diff --git a/packages/core/src/agent/legacy-agent-session.test.ts b/packages/core/src/agent/legacy-agent-session.test.ts index 34f745a6b2..4ce59edd63 100644 --- a/packages/core/src/agent/legacy-agent-session.test.ts +++ b/packages/core/src/agent/legacy-agent-session.test.ts @@ -101,24 +101,12 @@ async function collectEvents( options?: { streamId?: string; eventId?: string }, ): Promise { const events: AgentEvent[] = []; - const streamOptions: { streamId?: string; eventId?: string } = - options?.eventId - ? { - eventId: options.eventId, - ...(options.streamId ? { streamId: options.streamId } : {}), - } - : { - streamId: - options?.streamId ?? - session.events.findLast((event) => event.type === 'agent_start') - ?.streamId, - }; + const streamOptions = + options?.eventId || options?.streamId ? options : undefined; - if (!streamOptions.eventId && !streamOptions.streamId) { - return events; - } - - for await (const event of session.stream(streamOptions)) { + for await (const event of streamOptions + ? session.stream(streamOptions) + : session.stream()) { events.push(event); } return events; @@ -188,6 +176,40 @@ describe('LegacyAgentSession', () => { await collectEvents(session, { streamId: streamId ?? undefined }); }); + it('returns streamId before emitting agent_start', async () => { + const sendMock = deps.client.sendMessageStream as ReturnType< + typeof vi.fn + >; + sendMock.mockReturnValue( + makeStream([ + { + type: GeminiEventType.Finished, + value: { reason: FinishReason.STOP, usageMetadata: undefined }, + }, + ]), + ); + + const session = new LegacyAgentSession(deps); + const liveEvents: AgentEvent[] = []; + session.subscribe((event) => { + liveEvents.push(event); + }); + + const { streamId } = await session.send({ + message: [{ type: 'text', text: 'hi' }], + }); + + expect(streamId).toBe('test-stream'); + expect(liveEvents.some((event) => event.type === 'agent_start')).toBe( + false, + ); + + await collectEvents(session, { streamId: streamId ?? undefined }); + expect(liveEvents.some((event) => event.type === 'agent_start')).toBe( + true, + ); + }); + it('throws for non-message payloads', async () => { const session = new LegacyAgentSession(deps); await expect(session.send({ update: { title: 'test' } })).rejects.toThrow( @@ -213,14 +235,17 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'first' }] }); + const { streamId } = await session.send({ + message: [{ type: 'text', text: 'first' }], + }); + await vi.advanceTimersByTimeAsync(0); await expect( session.send({ message: [{ type: 'text', text: 'second' }] }), ).rejects.toThrow('cannot be called while a stream is active'); resolveHang?.(); - await collectEvents(session); + await collectEvents(session, { streamId: streamId ?? undefined }); }); it('creates a new streamId after the previous stream completes', async () => { @@ -756,6 +781,48 @@ describe('LegacyAgentSession', () => { }); describe('abort', () => { + it('treats abort before the first model event as aborted without fatal error', async () => { + let releaseAbort: (() => void) | undefined; + const sendMock = deps.client.sendMessageStream as ReturnType< + typeof vi.fn + >; + sendMock.mockReturnValue( + (async function* () { + await new Promise((resolve) => { + releaseAbort = resolve; + }); + yield* []; + const abortError = new Error('Aborted'); + abortError.name = 'AbortError'; + throw abortError; + })(), + ); + + const session = new LegacyAgentSession(deps); + const { streamId } = await session.send({ + message: [{ type: 'text', text: 'hi' }], + }); + await vi.advanceTimersByTimeAsync(0); + + await session.abort(); + releaseAbort?.(); + + const events = await collectEvents(session, { + streamId: streamId ?? undefined, + }); + expect( + events.some( + (event): event is AgentEvent<'error'> => + event.type === 'error' && event.fatal, + ), + ).toBe(false); + + const streamEnd = events.findLast( + (event): event is AgentEvent<'agent_end'> => event.type === 'agent_end', + ); + expect(streamEnd?.reason).toBe('aborted'); + }); + it('aborts the stream', async () => { const sendMock = deps.client.sendMessageStream as ReturnType< typeof vi.fn @@ -797,6 +864,59 @@ describe('LegacyAgentSession', () => { ); expect(streamEnd?.reason).toBe('aborted'); }); + + it('treats abort during pending scheduler work as aborted without fatal error', async () => { + let resolveSchedule: ((value: CompletedToolCall[]) => void) | undefined; + const sendMock = deps.client.sendMessageStream as ReturnType< + typeof vi.fn + >; + sendMock.mockReturnValue( + makeStream([ + { + type: GeminiEventType.ToolCallRequest, + value: makeToolRequest('call-1', 'slow_tool'), + }, + { + type: GeminiEventType.Finished, + value: { reason: FinishReason.STOP, usageMetadata: undefined }, + }, + ]), + ); + + const scheduleMock = deps.scheduler.schedule as ReturnType; + scheduleMock.mockReturnValue( + new Promise((resolve) => { + resolveSchedule = resolve; + }), + ); + + const session = new LegacyAgentSession(deps); + const { streamId } = await session.send({ + message: [{ type: 'text', text: 'hi' }], + }); + + await new Promise((resolve) => setTimeout(resolve, 25)); + await session.abort(); + resolveSchedule?.([makeCompletedToolCall('call-1', 'slow_tool', 'done')]); + + const events = await collectEvents(session, { + streamId: streamId ?? undefined, + }); + expect( + events.some( + (event): event is AgentEvent<'error'> => + event.type === 'error' && event.fatal, + ), + ).toBe(false); + expect(events.some((event) => event.type === 'tool_response')).toBe( + false, + ); + + const streamEnd = events.findLast( + (event): event is AgentEvent<'agent_end'> => event.type === 'agent_end', + ); + expect(streamEnd?.reason).toBe('aborted'); + }); }); describe('events property', () => { @@ -1273,5 +1393,25 @@ describe('LegacyAgentSession', () => { ); expect(err?._meta?.['code']).toBe('ENOENT'); }); + + it('preserves status in _meta for errors with status property', async () => { + const sendMock = deps.client.sendMessageStream as ReturnType< + typeof vi.fn + >; + const statusError = new Error('rate limited'); + (statusError as Error & { status: string }).status = 'RESOURCE_EXHAUSTED'; + sendMock.mockImplementation(() => { + throw statusError; + }); + + const session = new LegacyAgentSession(deps); + await session.send({ message: [{ type: 'text', text: 'hi' }] }); + const events = await collectEvents(session); + + const err = events.find( + (e): e is AgentEvent<'error'> => e.type === 'error', + ); + expect(err?._meta?.['status']).toBe('RESOURCE_EXHAUSTED'); + }); }); }); diff --git a/packages/core/src/agent/legacy-agent-session.ts b/packages/core/src/agent/legacy-agent-session.ts index 789ecc40ec..3feeb42512 100644 --- a/packages/core/src/agent/legacy-agent-session.ts +++ b/packages/core/src/agent/legacy-agent-session.ts @@ -40,6 +40,10 @@ import type { Unsubscribe, } from './types.js'; +function isAbortLikeError(err: unknown): boolean { + return err instanceof Error && err.name === 'AbortError'; +} + export interface LegacySessionDeps { client: GeminiClient; scheduler: Scheduler; @@ -71,7 +75,7 @@ class LegacyAgentProtocol implements AgentProtocol { this._promptId = deps.promptId; } - get events(): AgentEvent[] { + get events(): readonly AgentEvent[] { return this._events; } @@ -102,23 +106,11 @@ class LegacyAgentProtocol implements AgentProtocol { this._beginNewStream(); const streamId = this._translationState.streamId; const parts = contentPartsToGeminiParts(message); - const userMessage = this._makeInternalEvent('message', { - role: 'user', - content: message, - ...(payload._meta ? { _meta: payload._meta } : {}), - }); + const userMessage = this._makeUserMessageEvent(message, payload._meta); this._emit([userMessage]); - void Promise.resolve().then(async () => { - this._ensureAgentStart(); - try { - await this._runLoop(parts); - } catch (err: unknown) { - this._emitErrorAndAgentEnd(err); - this._markStreamDone(); - } - }); + this._scheduleRunLoop(parts); return { streamId }; } @@ -127,6 +119,26 @@ class LegacyAgentProtocol implements AgentProtocol { this._abortController.abort(); } + private _scheduleRunLoop(initialParts: Part[]): void { + setTimeout(() => { + void this._runLoopInBackground(initialParts); + }, 0); + } + + private async _runLoopInBackground(initialParts: Part[]): Promise { + this._ensureAgentStart(); + try { + await this._runLoop(initialParts); + } catch (err: unknown) { + if (this._abortController.signal.aborted || isAbortLikeError(err)) { + this._ensureAgentEnd('aborted'); + } else { + this._emitErrorAndAgentEnd(err); + } + this._markStreamDone(); + } + } + private async _runLoop(initialParts: Part[]): Promise { let currentParts: Part[] = initialParts; let turnCount = 0; @@ -135,17 +147,11 @@ class LegacyAgentProtocol implements AgentProtocol { while (true) { turnCount++; if (maxTurns >= 0 && turnCount > maxTurns) { - this._emit([ - this._makeInternalEvent('agent_end', { - reason: 'max_turns', - data: { - code: 'MAX_TURNS_EXCEEDED', - maxTurns, - turnCount: turnCount - 1, - }, - }), - ]); - this._markStreamDone(); + this._finishStream('max_turns', { + code: 'MAX_TURNS_EXCEEDED', + maxTurns, + turnCount: turnCount - 1, + }); return; } @@ -158,8 +164,7 @@ class LegacyAgentProtocol implements AgentProtocol { for await (const event of responseStream) { if (this._abortController.signal.aborted) { - this._ensureAgentEnd('aborted'); - this._markStreamDone(); + this._finishStream('aborted'); return; } @@ -170,8 +175,7 @@ class LegacyAgentProtocol implements AgentProtocol { this._emit(translateEvent(event, this._translationState)); if (event.type === GeminiEventType.Error) { - this._ensureAgentEnd('failed'); - this._markStreamDone(); + this._finishStream('failed'); return; } @@ -179,15 +183,13 @@ class LegacyAgentProtocol implements AgentProtocol { event.type === GeminiEventType.InvalidStream || event.type === GeminiEventType.ContextWindowWillOverflow ) { - this._ensureAgentEnd('failed'); - this._markStreamDone(); + this._finishStream('failed'); return; } if (event.type === GeminiEventType.Finished) { if (toolCallRequests.length === 0) { - this._ensureAgentEnd(mapFinishReason(event.value.reason)); - this._markStreamDone(); + this._finishStream(mapFinishReason(event.value.reason)); return; } continue; @@ -203,9 +205,13 @@ class LegacyAgentProtocol implements AgentProtocol { } } + if (this._abortController.signal.aborted) { + this._finishStream('aborted'); + return; + } + if (toolCallRequests.length === 0) { - this._ensureAgentEnd('completed'); - this._markStreamDone(); + this._finishStream('completed'); return; } @@ -214,6 +220,11 @@ class LegacyAgentProtocol implements AgentProtocol { this._abortController.signal, ); + if (this._abortController.signal.aborted) { + this._finishStream('aborted'); + return; + } + const toolResponseParts: Part[] = []; for (const tc of completedToolCalls) { const response = tc.response; @@ -227,7 +238,7 @@ class LegacyAgentProtocol implements AgentProtocol { const data = buildToolResponseData(response); this._emit([ - this._makeInternalEvent('tool_response', { + this._makeToolResponseEvent({ requestId: request.callId, name: request.name, content, @@ -261,8 +272,7 @@ class LegacyAgentProtocol implements AgentProtocol { tc.response.error !== undefined, ); if (stopTool) { - this._ensureAgentEnd('completed'); - this._markStreamDone(); + this._finishStream('completed'); return; } @@ -270,8 +280,7 @@ class LegacyAgentProtocol implements AgentProtocol { isFatalToolError(tc.response.errorType), ); if (fatalTool) { - this._ensureAgentEnd('failed'); - this._markStreamDone(); + this._finishStream('failed'); return; } @@ -313,21 +322,29 @@ class LegacyAgentProtocol implements AgentProtocol { private _ensureAgentStart(): void { if (!this._translationState.streamStartEmitted) { this._translationState.streamStartEmitted = true; - this._emit([this._makeInternalEvent('agent_start', {})]); + this._emit([this._makeAgentStartEvent()]); } } private _ensureAgentEnd(reason: StreamEndReason = 'completed'): void { if (!this._agentEndEmitted && this._translationState.streamStartEmitted) { this._agentEndEmitted = true; - this._emit([ - this._makeInternalEvent('agent_end', { - reason, - }), - ]); + this._emit([this._makeAgentEndEvent(reason)]); } } + private _finishStream( + reason: StreamEndReason, + data?: Record, + ): void { + if (data && !this._agentEndEmitted) { + this._emit([this._makeAgentEndEvent(reason, data)]); + } else { + this._ensureAgentEnd(reason); + } + this._markStreamDone(); + } + /** * Preserve error identity fields in _meta so downstream consumers can * reconstruct fatal CLI errors. @@ -346,10 +363,13 @@ class LegacyAgentProtocol implements AgentProtocol { if ('code' in err) { meta['code'] = err.code; } + if ('status' in err) { + meta['status'] = err.status; + } } this._emit([ - this._makeInternalEvent('error', { + this._makeErrorEvent({ status: 'INTERNAL', message, fatal: true, @@ -360,22 +380,76 @@ class LegacyAgentProtocol implements AgentProtocol { this._ensureAgentEnd('failed'); } - private _makeInternalEvent( - type: T, - payload: Omit< - Partial>, - 'id' | 'timestamp' | 'streamId' | 'type' - >, - ): AgentEvent { - const id = `${this._translationState.streamId}-${this._translationState.eventCounter++}`; + private _nextEventFields() { return { - ...payload, - id, + id: `${this._translationState.streamId}-${this._translationState.eventCounter++}`, timestamp: new Date().toISOString(), streamId: this._translationState.streamId, - type, }; } + + private _makeUserMessageEvent( + content: ContentPart[], + meta?: Record, + ): AgentEvent<'message'> { + const event = { + ...this._nextEventFields(), + type: 'message', + role: 'user', + content, + ...(meta ? { _meta: meta } : {}), + } satisfies AgentEvent<'message'>; + return event; + } + + private _makeToolResponseEvent( + payload: Omit< + AgentEvent<'tool_response'>, + 'id' | 'timestamp' | 'streamId' | 'type' + >, + ): AgentEvent<'tool_response'> { + const event = { + ...this._nextEventFields(), + type: 'tool_response', + ...payload, + } satisfies AgentEvent<'tool_response'>; + return event; + } + + private _makeAgentStartEvent(): AgentEvent<'agent_start'> { + const event = { + ...this._nextEventFields(), + type: 'agent_start', + } satisfies AgentEvent<'agent_start'>; + return event; + } + + private _makeAgentEndEvent( + reason: StreamEndReason, + data?: Record, + ): AgentEvent<'agent_end'> { + const event = { + ...this._nextEventFields(), + type: 'agent_end', + reason, + ...(data ? { data } : {}), + } satisfies AgentEvent<'agent_end'>; + return event; + } + + private _makeErrorEvent( + payload: Omit< + AgentEvent<'error'>, + 'id' | 'timestamp' | 'streamId' | 'type' + >, + ): AgentEvent<'error'> { + const event = { + ...this._nextEventFields(), + type: 'error', + ...payload, + } satisfies AgentEvent<'error'>; + return event; + } } export class LegacyAgentSession extends AgentSession {