diff --git a/packages/cli/src/__snapshots__/nonInteractiveCliAgentSession.test.ts.snap b/packages/cli/src/__snapshots__/nonInteractiveCliAgentSession.test.ts.snap new file mode 100644 index 0000000000..92f396a59c --- /dev/null +++ b/packages/cli/src/__snapshots__/nonInteractiveCliAgentSession.test.ts.snap @@ -0,0 +1,35 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`runNonInteractive > should emit appropriate error event in streaming JSON mode: 'loop detected' 1`] = ` +"{"type":"init","timestamp":"","session_id":"test-session-id","model":"test-model"} +{"type":"message","timestamp":"","role":"user","content":"Loop test"} +{"type":"error","timestamp":"","severity":"warning","message":"Loop detected, stopping execution"} +{"type":"result","timestamp":"","status":"success","stats":{"total_tokens":0,"input_tokens":0,"output_tokens":0,"cached":0,"input":0,"duration_ms":,"tool_calls":0,"models":{}}} +" +`; + +exports[`runNonInteractive > should emit appropriate error event in streaming JSON mode: 'max session turns' 1`] = ` +"{"type":"init","timestamp":"","session_id":"test-session-id","model":"test-model"} +{"type":"message","timestamp":"","role":"user","content":"Max turns test"} +{"type":"error","timestamp":"","severity":"error","message":"Maximum session turns exceeded"} +{"type":"result","timestamp":"","status":"success","stats":{"total_tokens":0,"input_tokens":0,"output_tokens":0,"cached":0,"input":0,"duration_ms":,"tool_calls":0,"models":{}}} +" +`; + +exports[`runNonInteractive > should emit appropriate events for streaming JSON output 1`] = ` +"{"type":"init","timestamp":"","session_id":"test-session-id","model":"test-model"} +{"type":"message","timestamp":"","role":"user","content":"Stream test"} +{"type":"message","timestamp":"","role":"assistant","content":"Thinking...","delta":true} +{"type":"tool_use","timestamp":"","tool_name":"testTool","tool_id":"tool-1","parameters":{"arg1":"value1"}} +{"type":"tool_result","timestamp":"","tool_id":"tool-1","status":"success","output":"Tool executed successfully"} +{"type":"message","timestamp":"","role":"assistant","content":"Final answer","delta":true} +{"type":"result","timestamp":"","status":"success","stats":{"total_tokens":0,"input_tokens":0,"output_tokens":0,"cached":0,"input":0,"duration_ms":,"tool_calls":0,"models":{}}} +" +`; + +exports[`runNonInteractive > should write a single newline between sequential text outputs from the model 1`] = ` +"Use mock tool +Use mock tool again +Finished. +" +`; diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index 6adf1e22ef..855707de9e 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -166,7 +166,7 @@ describe('runNonInteractive', () => { }; mockConfig = { - initialize: vi.fn().mockResolvedValue(undefined), + initialize: vi.fn().mockReturnValue(Promise.resolve(undefined)), getMessageBus: vi.fn().mockReturnValue({ subscribe: vi.fn(), unsubscribe: vi.fn(), @@ -190,6 +190,7 @@ describe('runNonInteractive', () => { isTrustedFolder: vi.fn().mockReturnValue(false), getRawOutput: vi.fn().mockReturnValue(false), getAcceptRawOutputRisk: vi.fn().mockReturnValue(false), + getAgentSessionNoninteractiveEnabled: vi.fn().mockReturnValue(false), } as unknown as Config; mockSettings = { diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index 4f9d817204..26daaf66a1 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -46,6 +46,7 @@ import { handleMaxTurnsExceededError, } from './utils/errors.js'; import { TextOutput } from './ui/utils/textOutput.js'; +import { runNonInteractive as runNonInteractiveAgentSession } from './nonInteractiveCliAgentSession.js'; interface RunNonInteractiveParams { config: Config; @@ -55,13 +56,16 @@ interface RunNonInteractiveParams { resumedSessionData?: ResumedSessionData; } -export async function runNonInteractive({ - config, - settings, - input, - prompt_id, - resumedSessionData, -}: RunNonInteractiveParams): Promise { +export async function runNonInteractive( + params: RunNonInteractiveParams, +): Promise { + const useAgentSession = params.config.getAgentSessionNoninteractiveEnabled(); + if (useAgentSession) { + return runNonInteractiveAgentSession(params); + } + + const { config, settings, input, prompt_id, resumedSessionData } = params; + return promptIdContext.run(prompt_id, async () => { const consolePatcher = new ConsolePatcher({ stderr: true, diff --git a/packages/cli/src/nonInteractiveCliAgentSession.test.ts b/packages/cli/src/nonInteractiveCliAgentSession.test.ts new file mode 100644 index 0000000000..617f80aca6 --- /dev/null +++ b/packages/cli/src/nonInteractiveCliAgentSession.test.ts @@ -0,0 +1,2436 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + Config, + ToolRegistry, + ServerGeminiStreamEvent, + SessionMetrics, + AnyDeclarativeTool, + AnyToolInvocation, + UserFeedbackPayload, +} from '@google/gemini-cli-core'; +import { + ToolErrorType, + GeminiEventType, + OutputFormat, + uiTelemetryService, + FatalInputError, + CoreEvent, + CoreToolCallStatus, +} from '@google/gemini-cli-core'; +import type { Part } from '@google/genai'; +import { runNonInteractive } from './nonInteractiveCliAgentSession.js'; +import { + describe, + it, + expect, + beforeEach, + afterEach, + vi, + type Mock, + type MockInstance, +} from 'vitest'; +import type { LoadedSettings } from './config/settings.js'; + +// Mock core modules +vi.mock('./ui/hooks/atCommandProcessor.js'); + +const mockSetupInitialActivityLogger = vi.hoisted(() => vi.fn()); +vi.mock('./utils/devtoolsService.js', () => ({ + setupInitialActivityLogger: mockSetupInitialActivityLogger, +})); + +const mockCoreEvents = vi.hoisted(() => ({ + on: vi.fn(), + off: vi.fn(), + emit: vi.fn(), + emitConsoleLog: vi.fn(), + emitFeedback: vi.fn(), + drainBacklogs: vi.fn(), +})); + +const mockSchedulerSchedule = vi.hoisted(() => vi.fn()); + +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const original = + await importOriginal(); + + class MockChatRecordingService { + initialize = vi.fn(); + recordMessage = vi.fn(); + recordMessageTokens = vi.fn(); + recordToolCalls = vi.fn(); + } + + return { + ...original, + Scheduler: class { + schedule = mockSchedulerSchedule; + cancelAll = vi.fn(); + }, + isTelemetrySdkInitialized: vi.fn().mockReturnValue(true), + ChatRecordingService: MockChatRecordingService, + uiTelemetryService: { + getMetrics: vi.fn(), + }, + LegacyAgentSession: original.LegacyAgentSession, + geminiPartsToContentParts: original.geminiPartsToContentParts, + coreEvents: mockCoreEvents, + createWorkingStdio: vi.fn(() => ({ + stdout: process.stdout, + stderr: process.stderr, + })), + }; +}); + +const mockGetCommands = vi.hoisted(() => vi.fn()); +const mockCommandServiceCreate = vi.hoisted(() => vi.fn()); +vi.mock('./services/CommandService.js', () => ({ + CommandService: { + create: mockCommandServiceCreate, + }, +})); + +vi.mock('./services/FileCommandLoader.js'); +vi.mock('./services/McpPromptLoader.js'); +vi.mock('./services/BuiltinCommandLoader.js'); + +describe('runNonInteractive', () => { + let mockConfig: Config; + let mockSettings: LoadedSettings; + let mockToolRegistry: ToolRegistry; + let consoleErrorSpy: MockInstance; + let processStdoutSpy: MockInstance; + let processStderrSpy: MockInstance; + let mockGeminiClient: { + sendMessageStream: Mock; + resumeChat: Mock; + getChatRecordingService: Mock; + getChat: Mock; + getCurrentSequenceModel: Mock; + }; + const MOCK_SESSION_METRICS: SessionMetrics = { + models: {}, + tools: { + totalCalls: 0, + totalSuccess: 0, + totalFail: 0, + totalDurationMs: 0, + totalDecisions: { + accept: 0, + reject: 0, + modify: 0, + auto_accept: 0, + }, + byName: {}, + }, + files: { + totalLinesAdded: 0, + totalLinesRemoved: 0, + }, + }; + + beforeEach(async () => { + mockSchedulerSchedule.mockReset(); + + mockCommandServiceCreate.mockResolvedValue({ + getCommands: mockGetCommands, + }); + + consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + processStdoutSpy = vi + .spyOn(process.stdout, 'write') + .mockImplementation(() => true); + vi.spyOn(process.stdout, 'on').mockImplementation(() => process.stdout); + processStderrSpy = vi + .spyOn(process.stderr, 'write') + .mockImplementation(() => true); + vi.spyOn(process, 'exit').mockImplementation((code) => { + throw new Error(`process.exit(${code}) called`); + }); + + mockToolRegistry = { + getTool: vi.fn(), + getFunctionDeclarations: vi.fn().mockReturnValue([]), + } as unknown as ToolRegistry; + + mockGeminiClient = { + sendMessageStream: vi.fn(), + resumeChat: vi.fn().mockResolvedValue(undefined), + getChatRecordingService: vi.fn(() => ({ + initialize: vi.fn(), + recordMessage: vi.fn(), + recordMessageTokens: vi.fn(), + recordToolCalls: vi.fn(), + })), + getChat: vi.fn(() => ({ recordCompletedToolCalls: vi.fn() })), + getCurrentSequenceModel: vi.fn().mockReturnValue(null), + }; + + mockConfig = { + initialize: vi.fn().mockReturnValue(Promise.resolve(undefined)), + getMessageBus: vi.fn().mockReturnValue({ + subscribe: vi.fn(), + unsubscribe: vi.fn(), + publish: vi.fn(), + }), + getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), + getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + getMaxSessionTurns: vi.fn().mockReturnValue(10), + getSessionId: vi.fn().mockReturnValue('test-session-id'), + getProjectRoot: vi.fn().mockReturnValue('/test/project'), + storage: { + getProjectTempDir: vi.fn().mockReturnValue('/test/project/.gemini/tmp'), + }, + getIdeMode: vi.fn().mockReturnValue(false), + + getContentGeneratorConfig: vi.fn().mockReturnValue({}), + getDebugMode: vi.fn().mockReturnValue(false), + getOutputFormat: vi.fn().mockReturnValue('text'), + getModel: vi.fn().mockReturnValue('test-model'), + getFolderTrust: vi.fn().mockReturnValue(false), + isTrustedFolder: vi.fn().mockReturnValue(false), + getRawOutput: vi.fn().mockReturnValue(false), + getAcceptRawOutputRisk: vi.fn().mockReturnValue(false), + getAgentSessionNoninteractiveEnabled: vi.fn().mockReturnValue(false), + } as unknown as Config; + + mockSettings = { + system: { path: '', settings: {} }, + systemDefaults: { path: '', settings: {} }, + user: { path: '', settings: {} }, + workspace: { path: '', settings: {} }, + errors: [], + setValue: vi.fn(), + merged: { + security: { + auth: { + enforcedType: undefined, + }, + }, + }, + isTrusted: true, + migratedInMemoryScopes: new Set(), + forScope: vi.fn(), + computeMergedSettings: vi.fn(), + } as unknown as LoadedSettings; + + const { handleAtCommand } = await import( + './ui/hooks/atCommandProcessor.js' + ); + vi.mocked(handleAtCommand).mockImplementation(async ({ query }) => ({ + processedQuery: [{ text: query }], + })); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + async function* createStreamFromEvents( + events: ServerGeminiStreamEvent[], + ): AsyncGenerator { + for (const event of events) { + yield event; + } + } + + const getWrittenOutput = () => + processStdoutSpy.mock.calls.map((c) => c[0]).join(''); + + it('should process input and write text output', async () => { + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Hello' }, + { type: GeminiEventType.Content, value: ' World' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-1', + }); + + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + [{ text: 'Test input' }], + expect.any(AbortSignal), + 'prompt-id-1', + undefined, + false, + 'Test input', + ); + expect(getWrittenOutput()).toBe('Hello World\n'); + // Note: Telemetry shutdown is now handled in runExitCleanup() in cleanup.ts + // so we no longer expect shutdownTelemetry to be called directly here + }); + + it('should stream the specific stream started by send', async () => { + const { LegacyAgentSession } = await import('@google/gemini-cli-core'); + const streamSpy = vi.spyOn(LegacyAgentSession.prototype, 'stream'); + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Hello again' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-stream', + }); + + expect(streamSpy).toHaveBeenCalledWith({ streamId: expect.any(String) }); + }); + + it('fails fast if the session acknowledges a message send without a stream', async () => { + const { LegacyAgentSession } = await import('@google/gemini-cli-core'); + const sendSpy = vi + .spyOn(LegacyAgentSession.prototype, 'send') + .mockResolvedValue({ streamId: null }); + const streamSpy = vi.spyOn(LegacyAgentSession.prototype, 'stream'); + + await expect( + runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-null-stream', + }), + ).rejects.toThrow( + 'LegacyAgentSession.send() unexpectedly returned no stream for a message send.', + ); + + expect(streamSpy).not.toHaveBeenCalled(); + + sendSpy.mockRestore(); + streamSpy.mockRestore(); + }); + + it('should register activity logger when GEMINI_CLI_ACTIVITY_LOG_TARGET is set', async () => { + vi.stubEnv('GEMINI_CLI_ACTIVITY_LOG_TARGET', '/tmp/test.jsonl'); + const events: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 0 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'test', + prompt_id: 'prompt-id-activity-logger', + }); + + expect(mockSetupInitialActivityLogger).toHaveBeenCalledWith(mockConfig); + vi.unstubAllEnvs(); + }); + + it('should not register activity logger when GEMINI_CLI_ACTIVITY_LOG_TARGET is not set', async () => { + vi.stubEnv('GEMINI_CLI_ACTIVITY_LOG_TARGET', ''); + const events: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 0 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'test', + prompt_id: 'prompt-id-activity-logger-off', + }); + + expect(mockSetupInitialActivityLogger).not.toHaveBeenCalled(); + vi.unstubAllEnvs(); + }); + + it('should handle a single tool call and respond', async () => { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', + name: 'testTool', + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-2', + }, + }; + const toolResponse: Part[] = [{ text: 'Tool response' }]; + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Success, + request: { + callId: 'tool-1', + name: 'testTool', + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-2', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: toolResponse, + callId: 'tool-1', + error: undefined, + errorType: undefined, + contentLength: undefined, + }, + }, + ]); + + const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent]; + const secondCallEvents: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Final answer' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents(firstCallEvents)) + .mockReturnValueOnce(createStreamFromEvents(secondCallEvents)); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Use a tool', + prompt_id: 'prompt-id-2', + }); + + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2); + expect(mockSchedulerSchedule).toHaveBeenCalledWith( + [expect.objectContaining({ name: 'testTool' })], + expect.any(AbortSignal), + ); + expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith( + 2, + [{ text: 'Tool response' }], + expect.any(AbortSignal), + 'prompt-id-2', + undefined, + false, + undefined, + ); + expect(getWrittenOutput()).toBe('Final answer\n'); + }); + + it('should write a single newline between sequential text outputs from the model', async () => { + // This test simulates a multi-turn conversation to ensure that a single newline + // is printed between each block of text output from the model. + + // 1. Define the tool requests that the model will ask the CLI to run. + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'mock-tool', + name: 'mockTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-multi', + }, + }; + + // 2. Mock the execution of the tools. We just need them to succeed. + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Success, + request: toolCallEvent.value, // This is generic enough for both calls + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: [], + callId: 'mock-tool', + }, + }, + ]); + + // 3. Define the sequence of events streamed from the mock model. + // Turn 1: Model outputs text, then requests a tool call. + const modelTurn1: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Use mock tool' }, + toolCallEvent, + ]; + // Turn 2: Model outputs more text, then requests another tool call. + const modelTurn2: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Use mock tool again' }, + toolCallEvent, + ]; + // Turn 3: Model outputs a final answer. + const modelTurn3: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Finished.' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents(modelTurn1)) + .mockReturnValueOnce(createStreamFromEvents(modelTurn2)) + .mockReturnValueOnce(createStreamFromEvents(modelTurn3)); + + // 4. Run the command. + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Use mock tool multiple times', + prompt_id: 'prompt-id-multi', + }); + + // 5. Verify the output. + // The rendered output should contain the text from each turn, separated by a + // single newline, with a final newline at the end. + expect(getWrittenOutput()).toMatchSnapshot(); + + // Also verify the tools were called as expected. + expect(mockSchedulerSchedule).toHaveBeenCalledTimes(2); + }); + + it('should handle error during tool execution and should send error back to the model', async () => { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', + name: 'errorTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-3', + }, + }; + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Error, + request: { + callId: 'tool-1', + name: 'errorTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-3', + }, + tool: {} as AnyDeclarativeTool, + response: { + callId: 'tool-1', + error: new Error('Execution failed'), + errorType: ToolErrorType.EXECUTION_FAILED, + responseParts: [ + { + functionResponse: { + name: 'errorTool', + response: { + output: 'Error: Execution failed', + }, + }, + }, + ], + resultDisplay: 'Execution failed', + contentLength: undefined, + }, + }, + ]); + const finalResponse: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Content, + value: 'Sorry, let me try again.', + }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents([toolCallEvent])) + .mockReturnValueOnce(createStreamFromEvents(finalResponse)); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Trigger tool error', + prompt_id: 'prompt-id-3', + }); + + expect(mockSchedulerSchedule).toHaveBeenCalled(); + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Error executing tool errorTool: Execution failed', + ); + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2); + expect(mockGeminiClient.sendMessageStream).toHaveBeenNthCalledWith( + 2, + [ + { + functionResponse: { + name: 'errorTool', + response: { + output: 'Error: Execution failed', + }, + }, + }, + ], + expect.any(AbortSignal), + 'prompt-id-3', + undefined, + false, + undefined, + ); + expect(getWrittenOutput()).toBe('Sorry, let me try again.\n'); + }); + + it('should exit with error if sendMessageStream throws initially', async () => { + const apiError = new Error('API connection failed'); + mockGeminiClient.sendMessageStream.mockImplementation(() => { + throw apiError; + }); + + await expect( + runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Initial fail', + prompt_id: 'prompt-id-4', + }), + ).rejects.toThrow('API connection failed'); + }); + + it('should not exit if a tool is not found, and should send error back to model', async () => { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', + name: 'nonexistentTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-5', + }, + }; + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Error, + request: { + callId: 'tool-1', + name: 'nonexistentTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-5', + }, + response: { + callId: 'tool-1', + error: new Error('Tool "nonexistentTool" not found in registry.'), + resultDisplay: 'Tool "nonexistentTool" not found in registry.', + responseParts: [], + errorType: undefined, + contentLength: undefined, + }, + }, + ]); + const finalResponse: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Content, + value: "Sorry, I can't find that tool.", + }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents([toolCallEvent])) + .mockReturnValueOnce(createStreamFromEvents(finalResponse)); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Trigger tool not found', + prompt_id: 'prompt-id-5', + }); + + expect(mockSchedulerSchedule).toHaveBeenCalled(); + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Error executing tool nonexistentTool: Tool "nonexistentTool" not found in registry.', + ); + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2); + expect(getWrittenOutput()).toBe("Sorry, I can't find that tool.\n"); + }); + + it('should exit when max session turns are exceeded', async () => { + vi.mocked(mockConfig.getMaxSessionTurns).mockReturnValue(0); + await expect( + runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Trigger loop', + prompt_id: 'prompt-id-6', + }), + ).rejects.toThrow('Reached max session turns for this session'); + }); + + it('should preprocess @include commands before sending to the model', async () => { + // 1. Mock the imported atCommandProcessor + const { handleAtCommand } = await import( + './ui/hooks/atCommandProcessor.js' + ); + const mockHandleAtCommand = vi.mocked(handleAtCommand); + + // 2. Define the raw input and the expected processed output + const rawInput = 'Summarize @file.txt'; + const processedParts: Part[] = [ + { text: 'Summarize @file.txt' }, + { text: '\n--- Content from referenced files ---\n' }, + { text: 'This is the content of the file.' }, + { text: '\n--- End of content ---' }, + ]; + + // 3. Setup the mock to return the processed parts + mockHandleAtCommand.mockResolvedValue({ + processedQuery: processedParts, + }); + + // Mock a simple stream response from the Gemini client + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Summary complete.' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + // 4. Run the non-interactive mode with the raw input + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: rawInput, + prompt_id: 'prompt-id-7', + }); + + // 5. Assert that sendMessageStream was called with the PROCESSED parts, not the raw input + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + processedParts, + expect.any(AbortSignal), + 'prompt-id-7', + undefined, + false, + rawInput, + ); + + // 6. Assert the final output is correct + expect(getWrittenOutput()).toBe('Summary complete.\n'); + }); + + it('should process input and write JSON output with stats', async () => { + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Hello World' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue( + MOCK_SESSION_METRICS, + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-1', + }); + + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + [{ text: 'Test input' }], + expect.any(AbortSignal), + 'prompt-id-1', + undefined, + false, + 'Test input', + ); + expect(processStdoutSpy).toHaveBeenCalledWith( + JSON.stringify( + { + session_id: 'test-session-id', + response: 'Hello World', + stats: MOCK_SESSION_METRICS, + }, + null, + 2, + ), + ); + }); + + it('should write JSON output with stats for tool-only commands (no text response)', async () => { + // Test the scenario where a command completes successfully with only tool calls + // but no text response - this would have caught the original bug + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', + name: 'testTool', + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-tool-only', + }, + }; + const toolResponse: Part[] = [{ text: 'Tool executed successfully' }]; + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Success, + request: { + callId: 'tool-1', + name: 'testTool', + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-tool-only', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: toolResponse, + callId: 'tool-1', + error: undefined, + errorType: undefined, + contentLength: undefined, + }, + }, + ]); + + // First call returns only tool call, no content + const firstCallEvents: ServerGeminiStreamEvent[] = [ + toolCallEvent, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 5 } }, + }, + ]; + + // Second call returns no content (tool-only completion) + const secondCallEvents: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 3 } }, + }, + ]; + + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents(firstCallEvents)) + .mockReturnValueOnce(createStreamFromEvents(secondCallEvents)); + + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue( + MOCK_SESSION_METRICS, + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Execute tool only', + prompt_id: 'prompt-id-tool-only', + }); + + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(2); + expect(mockSchedulerSchedule).toHaveBeenCalledWith( + [expect.objectContaining({ name: 'testTool' })], + expect.any(AbortSignal), + ); + + // This should output JSON with empty response but include stats + expect(processStdoutSpy).toHaveBeenCalledWith( + JSON.stringify( + { + session_id: 'test-session-id', + response: '', + stats: MOCK_SESSION_METRICS, + }, + null, + 2, + ), + ); + }); + + it('should keep only the final post-tool assistant text in JSON output', async () => { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', + name: 'testTool', + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-json-tool-text', + }, + }; + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Success, + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: [{ text: 'Tool executed successfully' }], + callId: 'tool-1', + error: undefined, + errorType: undefined, + contentLength: undefined, + }, + }, + ]); + + mockGeminiClient.sendMessageStream + .mockReturnValueOnce( + createStreamFromEvents([ + { type: GeminiEventType.Content, value: 'Let me check that...' }, + toolCallEvent, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 5 } }, + }, + ]), + ) + .mockReturnValueOnce( + createStreamFromEvents([ + { type: GeminiEventType.Content, value: 'Final answer' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 3 } }, + }, + ]), + ); + + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue( + MOCK_SESSION_METRICS, + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Use a tool', + prompt_id: 'prompt-id-json-tool-text', + }); + + expect(processStdoutSpy).toHaveBeenCalledWith( + JSON.stringify( + { + session_id: 'test-session-id', + response: 'Final answer', + stats: MOCK_SESSION_METRICS, + }, + null, + 2, + ), + ); + }); + + it('should write JSON output with stats for empty response commands', async () => { + // Test the scenario where a command completes but produces no content at all + const events: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 1 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue( + MOCK_SESSION_METRICS, + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Empty response test', + prompt_id: 'prompt-id-empty', + }); + + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + [{ text: 'Empty response test' }], + expect.any(AbortSignal), + 'prompt-id-empty', + undefined, + false, + 'Empty response test', + ); + + // This should output JSON with empty response but include stats + expect(processStdoutSpy).toHaveBeenCalledWith( + JSON.stringify( + { + session_id: 'test-session-id', + response: '', + stats: MOCK_SESSION_METRICS, + }, + null, + 2, + ), + ); + }); + + it('should handle errors in JSON format', async () => { + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + const testError = new Error('Invalid input provided'); + + mockGeminiClient.sendMessageStream.mockImplementation(() => { + throw testError; + }); + + let thrownError: Error | null = null; + try { + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-error', + }); + // Should not reach here + expect.fail('Expected process.exit to be called'); + } catch (error) { + thrownError = error as Error; + } + + // Should throw because of mocked process.exit + expect(thrownError?.message).toBe('process.exit(1) called'); + + expect(mockCoreEvents.emitFeedback).toHaveBeenCalledWith( + 'error', + JSON.stringify( + { + session_id: 'test-session-id', + error: { + type: 'Error', + message: 'Invalid input provided', + code: 1, + }, + }, + null, + 2, + ), + ); + }); + + it('should handle FatalInputError with custom exit code in JSON format', async () => { + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + const fatalError = new FatalInputError('Invalid command syntax provided'); + + mockGeminiClient.sendMessageStream.mockImplementation(() => { + throw fatalError; + }); + + let thrownError: Error | null = null; + try { + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Invalid syntax', + prompt_id: 'prompt-id-fatal', + }); + // Should not reach here + expect.fail('Expected process.exit to be called'); + } catch (error) { + thrownError = error as Error; + } + + // Should throw because of mocked process.exit with custom exit code + expect(thrownError?.message).toBe('process.exit(42) called'); + + expect(mockCoreEvents.emitFeedback).toHaveBeenCalledWith( + 'error', + JSON.stringify( + { + session_id: 'test-session-id', + error: { + type: 'FatalInputError', + message: 'Invalid command syntax provided', + code: 42, + }, + }, + null, + 2, + ), + ); + }); + + it('should execute a slash command that returns a prompt', async () => { + const mockCommand = { + name: 'testcommand', + description: 'a test command', + action: vi.fn().mockResolvedValue({ + type: 'submit_prompt', + content: [{ text: 'Prompt from command' }], + }), + }; + mockGetCommands.mockReturnValue([mockCommand]); + + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Response from command' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 5 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: '/testcommand', + prompt_id: 'prompt-id-slash', + }); + + // Ensure the prompt sent to the model is from the command, not the raw input + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + [{ text: 'Prompt from command' }], + expect.any(AbortSignal), + 'prompt-id-slash', + undefined, + false, + '/testcommand', + ); + + expect(getWrittenOutput()).toBe('Response from command\n'); + }); + + it('should handle slash commands', async () => { + const nonInteractiveCliCommands = await import( + './nonInteractiveCliCommands.js' + ); + const handleSlashCommandSpy = vi.spyOn( + nonInteractiveCliCommands, + 'handleSlashCommand', + ); + handleSlashCommandSpy.mockResolvedValue([{ text: 'Slash command output' }]); + + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Response to slash command' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: '/help', + prompt_id: 'prompt-id-slash', + }); + + expect(handleSlashCommandSpy).toHaveBeenCalledWith( + '/help', + expect.any(AbortController), + mockConfig, + mockSettings, + ); + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + [{ text: 'Slash command output' }], + expect.any(AbortSignal), + 'prompt-id-slash', + undefined, + false, + '/help', + ); + expect(getWrittenOutput()).toBe('Response to slash command\n'); + handleSlashCommandSpy.mockRestore(); + }); + + it('should handle cancellation (Ctrl+C)', async () => { + // Mock isTTY and setRawMode safely + const originalIsTTY = process.stdin.isTTY; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const originalSetRawMode = (process.stdin as any).setRawMode; + + Object.defineProperty(process.stdin, 'isTTY', { + value: true, + configurable: true, + }); + if (!originalSetRawMode) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (process.stdin as any).setRawMode = vi.fn(); + } + + const stdinOnSpy = vi + .spyOn(process.stdin, 'on') + .mockImplementation(() => process.stdin); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + vi.spyOn(process.stdin as any, 'setRawMode').mockImplementation(() => true); + vi.spyOn(process.stdin, 'resume').mockImplementation(() => process.stdin); + vi.spyOn(process.stdin, 'pause').mockImplementation(() => process.stdin); + vi.spyOn(process.stdin, 'removeAllListeners').mockImplementation( + () => process.stdin, + ); + + // Cancellation will throw FatalCancellationError directly + + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Thinking...' }, + ]; + // Create a stream that responds to abortion + mockGeminiClient.sendMessageStream.mockImplementation( + (_messages, signal: AbortSignal) => + (async function* () { + yield events[0]; + await new Promise((resolve, reject) => { + const timeout = setTimeout(resolve, 1000); + signal.addEventListener('abort', () => { + clearTimeout(timeout); + setTimeout(() => { + reject(new Error('Aborted')); + }, 300); + }); + }); + })(), + ); + + const runPromise = runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Long running query', + prompt_id: 'prompt-id-cancel', + }); + + // Wait a bit for setup to complete and listeners to be registered + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Find the keypress handler registered by runNonInteractive + const keypressCall = stdinOnSpy.mock.calls.find( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (call) => (call[0] as any) === 'keypress', + ); + expect(keypressCall).toBeDefined(); + const keypressHandler = keypressCall?.[1] as ( + str: string, + key: { name?: string; ctrl?: boolean }, + ) => void; + + if (keypressHandler) { + // Simulate Ctrl+C + keypressHandler('\u0003', { ctrl: true, name: 'c' }); + } + + await expect(runPromise).rejects.toThrow('Operation cancelled.'); + + expect( + processStderrSpy.mock.calls.some( + // eslint-disable-next-line no-restricted-syntax + (call) => typeof call[0] === 'string' && call[0].includes('Cancelling'), + ), + ).toBe(true); + + // Restore original values + Object.defineProperty(process.stdin, 'isTTY', { + value: originalIsTTY, + configurable: true, + }); + if (originalSetRawMode) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (process.stdin as any).setRawMode = originalSetRawMode; + } else { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + delete (process.stdin as any).setRawMode; + } + // Spies are automatically restored by vi.restoreAllMocks() in afterEach, + // but we can also do it manually if needed. + }); + + it('should honor cancellation that happens before session.send()', async () => { + const originalIsTTY = process.stdin.isTTY; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const originalSetRawMode = (process.stdin as any).setRawMode; + + Object.defineProperty(process.stdin, 'isTTY', { + value: true, + configurable: true, + }); + if (!originalSetRawMode) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (process.stdin as any).setRawMode = vi.fn(); + } + + const stdinOnSpy = vi + .spyOn(process.stdin, 'on') + .mockImplementation( + (event: string | symbol, listener: (...args: unknown[]) => void) => { + if (event === 'keypress') { + listener('\u0003', { ctrl: true, name: 'c' }); + } + return process.stdin; + }, + ); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + vi.spyOn(process.stdin as any, 'setRawMode').mockImplementation(() => true); + vi.spyOn(process.stdin, 'resume').mockImplementation(() => process.stdin); + vi.spyOn(process.stdin, 'pause').mockImplementation(() => process.stdin); + vi.spyOn(process.stdin, 'removeAllListeners').mockImplementation( + () => process.stdin, + ); + + // Cancellation will throw FatalCancellationError directly + + const { LegacyAgentSession } = await import('@google/gemini-cli-core'); + const sendSpy = vi.spyOn(LegacyAgentSession.prototype, 'send'); + + await expect( + runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Cancelled query', + prompt_id: 'prompt-id-pre-send-cancel', + }), + ).rejects.toThrow('Operation cancelled.'); + + expect(sendSpy).not.toHaveBeenCalled(); + expect(stdinOnSpy).toHaveBeenCalled(); + sendSpy.mockRestore(); + + Object.defineProperty(process.stdin, 'isTTY', { + value: originalIsTTY, + configurable: true, + }); + if (originalSetRawMode) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (process.stdin as any).setRawMode = originalSetRawMode; + } else { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + delete (process.stdin as any).setRawMode; + } + }); + + it('should throw FatalInputError if a command requires confirmation', async () => { + const mockCommand = { + name: 'confirm', + description: 'a command that needs confirmation', + action: vi.fn().mockResolvedValue({ + type: 'confirm_shell_commands', + commands: ['rm -rf /'], + }), + }; + mockGetCommands.mockReturnValue([mockCommand]); + + await expect( + runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: '/confirm', + prompt_id: 'prompt-id-confirm', + }), + ).rejects.toThrow( + 'Exiting due to a confirmation prompt requested by the command.', + ); + }); + + it('should treat an unknown slash command as a regular prompt', async () => { + // No commands are mocked, so any slash command is "unknown" + mockGetCommands.mockReturnValue([]); + + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Response to unknown' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 5 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: '/unknowncommand', + prompt_id: 'prompt-id-unknown', + }); + + // Ensure the raw input is sent to the model + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledWith( + [{ text: '/unknowncommand' }], + expect.any(AbortSignal), + 'prompt-id-unknown', + undefined, + false, + '/unknowncommand', + ); + + expect(getWrittenOutput()).toBe('Response to unknown\n'); + }); + + it('should throw for unhandled command result types', async () => { + const mockCommand = { + name: 'noaction', + description: 'unhandled type', + action: vi.fn().mockResolvedValue({ + type: 'unhandled', + }), + }; + mockGetCommands.mockReturnValue([mockCommand]); + + await expect( + runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: '/noaction', + prompt_id: 'prompt-id-unhandled', + }), + ).rejects.toThrow( + 'Exiting due to command result that is not supported in non-interactive mode.', + ); + }); + + it('should pass arguments to the slash command action', async () => { + const mockAction = vi.fn().mockResolvedValue({ + type: 'submit_prompt', + content: [{ text: 'Prompt from command' }], + }); + const mockCommand = { + name: 'testargs', + description: 'a test command', + action: mockAction, + }; + mockGetCommands.mockReturnValue([mockCommand]); + + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Acknowledged' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 1 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: '/testargs arg1 arg2', + prompt_id: 'prompt-id-args', + }); + + expect(mockAction).toHaveBeenCalledWith(expect.any(Object), 'arg1 arg2'); + + expect(getWrittenOutput()).toBe('Acknowledged\n'); + }); + + it('should instantiate CommandService with correct loaders for slash commands', async () => { + // This test indirectly checks that handleSlashCommand is using the right loaders. + const { FileCommandLoader } = await import( + './services/FileCommandLoader.js' + ); + const { McpPromptLoader } = await import('./services/McpPromptLoader.js'); + const { BuiltinCommandLoader } = await import( + './services/BuiltinCommandLoader.js' + ); + mockGetCommands.mockReturnValue([]); // No commands found, so it will fall through + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Acknowledged' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 1 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: '/mycommand', + prompt_id: 'prompt-id-loaders', + }); + + // Check that loaders were instantiated with the config + expect(FileCommandLoader).toHaveBeenCalledTimes(1); + expect(FileCommandLoader).toHaveBeenCalledWith(mockConfig); + expect(McpPromptLoader).toHaveBeenCalledTimes(1); + expect(McpPromptLoader).toHaveBeenCalledWith(mockConfig); + expect(BuiltinCommandLoader).toHaveBeenCalledWith(mockConfig); + + // Check that instances were passed to CommandService.create + expect(mockCommandServiceCreate).toHaveBeenCalledTimes(1); + const loadersArg = mockCommandServiceCreate.mock.calls[0][0]; + expect(loadersArg).toHaveLength(3); + expect(loadersArg[0]).toBe( + vi.mocked(BuiltinCommandLoader).mock.instances[0], + ); + expect(loadersArg[1]).toBe(vi.mocked(McpPromptLoader).mock.instances[0]); + expect(loadersArg[2]).toBe(vi.mocked(FileCommandLoader).mock.instances[0]); + }); + + it('should allow a normally-excluded tool when --allowed-tools is set', async () => { + // By default, ShellTool is excluded in non-interactive mode. + // This test ensures that --allowed-tools overrides this exclusion. + vi.mocked(mockConfig.getToolRegistry).mockReturnValue({ + getTool: vi.fn().mockReturnValue({ + name: 'ShellTool', + description: 'A shell tool', + run: vi.fn(), + }), + getFunctionDeclarations: vi.fn().mockReturnValue([{ name: 'ShellTool' }]), + } as unknown as ToolRegistry); + + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-shell-1', + name: 'ShellTool', + args: { command: 'ls' }, + isClientInitiated: false, + prompt_id: 'prompt-id-allowed', + }, + }; + const toolResponse: Part[] = [{ text: 'file.txt' }]; + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Success, + request: { + callId: 'tool-shell-1', + name: 'ShellTool', + args: { command: 'ls' }, + isClientInitiated: false, + prompt_id: 'prompt-id-allowed', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: toolResponse, + callId: 'tool-shell-1', + error: undefined, + errorType: undefined, + contentLength: undefined, + }, + }, + ]); + + const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent]; + const secondCallEvents: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'file.txt' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents(firstCallEvents)) + .mockReturnValueOnce(createStreamFromEvents(secondCallEvents)); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'List the files', + prompt_id: 'prompt-id-allowed', + }); + + expect(mockSchedulerSchedule).toHaveBeenCalledWith( + [expect.objectContaining({ name: 'ShellTool' })], + expect.any(AbortSignal), + ); + expect(getWrittenOutput()).toBe('file.txt\n'); + }); + + describe('CoreEvents Integration', () => { + it('subscribes to UserFeedback and drains backlog on start', async () => { + const events: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 0 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'test', + prompt_id: 'prompt-id-events', + }); + + expect(mockCoreEvents.on).toHaveBeenCalledWith( + CoreEvent.UserFeedback, + expect.any(Function), + ); + expect(mockCoreEvents.drainBacklogs).toHaveBeenCalledTimes(1); + }); + + it('unsubscribes from UserFeedback on finish', async () => { + const events: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 0 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'test', + prompt_id: 'prompt-id-events', + }); + + expect(mockCoreEvents.off).toHaveBeenCalledWith( + CoreEvent.UserFeedback, + expect.any(Function), + ); + }); + + it('logs to process.stderr when UserFeedback event is received', async () => { + const events: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 0 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'test', + prompt_id: 'prompt-id-events', + }); + + // Get the registered handler + const handler = mockCoreEvents.on.mock.calls.find( + (call: unknown[]) => call[0] === CoreEvent.UserFeedback, + )?.[1]; + expect(handler).toBeDefined(); + + // Simulate an event + const payload: UserFeedbackPayload = { + severity: 'error', + message: 'Test error message', + }; + handler(payload); + + expect(processStderrSpy).toHaveBeenCalledWith( + '[ERROR] Test error message\n', + ); + }); + + it('logs optional error object to process.stderr in debug mode', async () => { + vi.mocked(mockConfig.getDebugMode).mockReturnValue(true); + const events: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 0 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'test', + prompt_id: 'prompt-id-events', + }); + + // Get the registered handler + const handler = mockCoreEvents.on.mock.calls.find( + (call: unknown[]) => call[0] === CoreEvent.UserFeedback, + )?.[1]; + expect(handler).toBeDefined(); + + // Simulate an event with error object + const errorObj = new Error('Original error'); + // Mock stack for deterministic testing + errorObj.stack = 'Error: Original error\n at test'; + const payload: UserFeedbackPayload = { + severity: 'warning', + message: 'Test warning message', + error: errorObj, + }; + handler(payload); + + expect(processStderrSpy).toHaveBeenCalledWith( + '[WARNING] Test warning message\n', + ); + expect(processStderrSpy).toHaveBeenCalledWith( + 'Error: Original error\n at test\n', + ); + }); + }); + + it('should emit appropriate events for streaming JSON output', async () => { + vi.mocked(mockConfig.getOutputFormat).mockReturnValue( + OutputFormat.STREAM_JSON, + ); + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue( + MOCK_SESSION_METRICS, + ); + + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', + name: 'testTool', + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-stream', + }, + }; + + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Success, + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: [{ text: 'Tool response' }], + callId: 'tool-1', + error: undefined, + errorType: undefined, + contentLength: undefined, + resultDisplay: 'Tool executed successfully', + }, + }, + ]); + + const firstCallEvents: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Thinking...' }, + toolCallEvent, + ]; + const secondCallEvents: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Final answer' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents(firstCallEvents)) + .mockReturnValueOnce(createStreamFromEvents(secondCallEvents)); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Stream test', + prompt_id: 'prompt-id-stream', + }); + + const output = getWrittenOutput(); + const sanitizedOutput = output + .replace(/"timestamp":"[^"]+"/g, '"timestamp":""') + .replace(/"duration_ms":\d+/g, '"duration_ms":'); + expect(sanitizedOutput).toMatchSnapshot(); + }); + + it('should handle EPIPE error gracefully', async () => { + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Hello' }, + { type: GeminiEventType.Content, value: ' World' }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + // Mock process.exit to track calls without throwing + vi.spyOn(process, 'exit').mockImplementation((_code) => undefined as never); + + // Simulate EPIPE error on stdout + const stdoutErrorCallback = (process.stdout.on as Mock).mock.calls.find( + (call) => call[0] === 'error', + )?.[1]; + + if (stdoutErrorCallback) { + stdoutErrorCallback({ code: 'EPIPE' }); + } + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'EPIPE test', + prompt_id: 'prompt-id-epipe', + }); + + // Since EPIPE is simulated, it might exit early or continue depending on timing, + // but our main goal is to verify the handler is registered and handles EPIPE. + expect(process.stdout.on).toHaveBeenCalledWith( + 'error', + expect.any(Function), + ); + }); + + it('should resume chat when resumedSessionData is provided', async () => { + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Resumed' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 5 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + const resumedSessionData = { + conversation: { + sessionId: 'resumed-session-id', + messages: [ + { role: 'user', parts: [{ text: 'Previous message' }] }, + ] as any, // eslint-disable-line @typescript-eslint/no-explicit-any + startTime: new Date().toISOString(), + lastUpdated: new Date().toISOString(), + firstUserMessage: 'Previous message', + projectHash: 'test-hash', + }, + filePath: '/path/to/session.json', + }; + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Continue', + prompt_id: 'prompt-id-resume', + resumedSessionData, + }); + + expect(mockGeminiClient.resumeChat).toHaveBeenCalledWith( + expect.any(Array), + resumedSessionData, + ); + expect(getWrittenOutput()).toBe('Resumed\n'); + }); + + it.each([ + { + name: 'loop detected', + events: [ + { type: GeminiEventType.LoopDetected }, + ] as ServerGeminiStreamEvent[], + input: 'Loop test', + promptId: 'prompt-id-loop', + }, + { + name: 'max session turns', + events: [ + { type: GeminiEventType.MaxSessionTurns }, + ] as ServerGeminiStreamEvent[], + input: 'Max turns test', + promptId: 'prompt-id-max-turns', + }, + ])( + 'should emit appropriate error event in streaming JSON mode: $name', + async ({ events, input, promptId }) => { + vi.mocked(mockConfig.getOutputFormat).mockReturnValue( + OutputFormat.STREAM_JSON, + ); + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue( + MOCK_SESSION_METRICS, + ); + + const streamEvents: ServerGeminiStreamEvent[] = [ + ...events, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 0 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(streamEvents), + ); + + try { + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input, + prompt_id: promptId, + }); + } catch { + // Expected exit + } + + const output = getWrittenOutput(); + const sanitizedOutput = output + .replace(/"timestamp":"[^"]+"/g, '"timestamp":""') + .replace(/"duration_ms":\d+/g, '"duration_ms":'); + expect(sanitizedOutput).toMatchSnapshot(); + }, + ); + + it('should log error when tool recording fails', async () => { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', + name: 'testTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-tool-error', + }, + }; + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Success, + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + responseParts: [], + callId: 'tool-1', + error: undefined, + errorType: undefined, + contentLength: undefined, + }, + }, + ]); + + const events: ServerGeminiStreamEvent[] = [ + toolCallEvent, + { type: GeminiEventType.Content, value: 'Done' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 5 } }, + }, + ]; + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents(events)) + .mockReturnValueOnce( + createStreamFromEvents([ + { type: GeminiEventType.Content, value: 'Done' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 5 } }, + }, + ]), + ); + + // Mock getChat to throw when recording tool calls + const mockChat = { + recordCompletedToolCalls: vi.fn().mockImplementation(() => { + throw new Error('Recording failed'); + }), + }; + mockGeminiClient.getChat = vi.fn().mockReturnValue(mockChat); + mockGeminiClient.getCurrentSequenceModel = vi + .fn() + .mockReturnValue('model-1'); + + // Mock debugLogger.error + const { debugLogger } = await import('@google/gemini-cli-core'); + const debugLoggerErrorSpy = vi + .spyOn(debugLogger, 'error') + .mockImplementation(() => {}); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Tool recording error test', + prompt_id: 'prompt-id-tool-error', + }); + + expect(debugLoggerErrorSpy).toHaveBeenCalledWith( + expect.stringContaining( + 'Error recording completed tool call information: Error: Recording failed', + ), + ); + expect(getWrittenOutput()).toContain('Done'); + }); + + it('should stop agent execution immediately when a tool call returns STOP_EXECUTION error', async () => { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'stop-call', + name: 'stopTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-stop', + }, + }; + + // Mock tool execution returning STOP_EXECUTION + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Error, + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'stop-call', + responseParts: [{ text: 'error occurred' }], + errorType: ToolErrorType.STOP_EXECUTION, + error: new Error('Stop reason from hook'), + resultDisplay: undefined, + }, + }, + ]); + + const firstCallEvents: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Executing tool...' }, + toolCallEvent, + ]; + + // Setup the mock to return events for the first call. + // We expect the loop to terminate after the tool execution. + // If it doesn't, it might call sendMessageStream again, which we'll assert against. + mockGeminiClient.sendMessageStream + .mockReturnValueOnce(createStreamFromEvents(firstCallEvents)) + .mockReturnValueOnce(createStreamFromEvents([])); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Run stop tool', + prompt_id: 'prompt-id-stop', + }); + + expect(mockSchedulerSchedule).toHaveBeenCalled(); + + // The key assertion: sendMessageStream should have been called ONLY ONCE (initial user input). + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(1); + + expect(processStderrSpy).toHaveBeenCalledWith( + 'Agent execution stopped: Stop reason from hook\n', + ); + }); + + it('should write JSON output when a tool call returns STOP_EXECUTION error', async () => { + vi.mocked(mockConfig.getOutputFormat).mockReturnValue(OutputFormat.JSON); + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue( + MOCK_SESSION_METRICS, + ); + + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'stop-call', + name: 'stopTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-stop-json', + }, + }; + + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Error, + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'stop-call', + responseParts: [{ text: 'error occurred' }], + errorType: ToolErrorType.STOP_EXECUTION, + error: new Error('Stop reason'), + resultDisplay: undefined, + }, + }, + ]); + + const firstCallEvents: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Partial content' }, + toolCallEvent, + ]; + + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(firstCallEvents), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Run stop tool', + prompt_id: 'prompt-id-stop-json', + }); + + expect(processStdoutSpy).toHaveBeenCalledWith( + JSON.stringify( + { + session_id: 'test-session-id', + response: 'Partial content', + stats: MOCK_SESSION_METRICS, + }, + null, + 2, + ), + ); + }); + + it('should emit result event when a tool call returns STOP_EXECUTION error in streaming JSON mode', async () => { + vi.mocked(mockConfig.getOutputFormat).mockReturnValue( + OutputFormat.STREAM_JSON, + ); + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue( + MOCK_SESSION_METRICS, + ); + + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'stop-call', + name: 'stopTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-stop-stream', + }, + }; + + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Error, + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'stop-call', + responseParts: [{ text: 'error occurred' }], + errorType: ToolErrorType.STOP_EXECUTION, + error: new Error('Stop reason'), + resultDisplay: undefined, + }, + }, + ]); + + const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent]; + + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(firstCallEvents), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Run stop tool', + prompt_id: 'prompt-id-stop-stream', + }); + + const output = getWrittenOutput(); + expect(output).toContain('"type":"result"'); + expect(output).toContain('"status":"success"'); + }); + + describe('Agent Execution Events', () => { + it('should handle AgentExecutionStopped event', async () => { + const events: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.AgentExecutionStopped, + value: { reason: 'Stopped by hook' }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'test stop', + prompt_id: 'prompt-id-stop', + }); + + expect(processStderrSpy).toHaveBeenCalledWith( + 'Agent execution stopped: Stopped by hook\n', + ); + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(1); + }); + + it('should handle AgentExecutionBlocked event', async () => { + const allEvents: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.AgentExecutionBlocked, + value: { reason: 'Blocked by hook' }, + }, + { type: GeminiEventType.Content, value: 'Final answer' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(allEvents), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'test block', + prompt_id: 'prompt-id-block', + }); + + expect(processStderrSpy).toHaveBeenCalledWith( + '[WARNING] Agent execution blocked: Blocked by hook\n', + ); + // Stream continues after blocked event — content should be output + expect(getWrittenOutput()).toBe('Final answer\n'); + expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(1); + }); + }); + + describe('Output Sanitization', () => { + const ANSI_SEQUENCE = '\u001B[31mRed Text\u001B[0m'; + const OSC_HYPERLINK = + '\u001B]8;;http://example.com\u001B\\Link\u001B]8;;\u001B\\'; + const PLAIN_TEXT_RED = 'Red Text'; + const PLAIN_TEXT_LINK = 'Link'; + + it('should sanitize ANSI output by default', async () => { + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: ANSI_SEQUENCE }, + { type: GeminiEventType.Content, value: ' ' }, + { type: GeminiEventType.Content, value: OSC_HYPERLINK }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + vi.mocked(mockConfig.getRawOutput).mockReturnValue(false); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-sanitization', + }); + + expect(getWrittenOutput()).toBe(`${PLAIN_TEXT_RED} ${PLAIN_TEXT_LINK}\n`); + }); + + it('should allow ANSI output when rawOutput is true', async () => { + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: ANSI_SEQUENCE }, + { type: GeminiEventType.Content, value: ' ' }, + { type: GeminiEventType.Content, value: OSC_HYPERLINK }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + vi.mocked(mockConfig.getRawOutput).mockReturnValue(true); + vi.mocked(mockConfig.getAcceptRawOutputRisk).mockReturnValue(true); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-raw', + }); + + expect(getWrittenOutput()).toBe(`${ANSI_SEQUENCE} ${OSC_HYPERLINK}\n`); + }); + + it('should allow ANSI output when only acceptRawOutputRisk is true', async () => { + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: ANSI_SEQUENCE }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 5 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + vi.mocked(mockConfig.getRawOutput).mockReturnValue(false); + vi.mocked(mockConfig.getAcceptRawOutputRisk).mockReturnValue(true); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-accept-only', + }); + + expect(getWrittenOutput()).toBe(`${ANSI_SEQUENCE}\n`); + }); + + it('should warn when rawOutput is true and acceptRisk is false', async () => { + const events: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 0 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + vi.mocked(mockConfig.getRawOutput).mockReturnValue(true); + vi.mocked(mockConfig.getAcceptRawOutputRisk).mockReturnValue(false); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-warn', + }); + + expect(processStderrSpy).toHaveBeenCalledWith( + expect.stringContaining('[WARNING] --raw-output is enabled'), + ); + }); + + it('should not warn when rawOutput is true and acceptRisk is true', async () => { + const events: ServerGeminiStreamEvent[] = [ + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 0 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + vi.mocked(mockConfig.getRawOutput).mockReturnValue(true); + vi.mocked(mockConfig.getAcceptRawOutputRisk).mockReturnValue(true); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-no-warn', + }); + + expect(processStderrSpy).not.toHaveBeenCalledWith( + expect.stringContaining('[WARNING] --raw-output is enabled'), + ); + }); + + it('should emit warning event for loop_detected in streaming JSON mode', async () => { + vi.mocked(mockConfig.getOutputFormat).mockReturnValue( + OutputFormat.STREAM_JSON, + ); + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue( + MOCK_SESSION_METRICS, + ); + + const streamEvents: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.LoopDetected } as ServerGeminiStreamEvent, + { type: GeminiEventType.Content, value: 'Continuing after loop' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 5 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(streamEvents), + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Loop test explicit', + prompt_id: 'prompt-id-loop-explicit', + }); + + const output = getWrittenOutput(); + // The STREAM_JSON output should contain an error event with warning severity + expect(output).toContain('"type":"error"'); + expect(output).toContain('"severity":"warning"'); + expect(output).toContain('Loop detected'); + }); + + it('should report cancelled tool calls as success in stream-json mode (legacy parity)', async () => { + const toolCallEvent: ServerGeminiStreamEvent = { + type: GeminiEventType.ToolCallRequest, + value: { + callId: 'tool-1', + name: 'testTool', + args: { arg1: 'value1' }, + isClientInitiated: false, + prompt_id: 'prompt-id-cancel', + }, + }; + + // Mock the scheduler to return a cancelled status + mockSchedulerSchedule.mockResolvedValue([ + { + status: CoreToolCallStatus.Cancelled, + request: toolCallEvent.value, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'tool-1', + responseParts: [{ text: 'Operation cancelled' }], + resultDisplay: 'Cancelled', + }, + }, + ]); + + const events: ServerGeminiStreamEvent[] = [ + toolCallEvent, + { + type: GeminiEventType.Content, + value: 'Model continues...', + }, + ]; + + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + vi.mocked(mockConfig.getOutputFormat).mockReturnValue( + OutputFormat.STREAM_JSON, + ); + vi.mocked(uiTelemetryService.getMetrics).mockReturnValue( + MOCK_SESSION_METRICS, + ); + + await runNonInteractive({ + config: mockConfig, + settings: mockSettings, + input: 'Test input', + prompt_id: 'prompt-id-cancel', + }); + + const output = getWrittenOutput(); + expect(output).toContain('"type":"tool_result"'); + expect(output).toContain('"status":"success"'); + }); + }); +}); diff --git a/packages/cli/src/nonInteractiveCliAgentSession.ts b/packages/cli/src/nonInteractiveCliAgentSession.ts new file mode 100644 index 0000000000..78fc18be4e --- /dev/null +++ b/packages/cli/src/nonInteractiveCliAgentSession.ts @@ -0,0 +1,621 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + Config, + ResumedSessionData, + UserFeedbackPayload, + AgentEvent, + ContentPart, +} from '@google/gemini-cli-core'; +import { isSlashCommand } from './ui/utils/commandUtils.js'; +import type { LoadedSettings } from './config/settings.js'; +import { + convertSessionToClientHistory, + FatalError, + FatalAuthenticationError, + FatalInputError, + FatalSandboxError, + FatalConfigError, + FatalTurnLimitedError, + FatalToolExecutionError, + FatalCancellationError, + promptIdContext, + OutputFormat, + JsonFormatter, + StreamJsonFormatter, + JsonStreamEventType, + uiTelemetryService, + coreEvents, + CoreEvent, + createWorkingStdio, + Scheduler, + ROOT_SCHEDULER_ID, + LegacyAgentSession, + ToolErrorType, + geminiPartsToContentParts, +} from '@google/gemini-cli-core'; + +import type { Part } from '@google/genai'; +import readline from 'node:readline'; +import stripAnsi from 'strip-ansi'; + +import { handleSlashCommand } from './nonInteractiveCliCommands.js'; +import { ConsolePatcher } from './ui/utils/ConsolePatcher.js'; +import { handleAtCommand } from './ui/hooks/atCommandProcessor.js'; +import { handleError, handleToolError } from './utils/errors.js'; +import { TextOutput } from './ui/utils/textOutput.js'; + +interface RunNonInteractiveParams { + config: Config; + settings: LoadedSettings; + input: string; + prompt_id: string; + resumedSessionData?: ResumedSessionData; +} + +export async function runNonInteractive({ + config, + settings, + input, + prompt_id, + resumedSessionData, +}: RunNonInteractiveParams): Promise { + return promptIdContext.run(prompt_id, async () => { + const consolePatcher = new ConsolePatcher({ + stderr: true, + interactive: false, + debugMode: config.getDebugMode(), + onNewMessage: (msg) => { + coreEvents.emitConsoleLog(msg.type, msg.content); + }, + }); + + if (process.env['GEMINI_CLI_ACTIVITY_LOG_TARGET']) { + const { setupInitialActivityLogger } = await import( + './utils/devtoolsService.js' + ); + await setupInitialActivityLogger(config); + } + + const { stdout: workingStdout } = createWorkingStdio(); + const textOutput = new TextOutput(workingStdout); + + const handleUserFeedback = (payload: UserFeedbackPayload) => { + const prefix = payload.severity.toUpperCase(); + process.stderr.write(`[${prefix}] ${payload.message}\n`); + if (payload.error && config.getDebugMode()) { + const errorToLog = + payload.error instanceof Error + ? payload.error.stack || payload.error.message + : String(payload.error); + process.stderr.write(`${errorToLog}\n`); + } + }; + + const startTime = Date.now(); + const streamFormatter = + config.getOutputFormat() === OutputFormat.STREAM_JSON + ? new StreamJsonFormatter() + : null; + + const abortController = new AbortController(); + + // Track cancellation state + let isAborting = false; + let cancelMessageTimer: NodeJS.Timeout | null = null; + + // Setup stdin listener for Ctrl+C detection + let stdinWasRaw = false; + let rl: readline.Interface | null = null; + + const setupStdinCancellation = () => { + // Only setup if stdin is a TTY (user can interact) + if (!process.stdin.isTTY) { + return; + } + + // Save original raw mode state + stdinWasRaw = process.stdin.isRaw || false; + + // Enable raw mode to capture individual keypresses + process.stdin.setRawMode(true); + process.stdin.resume(); + + // Setup readline to emit keypress events + rl = readline.createInterface({ + input: process.stdin, + escapeCodeTimeout: 0, + }); + readline.emitKeypressEvents(process.stdin, rl); + + // Listen for Ctrl+C + const keypressHandler = ( + str: string, + key: { name?: string; ctrl?: boolean }, + ) => { + // Detect Ctrl+C: either ctrl+c key combo or raw character code 3 + if ((key && key.ctrl && key.name === 'c') || str === '\u0003') { + // Only handle once + if (isAborting) { + return; + } + + isAborting = true; + + // Only show message if cancellation takes longer than 200ms + // This reduces verbosity for fast cancellations + cancelMessageTimer = setTimeout(() => { + process.stderr.write('\nCancelling...\n'); + }, 200); + + abortController.abort(); + } + }; + + process.stdin.on('keypress', keypressHandler); + }; + + const cleanupStdinCancellation = () => { + // Clear any pending cancel message timer + if (cancelMessageTimer) { + clearTimeout(cancelMessageTimer); + cancelMessageTimer = null; + } + + // Cleanup readline and stdin listeners + if (rl) { + rl.close(); + rl = null; + } + + // Remove keypress listener + process.stdin.removeAllListeners('keypress'); + + // Restore stdin to original state + if (process.stdin.isTTY) { + process.stdin.setRawMode(stdinWasRaw); + process.stdin.pause(); + } + }; + + let errorToHandle: unknown | undefined; + let abortSession = () => {}; + try { + consolePatcher.patch(); + + if ( + config.getRawOutput() && + !config.getAcceptRawOutputRisk() && + config.getOutputFormat() === OutputFormat.TEXT + ) { + process.stderr.write( + '[WARNING] --raw-output is enabled. Model output is not sanitized and may contain harmful ANSI sequences (e.g. for phishing or command injection). Use --accept-raw-output-risk to suppress this warning.\n', + ); + } + + // Setup stdin cancellation listener + setupStdinCancellation(); + + coreEvents.on(CoreEvent.UserFeedback, handleUserFeedback); + coreEvents.drainBacklogs(); + + // Handle EPIPE errors when the output is piped to a command that closes early. + process.stdout.on('error', (err: NodeJS.ErrnoException) => { + if (err.code === 'EPIPE') { + // Exit gracefully if the pipe is closed. + cleanupStdinCancellation(); + consolePatcher.cleanup(); + process.exit(0); + } + }); + + const geminiClient = config.getGeminiClient(); + const scheduler = new Scheduler({ + context: config, + messageBus: config.getMessageBus(), + getPreferredEditor: () => undefined, + schedulerId: ROOT_SCHEDULER_ID, + }); + + // Initialize chat. Resume if resume data is passed. + if (resumedSessionData) { + await geminiClient.resumeChat( + convertSessionToClientHistory( + resumedSessionData.conversation.messages, + ), + resumedSessionData, + ); + } + + // Emit init event for streaming JSON + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.INIT, + timestamp: new Date().toISOString(), + session_id: config.getSessionId(), + model: config.getModel(), + }); + } + + let query: Part[] | undefined; + + if (isSlashCommand(input)) { + const slashCommandResult = await handleSlashCommand( + input, + abortController, + config, + settings, + ); + if (slashCommandResult) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + query = slashCommandResult as Part[]; + } + } + + if (!query) { + const { processedQuery, error } = await handleAtCommand({ + query: input, + config, + addItem: (_item, _timestamp) => 0, + onDebugMessage: () => {}, + messageId: Date.now(), + signal: abortController.signal, + escapePastedAtSymbols: false, + }); + if (error || !processedQuery) { + throw new FatalInputError( + error || 'Exiting due to an error processing the @ command.', + ); + } + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + query = processedQuery as Part[]; + } + + // Emit user message event for streaming JSON + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.MESSAGE, + timestamp: new Date().toISOString(), + role: 'user', + content: input, + }); + } + + // Create LegacyAgentSession — owns the agentic loop + const session = new LegacyAgentSession({ + client: geminiClient, + scheduler, + config, + promptId: prompt_id, + }); + + // Wire Ctrl+C to session abort + abortSession = () => { + void session.abort(); + }; + abortController.signal.addEventListener('abort', abortSession); + if (abortController.signal.aborted) { + throw new FatalCancellationError('Operation cancelled.'); + } + + // Start the agentic loop (runs in background) + const { streamId } = await session.send({ + message: { + content: geminiPartsToContentParts(query), + displayContent: input, + }, + }); + if (streamId === null) { + throw new Error( + 'LegacyAgentSession.send() unexpectedly returned no stream for a message send.', + ); + } + + const getTextContent = (parts?: ContentPart[]): string | undefined => { + const text = parts + ?.map((part) => (part.type === 'text' ? part.text : '')) + .join(''); + return text ? text : undefined; + }; + + const emitFinalSuccessResult = (): void => { + if (streamFormatter) { + const metrics = uiTelemetryService.getMetrics(); + const durationMs = Date.now() - startTime; + streamFormatter.emitEvent({ + type: JsonStreamEventType.RESULT, + timestamp: new Date().toISOString(), + status: 'success', + stats: streamFormatter.convertToStreamStats(metrics, durationMs), + }); + } else if (config.getOutputFormat() === OutputFormat.JSON) { + const formatter = new JsonFormatter(); + const stats = uiTelemetryService.getMetrics(); + textOutput.write( + formatter.format(config.getSessionId(), responseText, stats), + ); + } else { + textOutput.ensureTrailingNewline(); + } + }; + + const reconstructFatalError = (event: AgentEvent<'error'>): Error => { + const errorMeta = event._meta; + const name = + typeof errorMeta?.['errorName'] === 'string' + ? errorMeta['errorName'] + : undefined; + + let errToThrow: Error; + switch (name) { + case 'FatalAuthenticationError': + errToThrow = new FatalAuthenticationError(event.message); + break; + case 'FatalInputError': + errToThrow = new FatalInputError(event.message); + break; + case 'FatalSandboxError': + errToThrow = new FatalSandboxError(event.message); + break; + case 'FatalConfigError': + errToThrow = new FatalConfigError(event.message); + break; + case 'FatalTurnLimitedError': + errToThrow = new FatalTurnLimitedError(event.message); + break; + case 'FatalToolExecutionError': + errToThrow = new FatalToolExecutionError(event.message); + break; + case 'FatalCancellationError': + errToThrow = new FatalCancellationError(event.message); + break; + case 'FatalError': + errToThrow = new FatalError( + event.message, + typeof errorMeta?.['exitCode'] === 'number' + ? errorMeta['exitCode'] + : 1, + ); + break; + default: + errToThrow = new Error(event.message); + if (name) { + Object.defineProperty(errToThrow, 'name', { + value: name, + enumerable: true, + }); + } + break; + } + + if (errorMeta?.['exitCode'] !== undefined) { + Object.defineProperty(errToThrow, 'exitCode', { + value: errorMeta['exitCode'], + enumerable: true, + }); + } + if (errorMeta?.['code'] !== undefined) { + Object.defineProperty(errToThrow, 'code', { + value: errorMeta['code'], + enumerable: true, + }); + } + if (errorMeta?.['status'] !== undefined) { + Object.defineProperty(errToThrow, 'status', { + value: errorMeta['status'], + enumerable: true, + }); + } + return errToThrow; + }; + + // Consume AgentEvents for output formatting + let responseText = ''; + let preToolResponseText: string | undefined; + let streamEnded = false; + for await (const event of session.stream({ streamId })) { + if (streamEnded) break; + switch (event.type) { + case 'message': { + if (event.role === 'agent') { + for (const part of event.content) { + if (part.type === 'text') { + const isRaw = + config.getRawOutput() || config.getAcceptRawOutputRisk(); + const output = isRaw ? part.text : stripAnsi(part.text); + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.MESSAGE, + timestamp: new Date().toISOString(), + role: 'assistant', + content: output, + delta: true, + }); + } else if (config.getOutputFormat() === OutputFormat.JSON) { + responseText += output; + } else { + if (part.text) { + textOutput.write(output); + } + } + } + } + } + break; + } + case 'tool_request': { + if (config.getOutputFormat() === OutputFormat.JSON) { + // Final JSON output should reflect the last assistant answer after + // any tool orchestration, not intermediate pre-tool text. + preToolResponseText = responseText || preToolResponseText; + responseText = ''; + } + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.TOOL_USE, + timestamp: new Date().toISOString(), + tool_name: event.name, + tool_id: event.requestId, + parameters: event.args, + }); + } + break; + } + case 'tool_response': { + textOutput.ensureTrailingNewline(); + if (streamFormatter) { + const displayText = getTextContent(event.displayContent); + const errorMsg = getTextContent(event.content) ?? 'Tool error'; + streamFormatter.emitEvent({ + type: JsonStreamEventType.TOOL_RESULT, + timestamp: new Date().toISOString(), + tool_id: event.requestId, + status: event.isError ? 'error' : 'success', + output: displayText, + error: event.isError + ? { + type: + typeof event.data?.['errorType'] === 'string' + ? event.data['errorType'] + : 'TOOL_EXECUTION_ERROR', + message: errorMsg, + } + : undefined, + }); + } + if (event.isError) { + const displayText = getTextContent(event.displayContent); + const errorMsg = getTextContent(event.content) ?? 'Tool error'; + + if (event.data?.['errorType'] === ToolErrorType.STOP_EXECUTION) { + if ( + config.getOutputFormat() === OutputFormat.JSON && + !responseText && + preToolResponseText + ) { + responseText = preToolResponseText; + } + const stopMessage = `Agent execution stopped: ${errorMsg}`; + if (config.getOutputFormat() === OutputFormat.TEXT) { + process.stderr.write(`${stopMessage}\n`); + } + } + + if (event.data?.['errorType'] === ToolErrorType.NO_SPACE_LEFT) { + throw new FatalToolExecutionError( + 'Error executing tool ' + + event.name + + ': ' + + (displayText || errorMsg), + ); + } + handleToolError( + event.name, + new Error(errorMsg), + config, + typeof event.data?.['errorType'] === 'string' + ? event.data['errorType'] + : undefined, + displayText, + ); + } + break; + } + case 'error': { + if (event.fatal) { + throw reconstructFatalError(event); + } + + const errorCode = event._meta?.['code']; + + if (errorCode === 'AGENT_EXECUTION_BLOCKED') { + if (config.getOutputFormat() === OutputFormat.TEXT) { + process.stderr.write(`[WARNING] ${event.message}\n`); + } + break; + } + + const severity = + event.status === 'RESOURCE_EXHAUSTED' ? 'error' : 'warning'; + if (config.getOutputFormat() === OutputFormat.TEXT) { + process.stderr.write(`[WARNING] ${event.message}\n`); + } + if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.ERROR, + timestamp: new Date().toISOString(), + severity, + message: event.message, + }); + } + break; + } + case 'agent_end': { + if (event.reason === 'aborted') { + throw new FatalCancellationError('Operation cancelled.'); + } else if (event.reason === 'max_turns') { + const isConfiguredTurnLimit = + typeof event.data?.['maxTurns'] === 'number' || + typeof event.data?.['turnCount'] === 'number'; + + if (isConfiguredTurnLimit) { + throw new FatalTurnLimitedError( + 'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.', + ); + } else if (streamFormatter) { + streamFormatter.emitEvent({ + type: JsonStreamEventType.ERROR, + timestamp: new Date().toISOString(), + severity: 'error', + message: 'Maximum session turns exceeded', + }); + } + } + + const stopMessage = + typeof event.data?.['message'] === 'string' + ? event.data['message'] + : ''; + if (stopMessage && config.getOutputFormat() === OutputFormat.TEXT) { + process.stderr.write(`Agent execution stopped: ${stopMessage}\n`); + } + + emitFinalSuccessResult(); + streamEnded = true; + break; + } + case 'initialize': + case 'session_update': + case 'agent_start': + case 'tool_update': + case 'elicitation_request': + case 'elicitation_response': + case 'usage': + case 'custom': + // Explicitly ignore these non-interactive events + break; + default: + event satisfies never; + break; + } + } + } catch (error) { + errorToHandle = error; + } finally { + // Cleanup stdin cancellation before other cleanup + cleanupStdinCancellation(); + abortController.signal.removeEventListener('abort', abortSession); + + consolePatcher.cleanup(); + coreEvents.off(CoreEvent.UserFeedback, handleUserFeedback); + } + + if (errorToHandle) { + handleError(errorToHandle, config); + } + }); +} diff --git a/packages/cli/src/utils/errors.ts b/packages/cli/src/utils/errors.ts index 913fc0d562..774d9e994c 100644 --- a/packages/cli/src/utils/errors.ts +++ b/packages/cli/src/utils/errors.ts @@ -18,8 +18,8 @@ import { isFatalToolError, debugLogger, coreEvents, - getErrorMessage, getErrorType, + getErrorMessage, } from '@google/gemini-cli-core'; import { runSyncCleanup } from './cleanup.js'; diff --git a/packages/core/src/agent/agent-session.test.ts b/packages/core/src/agent/agent-session.test.ts index e3ff1c5dc0..2ee9e4b7f3 100644 --- a/packages/core/src/agent/agent-session.test.ts +++ b/packages/core/src/agent/agent-session.test.ts @@ -7,7 +7,19 @@ import { describe, expect, it } from 'vitest'; import { AgentSession } from './agent-session.js'; import { MockAgentProtocol } from './mock.js'; -import type { AgentEvent } from './types.js'; +import type { AgentEvent, AgentSend } from './types.js'; + +function makeMessageSend( + text: string, + displayContent?: string, +): Extract { + return { + message: { + content: [{ type: 'text', text }], + ...(displayContent ? { displayContent } : {}), + }, + }; +} describe('AgentSession', () => { it('should passthrough simple methods', async () => { @@ -51,7 +63,7 @@ describe('AgentSession', () => { const events: AgentEvent[] = []; for await (const event of session.sendStream({ - message: [{ type: 'text', text: 'hi' }], + ...makeMessageSend('hi'), })) { events.push(event); } @@ -139,7 +151,7 @@ describe('AgentSession', () => { const events: AgentEvent[] = []; for await (const event of session.sendStream({ - message: [{ type: 'text', text: 'hi' }], + ...makeMessageSend('hi'), })) { events.push(event); } @@ -178,7 +190,7 @@ describe('AgentSession', () => { protocol.pushResponse([{ type: 'message' }]); const { streamId } = await session.send({ - message: [{ type: 'text', text: 'request' }], + ...makeMessageSend('request'), }); await new Promise((resolve) => setTimeout(resolve, 10)); @@ -242,7 +254,7 @@ describe('AgentSession', () => { }, ]); await session.send({ - message: [{ type: 'text', text: 'request' }], + ...makeMessageSend('request'), }); await new Promise((resolve) => setTimeout(resolve, 10)); @@ -303,7 +315,7 @@ describe('AgentSession', () => { }, ]); const { streamId: streamId1 } = await session.send({ - message: [{ type: 'text', text: 'first request' }], + ...makeMessageSend('first request'), }); await new Promise((resolve) => setTimeout(resolve, 10)); @@ -315,7 +327,7 @@ describe('AgentSession', () => { }, ]); await session.send({ - message: [{ type: 'text', text: 'second request' }], + ...makeMessageSend('second request'), }); await new Promise((resolve) => setTimeout(resolve, 10)); diff --git a/packages/core/src/agent/event-translator.test.ts b/packages/core/src/agent/event-translator.test.ts index f40c6c27ad..be9d8ea40e 100644 --- a/packages/core/src/agent/event-translator.test.ts +++ b/packages/core/src/agent/event-translator.test.ts @@ -679,6 +679,7 @@ describe('mapError', () => { expect(result.status).toBe('RESOURCE_EXHAUSTED'); expect(result.message).toBe('Rate limit'); expect(result.fatal).toBe(true); + expect(result._meta?.['status']).toBe(429); expect(result._meta?.['rawError']).toEqual({ message: 'Rate limit', status: 429, diff --git a/packages/core/src/agent/event-translator.ts b/packages/core/src/agent/event-translator.ts index 73f93f4a15..00b5d12b4f 100644 --- a/packages/core/src/agent/event-translator.ts +++ b/packages/core/src/agent/event-translator.ts @@ -403,7 +403,7 @@ export function mapError( } if (isStructuredError(error)) { - const structuredMeta = { ...meta, rawError: error }; + const structuredMeta = { ...meta, rawError: error, status: error.status }; return { status: mapHttpToGrpcStatus(error.status), message: error.message, diff --git a/packages/core/src/agent/legacy-agent-session.test.ts b/packages/core/src/agent/legacy-agent-session.test.ts index 438b1e5ef0..38bea34910 100644 --- a/packages/core/src/agent/legacy-agent-session.test.ts +++ b/packages/core/src/agent/legacy-agent-session.test.ts @@ -10,7 +10,7 @@ import { LegacyAgentSession } from './legacy-agent-session.js'; import type { LegacyAgentSessionDeps } from './legacy-agent-session.js'; import { GeminiEventType } from '../core/turn.js'; import type { ServerGeminiStreamEvent } from '../core/turn.js'; -import type { AgentEvent } from './types.js'; +import type { AgentEvent, AgentSend } from './types.js'; import { ToolErrorType } from '../tools/tool-error.js'; import type { CompletedToolCall, @@ -72,6 +72,18 @@ function makeToolRequest(callId: string, name: string): ToolCallRequestInfo { }; } +function makeMessageSend( + text: string, + displayContent?: string, +): Extract { + return { + message: { + content: [{ type: 'text', text }], + ...(displayContent ? { displayContent } : {}), + }, + }; +} + function makeCompletedToolCall( callId: string, name: string, @@ -140,9 +152,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const result = await session.send({ - message: [{ type: 'text', text: 'hi' }], - }); + const result = await session.send(makeMessageSend('hi')); expect(result.streamId).toBe('test-stream'); }); @@ -162,7 +172,10 @@ describe('LegacyAgentSession', () => { const session = new LegacyAgentSession(deps); const { streamId } = await session.send({ - message: [{ type: 'text', text: 'hi' }], + message: { + content: [{ type: 'text', text: 'hi' }], + displayContent: 'raw input', + }, _meta: { source: 'user-test' }, }); @@ -170,8 +183,19 @@ describe('LegacyAgentSession', () => { (e): e is AgentEvent<'message'> => e.type === 'message' && e.role === 'user' && e.streamId === streamId, ); - expect(userMessage?.content).toEqual([{ type: 'text', text: 'hi' }]); + expect(userMessage?.content).toEqual([ + { type: 'text', text: 'raw input' }, + ]); expect(userMessage?._meta).toEqual({ source: 'user-test' }); + await vi.advanceTimersByTimeAsync(0); + expect(sendMock).toHaveBeenCalledWith( + [{ text: 'hi' }], + expect.any(AbortSignal), + 'test-prompt', + undefined, + false, + 'raw input', + ); await collectEvents(session, { streamId: streamId ?? undefined }); }); @@ -195,9 +219,7 @@ describe('LegacyAgentSession', () => { liveEvents.push(event); }); - const { streamId } = await session.send({ - message: [{ type: 'text', text: 'hi' }], - }); + const { streamId } = await session.send(makeMessageSend('hi')); expect(streamId).toBe('test-stream'); expect(liveEvents.some((event) => event.type === 'agent_start')).toBe( @@ -235,14 +257,12 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const { streamId } = await session.send({ - message: [{ type: 'text', text: 'first' }], - }); + const { streamId } = await session.send(makeMessageSend('first')); await vi.advanceTimersByTimeAsync(0); - await expect( - session.send({ message: [{ type: 'text', text: 'second' }] }), - ).rejects.toThrow('cannot be called while a stream is active'); + await expect(session.send(makeMessageSend('second'))).rejects.toThrow( + 'cannot be called while a stream is active', + ); resolveHang?.(); await collectEvents(session, { streamId: streamId ?? undefined }); @@ -273,16 +293,12 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const first = await session.send({ - message: [{ type: 'text', text: 'first' }], - }); + const first = await session.send(makeMessageSend('first')); const firstEvents = await collectEvents(session, { streamId: first.streamId ?? undefined, }); - const second = await session.send({ - message: [{ type: 'text', text: 'second' }], - }); + const second = await session.send(makeMessageSend('second')); const secondEvents = await collectEvents(session, { streamId: second.streamId ?? undefined, }); @@ -330,7 +346,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const types = events.map((e) => e.type); @@ -387,7 +403,7 @@ describe('LegacyAgentSession', () => { ]); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'read a file' }] }); + await session.send(makeMessageSend('read a file')); const events = await collectEvents(session); const types = events.map((e) => e.type); @@ -455,9 +471,7 @@ describe('LegacyAgentSession', () => { scheduleMock.mockResolvedValueOnce([errorToolCall]); const session = new LegacyAgentSession(deps); - await session.send({ - message: [{ type: 'text', text: 'write file' }], - }); + await session.send(makeMessageSend('write file')); const events = await collectEvents(session); const toolResp = events.find( @@ -506,9 +520,7 @@ describe('LegacyAgentSession', () => { scheduleMock.mockResolvedValueOnce([stopToolCall]); const session = new LegacyAgentSession(deps); - await session.send({ - message: [{ type: 'text', text: 'do something' }], - }); + await session.send(makeMessageSend('do something')); const events = await collectEvents(session); const streamEnd = events.find( @@ -552,9 +564,7 @@ describe('LegacyAgentSession', () => { scheduleMock.mockResolvedValueOnce([fatalToolCall]); const session = new LegacyAgentSession(deps); - await session.send({ - message: [{ type: 'text', text: 'write file' }], - }); + await session.send(makeMessageSend('write file')); const events = await collectEvents(session); const toolResp = events.find( @@ -592,7 +602,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const streamEnd = events.find( @@ -621,7 +631,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const blocked = events.find( @@ -663,7 +673,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const err = events.find( @@ -690,7 +700,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const warning = events.find( @@ -738,7 +748,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const streamEnd = events.find( @@ -762,7 +772,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const errorEvents = events.filter( @@ -799,9 +809,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const { streamId } = await session.send({ - message: [{ type: 'text', text: 'hi' }], - }); + const { streamId } = await session.send(makeMessageSend('hi')); await vi.advanceTimersByTimeAsync(0); await session.abort(); @@ -847,7 +855,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); // Give the loop time to start processing await new Promise((r) => setTimeout(r, 50)); @@ -891,9 +899,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const { streamId } = await session.send({ - message: [{ type: 'text', text: 'hi' }], - }); + const { streamId } = await session.send(makeMessageSend('hi')); await new Promise((resolve) => setTimeout(resolve, 25)); await session.abort(); @@ -935,7 +941,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); await collectEvents(session); expect(session.events.length).toBeGreaterThan(0); @@ -964,9 +970,7 @@ describe('LegacyAgentSession', () => { liveEvents.push(event); }); - const { streamId } = await session.send({ - message: [{ type: 'text', text: 'hi' }], - }); + const { streamId } = await session.send(makeMessageSend('hi')); await collectEvents(session, { streamId: streamId ?? undefined }); unsubscribe(); @@ -1002,9 +1006,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const first = await session.send({ - message: [{ type: 'text', text: 'first request' }], - }); + const first = await session.send(makeMessageSend('first request')); await collectEvents(session, { streamId: first.streamId ?? undefined }); const liveEvents: AgentEvent[] = []; @@ -1012,9 +1014,7 @@ describe('LegacyAgentSession', () => { liveEvents.push(event); }); - const second = await session.send({ - message: [{ type: 'text', text: 'second request' }], - }); + const second = await session.send(makeMessageSend('second request')); await collectEvents(session, { streamId: second.streamId ?? undefined }); unsubscribe(); @@ -1058,14 +1058,10 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const first = await session.send({ - message: [{ type: 'text', text: 'first request' }], - }); + const first = await session.send(makeMessageSend('first request')); await collectEvents(session, { streamId: first.streamId ?? undefined }); - const second = await session.send({ - message: [{ type: 'text', text: 'second request' }], - }); + const second = await session.send(makeMessageSend('second request')); await collectEvents(session, { streamId: second.streamId ?? undefined }); const firstStreamEvents = await collectEvents(session, { @@ -1120,14 +1116,10 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - const first = await session.send({ - message: [{ type: 'text', text: 'first request' }], - }); + const first = await session.send(makeMessageSend('first request')); await collectEvents(session, { streamId: first.streamId ?? undefined }); - await session.send({ - message: [{ type: 'text', text: 'second request' }], - }); + await session.send(makeMessageSend('second request')); await collectEvents(session); const firstAgentMessage = session.events.find( @@ -1175,7 +1167,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); expect(events.length).toBeGreaterThan(0); @@ -1196,7 +1188,7 @@ describe('LegacyAgentSession', () => { ); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); expect(events[events.length - 1]?.type).toBe('agent_end'); @@ -1244,7 +1236,7 @@ describe('LegacyAgentSession', () => { ]); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'do it' }] }); + await session.send(makeMessageSend('do it')); const events = await collectEvents(session); // Only one agent_end at the very end @@ -1291,7 +1283,7 @@ describe('LegacyAgentSession', () => { ]); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'go' }] }); + await session.send(makeMessageSend('go')); const events = await collectEvents(session); // Should have at least one usage event from the intermediate Finished @@ -1314,7 +1306,7 @@ describe('LegacyAgentSession', () => { }); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const err = events.find( @@ -1342,7 +1334,7 @@ describe('LegacyAgentSession', () => { }); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const err = events.find( @@ -1365,7 +1357,7 @@ describe('LegacyAgentSession', () => { }); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const err = events.find( @@ -1385,7 +1377,7 @@ describe('LegacyAgentSession', () => { }); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const err = events.find( @@ -1405,7 +1397,7 @@ describe('LegacyAgentSession', () => { }); const session = new LegacyAgentSession(deps); - await session.send({ message: [{ type: 'text', text: 'hi' }] }); + await session.send(makeMessageSend('hi')); const events = await collectEvents(session); const err = events.find( diff --git a/packages/core/src/agent/legacy-agent-session.ts b/packages/core/src/agent/legacy-agent-session.ts index d8044e77e3..667c85f5ed 100644 --- a/packages/core/src/agent/legacy-agent-session.ts +++ b/packages/core/src/agent/legacy-agent-session.ts @@ -105,12 +105,16 @@ class LegacyAgentProtocol implements AgentProtocol { this._beginNewStream(); const streamId = this._translationState.streamId; - const parts = contentPartsToGeminiParts(message); - const userMessage = this._makeUserMessageEvent(message, payload._meta); + const parts = contentPartsToGeminiParts(message.content); + const userMessage = this._makeUserMessageEvent( + message.content, + message.displayContent, + payload._meta, + ); this._emit([userMessage]); - this._scheduleRunLoop(parts); + this._scheduleRunLoop(parts, message.displayContent); return { streamId }; } @@ -119,18 +123,24 @@ class LegacyAgentProtocol implements AgentProtocol { this._abortController.abort(); } - private _scheduleRunLoop(initialParts: Part[]): void { + private _scheduleRunLoop( + initialParts: Part[], + displayContent?: string, + ): void { // Use a macrotask so send() resolves with the streamId before agent_start // is emitted and consumers can attach to the stream without racing startup. setTimeout(() => { - void this._runLoopInBackground(initialParts); + void this._runLoopInBackground(initialParts, displayContent); }, 0); } - private async _runLoopInBackground(initialParts: Part[]): Promise { + private async _runLoopInBackground( + initialParts: Part[], + displayContent?: string, + ): Promise { this._ensureAgentStart(); try { - await this._runLoop(initialParts); + await this._runLoop(initialParts, displayContent); } catch (err: unknown) { if (this._abortController.signal.aborted || isAbortLikeError(err)) { this._ensureAgentEnd('aborted'); @@ -141,8 +151,12 @@ class LegacyAgentProtocol implements AgentProtocol { } } - private async _runLoop(initialParts: Part[]): Promise { + private async _runLoop( + initialParts: Part[], + initialDisplayContent?: string, + ): Promise { let currentParts: Part[] = initialParts; + let currentDisplayContent = initialDisplayContent; let turnCount = 0; const maxTurns = this._config.getMaxSessionTurns(); @@ -162,7 +176,11 @@ class LegacyAgentProtocol implements AgentProtocol { currentParts, this._abortController.signal, this._promptId, + undefined, + false, + currentDisplayContent, ); + currentDisplayContent = undefined; for await (const event of responseStream) { if (this._abortController.signal.aborted) { @@ -383,13 +401,17 @@ class LegacyAgentProtocol implements AgentProtocol { private _makeUserMessageEvent( content: ContentPart[], + displayContent?: string, meta?: Record, ): AgentEvent<'message'> { + const eventContent: ContentPart[] = displayContent + ? [{ type: 'text', text: displayContent }] + : content; const event = { ...this._nextEventFields(), type: 'message', role: 'user', - content, + content: eventContent, ...(meta ? { _meta: meta } : {}), } satisfies AgentEvent<'message'>; return event; diff --git a/packages/core/src/agent/mock.test.ts b/packages/core/src/agent/mock.test.ts index f5138e388a..64403008a6 100644 --- a/packages/core/src/agent/mock.test.ts +++ b/packages/core/src/agent/mock.test.ts @@ -34,7 +34,7 @@ describe('MockAgentProtocol', () => { const streamPromise = waitForStreamEnd(session); const { streamId } = await session.send({ - message: [{ type: 'text', text: 'hi' }], + message: { content: [{ type: 'text', text: 'hi' }] }, }); expect(streamId).toBeDefined(); diff --git a/packages/core/src/agent/mock.ts b/packages/core/src/agent/mock.ts index 80d8ebae2f..26ef965d6d 100644 --- a/packages/core/src/agent/mock.ts +++ b/packages/core/src/agent/mock.ts @@ -10,6 +10,7 @@ import type { AgentEventData, AgentProtocol, AgentSend, + ContentPart, Unsubscribe, } from './types.js'; @@ -133,11 +134,17 @@ export class MockAgentProtocol implements AgentProtocol { // 1. User/Update event (BEFORE agent_start) if ('message' in payload && payload.message) { + const message = Array.isArray(payload.message) + ? { content: payload.message, displayContent: undefined } + : payload.message; + const userContent: ContentPart[] = message.displayContent + ? [{ type: 'text', text: message.displayContent }] + : message.content; eventsToEmit.push( normalize({ type: 'message', role: 'user', - content: payload.message, + content: userContent, _meta: payload._meta, }), ); diff --git a/packages/core/src/agent/types.ts b/packages/core/src/agent/types.ts index 4ec369d066..9bc3e81e0f 100644 --- a/packages/core/src/agent/types.ts +++ b/packages/core/src/agent/types.ts @@ -46,7 +46,10 @@ type RequireExactlyOne = { }[keyof T]; interface AgentSendPayloads { - message: ContentPart[]; + message: { + content: ContentPart[]; + displayContent?: string; + }; elicitations: ElicitationResponse[]; update: { title?: string; model?: string; config?: Record }; action: { type: string; data: unknown }; diff --git a/packages/core/src/output/json-formatter.ts b/packages/core/src/output/json-formatter.ts index dd3e558a6f..bce5055a6b 100644 --- a/packages/core/src/output/json-formatter.ts +++ b/packages/core/src/output/json-formatter.ts @@ -6,6 +6,7 @@ import stripAnsi from 'strip-ansi'; import type { SessionMetrics } from '../telemetry/uiTelemetry.js'; +import { getErrorType } from '../utils/errors.js'; import type { JsonError, JsonOutput } from './types.js'; export class JsonFormatter { @@ -42,7 +43,7 @@ export class JsonFormatter { sessionId?: string, ): string { const jsonError: JsonError = { - type: error.constructor.name, + type: getErrorType(error), message: stripAnsi(error.message), ...(code && { code }), };