From 0ccc5ce58f4d7c1ef1ca47d3ae8169399d40cd11 Mon Sep 17 00:00:00 2001 From: Sri Pasumarthi <111310667+sripasg@users.noreply.github.com> Date: Wed, 29 Apr 2026 16:58:16 -0700 Subject: [PATCH] refactor(acp): delegate prompt turn processing logic to GeminiClient (#26222) --- packages/cli/src/acp/acpSession.test.ts | 242 +++++++++++++++------ packages/cli/src/acp/acpSession.ts | 267 +++++++++++++++--------- packages/cli/src/config/config.test.ts | 8 +- packages/cli/src/config/config.ts | 6 + 4 files changed, 354 insertions(+), 169 deletions(-) diff --git a/packages/cli/src/acp/acpSession.test.ts b/packages/cli/src/acp/acpSession.test.ts index 07639108ce..c87c1cc4b4 100644 --- a/packages/cli/src/acp/acpSession.test.ts +++ b/packages/cli/src/acp/acpSession.test.ts @@ -13,21 +13,22 @@ import { afterEach, type Mock, type Mocked, + type MockInstance, } from 'vitest'; import { Session } from './acpSession.js'; import type * as acp from '@agentclientprotocol/sdk'; import { - StreamEventType, ReadManyFilesTool, type GeminiChat, type Config, type MessageBus, - LlmRole, type GitService, - type ModelRouterService, InvalidStreamError, + GeminiEventType, + type ServerGeminiStreamEvent, } from '@google/gemini-cli-core'; import type { LoadedSettings } from '../config/settings.js'; +import { type Part, FinishReason } from '@google/genai'; import * as fs from 'node:fs/promises'; import * as path from 'node:path'; import type { CommandHandler } from './acpCommandHandler.js'; @@ -57,11 +58,23 @@ vi.mock( }, ); -// eslint-disable-next-line @typescript-eslint/no-explicit-any -async function* createMockStream(items: any[]) { +async function* createMockStream( + items: readonly ServerGeminiStreamEvent[], +): AsyncGenerator { for (const item of items) { yield item; } + + yield { + type: GeminiEventType.Finished, + value: { + reason: FinishReason.STOP, + usageMetadata: { + promptTokenCount: 5, + candidatesTokenCount: 10, + }, + }, + }; } describe('Session', () => { @@ -72,6 +85,13 @@ describe('Session', () => { let mockToolRegistry: { getTool: Mock }; let mockTool: { kind: string; build: Mock }; let mockMessageBus: Mocked; + let mockSendMessageStream: MockInstance< + ( + request: Part[], + signal: AbortSignal, + promptId: string, + ) => AsyncGenerator + >; beforeEach(() => { mockChat = { @@ -97,6 +117,7 @@ describe('Session', () => { subscribe: vi.fn(), unsubscribe: vi.fn(), } as unknown as Mocked; + mockSendMessageStream = vi.fn(); mockConfig = { getModel: vi.fn().mockReturnValue('gemini-pro'), getActiveModel: vi.fn().mockReturnValue('gemini-pro'), @@ -124,6 +145,11 @@ describe('Session', () => { }), waitForMcpInit: vi.fn(), getDisableAlwaysAllow: vi.fn().mockReturnValue(false), + getMaxSessionTurns: vi.fn().mockReturnValue(-1), + geminiClient: { + sendMessageStream: mockSendMessageStream, + getChat: vi.fn().mockReturnValue(mockChat), + }, get config() { return this; }, @@ -176,11 +202,11 @@ describe('Session', () => { it('should await MCP initialization before processing a prompt', async () => { const stream = createMockStream([ { - type: StreamEventType.CHUNK, - value: { candidates: [{ content: { parts: [{ text: 'Hi' }] } }] }, + type: GeminiEventType.Content, + value: 'Hi', }, ]); - mockChat.sendMessageStream.mockResolvedValue(stream); + mockSendMessageStream.mockReturnValue(stream); await session.prompt({ sessionId: 'session-1', @@ -193,20 +219,18 @@ describe('Session', () => { it('should handle prompt with text response', async () => { const stream = createMockStream([ { - type: StreamEventType.CHUNK, - value: { - candidates: [{ content: { parts: [{ text: 'Hello' }] } }], - }, + type: GeminiEventType.Content, + value: 'Hello', }, ]); - mockChat.sendMessageStream.mockResolvedValue(stream); + mockSendMessageStream.mockReturnValue(stream); const result = await session.prompt({ sessionId: 'session-1', prompt: [{ type: 'text', text: 'Hi' }], }); - expect(mockChat.sendMessageStream).toHaveBeenCalled(); + expect(mockSendMessageStream).toHaveBeenCalled(); expect(mockConnection.sessionUpdate).toHaveBeenCalledWith({ sessionId: 'session-1', update: { @@ -217,41 +241,40 @@ describe('Session', () => { expect(result).toMatchObject({ stopReason: 'end_turn' }); }); - it('should use model router to determine model', async () => { - const mockRouter = { - route: vi.fn().mockResolvedValue({ model: 'routed-model' }), - } as unknown as ModelRouterService; - mockConfig.getModelRouterService.mockReturnValue(mockRouter); - + it('should pass current session information directly onto geminiClient.sendMessageStream', async () => { const stream = createMockStream([ { - type: StreamEventType.CHUNK, - value: { - candidates: [{ content: { parts: [{ text: 'Hello' }] } }], - }, + type: GeminiEventType.Content, + value: 'Hello', }, ]); - mockChat.sendMessageStream.mockResolvedValue(stream); + mockSendMessageStream.mockReturnValue(stream); await session.prompt({ sessionId: 'session-1', prompt: [{ type: 'text', text: 'Hi' }], }); - expect(mockRouter.route).toHaveBeenCalled(); - expect(mockChat.sendMessageStream).toHaveBeenCalledWith( - expect.objectContaining({ model: 'routed-model' }), - expect.any(Array), - expect.any(String), - expect.any(Object), + expect(mockSendMessageStream).toHaveBeenCalledWith( + expect.arrayContaining([{ text: 'Hi' }]), + expect.any(AbortSignal), expect.any(String), ); }); it('should handle prompt with empty response (InvalidStreamError)', async () => { - mockChat.sendMessageStream.mockRejectedValue( - new InvalidStreamError('Empty response', 'NO_RESPONSE_TEXT'), - ); + const error = new InvalidStreamError('Empty response', 'NO_RESPONSE_TEXT'); + mockSendMessageStream.mockImplementation(() => { + async function* errorGen(): AsyncGenerator< + ServerGeminiStreamEvent, + void, + unknown + > { + yield* []; + throw error; + } + return errorGen(); + }); const result = await session.prompt({ sessionId: 'session-1', @@ -262,9 +285,21 @@ describe('Session', () => { }); it('should handle prompt with no finish reason (InvalidStreamError)', async () => { - mockChat.sendMessageStream.mockRejectedValue( - new InvalidStreamError('No finish reason', 'NO_FINISH_REASON'), + const error = new InvalidStreamError( + 'No finish reason', + 'NO_FINISH_REASON', ); + mockSendMessageStream.mockImplementation(() => { + async function* errorGen(): AsyncGenerator< + ServerGeminiStreamEvent, + void, + unknown + > { + yield* []; + throw error; + } + return errorGen(); + }); const result = await session.prompt({ sessionId: 'session-1', @@ -298,24 +333,26 @@ describe('Session', () => { it('should handle tool calls', async () => { const stream1 = createMockStream([ { - type: StreamEventType.CHUNK, + type: GeminiEventType.ToolCallRequest, value: { - functionCalls: [{ name: 'test_tool', args: { foo: 'bar' } }], + callId: 'call-1', + name: 'test_tool', + args: { foo: 'bar' }, + isClientInitiated: false, + prompt_id: 'prompt-1', }, }, ]); const stream2 = createMockStream([ { - type: StreamEventType.CHUNK, - value: { - candidates: [{ content: { parts: [{ text: 'Result' }] } }], - }, + type: GeminiEventType.Content, + value: 'Result', }, ]); - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); + mockSendMessageStream + .mockReturnValueOnce(stream1) + .mockReturnValueOnce(stream2); const result = await session.prompt({ sessionId: 'session-1', @@ -347,22 +384,26 @@ describe('Session', () => { const stream1 = createMockStream([ { - type: StreamEventType.CHUNK, + type: GeminiEventType.ToolCallRequest, value: { - functionCalls: [{ name: 'test_tool', args: {} }], + callId: 'call-1', + name: 'test_tool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-1', }, }, ]); const stream2 = createMockStream([ { - type: StreamEventType.CHUNK, - value: { candidates: [] }, + type: GeminiEventType.Content, + value: '', }, ]); - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); + mockSendMessageStream + .mockReturnValueOnce(stream1) + .mockReturnValueOnce(stream2); await session.prompt({ sessionId: 'session-1', @@ -381,11 +422,11 @@ describe('Session', () => { const stream = createMockStream([ { - type: StreamEventType.CHUNK, - value: { candidates: [] }, + type: GeminiEventType.Content, + value: '', }, ]); - mockChat.sendMessageStream.mockResolvedValue(stream); + mockSendMessageStream.mockReturnValue(stream); await session.prompt({ sessionId: 'session-1', @@ -402,23 +443,33 @@ describe('Session', () => { expect(path.resolve).toHaveBeenCalled(); expect(fs.stat).toHaveBeenCalled(); - expect(mockChat.sendMessageStream).toHaveBeenCalledWith( - expect.anything(), + expect(mockSendMessageStream).toHaveBeenCalledWith( expect.arrayContaining([ expect.objectContaining({ text: expect.stringContaining('Content from @file.txt'), }), ]), - expect.anything(), expect.any(AbortSignal), - LlmRole.MAIN, + expect.any(String), ); }); it('should handle rate limit error', async () => { const error = new Error('Rate limit'); - (error as unknown as { status: number }).status = 429; - mockChat.sendMessageStream.mockRejectedValue(error); + const customError = error as { status?: number; message?: string }; + customError.status = 429; + + mockSendMessageStream.mockImplementation(() => { + async function* errorGen(): AsyncGenerator< + ServerGeminiStreamEvent, + void, + unknown + > { + yield* []; + throw customError; + } + return errorGen(); + }); await expect( session.prompt({ @@ -436,28 +487,81 @@ describe('Session', () => { const stream1 = createMockStream([ { - type: StreamEventType.CHUNK, + type: GeminiEventType.ToolCallRequest, value: { - functionCalls: [{ name: 'unknown_tool', args: {} }], + callId: 'call-1', + name: 'unknown_tool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-1', }, }, ]); const stream2 = createMockStream([ { - type: StreamEventType.CHUNK, - value: { candidates: [] }, + type: GeminiEventType.Content, + value: '', }, ]); - mockChat.sendMessageStream - .mockResolvedValueOnce(stream1) - .mockResolvedValueOnce(stream2); + mockSendMessageStream + .mockReturnValueOnce(stream1) + .mockReturnValueOnce(stream2); await session.prompt({ sessionId: 'session-1', prompt: [{ type: 'text', text: 'Call tool' }], }); - expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2); + expect(mockSendMessageStream).toHaveBeenCalledTimes(2); + }); + + it('should handle GeminiEventType.LoopDetected', async () => { + const stream = createMockStream([ + { + type: GeminiEventType.LoopDetected, + }, + ]); + mockSendMessageStream.mockReturnValue(stream); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Trigger Loop Simulation' }], + }); + + expect(result.stopReason).toBe('max_turn_requests'); + }); + + it('should handle GeminiEventType.ContextWindowWillOverflow', async () => { + const stream = createMockStream([ + { + type: GeminiEventType.ContextWindowWillOverflow, + value: { estimatedRequestTokenCount: 1000, remainingTokenCount: 200 }, + }, + ]); + mockSendMessageStream.mockReturnValue(stream); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Trigger Overflow Simulation' }], + }); + + expect(result.stopReason).toBe('max_tokens'); + }); + + it('should handle GeminiEventType.MaxSessionTurns', async () => { + const stream = createMockStream([ + { + type: GeminiEventType.MaxSessionTurns, + }, + ]); + mockSendMessageStream.mockReturnValue(stream); + + const result = await session.prompt({ + sessionId: 'session-1', + prompt: [{ type: 'text', text: 'Trigger Safety Limits' }], + }); + + expect(result.stopReason).toBe('max_turn_requests'); }); }); diff --git a/packages/cli/src/acp/acpSession.ts b/packages/cli/src/acp/acpSession.ts index db0c185007..bcc8a86248 100644 --- a/packages/cli/src/acp/acpSession.ts +++ b/packages/cli/src/acp/acpSession.ts @@ -6,35 +6,34 @@ import { type ApprovalMode, - type GeminiChat, - type ToolResult, type ConversationRecord, CoreToolCallStatus, logToolCall, convertToFunctionResponse, ToolConfirmationOutcome, - isWithinRoot, getErrorStatus, DiscoveredMCPTool, - StreamEventType, ToolCallEvent, debugLogger, ReadManyFilesTool, - REFERENCE_CONTENT_START, - type RoutingContext, partListUnionToString, - LlmRole, - processSingleFileContent, - InvalidStreamError, type AgentLoopContext, updatePolicy, - isNodeError, getErrorMessage, type FilterFilesOptions, isTextPart, + GeminiEventType, + type ToolCallRequestInfo, + type GeminiChat, + type ToolResult, + isWithinRoot, + processSingleFileContent, + isNodeError, + REFERENCE_CONTENT_START, + InvalidStreamError, } from '@google/gemini-cli-core'; import * as acp from '@agentclientprotocol/sdk'; -import type { Content, Part, FunctionCall } from '@google/genai'; +import type { Part, FunctionCall } from '@google/genai'; import type { LoadedSettings } from '../config/settings.js'; import * as fs from 'node:fs/promises'; import * as path from 'node:path'; @@ -50,6 +49,11 @@ import { import { z } from 'zod'; import { getAcpErrorMessage } from './acpErrors.js'; +const StructuredErrorSchema = z.object({ + status: z.number().optional(), + message: z.string().optional(), +}); + export class Session { private pendingPrompt: AbortController | null = null; private commandHandler = new CommandHandler(); @@ -188,7 +192,6 @@ export class Session { await this.context.config.waitForMcpInit(); const promptId = Math.random().toString(16).slice(2); - const chat = this.chat; const parts = await this.#resolvePrompt(params.prompt, pendingSend.signal); @@ -236,100 +239,125 @@ export class Session { let totalOutputTokens = 0; const modelUsageMap = new Map(); - let nextMessage: Content | null = { role: 'user', parts }; + let currentParts: Part[] = parts; + let turnCount = 0; + const maxTurns = this.context.config.getMaxSessionTurns(); - while (nextMessage !== null) { - if (pendingSend.signal.aborted) { - chat.addHistory(nextMessage); - return { stopReason: CoreToolCallStatus.Cancelled }; + while (true) { + turnCount++; + if (maxTurns >= 0 && turnCount > maxTurns) { + return { + stopReason: 'max_turn_requests', + _meta: { + quota: { + token_count: { + input_tokens: totalInputTokens, + output_tokens: totalOutputTokens, + }, + model_usage: Array.from(modelUsageMap.entries()).map( + ([modelName, counts]) => ({ + model: modelName, + token_count: { + input_tokens: counts.input, + output_tokens: counts.output, + }, + }), + ), + }, + }, + }; } - const functionCalls: FunctionCall[] = []; + if (pendingSend.signal.aborted) { + return { stopReason: 'cancelled' }; + } + + const toolCallRequests: ToolCallRequestInfo[] = []; + let stopReason: acp.StopReason = 'end_turn'; + let turnModelId = this.context.config.getModel(); + let turnInputTokens = 0; + let turnOutputTokens = 0; try { - const routingContext: RoutingContext = { - history: chat.getHistory(/*curated=*/ true), - request: nextMessage?.parts ?? [], - signal: pendingSend.signal, - requestedModel: this.context.config.getModel(), - }; - - const router = this.context.config.getModelRouterService(); - const { model } = await router.route(routingContext); - const responseStream = await chat.sendMessageStream( - { model }, - nextMessage?.parts ?? [], - promptId, + const responseStream = this.context.geminiClient.sendMessageStream( + currentParts, pendingSend.signal, - LlmRole.MAIN, + promptId, ); - nextMessage = null; - let turnInputTokens = 0; - let turnOutputTokens = 0; - let turnModelId = model; - - for await (const resp of responseStream) { + for await (const event of responseStream) { if (pendingSend.signal.aborted) { - return { stopReason: CoreToolCallStatus.Cancelled }; + return { stopReason: 'cancelled' }; } - if (resp.type === StreamEventType.CHUNK && resp.value.usageMetadata) { - turnInputTokens = - resp.value.usageMetadata.promptTokenCount ?? turnInputTokens; - turnOutputTokens = - resp.value.usageMetadata.candidatesTokenCount ?? turnOutputTokens; - if (resp.value.modelVersion) { - turnModelId = resp.value.modelVersion; - } - } - - if ( - resp.type === StreamEventType.CHUNK && - resp.value.candidates && - resp.value.candidates.length > 0 - ) { - const candidate = resp.value.candidates[0]; - for (const part of candidate.content?.parts ?? []) { - if (!part.text) { - continue; - } - + switch (event.type) { + case GeminiEventType.Content: { const content: acp.ContentBlock = { type: 'text', - text: part.text, + text: event.value, }; - // eslint-disable-next-line @typescript-eslint/no-floating-promises - this.sendUpdate({ - sessionUpdate: part.thought - ? 'agent_thought_chunk' - : 'agent_message_chunk', + await this.sendUpdate({ + sessionUpdate: 'agent_message_chunk', content, }); + break; } + + case GeminiEventType.Thought: { + const thoughtText = `**${event.value.subject}**\n${event.value.description}`; + await this.sendUpdate({ + sessionUpdate: 'agent_thought_chunk', + content: { type: 'text', text: thoughtText }, + }); + break; + } + + case GeminiEventType.ToolCallRequest: + toolCallRequests.push(event.value); + break; + + case GeminiEventType.Finished: { + const usage = event.value.usageMetadata; + if (usage) { + turnInputTokens = usage.promptTokenCount ?? turnInputTokens; + turnOutputTokens = + usage.candidatesTokenCount ?? turnOutputTokens; + } + break; + } + + case GeminiEventType.ModelInfo: + turnModelId = event.value; + break; + + case GeminiEventType.MaxSessionTurns: + stopReason = 'max_turn_requests'; + break; + + case GeminiEventType.LoopDetected: + stopReason = 'max_turn_requests'; + break; + + case GeminiEventType.ContextWindowWillOverflow: + stopReason = 'max_tokens'; + break; + + case GeminiEventType.Error: { + const parseResult = StructuredErrorSchema.safeParse( + event.value.error, + ); + const errData = parseResult.success ? parseResult.data : {}; + + throw new acp.RequestError( + errData.status ?? 500, + errData.message ?? 'Unknown stream execution error.', + ); + } + + default: + break; } - - if (resp.type === StreamEventType.CHUNK && resp.value.functionCalls) { - functionCalls.push(...resp.value.functionCalls); - } - } - - totalInputTokens += turnInputTokens; - totalOutputTokens += turnOutputTokens; - - if (turnInputTokens > 0 || turnOutputTokens > 0) { - const existing = modelUsageMap.get(turnModelId) ?? { - input: 0, - output: 0, - }; - existing.input += turnInputTokens; - existing.output += turnOutputTokens; - modelUsageMap.set(turnModelId, existing); - } - - if (pendingSend.signal.aborted) { - return { stopReason: CoreToolCallStatus.Cancelled }; } } catch (error) { if (getErrorStatus(error) === 429) { @@ -343,7 +371,11 @@ export class Session { pendingSend.signal.aborted || (error instanceof Error && error.name === 'AbortError') ) { - return { stopReason: CoreToolCallStatus.Cancelled }; + return { stopReason: 'cancelled' }; + } + + if (error instanceof acp.RequestError) { + throw error; } if ( @@ -386,16 +418,59 @@ export class Session { ); } - if (functionCalls.length > 0) { - const toolResponseParts: Part[] = []; + totalInputTokens += turnInputTokens; + totalOutputTokens += turnOutputTokens; - for (const fc of functionCalls) { - const response = await this.runTool(pendingSend.signal, promptId, fc); - toolResponseParts.push(...response); - } - - nextMessage = { role: 'user', parts: toolResponseParts }; + if (turnInputTokens > 0 || turnOutputTokens > 0) { + const existing = modelUsageMap.get(turnModelId) ?? { + input: 0, + output: 0, + }; + existing.input += turnInputTokens; + existing.output += turnOutputTokens; + modelUsageMap.set(turnModelId, existing); } + + if (stopReason !== 'end_turn') { + return { + stopReason, + _meta: { + quota: { + token_count: { + input_tokens: totalInputTokens, + output_tokens: totalOutputTokens, + }, + model_usage: Array.from(modelUsageMap.entries()).map( + ([modelName, counts]) => ({ + model: modelName, + token_count: { + input_tokens: counts.input, + output_tokens: counts.output, + }, + }), + ), + }, + }, + }; + } + + if (toolCallRequests.length === 0) { + break; + } + + const toolResponseParts: Part[] = []; + for (const tReq of toolCallRequests) { + const fc: FunctionCall = { + id: tReq.callId, + name: tReq.name, + args: tReq.args, + }; + + const response = await this.runTool(pendingSend.signal, promptId, fc); + toolResponseParts.push(...response); + } + + currentParts = toolResponseParts; } const modelUsageArray = Array.from(modelUsageMap.entries()).map( diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index 0ee7a42ec9..312517db56 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -3988,7 +3988,7 @@ describe('loadCliConfig acpMode and clientName', () => { expect(config.getClientName()).toBe('acp-vscode'); }); - it('should set acpMode to true but leave clientName undefined for generic terminals', async () => { + it('should set acpMode to true and set clientName to acp for generic terminals', async () => { process.argv = ['node', 'script.js', '--acp']; vi.stubEnv('TERM_PROGRAM', 'iTerm.app'); // Generic terminal vi.stubEnv('VSCODE_GIT_ASKPASS_MAIN', ''); @@ -4000,10 +4000,10 @@ describe('loadCliConfig acpMode and clientName', () => { argv, ); expect(config.getAcpMode()).toBe(true); - expect(config.getClientName()).toBeUndefined(); + expect(config.getClientName()).toBe('acp'); }); - it('should set acpMode to false and clientName to undefined by default', async () => { + it('should set acpMode to false and clientName to tui by default', async () => { process.argv = ['node', 'script.js']; const argv = await parseArguments(createTestMergedSettings()); const config = await loadCliConfig( @@ -4012,6 +4012,6 @@ describe('loadCliConfig acpMode and clientName', () => { argv, ); expect(config.getAcpMode()).toBe(false); - expect(config.getClientName()).toBeUndefined(); + expect(config.getClientName()).toBe('tui'); }); }); diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index bd1f616c8c..7e9ba97bf5 100755 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -931,7 +931,13 @@ export async function loadCliConfig( (ide.name !== 'vscode' || process.env['TERM_PROGRAM'] === 'vscode') ) { clientName = `acp-${ide.name}`; + } else { + clientName = 'acp'; } + } else if (argv.isCommand) { + clientName = 'cli-command'; + } else { + clientName = 'tui'; } // TODO(joshualitt): Clean this up alongside removal of the legacy config.