!feat(core): tighten legacy agent session stream lifecycle

This commit is contained in:
Adam Weidman
2026-03-20 14:33:36 -04:00
parent f038ba417f
commit 7abec8e3c7
2 changed files with 292 additions and 78 deletions
@@ -101,24 +101,12 @@ async function collectEvents(
options?: { streamId?: string; eventId?: string },
): Promise<AgentEvent[]> {
const events: AgentEvent[] = [];
const streamOptions: { streamId?: string; eventId?: string } =
options?.eventId
? {
eventId: options.eventId,
...(options.streamId ? { streamId: options.streamId } : {}),
}
: {
streamId:
options?.streamId ??
session.events.findLast((event) => event.type === 'agent_start')
?.streamId,
};
const streamOptions =
options?.eventId || options?.streamId ? options : undefined;
if (!streamOptions.eventId && !streamOptions.streamId) {
return events;
}
for await (const event of session.stream(streamOptions)) {
for await (const event of streamOptions
? session.stream(streamOptions)
: session.stream()) {
events.push(event);
}
return events;
@@ -188,6 +176,40 @@ describe('LegacyAgentSession', () => {
await collectEvents(session, { streamId: streamId ?? undefined });
});
it('returns streamId before emitting agent_start', async () => {
const sendMock = deps.client.sendMessageStream as ReturnType<
typeof vi.fn
>;
sendMock.mockReturnValue(
makeStream([
{
type: GeminiEventType.Finished,
value: { reason: FinishReason.STOP, usageMetadata: undefined },
},
]),
);
const session = new LegacyAgentSession(deps);
const liveEvents: AgentEvent[] = [];
session.subscribe((event) => {
liveEvents.push(event);
});
const { streamId } = await session.send({
message: [{ type: 'text', text: 'hi' }],
});
expect(streamId).toBe('test-stream');
expect(liveEvents.some((event) => event.type === 'agent_start')).toBe(
false,
);
await collectEvents(session, { streamId: streamId ?? undefined });
expect(liveEvents.some((event) => event.type === 'agent_start')).toBe(
true,
);
});
it('throws for non-message payloads', async () => {
const session = new LegacyAgentSession(deps);
await expect(session.send({ update: { title: 'test' } })).rejects.toThrow(
@@ -213,14 +235,17 @@ describe('LegacyAgentSession', () => {
);
const session = new LegacyAgentSession(deps);
await session.send({ message: [{ type: 'text', text: 'first' }] });
const { streamId } = await session.send({
message: [{ type: 'text', text: 'first' }],
});
await vi.advanceTimersByTimeAsync(0);
await expect(
session.send({ message: [{ type: 'text', text: 'second' }] }),
).rejects.toThrow('cannot be called while a stream is active');
resolveHang?.();
await collectEvents(session);
await collectEvents(session, { streamId: streamId ?? undefined });
});
it('creates a new streamId after the previous stream completes', async () => {
@@ -756,6 +781,48 @@ describe('LegacyAgentSession', () => {
});
describe('abort', () => {
it('treats abort before the first model event as aborted without fatal error', async () => {
let releaseAbort: (() => void) | undefined;
const sendMock = deps.client.sendMessageStream as ReturnType<
typeof vi.fn
>;
sendMock.mockReturnValue(
(async function* () {
await new Promise<void>((resolve) => {
releaseAbort = resolve;
});
yield* [];
const abortError = new Error('Aborted');
abortError.name = 'AbortError';
throw abortError;
})(),
);
const session = new LegacyAgentSession(deps);
const { streamId } = await session.send({
message: [{ type: 'text', text: 'hi' }],
});
await vi.advanceTimersByTimeAsync(0);
await session.abort();
releaseAbort?.();
const events = await collectEvents(session, {
streamId: streamId ?? undefined,
});
expect(
events.some(
(event): event is AgentEvent<'error'> =>
event.type === 'error' && event.fatal,
),
).toBe(false);
const streamEnd = events.findLast(
(event): event is AgentEvent<'agent_end'> => event.type === 'agent_end',
);
expect(streamEnd?.reason).toBe('aborted');
});
it('aborts the stream', async () => {
const sendMock = deps.client.sendMessageStream as ReturnType<
typeof vi.fn
@@ -797,6 +864,59 @@ describe('LegacyAgentSession', () => {
);
expect(streamEnd?.reason).toBe('aborted');
});
it('treats abort during pending scheduler work as aborted without fatal error', async () => {
let resolveSchedule: ((value: CompletedToolCall[]) => void) | undefined;
const sendMock = deps.client.sendMessageStream as ReturnType<
typeof vi.fn
>;
sendMock.mockReturnValue(
makeStream([
{
type: GeminiEventType.ToolCallRequest,
value: makeToolRequest('call-1', 'slow_tool'),
},
{
type: GeminiEventType.Finished,
value: { reason: FinishReason.STOP, usageMetadata: undefined },
},
]),
);
const scheduleMock = deps.scheduler.schedule as ReturnType<typeof vi.fn>;
scheduleMock.mockReturnValue(
new Promise<CompletedToolCall[]>((resolve) => {
resolveSchedule = resolve;
}),
);
const session = new LegacyAgentSession(deps);
const { streamId } = await session.send({
message: [{ type: 'text', text: 'hi' }],
});
await new Promise((resolve) => setTimeout(resolve, 25));
await session.abort();
resolveSchedule?.([makeCompletedToolCall('call-1', 'slow_tool', 'done')]);
const events = await collectEvents(session, {
streamId: streamId ?? undefined,
});
expect(
events.some(
(event): event is AgentEvent<'error'> =>
event.type === 'error' && event.fatal,
),
).toBe(false);
expect(events.some((event) => event.type === 'tool_response')).toBe(
false,
);
const streamEnd = events.findLast(
(event): event is AgentEvent<'agent_end'> => event.type === 'agent_end',
);
expect(streamEnd?.reason).toBe('aborted');
});
});
describe('events property', () => {
@@ -1273,5 +1393,25 @@ describe('LegacyAgentSession', () => {
);
expect(err?._meta?.['code']).toBe('ENOENT');
});
it('preserves status in _meta for errors with status property', async () => {
const sendMock = deps.client.sendMessageStream as ReturnType<
typeof vi.fn
>;
const statusError = new Error('rate limited');
(statusError as Error & { status: string }).status = 'RESOURCE_EXHAUSTED';
sendMock.mockImplementation(() => {
throw statusError;
});
const session = new LegacyAgentSession(deps);
await session.send({ message: [{ type: 'text', text: 'hi' }] });
const events = await collectEvents(session);
const err = events.find(
(e): e is AgentEvent<'error'> => e.type === 'error',
);
expect(err?._meta?.['status']).toBe('RESOURCE_EXHAUSTED');
});
});
});
+133 -59
View File
@@ -40,6 +40,10 @@ import type {
Unsubscribe,
} from './types.js';
function isAbortLikeError(err: unknown): boolean {
return err instanceof Error && err.name === 'AbortError';
}
export interface LegacySessionDeps {
client: GeminiClient;
scheduler: Scheduler;
@@ -71,7 +75,7 @@ class LegacyAgentProtocol implements AgentProtocol {
this._promptId = deps.promptId;
}
get events(): AgentEvent[] {
get events(): readonly AgentEvent[] {
return this._events;
}
@@ -102,23 +106,11 @@ class LegacyAgentProtocol implements AgentProtocol {
this._beginNewStream();
const streamId = this._translationState.streamId;
const parts = contentPartsToGeminiParts(message);
const userMessage = this._makeInternalEvent('message', {
role: 'user',
content: message,
...(payload._meta ? { _meta: payload._meta } : {}),
});
const userMessage = this._makeUserMessageEvent(message, payload._meta);
this._emit([userMessage]);
void Promise.resolve().then(async () => {
this._ensureAgentStart();
try {
await this._runLoop(parts);
} catch (err: unknown) {
this._emitErrorAndAgentEnd(err);
this._markStreamDone();
}
});
this._scheduleRunLoop(parts);
return { streamId };
}
@@ -127,6 +119,26 @@ class LegacyAgentProtocol implements AgentProtocol {
this._abortController.abort();
}
private _scheduleRunLoop(initialParts: Part[]): void {
setTimeout(() => {
void this._runLoopInBackground(initialParts);
}, 0);
}
private async _runLoopInBackground(initialParts: Part[]): Promise<void> {
this._ensureAgentStart();
try {
await this._runLoop(initialParts);
} catch (err: unknown) {
if (this._abortController.signal.aborted || isAbortLikeError(err)) {
this._ensureAgentEnd('aborted');
} else {
this._emitErrorAndAgentEnd(err);
}
this._markStreamDone();
}
}
private async _runLoop(initialParts: Part[]): Promise<void> {
let currentParts: Part[] = initialParts;
let turnCount = 0;
@@ -135,17 +147,11 @@ class LegacyAgentProtocol implements AgentProtocol {
while (true) {
turnCount++;
if (maxTurns >= 0 && turnCount > maxTurns) {
this._emit([
this._makeInternalEvent('agent_end', {
reason: 'max_turns',
data: {
code: 'MAX_TURNS_EXCEEDED',
maxTurns,
turnCount: turnCount - 1,
},
}),
]);
this._markStreamDone();
this._finishStream('max_turns', {
code: 'MAX_TURNS_EXCEEDED',
maxTurns,
turnCount: turnCount - 1,
});
return;
}
@@ -158,8 +164,7 @@ class LegacyAgentProtocol implements AgentProtocol {
for await (const event of responseStream) {
if (this._abortController.signal.aborted) {
this._ensureAgentEnd('aborted');
this._markStreamDone();
this._finishStream('aborted');
return;
}
@@ -170,8 +175,7 @@ class LegacyAgentProtocol implements AgentProtocol {
this._emit(translateEvent(event, this._translationState));
if (event.type === GeminiEventType.Error) {
this._ensureAgentEnd('failed');
this._markStreamDone();
this._finishStream('failed');
return;
}
@@ -179,15 +183,13 @@ class LegacyAgentProtocol implements AgentProtocol {
event.type === GeminiEventType.InvalidStream ||
event.type === GeminiEventType.ContextWindowWillOverflow
) {
this._ensureAgentEnd('failed');
this._markStreamDone();
this._finishStream('failed');
return;
}
if (event.type === GeminiEventType.Finished) {
if (toolCallRequests.length === 0) {
this._ensureAgentEnd(mapFinishReason(event.value.reason));
this._markStreamDone();
this._finishStream(mapFinishReason(event.value.reason));
return;
}
continue;
@@ -203,9 +205,13 @@ class LegacyAgentProtocol implements AgentProtocol {
}
}
if (this._abortController.signal.aborted) {
this._finishStream('aborted');
return;
}
if (toolCallRequests.length === 0) {
this._ensureAgentEnd('completed');
this._markStreamDone();
this._finishStream('completed');
return;
}
@@ -214,6 +220,11 @@ class LegacyAgentProtocol implements AgentProtocol {
this._abortController.signal,
);
if (this._abortController.signal.aborted) {
this._finishStream('aborted');
return;
}
const toolResponseParts: Part[] = [];
for (const tc of completedToolCalls) {
const response = tc.response;
@@ -227,7 +238,7 @@ class LegacyAgentProtocol implements AgentProtocol {
const data = buildToolResponseData(response);
this._emit([
this._makeInternalEvent('tool_response', {
this._makeToolResponseEvent({
requestId: request.callId,
name: request.name,
content,
@@ -261,8 +272,7 @@ class LegacyAgentProtocol implements AgentProtocol {
tc.response.error !== undefined,
);
if (stopTool) {
this._ensureAgentEnd('completed');
this._markStreamDone();
this._finishStream('completed');
return;
}
@@ -270,8 +280,7 @@ class LegacyAgentProtocol implements AgentProtocol {
isFatalToolError(tc.response.errorType),
);
if (fatalTool) {
this._ensureAgentEnd('failed');
this._markStreamDone();
this._finishStream('failed');
return;
}
@@ -313,21 +322,29 @@ class LegacyAgentProtocol implements AgentProtocol {
private _ensureAgentStart(): void {
if (!this._translationState.streamStartEmitted) {
this._translationState.streamStartEmitted = true;
this._emit([this._makeInternalEvent('agent_start', {})]);
this._emit([this._makeAgentStartEvent()]);
}
}
private _ensureAgentEnd(reason: StreamEndReason = 'completed'): void {
if (!this._agentEndEmitted && this._translationState.streamStartEmitted) {
this._agentEndEmitted = true;
this._emit([
this._makeInternalEvent('agent_end', {
reason,
}),
]);
this._emit([this._makeAgentEndEvent(reason)]);
}
}
private _finishStream(
reason: StreamEndReason,
data?: Record<string, unknown>,
): void {
if (data && !this._agentEndEmitted) {
this._emit([this._makeAgentEndEvent(reason, data)]);
} else {
this._ensureAgentEnd(reason);
}
this._markStreamDone();
}
/**
* Preserve error identity fields in _meta so downstream consumers can
* reconstruct fatal CLI errors.
@@ -346,10 +363,13 @@ class LegacyAgentProtocol implements AgentProtocol {
if ('code' in err) {
meta['code'] = err.code;
}
if ('status' in err) {
meta['status'] = err.status;
}
}
this._emit([
this._makeInternalEvent('error', {
this._makeErrorEvent({
status: 'INTERNAL',
message,
fatal: true,
@@ -360,22 +380,76 @@ class LegacyAgentProtocol implements AgentProtocol {
this._ensureAgentEnd('failed');
}
private _makeInternalEvent<T extends AgentEvent['type']>(
type: T,
payload: Omit<
Partial<AgentEvent<T>>,
'id' | 'timestamp' | 'streamId' | 'type'
>,
): AgentEvent<T> {
const id = `${this._translationState.streamId}-${this._translationState.eventCounter++}`;
private _nextEventFields() {
return {
...payload,
id,
id: `${this._translationState.streamId}-${this._translationState.eventCounter++}`,
timestamp: new Date().toISOString(),
streamId: this._translationState.streamId,
type,
};
}
private _makeUserMessageEvent(
content: ContentPart[],
meta?: Record<string, unknown>,
): AgentEvent<'message'> {
const event = {
...this._nextEventFields(),
type: 'message',
role: 'user',
content,
...(meta ? { _meta: meta } : {}),
} satisfies AgentEvent<'message'>;
return event;
}
private _makeToolResponseEvent(
payload: Omit<
AgentEvent<'tool_response'>,
'id' | 'timestamp' | 'streamId' | 'type'
>,
): AgentEvent<'tool_response'> {
const event = {
...this._nextEventFields(),
type: 'tool_response',
...payload,
} satisfies AgentEvent<'tool_response'>;
return event;
}
private _makeAgentStartEvent(): AgentEvent<'agent_start'> {
const event = {
...this._nextEventFields(),
type: 'agent_start',
} satisfies AgentEvent<'agent_start'>;
return event;
}
private _makeAgentEndEvent(
reason: StreamEndReason,
data?: Record<string, unknown>,
): AgentEvent<'agent_end'> {
const event = {
...this._nextEventFields(),
type: 'agent_end',
reason,
...(data ? { data } : {}),
} satisfies AgentEvent<'agent_end'>;
return event;
}
private _makeErrorEvent(
payload: Omit<
AgentEvent<'error'>,
'id' | 'timestamp' | 'streamId' | 'type'
>,
): AgentEvent<'error'> {
const event = {
...this._nextEventFields(),
type: 'error',
...payload,
} satisfies AgentEvent<'error'>;
return event;
}
}
export class LegacyAgentSession extends AgentSession {