refactor(acp): delegate prompt turn processing logic to GeminiClient (#26222)

This commit is contained in:
Sri Pasumarthi
2026-04-29 16:58:16 -07:00
committed by GitHub
parent 1834ad0298
commit 0ccc5ce58f
4 changed files with 354 additions and 169 deletions
+173 -69
View File
@@ -13,21 +13,22 @@ import {
afterEach,
type Mock,
type Mocked,
type MockInstance,
} from 'vitest';
import { Session } from './acpSession.js';
import type * as acp from '@agentclientprotocol/sdk';
import {
StreamEventType,
ReadManyFilesTool,
type GeminiChat,
type Config,
type MessageBus,
LlmRole,
type GitService,
type ModelRouterService,
InvalidStreamError,
GeminiEventType,
type ServerGeminiStreamEvent,
} from '@google/gemini-cli-core';
import type { LoadedSettings } from '../config/settings.js';
import { type Part, FinishReason } from '@google/genai';
import * as fs from 'node:fs/promises';
import * as path from 'node:path';
import type { CommandHandler } from './acpCommandHandler.js';
@@ -57,11 +58,23 @@ vi.mock(
},
);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
async function* createMockStream(items: any[]) {
/**
 * Test helper that builds a mock event stream.
 *
 * Replays every event in `items` in order, then yields one trailing
 * `GeminiEventType.Finished` event with `FinishReason.STOP` and fixed usage
 * metadata (promptTokenCount 5, candidatesTokenCount 10).
 *
 * @param items Events to yield before the trailing Finished event.
 * @returns An async generator of `ServerGeminiStreamEvent` values.
 */
async function* createMockStream(
items: readonly ServerGeminiStreamEvent[],
): AsyncGenerator<ServerGeminiStreamEvent> {
for (const item of items) {
yield item;
}
// Every mock stream ends with a Finished event, even when `items` is empty.
yield {
type: GeminiEventType.Finished,
value: {
reason: FinishReason.STOP,
usageMetadata: {
promptTokenCount: 5,
candidatesTokenCount: 10,
},
},
};
}
describe('Session', () => {
@@ -72,6 +85,13 @@ describe('Session', () => {
let mockToolRegistry: { getTool: Mock };
let mockTool: { kind: string; build: Mock };
let mockMessageBus: Mocked<MessageBus>;
let mockSendMessageStream: MockInstance<
(
request: Part[],
signal: AbortSignal,
promptId: string,
) => AsyncGenerator<ServerGeminiStreamEvent>
>;
beforeEach(() => {
mockChat = {
@@ -97,6 +117,7 @@ describe('Session', () => {
subscribe: vi.fn(),
unsubscribe: vi.fn(),
} as unknown as Mocked<MessageBus>;
mockSendMessageStream = vi.fn();
mockConfig = {
getModel: vi.fn().mockReturnValue('gemini-pro'),
getActiveModel: vi.fn().mockReturnValue('gemini-pro'),
@@ -124,6 +145,11 @@ describe('Session', () => {
}),
waitForMcpInit: vi.fn(),
getDisableAlwaysAllow: vi.fn().mockReturnValue(false),
getMaxSessionTurns: vi.fn().mockReturnValue(-1),
geminiClient: {
sendMessageStream: mockSendMessageStream,
getChat: vi.fn().mockReturnValue(mockChat),
},
get config() {
return this;
},
@@ -176,11 +202,11 @@ describe('Session', () => {
it('should await MCP initialization before processing a prompt', async () => {
const stream = createMockStream([
{
type: StreamEventType.CHUNK,
value: { candidates: [{ content: { parts: [{ text: 'Hi' }] } }] },
type: GeminiEventType.Content,
value: 'Hi',
},
]);
mockChat.sendMessageStream.mockResolvedValue(stream);
mockSendMessageStream.mockReturnValue(stream);
await session.prompt({
sessionId: 'session-1',
@@ -193,20 +219,18 @@ describe('Session', () => {
it('should handle prompt with text response', async () => {
const stream = createMockStream([
{
type: StreamEventType.CHUNK,
value: {
candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
},
type: GeminiEventType.Content,
value: 'Hello',
},
]);
mockChat.sendMessageStream.mockResolvedValue(stream);
mockSendMessageStream.mockReturnValue(stream);
const result = await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: 'Hi' }],
});
expect(mockChat.sendMessageStream).toHaveBeenCalled();
expect(mockSendMessageStream).toHaveBeenCalled();
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith({
sessionId: 'session-1',
update: {
@@ -217,41 +241,40 @@ describe('Session', () => {
expect(result).toMatchObject({ stopReason: 'end_turn' });
});
it('should use model router to determine model', async () => {
const mockRouter = {
route: vi.fn().mockResolvedValue({ model: 'routed-model' }),
} as unknown as ModelRouterService;
mockConfig.getModelRouterService.mockReturnValue(mockRouter);
it('should pass current session information directly onto geminiClient.sendMessageStream', async () => {
const stream = createMockStream([
{
type: StreamEventType.CHUNK,
value: {
candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
},
type: GeminiEventType.Content,
value: 'Hello',
},
]);
mockChat.sendMessageStream.mockResolvedValue(stream);
mockSendMessageStream.mockReturnValue(stream);
await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: 'Hi' }],
});
expect(mockRouter.route).toHaveBeenCalled();
expect(mockChat.sendMessageStream).toHaveBeenCalledWith(
expect.objectContaining({ model: 'routed-model' }),
expect.any(Array),
expect.any(String),
expect.any(Object),
expect(mockSendMessageStream).toHaveBeenCalledWith(
expect.arrayContaining([{ text: 'Hi' }]),
expect.any(AbortSignal),
expect.any(String),
);
});
it('should handle prompt with empty response (InvalidStreamError)', async () => {
mockChat.sendMessageStream.mockRejectedValue(
new InvalidStreamError('Empty response', 'NO_RESPONSE_TEXT'),
);
const error = new InvalidStreamError('Empty response', 'NO_RESPONSE_TEXT');
mockSendMessageStream.mockImplementation(() => {
async function* errorGen(): AsyncGenerator<
ServerGeminiStreamEvent,
void,
unknown
> {
yield* [];
throw error;
}
return errorGen();
});
const result = await session.prompt({
sessionId: 'session-1',
@@ -262,9 +285,21 @@ describe('Session', () => {
});
it('should handle prompt with no finish reason (InvalidStreamError)', async () => {
mockChat.sendMessageStream.mockRejectedValue(
new InvalidStreamError('No finish reason', 'NO_FINISH_REASON'),
const error = new InvalidStreamError(
'No finish reason',
'NO_FINISH_REASON',
);
mockSendMessageStream.mockImplementation(() => {
async function* errorGen(): AsyncGenerator<
ServerGeminiStreamEvent,
void,
unknown
> {
yield* [];
throw error;
}
return errorGen();
});
const result = await session.prompt({
sessionId: 'session-1',
@@ -298,24 +333,26 @@ describe('Session', () => {
it('should handle tool calls', async () => {
const stream1 = createMockStream([
{
type: StreamEventType.CHUNK,
type: GeminiEventType.ToolCallRequest,
value: {
functionCalls: [{ name: 'test_tool', args: { foo: 'bar' } }],
callId: 'call-1',
name: 'test_tool',
args: { foo: 'bar' },
isClientInitiated: false,
prompt_id: 'prompt-1',
},
},
]);
const stream2 = createMockStream([
{
type: StreamEventType.CHUNK,
value: {
candidates: [{ content: { parts: [{ text: 'Result' }] } }],
},
type: GeminiEventType.Content,
value: 'Result',
},
]);
mockChat.sendMessageStream
.mockResolvedValueOnce(stream1)
.mockResolvedValueOnce(stream2);
mockSendMessageStream
.mockReturnValueOnce(stream1)
.mockReturnValueOnce(stream2);
const result = await session.prompt({
sessionId: 'session-1',
@@ -347,22 +384,26 @@ describe('Session', () => {
const stream1 = createMockStream([
{
type: StreamEventType.CHUNK,
type: GeminiEventType.ToolCallRequest,
value: {
functionCalls: [{ name: 'test_tool', args: {} }],
callId: 'call-1',
name: 'test_tool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-1',
},
},
]);
const stream2 = createMockStream([
{
type: StreamEventType.CHUNK,
value: { candidates: [] },
type: GeminiEventType.Content,
value: '',
},
]);
mockChat.sendMessageStream
.mockResolvedValueOnce(stream1)
.mockResolvedValueOnce(stream2);
mockSendMessageStream
.mockReturnValueOnce(stream1)
.mockReturnValueOnce(stream2);
await session.prompt({
sessionId: 'session-1',
@@ -381,11 +422,11 @@ describe('Session', () => {
const stream = createMockStream([
{
type: StreamEventType.CHUNK,
value: { candidates: [] },
type: GeminiEventType.Content,
value: '',
},
]);
mockChat.sendMessageStream.mockResolvedValue(stream);
mockSendMessageStream.mockReturnValue(stream);
await session.prompt({
sessionId: 'session-1',
@@ -402,23 +443,33 @@ describe('Session', () => {
expect(path.resolve).toHaveBeenCalled();
expect(fs.stat).toHaveBeenCalled();
expect(mockChat.sendMessageStream).toHaveBeenCalledWith(
expect.anything(),
expect(mockSendMessageStream).toHaveBeenCalledWith(
expect.arrayContaining([
expect.objectContaining({
text: expect.stringContaining('Content from @file.txt'),
}),
]),
expect.anything(),
expect.any(AbortSignal),
LlmRole.MAIN,
expect.any(String),
);
});
it('should handle rate limit error', async () => {
const error = new Error('Rate limit');
(error as unknown as { status: number }).status = 429;
mockChat.sendMessageStream.mockRejectedValue(error);
const customError = error as { status?: number; message?: string };
customError.status = 429;
mockSendMessageStream.mockImplementation(() => {
async function* errorGen(): AsyncGenerator<
ServerGeminiStreamEvent,
void,
unknown
> {
yield* [];
throw customError;
}
return errorGen();
});
await expect(
session.prompt({
@@ -436,28 +487,81 @@ describe('Session', () => {
const stream1 = createMockStream([
{
type: StreamEventType.CHUNK,
type: GeminiEventType.ToolCallRequest,
value: {
functionCalls: [{ name: 'unknown_tool', args: {} }],
callId: 'call-1',
name: 'unknown_tool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-1',
},
},
]);
const stream2 = createMockStream([
{
type: StreamEventType.CHUNK,
value: { candidates: [] },
type: GeminiEventType.Content,
value: '',
},
]);
mockChat.sendMessageStream
.mockResolvedValueOnce(stream1)
.mockResolvedValueOnce(stream2);
mockSendMessageStream
.mockReturnValueOnce(stream1)
.mockReturnValueOnce(stream2);
await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: 'Call tool' }],
});
expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2);
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
});
it('should handle GeminiEventType.LoopDetected', async () => {
const stream = createMockStream([
{
type: GeminiEventType.LoopDetected,
},
]);
mockSendMessageStream.mockReturnValue(stream);
const result = await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: 'Trigger Loop Simulation' }],
});
expect(result.stopReason).toBe('max_turn_requests');
});
it('should handle GeminiEventType.ContextWindowWillOverflow', async () => {
const stream = createMockStream([
{
type: GeminiEventType.ContextWindowWillOverflow,
value: { estimatedRequestTokenCount: 1000, remainingTokenCount: 200 },
},
]);
mockSendMessageStream.mockReturnValue(stream);
const result = await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: 'Trigger Overflow Simulation' }],
});
expect(result.stopReason).toBe('max_tokens');
});
it('should handle GeminiEventType.MaxSessionTurns', async () => {
const stream = createMockStream([
{
type: GeminiEventType.MaxSessionTurns,
},
]);
mockSendMessageStream.mockReturnValue(stream);
const result = await session.prompt({
sessionId: 'session-1',
prompt: [{ type: 'text', text: 'Trigger Safety Limits' }],
});
expect(result.stopReason).toBe('max_turn_requests');
});
});
+171 -96
View File
@@ -6,35 +6,34 @@
import {
type ApprovalMode,
type GeminiChat,
type ToolResult,
type ConversationRecord,
CoreToolCallStatus,
logToolCall,
convertToFunctionResponse,
ToolConfirmationOutcome,
isWithinRoot,
getErrorStatus,
DiscoveredMCPTool,
StreamEventType,
ToolCallEvent,
debugLogger,
ReadManyFilesTool,
REFERENCE_CONTENT_START,
type RoutingContext,
partListUnionToString,
LlmRole,
processSingleFileContent,
InvalidStreamError,
type AgentLoopContext,
updatePolicy,
isNodeError,
getErrorMessage,
type FilterFilesOptions,
isTextPart,
GeminiEventType,
type ToolCallRequestInfo,
type GeminiChat,
type ToolResult,
isWithinRoot,
processSingleFileContent,
isNodeError,
REFERENCE_CONTENT_START,
InvalidStreamError,
} from '@google/gemini-cli-core';
import * as acp from '@agentclientprotocol/sdk';
import type { Content, Part, FunctionCall } from '@google/genai';
import type { Part, FunctionCall } from '@google/genai';
import type { LoadedSettings } from '../config/settings.js';
import * as fs from 'node:fs/promises';
import * as path from 'node:path';
@@ -50,6 +49,11 @@ import {
import { z } from 'zod';
import { getAcpErrorMessage } from './acpErrors.js';
// Zod schema describing an error payload with an optional numeric `status`
// and an optional string `message`; both fields may be absent.
const StructuredErrorSchema = z.object({
status: z.number().optional(),
message: z.string().optional(),
});
export class Session {
private pendingPrompt: AbortController | null = null;
private commandHandler = new CommandHandler();
@@ -188,7 +192,6 @@ export class Session {
await this.context.config.waitForMcpInit();
const promptId = Math.random().toString(16).slice(2);
const chat = this.chat;
const parts = await this.#resolvePrompt(params.prompt, pendingSend.signal);
@@ -236,100 +239,125 @@ export class Session {
let totalOutputTokens = 0;
const modelUsageMap = new Map<string, { input: number; output: number }>();
let nextMessage: Content | null = { role: 'user', parts };
let currentParts: Part[] = parts;
let turnCount = 0;
const maxTurns = this.context.config.getMaxSessionTurns();
while (nextMessage !== null) {
if (pendingSend.signal.aborted) {
chat.addHistory(nextMessage);
return { stopReason: CoreToolCallStatus.Cancelled };
while (true) {
turnCount++;
if (maxTurns >= 0 && turnCount > maxTurns) {
return {
stopReason: 'max_turn_requests',
_meta: {
quota: {
token_count: {
input_tokens: totalInputTokens,
output_tokens: totalOutputTokens,
},
model_usage: Array.from(modelUsageMap.entries()).map(
([modelName, counts]) => ({
model: modelName,
token_count: {
input_tokens: counts.input,
output_tokens: counts.output,
},
}),
),
},
},
};
}
const functionCalls: FunctionCall[] = [];
if (pendingSend.signal.aborted) {
return { stopReason: 'cancelled' };
}
const toolCallRequests: ToolCallRequestInfo[] = [];
let stopReason: acp.StopReason = 'end_turn';
let turnModelId = this.context.config.getModel();
let turnInputTokens = 0;
let turnOutputTokens = 0;
try {
const routingContext: RoutingContext = {
history: chat.getHistory(/*curated=*/ true),
request: nextMessage?.parts ?? [],
signal: pendingSend.signal,
requestedModel: this.context.config.getModel(),
};
const router = this.context.config.getModelRouterService();
const { model } = await router.route(routingContext);
const responseStream = await chat.sendMessageStream(
{ model },
nextMessage?.parts ?? [],
promptId,
const responseStream = this.context.geminiClient.sendMessageStream(
currentParts,
pendingSend.signal,
LlmRole.MAIN,
promptId,
);
nextMessage = null;
let turnInputTokens = 0;
let turnOutputTokens = 0;
let turnModelId = model;
for await (const resp of responseStream) {
for await (const event of responseStream) {
if (pendingSend.signal.aborted) {
return { stopReason: CoreToolCallStatus.Cancelled };
return { stopReason: 'cancelled' };
}
if (resp.type === StreamEventType.CHUNK && resp.value.usageMetadata) {
turnInputTokens =
resp.value.usageMetadata.promptTokenCount ?? turnInputTokens;
turnOutputTokens =
resp.value.usageMetadata.candidatesTokenCount ?? turnOutputTokens;
if (resp.value.modelVersion) {
turnModelId = resp.value.modelVersion;
}
}
if (
resp.type === StreamEventType.CHUNK &&
resp.value.candidates &&
resp.value.candidates.length > 0
) {
const candidate = resp.value.candidates[0];
for (const part of candidate.content?.parts ?? []) {
if (!part.text) {
continue;
}
switch (event.type) {
case GeminiEventType.Content: {
const content: acp.ContentBlock = {
type: 'text',
text: part.text,
text: event.value,
};
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.sendUpdate({
sessionUpdate: part.thought
? 'agent_thought_chunk'
: 'agent_message_chunk',
await this.sendUpdate({
sessionUpdate: 'agent_message_chunk',
content,
});
break;
}
case GeminiEventType.Thought: {
const thoughtText = `**${event.value.subject}**\n${event.value.description}`;
await this.sendUpdate({
sessionUpdate: 'agent_thought_chunk',
content: { type: 'text', text: thoughtText },
});
break;
}
case GeminiEventType.ToolCallRequest:
toolCallRequests.push(event.value);
break;
case GeminiEventType.Finished: {
const usage = event.value.usageMetadata;
if (usage) {
turnInputTokens = usage.promptTokenCount ?? turnInputTokens;
turnOutputTokens =
usage.candidatesTokenCount ?? turnOutputTokens;
}
break;
}
case GeminiEventType.ModelInfo:
turnModelId = event.value;
break;
case GeminiEventType.MaxSessionTurns:
stopReason = 'max_turn_requests';
break;
case GeminiEventType.LoopDetected:
stopReason = 'max_turn_requests';
break;
case GeminiEventType.ContextWindowWillOverflow:
stopReason = 'max_tokens';
break;
case GeminiEventType.Error: {
const parseResult = StructuredErrorSchema.safeParse(
event.value.error,
);
const errData = parseResult.success ? parseResult.data : {};
throw new acp.RequestError(
errData.status ?? 500,
errData.message ?? 'Unknown stream execution error.',
);
}
default:
break;
}
if (resp.type === StreamEventType.CHUNK && resp.value.functionCalls) {
functionCalls.push(...resp.value.functionCalls);
}
}
totalInputTokens += turnInputTokens;
totalOutputTokens += turnOutputTokens;
if (turnInputTokens > 0 || turnOutputTokens > 0) {
const existing = modelUsageMap.get(turnModelId) ?? {
input: 0,
output: 0,
};
existing.input += turnInputTokens;
existing.output += turnOutputTokens;
modelUsageMap.set(turnModelId, existing);
}
if (pendingSend.signal.aborted) {
return { stopReason: CoreToolCallStatus.Cancelled };
}
} catch (error) {
if (getErrorStatus(error) === 429) {
@@ -343,7 +371,11 @@ export class Session {
pendingSend.signal.aborted ||
(error instanceof Error && error.name === 'AbortError')
) {
return { stopReason: CoreToolCallStatus.Cancelled };
return { stopReason: 'cancelled' };
}
if (error instanceof acp.RequestError) {
throw error;
}
if (
@@ -386,16 +418,59 @@ export class Session {
);
}
if (functionCalls.length > 0) {
const toolResponseParts: Part[] = [];
totalInputTokens += turnInputTokens;
totalOutputTokens += turnOutputTokens;
for (const fc of functionCalls) {
const response = await this.runTool(pendingSend.signal, promptId, fc);
toolResponseParts.push(...response);
}
nextMessage = { role: 'user', parts: toolResponseParts };
if (turnInputTokens > 0 || turnOutputTokens > 0) {
const existing = modelUsageMap.get(turnModelId) ?? {
input: 0,
output: 0,
};
existing.input += turnInputTokens;
existing.output += turnOutputTokens;
modelUsageMap.set(turnModelId, existing);
}
if (stopReason !== 'end_turn') {
return {
stopReason,
_meta: {
quota: {
token_count: {
input_tokens: totalInputTokens,
output_tokens: totalOutputTokens,
},
model_usage: Array.from(modelUsageMap.entries()).map(
([modelName, counts]) => ({
model: modelName,
token_count: {
input_tokens: counts.input,
output_tokens: counts.output,
},
}),
),
},
},
};
}
if (toolCallRequests.length === 0) {
break;
}
const toolResponseParts: Part[] = [];
for (const tReq of toolCallRequests) {
const fc: FunctionCall = {
id: tReq.callId,
name: tReq.name,
args: tReq.args,
};
const response = await this.runTool(pendingSend.signal, promptId, fc);
toolResponseParts.push(...response);
}
currentParts = toolResponseParts;
}
const modelUsageArray = Array.from(modelUsageMap.entries()).map(
+4 -4
View File
@@ -3988,7 +3988,7 @@ describe('loadCliConfig acpMode and clientName', () => {
expect(config.getClientName()).toBe('acp-vscode');
});
it('should set acpMode to true but leave clientName undefined for generic terminals', async () => {
it('should set acpMode to true and set clientName to acp for generic terminals', async () => {
process.argv = ['node', 'script.js', '--acp'];
vi.stubEnv('TERM_PROGRAM', 'iTerm.app'); // Generic terminal
vi.stubEnv('VSCODE_GIT_ASKPASS_MAIN', '');
@@ -4000,10 +4000,10 @@ describe('loadCliConfig acpMode and clientName', () => {
argv,
);
expect(config.getAcpMode()).toBe(true);
expect(config.getClientName()).toBeUndefined();
expect(config.getClientName()).toBe('acp');
});
it('should set acpMode to false and clientName to undefined by default', async () => {
it('should set acpMode to false and clientName to tui by default', async () => {
process.argv = ['node', 'script.js'];
const argv = await parseArguments(createTestMergedSettings());
const config = await loadCliConfig(
@@ -4012,6 +4012,6 @@ describe('loadCliConfig acpMode and clientName', () => {
argv,
);
expect(config.getAcpMode()).toBe(false);
expect(config.getClientName()).toBeUndefined();
expect(config.getClientName()).toBe('tui');
});
});
+6
View File
@@ -931,7 +931,13 @@ export async function loadCliConfig(
(ide.name !== 'vscode' || process.env['TERM_PROGRAM'] === 'vscode')
) {
clientName = `acp-${ide.name}`;
} else {
clientName = 'acp';
}
} else if (argv.isCommand) {
clientName = 'cli-command';
} else {
clientName = 'tui';
}
// TODO(joshualitt): Clean this up alongside removal of the legacy config.