mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-30 23:14:32 -07:00
refactor(acp): delegate prompt turn processing logic to GeminiClient (#26222)
This commit is contained in:
@@ -13,21 +13,22 @@ import {
|
||||
afterEach,
|
||||
type Mock,
|
||||
type Mocked,
|
||||
type MockInstance,
|
||||
} from 'vitest';
|
||||
import { Session } from './acpSession.js';
|
||||
import type * as acp from '@agentclientprotocol/sdk';
|
||||
import {
|
||||
StreamEventType,
|
||||
ReadManyFilesTool,
|
||||
type GeminiChat,
|
||||
type Config,
|
||||
type MessageBus,
|
||||
LlmRole,
|
||||
type GitService,
|
||||
type ModelRouterService,
|
||||
InvalidStreamError,
|
||||
GeminiEventType,
|
||||
type ServerGeminiStreamEvent,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { LoadedSettings } from '../config/settings.js';
|
||||
import { type Part, FinishReason } from '@google/genai';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
import type { CommandHandler } from './acpCommandHandler.js';
|
||||
@@ -57,11 +58,23 @@ vi.mock(
|
||||
},
|
||||
);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
async function* createMockStream(items: any[]) {
|
||||
async function* createMockStream(
|
||||
items: readonly ServerGeminiStreamEvent[],
|
||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||
for (const item of items) {
|
||||
yield item;
|
||||
}
|
||||
|
||||
yield {
|
||||
type: GeminiEventType.Finished,
|
||||
value: {
|
||||
reason: FinishReason.STOP,
|
||||
usageMetadata: {
|
||||
promptTokenCount: 5,
|
||||
candidatesTokenCount: 10,
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe('Session', () => {
|
||||
@@ -72,6 +85,13 @@ describe('Session', () => {
|
||||
let mockToolRegistry: { getTool: Mock };
|
||||
let mockTool: { kind: string; build: Mock };
|
||||
let mockMessageBus: Mocked<MessageBus>;
|
||||
let mockSendMessageStream: MockInstance<
|
||||
(
|
||||
request: Part[],
|
||||
signal: AbortSignal,
|
||||
promptId: string,
|
||||
) => AsyncGenerator<ServerGeminiStreamEvent>
|
||||
>;
|
||||
|
||||
beforeEach(() => {
|
||||
mockChat = {
|
||||
@@ -97,6 +117,7 @@ describe('Session', () => {
|
||||
subscribe: vi.fn(),
|
||||
unsubscribe: vi.fn(),
|
||||
} as unknown as Mocked<MessageBus>;
|
||||
mockSendMessageStream = vi.fn();
|
||||
mockConfig = {
|
||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
getActiveModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
@@ -124,6 +145,11 @@ describe('Session', () => {
|
||||
}),
|
||||
waitForMcpInit: vi.fn(),
|
||||
getDisableAlwaysAllow: vi.fn().mockReturnValue(false),
|
||||
getMaxSessionTurns: vi.fn().mockReturnValue(-1),
|
||||
geminiClient: {
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getChat: vi.fn().mockReturnValue(mockChat),
|
||||
},
|
||||
get config() {
|
||||
return this;
|
||||
},
|
||||
@@ -176,11 +202,11 @@ describe('Session', () => {
|
||||
it('should await MCP initialization before processing a prompt', async () => {
|
||||
const stream = createMockStream([
|
||||
{
|
||||
type: StreamEventType.CHUNK,
|
||||
value: { candidates: [{ content: { parts: [{ text: 'Hi' }] } }] },
|
||||
type: GeminiEventType.Content,
|
||||
value: 'Hi',
|
||||
},
|
||||
]);
|
||||
mockChat.sendMessageStream.mockResolvedValue(stream);
|
||||
mockSendMessageStream.mockReturnValue(stream);
|
||||
|
||||
await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
@@ -193,20 +219,18 @@ describe('Session', () => {
|
||||
it('should handle prompt with text response', async () => {
|
||||
const stream = createMockStream([
|
||||
{
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
|
||||
},
|
||||
type: GeminiEventType.Content,
|
||||
value: 'Hello',
|
||||
},
|
||||
]);
|
||||
mockChat.sendMessageStream.mockResolvedValue(stream);
|
||||
mockSendMessageStream.mockReturnValue(stream);
|
||||
|
||||
const result = await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
prompt: [{ type: 'text', text: 'Hi' }],
|
||||
});
|
||||
|
||||
expect(mockChat.sendMessageStream).toHaveBeenCalled();
|
||||
expect(mockSendMessageStream).toHaveBeenCalled();
|
||||
expect(mockConnection.sessionUpdate).toHaveBeenCalledWith({
|
||||
sessionId: 'session-1',
|
||||
update: {
|
||||
@@ -217,41 +241,40 @@ describe('Session', () => {
|
||||
expect(result).toMatchObject({ stopReason: 'end_turn' });
|
||||
});
|
||||
|
||||
it('should use model router to determine model', async () => {
|
||||
const mockRouter = {
|
||||
route: vi.fn().mockResolvedValue({ model: 'routed-model' }),
|
||||
} as unknown as ModelRouterService;
|
||||
mockConfig.getModelRouterService.mockReturnValue(mockRouter);
|
||||
|
||||
it('should pass current session information directly onto geminiClient.sendMessageStream', async () => {
|
||||
const stream = createMockStream([
|
||||
{
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
|
||||
},
|
||||
type: GeminiEventType.Content,
|
||||
value: 'Hello',
|
||||
},
|
||||
]);
|
||||
mockChat.sendMessageStream.mockResolvedValue(stream);
|
||||
mockSendMessageStream.mockReturnValue(stream);
|
||||
|
||||
await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
prompt: [{ type: 'text', text: 'Hi' }],
|
||||
});
|
||||
|
||||
expect(mockRouter.route).toHaveBeenCalled();
|
||||
expect(mockChat.sendMessageStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ model: 'routed-model' }),
|
||||
expect.any(Array),
|
||||
expect.any(String),
|
||||
expect.any(Object),
|
||||
expect(mockSendMessageStream).toHaveBeenCalledWith(
|
||||
expect.arrayContaining([{ text: 'Hi' }]),
|
||||
expect.any(AbortSignal),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle prompt with empty response (InvalidStreamError)', async () => {
|
||||
mockChat.sendMessageStream.mockRejectedValue(
|
||||
new InvalidStreamError('Empty response', 'NO_RESPONSE_TEXT'),
|
||||
);
|
||||
const error = new InvalidStreamError('Empty response', 'NO_RESPONSE_TEXT');
|
||||
mockSendMessageStream.mockImplementation(() => {
|
||||
async function* errorGen(): AsyncGenerator<
|
||||
ServerGeminiStreamEvent,
|
||||
void,
|
||||
unknown
|
||||
> {
|
||||
yield* [];
|
||||
throw error;
|
||||
}
|
||||
return errorGen();
|
||||
});
|
||||
|
||||
const result = await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
@@ -262,9 +285,21 @@ describe('Session', () => {
|
||||
});
|
||||
|
||||
it('should handle prompt with no finish reason (InvalidStreamError)', async () => {
|
||||
mockChat.sendMessageStream.mockRejectedValue(
|
||||
new InvalidStreamError('No finish reason', 'NO_FINISH_REASON'),
|
||||
const error = new InvalidStreamError(
|
||||
'No finish reason',
|
||||
'NO_FINISH_REASON',
|
||||
);
|
||||
mockSendMessageStream.mockImplementation(() => {
|
||||
async function* errorGen(): AsyncGenerator<
|
||||
ServerGeminiStreamEvent,
|
||||
void,
|
||||
unknown
|
||||
> {
|
||||
yield* [];
|
||||
throw error;
|
||||
}
|
||||
return errorGen();
|
||||
});
|
||||
|
||||
const result = await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
@@ -298,24 +333,26 @@ describe('Session', () => {
|
||||
it('should handle tool calls', async () => {
|
||||
const stream1 = createMockStream([
|
||||
{
|
||||
type: StreamEventType.CHUNK,
|
||||
type: GeminiEventType.ToolCallRequest,
|
||||
value: {
|
||||
functionCalls: [{ name: 'test_tool', args: { foo: 'bar' } }],
|
||||
callId: 'call-1',
|
||||
name: 'test_tool',
|
||||
args: { foo: 'bar' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-1',
|
||||
},
|
||||
},
|
||||
]);
|
||||
const stream2 = createMockStream([
|
||||
{
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [{ content: { parts: [{ text: 'Result' }] } }],
|
||||
},
|
||||
type: GeminiEventType.Content,
|
||||
value: 'Result',
|
||||
},
|
||||
]);
|
||||
|
||||
mockChat.sendMessageStream
|
||||
.mockResolvedValueOnce(stream1)
|
||||
.mockResolvedValueOnce(stream2);
|
||||
mockSendMessageStream
|
||||
.mockReturnValueOnce(stream1)
|
||||
.mockReturnValueOnce(stream2);
|
||||
|
||||
const result = await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
@@ -347,22 +384,26 @@ describe('Session', () => {
|
||||
|
||||
const stream1 = createMockStream([
|
||||
{
|
||||
type: StreamEventType.CHUNK,
|
||||
type: GeminiEventType.ToolCallRequest,
|
||||
value: {
|
||||
functionCalls: [{ name: 'test_tool', args: {} }],
|
||||
callId: 'call-1',
|
||||
name: 'test_tool',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-1',
|
||||
},
|
||||
},
|
||||
]);
|
||||
const stream2 = createMockStream([
|
||||
{
|
||||
type: StreamEventType.CHUNK,
|
||||
value: { candidates: [] },
|
||||
type: GeminiEventType.Content,
|
||||
value: '',
|
||||
},
|
||||
]);
|
||||
|
||||
mockChat.sendMessageStream
|
||||
.mockResolvedValueOnce(stream1)
|
||||
.mockResolvedValueOnce(stream2);
|
||||
mockSendMessageStream
|
||||
.mockReturnValueOnce(stream1)
|
||||
.mockReturnValueOnce(stream2);
|
||||
|
||||
await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
@@ -381,11 +422,11 @@ describe('Session', () => {
|
||||
|
||||
const stream = createMockStream([
|
||||
{
|
||||
type: StreamEventType.CHUNK,
|
||||
value: { candidates: [] },
|
||||
type: GeminiEventType.Content,
|
||||
value: '',
|
||||
},
|
||||
]);
|
||||
mockChat.sendMessageStream.mockResolvedValue(stream);
|
||||
mockSendMessageStream.mockReturnValue(stream);
|
||||
|
||||
await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
@@ -402,23 +443,33 @@ describe('Session', () => {
|
||||
|
||||
expect(path.resolve).toHaveBeenCalled();
|
||||
expect(fs.stat).toHaveBeenCalled();
|
||||
expect(mockChat.sendMessageStream).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect(mockSendMessageStream).toHaveBeenCalledWith(
|
||||
expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
text: expect.stringContaining('Content from @file.txt'),
|
||||
}),
|
||||
]),
|
||||
expect.anything(),
|
||||
expect.any(AbortSignal),
|
||||
LlmRole.MAIN,
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle rate limit error', async () => {
|
||||
const error = new Error('Rate limit');
|
||||
(error as unknown as { status: number }).status = 429;
|
||||
mockChat.sendMessageStream.mockRejectedValue(error);
|
||||
const customError = error as { status?: number; message?: string };
|
||||
customError.status = 429;
|
||||
|
||||
mockSendMessageStream.mockImplementation(() => {
|
||||
async function* errorGen(): AsyncGenerator<
|
||||
ServerGeminiStreamEvent,
|
||||
void,
|
||||
unknown
|
||||
> {
|
||||
yield* [];
|
||||
throw customError;
|
||||
}
|
||||
return errorGen();
|
||||
});
|
||||
|
||||
await expect(
|
||||
session.prompt({
|
||||
@@ -436,28 +487,81 @@ describe('Session', () => {
|
||||
|
||||
const stream1 = createMockStream([
|
||||
{
|
||||
type: StreamEventType.CHUNK,
|
||||
type: GeminiEventType.ToolCallRequest,
|
||||
value: {
|
||||
functionCalls: [{ name: 'unknown_tool', args: {} }],
|
||||
callId: 'call-1',
|
||||
name: 'unknown_tool',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-1',
|
||||
},
|
||||
},
|
||||
]);
|
||||
const stream2 = createMockStream([
|
||||
{
|
||||
type: StreamEventType.CHUNK,
|
||||
value: { candidates: [] },
|
||||
type: GeminiEventType.Content,
|
||||
value: '',
|
||||
},
|
||||
]);
|
||||
|
||||
mockChat.sendMessageStream
|
||||
.mockResolvedValueOnce(stream1)
|
||||
.mockResolvedValueOnce(stream2);
|
||||
mockSendMessageStream
|
||||
.mockReturnValueOnce(stream1)
|
||||
.mockReturnValueOnce(stream2);
|
||||
|
||||
await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
prompt: [{ type: 'text', text: 'Call tool' }],
|
||||
});
|
||||
|
||||
expect(mockChat.sendMessageStream).toHaveBeenCalledTimes(2);
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should handle GeminiEventType.LoopDetected', async () => {
|
||||
const stream = createMockStream([
|
||||
{
|
||||
type: GeminiEventType.LoopDetected,
|
||||
},
|
||||
]);
|
||||
mockSendMessageStream.mockReturnValue(stream);
|
||||
|
||||
const result = await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
prompt: [{ type: 'text', text: 'Trigger Loop Simulation' }],
|
||||
});
|
||||
|
||||
expect(result.stopReason).toBe('max_turn_requests');
|
||||
});
|
||||
|
||||
it('should handle GeminiEventType.ContextWindowWillOverflow', async () => {
|
||||
const stream = createMockStream([
|
||||
{
|
||||
type: GeminiEventType.ContextWindowWillOverflow,
|
||||
value: { estimatedRequestTokenCount: 1000, remainingTokenCount: 200 },
|
||||
},
|
||||
]);
|
||||
mockSendMessageStream.mockReturnValue(stream);
|
||||
|
||||
const result = await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
prompt: [{ type: 'text', text: 'Trigger Overflow Simulation' }],
|
||||
});
|
||||
|
||||
expect(result.stopReason).toBe('max_tokens');
|
||||
});
|
||||
|
||||
it('should handle GeminiEventType.MaxSessionTurns', async () => {
|
||||
const stream = createMockStream([
|
||||
{
|
||||
type: GeminiEventType.MaxSessionTurns,
|
||||
},
|
||||
]);
|
||||
mockSendMessageStream.mockReturnValue(stream);
|
||||
|
||||
const result = await session.prompt({
|
||||
sessionId: 'session-1',
|
||||
prompt: [{ type: 'text', text: 'Trigger Safety Limits' }],
|
||||
});
|
||||
|
||||
expect(result.stopReason).toBe('max_turn_requests');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,35 +6,34 @@
|
||||
|
||||
import {
|
||||
type ApprovalMode,
|
||||
type GeminiChat,
|
||||
type ToolResult,
|
||||
type ConversationRecord,
|
||||
CoreToolCallStatus,
|
||||
logToolCall,
|
||||
convertToFunctionResponse,
|
||||
ToolConfirmationOutcome,
|
||||
isWithinRoot,
|
||||
getErrorStatus,
|
||||
DiscoveredMCPTool,
|
||||
StreamEventType,
|
||||
ToolCallEvent,
|
||||
debugLogger,
|
||||
ReadManyFilesTool,
|
||||
REFERENCE_CONTENT_START,
|
||||
type RoutingContext,
|
||||
partListUnionToString,
|
||||
LlmRole,
|
||||
processSingleFileContent,
|
||||
InvalidStreamError,
|
||||
type AgentLoopContext,
|
||||
updatePolicy,
|
||||
isNodeError,
|
||||
getErrorMessage,
|
||||
type FilterFilesOptions,
|
||||
isTextPart,
|
||||
GeminiEventType,
|
||||
type ToolCallRequestInfo,
|
||||
type GeminiChat,
|
||||
type ToolResult,
|
||||
isWithinRoot,
|
||||
processSingleFileContent,
|
||||
isNodeError,
|
||||
REFERENCE_CONTENT_START,
|
||||
InvalidStreamError,
|
||||
} from '@google/gemini-cli-core';
|
||||
import * as acp from '@agentclientprotocol/sdk';
|
||||
import type { Content, Part, FunctionCall } from '@google/genai';
|
||||
import type { Part, FunctionCall } from '@google/genai';
|
||||
import type { LoadedSettings } from '../config/settings.js';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
@@ -50,6 +49,11 @@ import {
|
||||
import { z } from 'zod';
|
||||
import { getAcpErrorMessage } from './acpErrors.js';
|
||||
|
||||
const StructuredErrorSchema = z.object({
|
||||
status: z.number().optional(),
|
||||
message: z.string().optional(),
|
||||
});
|
||||
|
||||
export class Session {
|
||||
private pendingPrompt: AbortController | null = null;
|
||||
private commandHandler = new CommandHandler();
|
||||
@@ -188,7 +192,6 @@ export class Session {
|
||||
await this.context.config.waitForMcpInit();
|
||||
|
||||
const promptId = Math.random().toString(16).slice(2);
|
||||
const chat = this.chat;
|
||||
|
||||
const parts = await this.#resolvePrompt(params.prompt, pendingSend.signal);
|
||||
|
||||
@@ -236,100 +239,125 @@ export class Session {
|
||||
let totalOutputTokens = 0;
|
||||
const modelUsageMap = new Map<string, { input: number; output: number }>();
|
||||
|
||||
let nextMessage: Content | null = { role: 'user', parts };
|
||||
let currentParts: Part[] = parts;
|
||||
let turnCount = 0;
|
||||
const maxTurns = this.context.config.getMaxSessionTurns();
|
||||
|
||||
while (nextMessage !== null) {
|
||||
if (pendingSend.signal.aborted) {
|
||||
chat.addHistory(nextMessage);
|
||||
return { stopReason: CoreToolCallStatus.Cancelled };
|
||||
while (true) {
|
||||
turnCount++;
|
||||
if (maxTurns >= 0 && turnCount > maxTurns) {
|
||||
return {
|
||||
stopReason: 'max_turn_requests',
|
||||
_meta: {
|
||||
quota: {
|
||||
token_count: {
|
||||
input_tokens: totalInputTokens,
|
||||
output_tokens: totalOutputTokens,
|
||||
},
|
||||
model_usage: Array.from(modelUsageMap.entries()).map(
|
||||
([modelName, counts]) => ({
|
||||
model: modelName,
|
||||
token_count: {
|
||||
input_tokens: counts.input,
|
||||
output_tokens: counts.output,
|
||||
},
|
||||
}),
|
||||
),
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const functionCalls: FunctionCall[] = [];
|
||||
if (pendingSend.signal.aborted) {
|
||||
return { stopReason: 'cancelled' };
|
||||
}
|
||||
|
||||
const toolCallRequests: ToolCallRequestInfo[] = [];
|
||||
let stopReason: acp.StopReason = 'end_turn';
|
||||
let turnModelId = this.context.config.getModel();
|
||||
let turnInputTokens = 0;
|
||||
let turnOutputTokens = 0;
|
||||
|
||||
try {
|
||||
const routingContext: RoutingContext = {
|
||||
history: chat.getHistory(/*curated=*/ true),
|
||||
request: nextMessage?.parts ?? [],
|
||||
signal: pendingSend.signal,
|
||||
requestedModel: this.context.config.getModel(),
|
||||
};
|
||||
|
||||
const router = this.context.config.getModelRouterService();
|
||||
const { model } = await router.route(routingContext);
|
||||
const responseStream = await chat.sendMessageStream(
|
||||
{ model },
|
||||
nextMessage?.parts ?? [],
|
||||
promptId,
|
||||
const responseStream = this.context.geminiClient.sendMessageStream(
|
||||
currentParts,
|
||||
pendingSend.signal,
|
||||
LlmRole.MAIN,
|
||||
promptId,
|
||||
);
|
||||
nextMessage = null;
|
||||
|
||||
let turnInputTokens = 0;
|
||||
let turnOutputTokens = 0;
|
||||
let turnModelId = model;
|
||||
|
||||
for await (const resp of responseStream) {
|
||||
for await (const event of responseStream) {
|
||||
if (pendingSend.signal.aborted) {
|
||||
return { stopReason: CoreToolCallStatus.Cancelled };
|
||||
return { stopReason: 'cancelled' };
|
||||
}
|
||||
|
||||
if (resp.type === StreamEventType.CHUNK && resp.value.usageMetadata) {
|
||||
turnInputTokens =
|
||||
resp.value.usageMetadata.promptTokenCount ?? turnInputTokens;
|
||||
turnOutputTokens =
|
||||
resp.value.usageMetadata.candidatesTokenCount ?? turnOutputTokens;
|
||||
if (resp.value.modelVersion) {
|
||||
turnModelId = resp.value.modelVersion;
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
resp.type === StreamEventType.CHUNK &&
|
||||
resp.value.candidates &&
|
||||
resp.value.candidates.length > 0
|
||||
) {
|
||||
const candidate = resp.value.candidates[0];
|
||||
for (const part of candidate.content?.parts ?? []) {
|
||||
if (!part.text) {
|
||||
continue;
|
||||
}
|
||||
|
||||
switch (event.type) {
|
||||
case GeminiEventType.Content: {
|
||||
const content: acp.ContentBlock = {
|
||||
type: 'text',
|
||||
text: part.text,
|
||||
text: event.value,
|
||||
};
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-floating-promises
|
||||
this.sendUpdate({
|
||||
sessionUpdate: part.thought
|
||||
? 'agent_thought_chunk'
|
||||
: 'agent_message_chunk',
|
||||
await this.sendUpdate({
|
||||
sessionUpdate: 'agent_message_chunk',
|
||||
content,
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
case GeminiEventType.Thought: {
|
||||
const thoughtText = `**${event.value.subject}**\n${event.value.description}`;
|
||||
await this.sendUpdate({
|
||||
sessionUpdate: 'agent_thought_chunk',
|
||||
content: { type: 'text', text: thoughtText },
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
case GeminiEventType.ToolCallRequest:
|
||||
toolCallRequests.push(event.value);
|
||||
break;
|
||||
|
||||
case GeminiEventType.Finished: {
|
||||
const usage = event.value.usageMetadata;
|
||||
if (usage) {
|
||||
turnInputTokens = usage.promptTokenCount ?? turnInputTokens;
|
||||
turnOutputTokens =
|
||||
usage.candidatesTokenCount ?? turnOutputTokens;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case GeminiEventType.ModelInfo:
|
||||
turnModelId = event.value;
|
||||
break;
|
||||
|
||||
case GeminiEventType.MaxSessionTurns:
|
||||
stopReason = 'max_turn_requests';
|
||||
break;
|
||||
|
||||
case GeminiEventType.LoopDetected:
|
||||
stopReason = 'max_turn_requests';
|
||||
break;
|
||||
|
||||
case GeminiEventType.ContextWindowWillOverflow:
|
||||
stopReason = 'max_tokens';
|
||||
break;
|
||||
|
||||
case GeminiEventType.Error: {
|
||||
const parseResult = StructuredErrorSchema.safeParse(
|
||||
event.value.error,
|
||||
);
|
||||
const errData = parseResult.success ? parseResult.data : {};
|
||||
|
||||
throw new acp.RequestError(
|
||||
errData.status ?? 500,
|
||||
errData.message ?? 'Unknown stream execution error.',
|
||||
);
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
if (resp.type === StreamEventType.CHUNK && resp.value.functionCalls) {
|
||||
functionCalls.push(...resp.value.functionCalls);
|
||||
}
|
||||
}
|
||||
|
||||
totalInputTokens += turnInputTokens;
|
||||
totalOutputTokens += turnOutputTokens;
|
||||
|
||||
if (turnInputTokens > 0 || turnOutputTokens > 0) {
|
||||
const existing = modelUsageMap.get(turnModelId) ?? {
|
||||
input: 0,
|
||||
output: 0,
|
||||
};
|
||||
existing.input += turnInputTokens;
|
||||
existing.output += turnOutputTokens;
|
||||
modelUsageMap.set(turnModelId, existing);
|
||||
}
|
||||
|
||||
if (pendingSend.signal.aborted) {
|
||||
return { stopReason: CoreToolCallStatus.Cancelled };
|
||||
}
|
||||
} catch (error) {
|
||||
if (getErrorStatus(error) === 429) {
|
||||
@@ -343,7 +371,11 @@ export class Session {
|
||||
pendingSend.signal.aborted ||
|
||||
(error instanceof Error && error.name === 'AbortError')
|
||||
) {
|
||||
return { stopReason: CoreToolCallStatus.Cancelled };
|
||||
return { stopReason: 'cancelled' };
|
||||
}
|
||||
|
||||
if (error instanceof acp.RequestError) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -386,16 +418,59 @@ export class Session {
|
||||
);
|
||||
}
|
||||
|
||||
if (functionCalls.length > 0) {
|
||||
const toolResponseParts: Part[] = [];
|
||||
totalInputTokens += turnInputTokens;
|
||||
totalOutputTokens += turnOutputTokens;
|
||||
|
||||
for (const fc of functionCalls) {
|
||||
const response = await this.runTool(pendingSend.signal, promptId, fc);
|
||||
toolResponseParts.push(...response);
|
||||
}
|
||||
|
||||
nextMessage = { role: 'user', parts: toolResponseParts };
|
||||
if (turnInputTokens > 0 || turnOutputTokens > 0) {
|
||||
const existing = modelUsageMap.get(turnModelId) ?? {
|
||||
input: 0,
|
||||
output: 0,
|
||||
};
|
||||
existing.input += turnInputTokens;
|
||||
existing.output += turnOutputTokens;
|
||||
modelUsageMap.set(turnModelId, existing);
|
||||
}
|
||||
|
||||
if (stopReason !== 'end_turn') {
|
||||
return {
|
||||
stopReason,
|
||||
_meta: {
|
||||
quota: {
|
||||
token_count: {
|
||||
input_tokens: totalInputTokens,
|
||||
output_tokens: totalOutputTokens,
|
||||
},
|
||||
model_usage: Array.from(modelUsageMap.entries()).map(
|
||||
([modelName, counts]) => ({
|
||||
model: modelName,
|
||||
token_count: {
|
||||
input_tokens: counts.input,
|
||||
output_tokens: counts.output,
|
||||
},
|
||||
}),
|
||||
),
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
if (toolCallRequests.length === 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
const toolResponseParts: Part[] = [];
|
||||
for (const tReq of toolCallRequests) {
|
||||
const fc: FunctionCall = {
|
||||
id: tReq.callId,
|
||||
name: tReq.name,
|
||||
args: tReq.args,
|
||||
};
|
||||
|
||||
const response = await this.runTool(pendingSend.signal, promptId, fc);
|
||||
toolResponseParts.push(...response);
|
||||
}
|
||||
|
||||
currentParts = toolResponseParts;
|
||||
}
|
||||
|
||||
const modelUsageArray = Array.from(modelUsageMap.entries()).map(
|
||||
|
||||
@@ -3988,7 +3988,7 @@ describe('loadCliConfig acpMode and clientName', () => {
|
||||
expect(config.getClientName()).toBe('acp-vscode');
|
||||
});
|
||||
|
||||
it('should set acpMode to true but leave clientName undefined for generic terminals', async () => {
|
||||
it('should set acpMode to true and set clientName to acp for generic terminals', async () => {
|
||||
process.argv = ['node', 'script.js', '--acp'];
|
||||
vi.stubEnv('TERM_PROGRAM', 'iTerm.app'); // Generic terminal
|
||||
vi.stubEnv('VSCODE_GIT_ASKPASS_MAIN', '');
|
||||
@@ -4000,10 +4000,10 @@ describe('loadCliConfig acpMode and clientName', () => {
|
||||
argv,
|
||||
);
|
||||
expect(config.getAcpMode()).toBe(true);
|
||||
expect(config.getClientName()).toBeUndefined();
|
||||
expect(config.getClientName()).toBe('acp');
|
||||
});
|
||||
|
||||
it('should set acpMode to false and clientName to undefined by default', async () => {
|
||||
it('should set acpMode to false and clientName to tui by default', async () => {
|
||||
process.argv = ['node', 'script.js'];
|
||||
const argv = await parseArguments(createTestMergedSettings());
|
||||
const config = await loadCliConfig(
|
||||
@@ -4012,6 +4012,6 @@ describe('loadCliConfig acpMode and clientName', () => {
|
||||
argv,
|
||||
);
|
||||
expect(config.getAcpMode()).toBe(false);
|
||||
expect(config.getClientName()).toBeUndefined();
|
||||
expect(config.getClientName()).toBe('tui');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -931,7 +931,13 @@ export async function loadCliConfig(
|
||||
(ide.name !== 'vscode' || process.env['TERM_PROGRAM'] === 'vscode')
|
||||
) {
|
||||
clientName = `acp-${ide.name}`;
|
||||
} else {
|
||||
clientName = 'acp';
|
||||
}
|
||||
} else if (argv.isCommand) {
|
||||
clientName = 'cli-command';
|
||||
} else {
|
||||
clientName = 'tui';
|
||||
}
|
||||
|
||||
// TODO(joshualitt): Clean this up alongside removal of the legacy config.
|
||||
|
||||
Reference in New Issue
Block a user