From bb459defe9be588510d572abfe70d0b51dc0ccf4 Mon Sep 17 00:00:00 2001 From: Adam Weidman Date: Thu, 26 Mar 2026 11:04:51 -0400 Subject: [PATCH] fix(core): preserve first-turn display content --- packages/cli/src/nonInteractiveCli.test.ts | 48 +++-- packages/cli/src/nonInteractiveCli.ts | 30 +-- packages/cli/src/utils/errors.ts | 1 + packages/core/src/agent/agent-session.test.ts | 26 ++- .../core/src/agent/event-translator.test.ts | 1 + packages/core/src/agent/event-translator.ts | 2 +- .../src/agent/legacy-agent-session.test.ts | 178 ++++++++++-------- .../core/src/agent/legacy-agent-session.ts | 43 ++++- packages/core/src/agent/mock.test.ts | 2 +- packages/core/src/agent/mock.ts | 9 +- packages/core/src/agent/types.ts | 7 +- 11 files changed, 227 insertions(+), 120 deletions(-) diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index fe3128be63..70ede3ddfe 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -265,6 +265,9 @@ describe('runNonInteractive', () => { [{ text: 'Test input' }], expect.any(AbortSignal), 'prompt-id-1', + undefined, + false, + 'Test input', ); expect(getWrittenOutput()).toBe('Hello World\n'); // Note: Telemetry shutdown is now handled in runExitCleanup() in cleanup.ts @@ -429,6 +432,9 @@ describe('runNonInteractive', () => { [{ text: 'Tool response' }], expect.any(AbortSignal), 'prompt-id-2', + undefined, + false, + undefined, ); expect(getWrittenOutput()).toBe('Final answer\n'); }); @@ -586,6 +592,9 @@ describe('runNonInteractive', () => { ], expect.any(AbortSignal), 'prompt-id-3', + undefined, + false, + undefined, ); expect(getWrittenOutput()).toBe('Sorry, let me try again.\n'); }); @@ -725,6 +734,9 @@ describe('runNonInteractive', () => { processedParts, expect.any(AbortSignal), 'prompt-id-7', + undefined, + false, + rawInput, ); // 6. Assert the final output is correct @@ -758,6 +770,9 @@ describe('runNonInteractive', () => { [{ text: 'Test input' }], expect.any(AbortSignal), 'prompt-id-1', + undefined, + false, + 'Test input', ); expect(processStdoutSpy).toHaveBeenCalledWith( JSON.stringify( @@ -961,6 +976,9 @@ describe('runNonInteractive', () => { [{ text: 'Empty response test' }], expect.any(AbortSignal), 'prompt-id-empty', + undefined, + false, + 'Empty response test', ); // This should output JSON with empty response but include stats @@ -1095,6 +1113,9 @@ describe('runNonInteractive', () => { [{ text: 'Prompt from command' }], expect.any(AbortSignal), 'prompt-id-slash', + undefined, + false, + '/testcommand', ); expect(getWrittenOutput()).toBe('Response from command\n'); @@ -1138,6 +1159,9 @@ describe('runNonInteractive', () => { [{ text: 'Slash command output' }], expect.any(AbortSignal), 'prompt-id-slash', + undefined, + false, + '/help', ); expect(getWrittenOutput()).toBe('Response to slash command\n'); handleSlashCommandSpy.mockRestore(); @@ -1268,17 +1292,16 @@ describe('runNonInteractive', () => { (process.stdin as any).setRawMode = vi.fn(); } - const stdinOnSpy = vi.spyOn(process.stdin, 'on').mockImplementation( - ( - event: string | symbol, - listener: (...args: unknown[]) => void, - ) => { - if (event === 'keypress') { - listener('\u0003', { ctrl: true, name: 'c' }); - } - return process.stdin; - }, - ); + const stdinOnSpy = vi + .spyOn(process.stdin, 'on') + .mockImplementation( + (event: string | symbol, listener: (...args: unknown[]) => void) => { + if (event === 'keypress') { + listener('\u0003', { ctrl: true, name: 'c' }); + } + return process.stdin; + }, + ); // eslint-disable-next-line @typescript-eslint/no-explicit-any vi.spyOn(process.stdin as any, 'setRawMode').mockImplementation(() => true); vi.spyOn(process.stdin, 'resume').mockImplementation(() => process.stdin); @@ -1377,6 +1400,9 @@ describe('runNonInteractive', () => { [{ text: '/unknowncommand' }], expect.any(AbortSignal), 'prompt-id-unknown', + undefined, + false, + '/unknowncommand', ); expect(getWrittenOutput()).toBe('Response to unknown\n'); diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index 6311247911..bae45e1e32 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -301,7 +301,10 @@ export async function runNonInteractive({ // Start the agentic loop (runs in background) const { streamId } = await session.send({ - message: geminiPartsToContentParts(query), + message: { + content: geminiPartsToContentParts(query), + displayContent: input, + }, }); if (streamId === null) { throw new Error( @@ -465,18 +468,17 @@ export async function runNonInteractive({ } if (event.data?.['errorType'] === ToolErrorType.NO_SPACE_LEFT) { - runTerminalExitHandler(() => - handleToolError( - event.name, - new Error(errorMsg), - config, - typeof event.data?.['errorType'] === 'string' - ? event.data['errorType'] - : undefined, - displayText, - ), + terminalProcessExitHandled = true; + handleToolError( + event.name, + new Error(errorMsg), + config, + typeof event.data?.['errorType'] === 'string' + ? event.data['errorType'] + : undefined, + displayText, ); - break; + return; } handleToolError( event.name, @@ -528,7 +530,9 @@ export async function runNonInteractive({ typeof event.data?.['turnCount'] === 'number'; if (isConfiguredTurnLimit) { - runTerminalExitHandler(() => handleMaxTurnsExceededError(config)); + runTerminalExitHandler(() => + handleMaxTurnsExceededError(config), + ); } else if (streamFormatter) { streamFormatter.emitEvent({ type: JsonStreamEventType.ERROR, diff --git a/packages/cli/src/utils/errors.ts b/packages/cli/src/utils/errors.ts index 913fc0d562..5e48abed99 100644 --- a/packages/cli/src/utils/errors.ts +++ b/packages/cli/src/utils/errors.ts @@ -18,6 +18,7 @@ import { isFatalToolError, debugLogger, coreEvents, + getErrorType, getErrorMessage, getErrorType, } from '@google/gemini-cli-core'; diff --git a/packages/core/src/agent/agent-session.test.ts b/packages/core/src/agent/agent-session.test.ts index e3ff1c5dc0..2ee9e4b7f3 100644 --- a/packages/core/src/agent/agent-session.test.ts +++ b/packages/core/src/agent/agent-session.test.ts @@ -7,7 +7,19 @@ import { describe, expect, it } from 'vitest'; import { AgentSession } from './agent-session.js'; import { MockAgentProtocol } from './mock.js'; -import type { AgentEvent } from './types.js'; +import type { AgentEvent, AgentSend } from './types.js'; + +function makeMessageSend( + text: string, + displayContent?: string, +): Extract { + return { + message: { + content: [{ type: 'text', text }], + ...(displayContent ? { displayContent } : {}), + }, + }; +} describe('AgentSession', () => { it('should passthrough simple methods', async () => { @@ -51,7 +63,7 @@ describe('AgentSession', () => { const events: AgentEvent[] = []; for await (const event of session.sendStream({ - message: [{ type: 'text', text: 'hi' }], + ...makeMessageSend('hi'), })) { events.push(event); } @@ -139,7 +151,7 @@ describe('AgentSession', () => { const events: AgentEvent[] = []; for await (const event of session.sendStream({ - message: [{ type: 'text', text: 'hi' }], + ...makeMessageSend('hi'), })) { events.push(event); } @@ -178,7 +190,7 @@ describe('AgentSession', () => { protocol.pushResponse([{ type: 'message' }]); const { streamId } = await session.send({ - message: [{ type: 'text', text: 'request' }], + ...makeMessageSend('request'), }); await new Promise((resolve) => setTimeout(resolve, 10)); @@ -242,7 +254,7 @@ describe('AgentSession', () => { }, ]); await session.send({ - message: [{ type: 'text', text: 'request' }], + ...makeMessageSend('request'), }); await new Promise((resolve) => setTimeout(resolve, 10)); @@ -303,7 +315,7 @@ describe('AgentSession', () => { }, ]); const { streamId: streamId1 } = await session.send({ - message: [{ type: 'text', text: 'first request' }], + ...makeMessageSend('first request'), }); await new Promise((resolve) => setTimeout(resolve, 10)); @@ -315,7 +327,7 @@ describe('AgentSession', () => { }, ]); await session.send({ - message: [{ type: 'text', text: 'second request' }], + ...makeMessageSend('second request'), }); await new Promise((resolve) => setTimeout(resolve, 10)); diff --git a/packages/core/src/agent/event-translator.test.ts b/packages/core/src/agent/event-translator.test.ts index f40c6c27ad..be9d8ea40e 100644 --- a/packages/core/src/agent/event-translator.test.ts +++ b/packages/core/src/agent/event-translator.test.ts @@ -679,6 +679,7 @@ describe('mapError', () => { expect(result.status).toBe('RESOURCE_EXHAUSTED'); expect(result.message).toBe('Rate limit'); expect(result.fatal).toBe(true); + expect(result._meta?.['status']).toBe(429); expect(result._meta?.['rawError']).toEqual({ message: 'Rate limit', status: 429, diff --git a/packages/core/src/agent/event-translator.ts b/packages/core/src/agent/event-translator.ts index 73f93f4a15..00b5d12b4f 100644 --- a/packages/core/src/agent/event-translator.ts +++ b/packages/core/src/agent/event-translator.ts @@ -403,7 +403,7 @@ export function mapError( } if (isStructuredError(error)) { - const structuredMeta = { ...meta, rawError: error }; + const structuredMeta = { ...meta, rawError: error, status: error.status }; return { status: mapHttpToGrpcStatus(error.status), message: error.message, diff --git a/packages/core/src/agent/legacy-agent-session.test.ts b/packages/core/src/agent/legacy-agent-session.test.ts index 438b1e5ef0..8cd92ca08d 100644 --- a/packages/core/src/agent/legacy-agent-session.test.ts +++ b/packages/core/src/agent/legacy-agent-session.test.ts @@ -10,7 +10,7 @@ import { LegacyAgentSession } from './legacy-agent-session.js'; import type { LegacyAgentSessionDeps } from './legacy-agent-session.js'; import { GeminiEventType } from '../core/turn.js'; import type { ServerGeminiStreamEvent } from '../core/turn.js'; -import type { AgentEvent } from './types.js'; +import type { AgentEvent, AgentSend } from './types.js'; import { ToolErrorType } from '../tools/tool-error.js'; import type { CompletedToolCall, @@ -72,6 +72,18 @@ function makeToolRequest(callId: string, name: string): ToolCallRequestInfo { }; } +function makeMessageSend( + text: string, + displayContent?: string, +): Extract { + return { + message: { + content: [{ type: 'text', text }], + ...(displayContent ? { displayContent } : {}), + }, + }; +} + function makeCompletedToolCall( callId: string, name: string, @@ -140,9 +152,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const result = await session.send({ - message: [{ type: 'text', text: 'hi' }], - }); + const result = await session.send(makeMessageSend('hi')); expect(result.streamId).toBe('test-stream'); }); @@ -162,7 +172,10 @@ describe('LegacyAgentSession', () => { const session = new LegacyAgentSession(deps); const { streamId } = await session.send({ - message: [{ type: 'text', text: 'hi' }], + message: { + content: [{ type: 'text', text: 'hi' }], + displayContent: 'raw input', + }, _meta: { source: 'user-test' }, }); @@ -170,12 +183,57 @@ describe('LegacyAgentSession', () => { (e): e is AgentEvent<'message'> => e.type === 'message' && e.role === 'user' && e.streamId === streamId, ); - expect(userMessage?.content).toEqual([{ type: 'text', text: 'hi' }]); + expect(userMessage?.content).toEqual([ + { type: 'text', text: 'raw input' }, + ]); expect(userMessage?._meta).toEqual({ source: 'user-test' }); + await vi.advanceTimersByTimeAsync(0); + expect(sendMock).toHaveBeenCalledWith( + [{ text: 'hi' }], + expect.any(AbortSignal), + 'test-prompt', + undefined, + false, + 'raw input', + ); await collectEvents(session, { streamId: streamId ?? undefined }); }); + it('accepts legacy message-array sends without displayContent', async () => { + const sendMock = deps.client.sendMessageStream as ReturnType< + typeof vi.fn + >; + sendMock.mockReturnValue( + makeStream([ + { + type: GeminiEventType.Finished, + value: { reason: FinishReason.STOP, usageMetadata: undefined }, + }, + ]), + ); + + const session = new LegacyAgentSession(deps); + const { streamId } = await session.send({ + message: [{ type: 'text', text: 'hi' }], + }); + + const userMessage = session.events.find( + (e): e is AgentEvent<'message'> => + e.type === 'message' && e.role === 'user' && e.streamId === streamId, + ); + expect(userMessage?.content).toEqual([{ type: 'text', text: 'hi' }]); + await vi.advanceTimersByTimeAsync(0); + expect(sendMock).toHaveBeenCalledWith( + [{ text: 'hi' }], + expect.any(AbortSignal), + 'test-prompt', + undefined, + false, + undefined, + ); + }); + it('returns streamId before emitting agent_start', async () => { const sendMock = deps.client.sendMessageStream as ReturnType< typeof vi.fn @@ -195,9 +253,7 @@ describe('LegacyAgentSession', () => { liveEvents.push(event); }); - const { streamId } = await session.send({ - message: [{ type: 'text', text: 'hi' }], - }); + const { streamId } = await session.send(makeMessageSend('hi')); expect(streamId).toBe('test-stream'); expect(liveEvents.some((event) => event.type === 'agent_start')).toBe( @@ -235,14 +291,12 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const { streamId } = await session.send({ - message: [{ type: 'text', text: 'first' }], - }); + const { streamId } = await session.send(makeMessageSend('first')); await vi.advanceTimersByTimeAsync(0); - await expect( - session.send({ message: [{ type: 'text', text: 'second' }] }), - ).rejects.toThrow('cannot be called while a stream is active'); + await expect(session.send(makeMessageSend('second'))).rejects.toThrow( + 'cannot be called while a stream is active', + ); resolveHang?.(); await collectEvents(session, { streamId: streamId ?? undefined }); @@ -273,16 +327,12 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const first = await session.send({ - message: [{ type: 'text', text: 'first' }], - }); + const first = await session.send(makeMessageSend('first')); const firstEvents = await collectEvents(session, { streamId: first.streamId ?? undefined, }); - const second = await session.send({ - message: [{ type: 'text', text: 'second' }], - }); + const second = await session.send(makeMessageSend('second')); const secondEvents = await collectEvents(session, { streamId: second.streamId ?? undefined, }); @@ -330,7 +380,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const types = events.map((e) => e.type); @@ -387,7 +437,7 @@ describe('LegacyAgentSession', () => { ]); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'read a file' }] }); + await session.send(makeMessageSend('read a file')); const events = await collectEvents(session); const types = events.map((e) => e.type); @@ -455,9 +505,7 @@ describe('LegacyAgentSession', () => { scheduleMock.mockResolvedValueOnce([errorToolCall]); const session = new LegacyAgentSession(deps); - await session.send({ - message: [{ type: 'text', text: 'write file' }], - }); + await session.send(makeMessageSend('write file')); const events = await collectEvents(session); const toolResp = events.find( @@ -506,9 +554,7 @@ describe('LegacyAgentSession', () => { scheduleMock.mockResolvedValueOnce([stopToolCall]); const session = new LegacyAgentSession(deps); - await session.send({ - message: [{ type: 'text', text: 'do something' }], - }); + await session.send(makeMessageSend('do something')); const events = await collectEvents(session); const streamEnd = events.find( @@ -552,9 +598,7 @@ describe('LegacyAgentSession', () => { scheduleMock.mockResolvedValueOnce([fatalToolCall]); const session = new LegacyAgentSession(deps); - await session.send({ - message: [{ type: 'text', text: 'write file' }], - }); + await session.send(makeMessageSend('write file')); const events = await collectEvents(session); const toolResp = events.find( @@ -592,7 +636,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const streamEnd = events.find( @@ -621,7 +665,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const blocked = events.find( @@ -663,7 +707,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const err = events.find( @@ -690,7 +734,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const warning = events.find( @@ -738,7 +782,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const streamEnd = events.find( @@ -762,7 +806,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const errorEvents = events.filter( @@ -799,9 +843,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const { streamId } = await session.send({ - message: [{ type: 'text', text: 'hi' }], - }); + const { streamId } = await session.send(makeMessageSend('hi')); await vi.advanceTimersByTimeAsync(0); await session.abort(); @@ -847,7 +889,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); // Give the loop time to start processing await new Promise((r) => setTimeout(r, 50)); @@ -891,9 +933,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const { streamId } = await session.send({ - message: [{ type: 'text', text: 'hi' }], - }); + const { streamId } = await session.send(makeMessageSend('hi')); await new Promise((resolve) => setTimeout(resolve, 25)); await session.abort(); @@ -935,7 +975,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); await collectEvents(session); expect(session.events.length).toBeGreaterThan(0); @@ -964,9 +1004,7 @@ describe('LegacyAgentSession', () => { liveEvents.push(event); }); - const { streamId } = await session.send({ - message: [{ type: 'text', text: 'hi' }], - }); + const { streamId } = await session.send(makeMessageSend('hi')); await collectEvents(session, { streamId: streamId ?? undefined }); unsubscribe(); @@ -1002,9 +1040,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const first = await session.send({ - message: [{ type: 'text', text: 'first request' }], - }); + const first = await session.send(makeMessageSend('first request')); await collectEvents(session, { streamId: first.streamId ?? undefined }); const liveEvents: AgentEvent[] = []; @@ -1012,9 +1048,7 @@ describe('LegacyAgentSession', () => { liveEvents.push(event); }); - const second = await session.send({ - message: [{ type: 'text', text: 'second request' }], - }); + const second = await session.send(makeMessageSend('second request')); await collectEvents(session, { streamId: second.streamId ?? undefined }); unsubscribe(); @@ -1058,14 +1092,10 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const first = await session.send({ - message: [{ type: 'text', text: 'first request' }], - }); + const first = await session.send(makeMessageSend('first request')); await collectEvents(session, { streamId: first.streamId ?? undefined }); - const second = await session.send({ - message: [{ type: 'text', text: 'second request' }], - }); + const second = await session.send(makeMessageSend('second request')); await collectEvents(session, { streamId: second.streamId ?? undefined }); const firstStreamEvents = await collectEvents(session, { @@ -1120,14 +1150,10 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const first = await session.send({ - message: [{ type: 'text', text: 'first request' }], - }); + const first = await session.send(makeMessageSend('first request')); await collectEvents(session, { streamId: first.streamId ?? undefined }); - await session.send({ - message: [{ type: 'text', text: 'second request' }], - }); + await session.send(makeMessageSend('second request')); await collectEvents(session); const firstAgentMessage = session.events.find( @@ -1175,7 +1201,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); expect(events.length).toBeGreaterThan(0); @@ -1196,7 +1222,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); expect(events[events.length - 1]?.type).toBe('agent_end'); @@ -1244,7 +1270,7 @@ describe('LegacyAgentSession', () => { ]); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'do it' }] }); + await session.send(makeMessageSend('do it')); const events = await collectEvents(session); // Only one agent_end at the very end @@ -1291,7 +1317,7 @@ describe('LegacyAgentSession', () => { ]); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'go' }] }); + await session.send(makeMessageSend('go')); const events = await collectEvents(session); // Should have at least one usage event from the intermediate Finished @@ -1314,7 +1340,7 @@ describe('LegacyAgentSession', () => { }); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const err = events.find( @@ -1342,7 +1368,7 @@ describe('LegacyAgentSession', () => { }); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const err = events.find( @@ -1365,7 +1391,7 @@ describe('LegacyAgentSession', () => { }); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const err = events.find( @@ -1385,7 +1411,7 @@ describe('LegacyAgentSession', () => { }); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const err = events.find( @@ -1405,7 +1431,7 @@ describe('LegacyAgentSession', () => { }); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const err = events.find( diff --git a/packages/core/src/agent/legacy-agent-session.ts b/packages/core/src/agent/legacy-agent-session.ts index d8044e77e3..fa1d652eb2 100644 --- a/packages/core/src/agent/legacy-agent-session.ts +++ b/packages/core/src/agent/legacy-agent-session.ts @@ -93,6 +93,9 @@ class LegacyAgentProtocol implements AgentProtocol { 'LegacyAgentSession.send() only supports message sends for the moment.', ); } + const normalizedMessage = Array.isArray(message) + ? { content: message, displayContent: undefined } + : message; if (this._activeStreamId) { // TODO: Interactive may eventually allow selected in-stream sends such as @@ -105,12 +108,16 @@ class LegacyAgentProtocol implements AgentProtocol { this._beginNewStream(); const streamId = this._translationState.streamId; - const parts = contentPartsToGeminiParts(message); - const userMessage = this._makeUserMessageEvent(message, payload._meta); + const parts = contentPartsToGeminiParts(normalizedMessage.content); + const userMessage = this._makeUserMessageEvent( + normalizedMessage.content, + normalizedMessage.displayContent, + payload._meta, + ); this._emit([userMessage]); - this._scheduleRunLoop(parts); + this._scheduleRunLoop(parts, normalizedMessage.displayContent); return { streamId }; } @@ -119,18 +126,24 @@ class LegacyAgentProtocol implements AgentProtocol { this._abortController.abort(); } - private _scheduleRunLoop(initialParts: Part[]): void { + private _scheduleRunLoop( + initialParts: Part[], + displayContent?: string, + ): void { // Use a macrotask so send() resolves with the streamId before agent_start // is emitted and consumers can attach to the stream without racing startup. setTimeout(() => { - void this._runLoopInBackground(initialParts); + void this._runLoopInBackground(initialParts, displayContent); }, 0); } - private async _runLoopInBackground(initialParts: Part[]): Promise { + private async _runLoopInBackground( + initialParts: Part[], + displayContent?: string, + ): Promise { this._ensureAgentStart(); try { - await this._runLoop(initialParts); + await this._runLoop(initialParts, displayContent); } catch (err: unknown) { if (this._abortController.signal.aborted || isAbortLikeError(err)) { this._ensureAgentEnd('aborted'); @@ -141,8 +154,12 @@ class LegacyAgentProtocol implements AgentProtocol { } } - private async _runLoop(initialParts: Part[]): Promise { + private async _runLoop( + initialParts: Part[], + initialDisplayContent?: string, + ): Promise { let currentParts: Part[] = initialParts; + let currentDisplayContent = initialDisplayContent; let turnCount = 0; const maxTurns = this._config.getMaxSessionTurns(); @@ -162,7 +179,11 @@ class LegacyAgentProtocol implements AgentProtocol { currentParts, this._abortController.signal, this._promptId, + undefined, + false, + currentDisplayContent, ); + currentDisplayContent = undefined; for await (const event of responseStream) { if (this._abortController.signal.aborted) { @@ -383,13 +404,17 @@ class LegacyAgentProtocol implements AgentProtocol { private _makeUserMessageEvent( content: ContentPart[], + displayContent?: string, meta?: Record, ): AgentEvent<'message'> { + const eventContent: ContentPart[] = displayContent + ? [{ type: 'text', text: displayContent }] + : content; const event = { ...this._nextEventFields(), type: 'message', role: 'user', - content, + content: eventContent, ...(meta ? { _meta: meta } : {}), } satisfies AgentEvent<'message'>; return event; diff --git a/packages/core/src/agent/mock.test.ts b/packages/core/src/agent/mock.test.ts index f5138e388a..64403008a6 100644 --- a/packages/core/src/agent/mock.test.ts +++ b/packages/core/src/agent/mock.test.ts @@ -34,7 +34,7 @@ describe('MockAgentProtocol', () => { const streamPromise = waitForStreamEnd(session); const { streamId } = await session.send({ - message: [{ type: 'text', text: 'hi' }], + message: { content: [{ type: 'text', text: 'hi' }] }, }); expect(streamId).toBeDefined(); diff --git a/packages/core/src/agent/mock.ts b/packages/core/src/agent/mock.ts index 80d8ebae2f..26ef965d6d 100644 --- a/packages/core/src/agent/mock.ts +++ b/packages/core/src/agent/mock.ts @@ -10,6 +10,7 @@ import type { AgentEventData, AgentProtocol, AgentSend, + ContentPart, Unsubscribe, } from './types.js'; @@ -133,11 +134,17 @@ export class MockAgentProtocol implements AgentProtocol { // 1. User/Update event (BEFORE agent_start) if ('message' in payload && payload.message) { + const message = Array.isArray(payload.message) + ? { content: payload.message, displayContent: undefined } + : payload.message; + const userContent: ContentPart[] = message.displayContent + ? [{ type: 'text', text: message.displayContent }] + : message.content; eventsToEmit.push( normalize({ type: 'message', role: 'user', - content: payload.message, + content: userContent, _meta: payload._meta, }), ); diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts index 4ec369d066..512a8c9507 100644 --- a/packages/core/src/agent/types.ts +++ b/packages/core/src/agent/types.ts @@ -46,7 +46,12 @@ type RequireExactlyOne = { }[keyof T]; interface AgentSendPayloads { - message: ContentPart[]; + message: + | ContentPart[] + | { + content: ContentPart[]; + displayContent?: string; + }; elicitations: ElicitationResponse[]; update: { title?: string; model?: string; config?: Record }; action: { type: string; data: unknown };