diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index 7de8393082..bf90bdc6cc 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -24,11 +24,20 @@ vi.mock('./ui/hooks/atCommandProcessor.js'); vi.mock('@google/gemini-cli-core', async (importOriginal) => { const original = await importOriginal(); + + class MockChatRecordingService { + initialize = vi.fn(); + recordMessage = vi.fn(); + recordMessageTokens = vi.fn(); + recordToolCalls = vi.fn(); + } + return { ...original, executeToolCall: vi.fn(), shutdownTelemetry: vi.fn(), isTelemetrySdkInitialized: vi.fn().mockReturnValue(true), + ChatRecordingService: MockChatRecordingService, }; }); @@ -41,6 +50,7 @@ describe('runNonInteractive', () => { let processStdoutSpy: vi.SpyInstance; let mockGeminiClient: { sendMessageStream: vi.Mock; + getChatRecordingService: vi.Mock; }; beforeEach(async () => { @@ -59,6 +69,12 @@ describe('runNonInteractive', () => { mockGeminiClient = { sendMessageStream: vi.fn(), + getChatRecordingService: vi.fn(() => ({ + initialize: vi.fn(), + recordMessage: vi.fn(), + recordMessageTokens: vi.fn(), + recordToolCalls: vi.fn(), + })), }; mockConfig = { @@ -66,6 +82,11 @@ describe('runNonInteractive', () => { getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), getMaxSessionTurns: vi.fn().mockReturnValue(10), + getSessionId: vi.fn().mockReturnValue('test-session-id'), + getProjectRoot: vi.fn().mockReturnValue('/test/project'), + storage: { + getProjectTempDir: vi.fn().mockReturnValue('/test/project/.gemini/tmp'), + }, getIdeMode: vi.fn().mockReturnValue(false), getFullContext: vi.fn().mockReturnValue(false), getContentGeneratorConfig: vi.fn().mockReturnValue({}), @@ -97,6 +118,10 @@ describe('runNonInteractive', () => { const events: ServerGeminiStreamEvent[] = [ { type: GeminiEventType.Content, value: 'Hello' }, { type: GeminiEventType.Content, value: ' World' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, ]; mockGeminiClient.sendMessageStream.mockReturnValue( createStreamFromEvents(events), @@ -132,6 +157,10 @@ describe('runNonInteractive', () => { const firstCallEvents: ServerGeminiStreamEvent[] = [toolCallEvent]; const secondCallEvents: ServerGeminiStreamEvent[] = [ { type: GeminiEventType.Content, value: 'Final answer' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, ]; mockGeminiClient.sendMessageStream @@ -187,6 +216,10 @@ describe('runNonInteractive', () => { type: GeminiEventType.Content, value: 'Sorry, let me try again.', }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, ]; mockGeminiClient.sendMessageStream .mockReturnValueOnce(createStreamFromEvents([toolCallEvent])) @@ -242,12 +275,17 @@ describe('runNonInteractive', () => { mockCoreExecuteToolCall.mockResolvedValue({ error: new Error('Tool "nonexistentTool" not found in registry.'), resultDisplay: 'Tool "nonexistentTool" not found in registry.', + responseParts: [], }); const finalResponse: ServerGeminiStreamEvent[] = [ { type: GeminiEventType.Content, value: "Sorry, I can't find that tool.", }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, ]; mockGeminiClient.sendMessageStream @@ -304,6 +342,10 @@ describe('runNonInteractive', () => { // Mock a simple stream response from the Gemini client const events: ServerGeminiStreamEvent[] = [ { type: GeminiEventType.Content, value: 'Summary complete.' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } }, + }, ]; mockGeminiClient.sendMessageStream.mockReturnValue( createStreamFromEvents(events), diff --git a/packages/cli/src/ui/App.test.tsx b/packages/cli/src/ui/App.test.tsx index e0d5fc9922..2d77e42816 100644 --- a/packages/cli/src/ui/App.test.tsx +++ b/packages/cli/src/ui/App.test.tsx @@ -157,7 +157,22 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => { getProjectRoot: vi.fn(() => opts.targetDir), getEnablePromptCompletion: vi.fn(() => false), getGeminiClient: vi.fn(() => ({ + isInitialized: vi.fn(() => true), getUserTier: vi.fn(), + getChatRecordingService: vi.fn(() => ({ + initialize: vi.fn(), + recordMessage: vi.fn(), + recordMessageTokens: vi.fn(), + recordToolCalls: vi.fn(), + })), + getChat: vi.fn(() => ({ + getChatRecordingService: vi.fn(() => ({ + initialize: vi.fn(), + recordMessage: vi.fn(), + recordMessageTokens: vi.fn(), + recordToolCalls: vi.fn(), + })), + })), })), getCheckpointingEnabled: vi.fn(() => opts.checkpointing ?? true), getAllGeminiMdFilenames: vi.fn(() => ['GEMINI.md']), diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 09dfb713ac..0652d01e45 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -48,6 +48,14 @@ const MockedGeminiClientClass = vi.hoisted(() => this.startChat = mockStartChat; this.sendMessageStream = mockSendMessageStream; this.addHistory = vi.fn(); + this.getChatRecordingService = vi.fn().mockReturnValue({ + recordThought: vi.fn(), + initialize: vi.fn(), + recordMessage: vi.fn(), + recordMessageTokens: vi.fn(), + recordToolCalls: vi.fn(), + getConversationFile: vi.fn(), + }); }), ); @@ -1275,7 +1283,10 @@ describe('useGeminiStream', () => { type: ServerGeminiEventType.Content, value: 'This is a truncated response...', }; - yield { type: ServerGeminiEventType.Finished, value: 'MAX_TOKENS' }; + yield { + type: ServerGeminiEventType.Finished, + value: { reason: 'MAX_TOKENS', usageMetadata: undefined }, + }; })(), ); @@ -1324,7 +1335,10 @@ describe('useGeminiStream', () => { type: ServerGeminiEventType.Content, value: 'Complete response', }; - yield { type: ServerGeminiEventType.Finished, value: 'STOP' }; + yield { + type: ServerGeminiEventType.Finished, + value: { reason: 'STOP', usageMetadata: undefined }, + }; })(), ); @@ -1373,7 +1387,10 @@ describe('useGeminiStream', () => { }; yield { type: ServerGeminiEventType.Finished, - value: 'FINISH_REASON_UNSPECIFIED', + value: { + reason: 'FINISH_REASON_UNSPECIFIED', + usageMetadata: undefined, + }, }; })(), ); @@ -1464,7 +1481,10 @@ describe('useGeminiStream', () => { type: ServerGeminiEventType.Content, value: `Response for ${reason}`, }; - yield { type: ServerGeminiEventType.Finished, value: reason }; + yield { + type: ServerGeminiEventType.Finished, + value: { reason, usageMetadata: undefined }, + }; })(), ); @@ -1579,7 +1599,10 @@ describe('useGeminiStream', () => { type: ServerGeminiEventType.Content, value: 'Some response content', }; - yield { type: ServerGeminiEventType.Finished, value: 'STOP' }; + yield { + type: ServerGeminiEventType.Finished, + value: { reason: 'STOP', usageMetadata: undefined }, + }; })(), ); @@ -1626,7 +1649,10 @@ describe('useGeminiStream', () => { type: ServerGeminiEventType.Content, value: 'New response content', }; - yield { type: ServerGeminiEventType.Finished, value: 'STOP' }; + yield { + type: ServerGeminiEventType.Finished, + value: { reason: 'STOP', usageMetadata: undefined }, + }; })(), ); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 139a39eafc..d09d3fcdea 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -516,7 +516,10 @@ export const useGeminiStream = ( const handleFinishedEvent = useCallback( (event: ServerGeminiFinishedEvent, userMessageTimestamp: number) => { - const finishReason = event.value; + const finishReason = event.value.reason; + if (!finishReason) { + return; + } const finishReasonMessages: Record = { [FinishReason.FINISH_REASON_UNSPECIFIED]: undefined, diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index 067174f7a2..ad03b1ac64 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -59,6 +59,7 @@ const mockConfig = { model: 'test-model', authType: 'oauth-personal', }), + getGeminiClient: () => null, // No client needed for these tests } as unknown as Config; const mockTool = new MockTool({ diff --git a/packages/core/src/code_assist/converter.ts b/packages/core/src/code_assist/converter.ts index 09a3631449..147dcb01ef 100644 --- a/packages/core/src/code_assist/converter.ts +++ b/packages/core/src/code_assist/converter.ts @@ -204,7 +204,7 @@ function toContent(content: ContentUnion): Content { }; } -function toParts(parts: PartUnion[]): Part[] { +export function toParts(parts: PartUnion[]): Part[] { return parts.map(toPart); } diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index b7ba47ee97..4a9706c029 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -49,6 +49,32 @@ import { tokenLimit } from './tokenLimits.js'; import { ideContext } from '../ide/ideContext.js'; import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js'; +// Mock fs module to prevent actual file system operations during tests +const mockFileSystem = new Map(); + +vi.mock('node:fs', () => { + const fsModule = { + mkdirSync: vi.fn(), + writeFileSync: vi.fn((path: string, data: string) => { + mockFileSystem.set(path, data); + }), + readFileSync: vi.fn((path: string) => { + if (mockFileSystem.has(path)) { + return mockFileSystem.get(path); + } + throw Object.assign(new Error('ENOENT: no such file or directory'), { + code: 'ENOENT', + }); + }), + existsSync: vi.fn((path: string) => mockFileSystem.has(path)), + }; + + return { + default: fsModule, + ...fsModule, + }; +}); + // --- Mocks --- const mockChatCreateFn = vi.fn(); const mockGenerateContentFn = vi.fn(); @@ -278,6 +304,10 @@ describe('Gemini Client (client.ts)', () => { setFallbackMode: vi.fn(), getChatCompression: vi.fn().mockReturnValue(undefined), getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false), + getProjectRoot: vi.fn().mockReturnValue('/test/project/root'), + storage: { + getProjectTempDir: vi.fn().mockReturnValue('/test/temp'), + }, }; const MockedConfig = vi.mocked(Config, true); MockedConfig.mockImplementation( diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index e07ed56e4c..d00504dc5b 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -30,6 +30,7 @@ import { retryWithBackoff } from '../utils/retry.js'; import { getErrorMessage } from '../utils/errors.js'; import { isFunctionResponse } from '../utils/messageInspectors.js'; import { tokenLimit } from './tokenLimits.js'; +import type { ChatRecordingService } from '../services/chatRecordingService.js'; import type { ContentGenerator, ContentGeneratorConfig, @@ -222,6 +223,10 @@ export class GeminiClient { this.chat = await this.startChat(); } + getChatRecordingService(): ChatRecordingService | undefined { + return this.chat?.getChatRecordingService(); + } + async addDirectoryContext(): Promise { if (!this.chat) { return; diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index e4ef0dc8e3..55f6b9f0af 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -168,6 +168,7 @@ describe('CoreToolScheduler', () => { authType: 'oauth-personal', }), getToolRegistry: () => mockToolRegistry, + getGeminiClient: () => null, // No client needed for these tests } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -201,6 +202,7 @@ describe('CoreToolScheduler', () => { // Create mocked tool registry const mockConfig = { getToolRegistry: () => mockToolRegistry, + getGeminiClient: () => null, // No client needed for these tests } as unknown as Config; const mockToolRegistry = { getAllToolNames: () => ['list_files', 'read_file', 'write_file'], @@ -265,6 +267,7 @@ describe('CoreToolScheduler with payload', () => { authType: 'oauth-personal', }), getToolRegistry: () => mockToolRegistry, + getGeminiClient: () => null, // No client needed for these tests } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -571,6 +574,7 @@ describe('CoreToolScheduler edit cancellation', () => { authType: 'oauth-personal', }), getToolRegistry: () => mockToolRegistry, + getGeminiClient: () => null, // No client needed for these tests } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -662,6 +666,7 @@ describe('CoreToolScheduler YOLO mode', () => { authType: 'oauth-personal', }), getToolRegistry: () => mockToolRegistry, + getGeminiClient: () => null, // No client needed for these tests } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -752,6 +757,7 @@ describe('CoreToolScheduler request queueing', () => { authType: 'oauth-personal', }), getToolRegistry: () => mockToolRegistry, + getGeminiClient: () => null, // No client needed for these tests } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -868,6 +874,7 @@ describe('CoreToolScheduler request queueing', () => { model: 'test-model', authType: 'oauth-personal', }), + getGeminiClient: () => null, // No client needed for these tests } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -948,6 +955,7 @@ describe('CoreToolScheduler request queueing', () => { authType: 'oauth-personal', }), getToolRegistry: () => mockToolRegistry, + getGeminiClient: () => null, // No client needed for these tests } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -1007,6 +1015,7 @@ describe('CoreToolScheduler request queueing', () => { setApprovalMode: (mode: ApprovalMode) => { approvalMode = mode; }, + getGeminiClient: () => null, // No client needed for these tests } as unknown as Config; const testTool = new TestApprovalTool(mockConfig); diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 161f2900e6..156b284237 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -16,6 +16,32 @@ import { GeminiChat, EmptyStreamError } from './geminiChat.js'; import type { Config } from '../config/config.js'; import { setSimulate429 } from '../utils/testUtils.js'; +// Mock fs module to prevent actual file system operations during tests +const mockFileSystem = new Map(); + +vi.mock('node:fs', () => { + const fsModule = { + mkdirSync: vi.fn(), + writeFileSync: vi.fn((path: string, data: string) => { + mockFileSystem.set(path, data); + }), + readFileSync: vi.fn((path: string) => { + if (mockFileSystem.has(path)) { + return mockFileSystem.get(path); + } + throw Object.assign(new Error('ENOENT: no such file or directory'), { + code: 'ENOENT', + }); + }), + existsSync: vi.fn((path: string) => mockFileSystem.has(path)), + }; + + return { + default: fsModule, + ...fsModule, + }; +}); + // Mocks const mockModelsModule = { generateContent: vi.fn(), @@ -59,6 +85,13 @@ describe('GeminiChat', () => { getQuotaErrorOccurred: vi.fn().mockReturnValue(false), setQuotaErrorOccurred: vi.fn(), flashFallbackHandler: undefined, + getProjectRoot: vi.fn().mockReturnValue('/test/project/root'), + storage: { + getProjectTempDir: vi.fn().mockReturnValue('/test/temp'), + }, + getToolRegistry: vi.fn().mockReturnValue({ + getTool: vi.fn(), + }), } as unknown as Config; // Disable 429 simulation for tests diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 6eea6f22aa..af08d0c0a4 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -15,6 +15,7 @@ import type { Part, Tool, } from '@google/genai'; +import { toParts } from '../code_assist/converter.js'; import { createUserContent } from '@google/genai'; import { retryWithBackoff } from '../utils/retry.js'; import type { ContentGenerator } from './contentGenerator.js'; @@ -23,16 +24,20 @@ import type { Config } from '../config/config.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { hasCycleInSchema } from '../tools/tools.js'; import type { StructuredError } from './turn.js'; +import type { CompletedToolCall } from './coreToolScheduler.js'; import { logContentRetry, logContentRetryFailure, logInvalidChunk, } from '../telemetry/loggers.js'; +import { ChatRecordingService } from '../services/chatRecordingService.js'; import { ContentRetryEvent, ContentRetryFailureEvent, InvalidChunkEvent, } from '../telemetry/types.js'; +import { isFunctionResponse } from '../utils/messageInspectors.js'; +import { partListUnionToString } from './geminiRequest.js'; /** * Options for retrying due to invalid content from the model. @@ -151,6 +156,7 @@ export class GeminiChat { // A promise to represent the current state of the message being sent to the // model. private sendPromise: Promise = Promise.resolve(); + private readonly chatRecordingService: ChatRecordingService; constructor( private readonly config: Config, @@ -159,6 +165,8 @@ export class GeminiChat { private history: Content[] = [], ) { validateHistory(history); + this.chatRecordingService = new ChatRecordingService(config); + this.chatRecordingService.initialize(); } /** @@ -237,6 +245,18 @@ export class GeminiChat { ): Promise { await this.sendPromise; const userContent = createUserContent(params.message); + + // Record user input - capture complete message with all parts (text, files, images, etc.) + // but skip recording function responses (tool call results) as they should be stored in tool call records + if (!isFunctionResponse(userContent)) { + const userMessage = Array.isArray(params.message) + ? params.message + : [params.message]; + this.chatRecordingService.recordMessage({ + type: 'user', + content: userMessage, + }); + } const requestContents = this.getHistory(true).concat(userContent); let response: GenerateContentResponse; @@ -351,6 +371,19 @@ export class GeminiChat { const userContent = createUserContent(params.message); + // Record user input - capture complete message with all parts (text, files, images, etc.) + // but skip recording function responses (tool call results) as they should be stored in tool call records + if (!isFunctionResponse(userContent)) { + const userMessage = Array.isArray(params.message) + ? params.message + : [params.message]; + const userMessageContent = partListUnionToString(toParts(userMessage)); + this.chatRecordingService.recordMessage({ + type: 'user', + content: userMessageContent, + }); + } + // Add user content to history ONCE before any attempts. this.history.push(userContent); const requestContents = this.getHistory(true); @@ -582,10 +615,15 @@ export class GeminiChat { const content = chunk.candidates?.[0]?.content; if (content?.parts) { - modelResponseParts.push(...content.parts); + if (content.parts.some((part) => part.thought)) { + // Record thoughts + this.recordThoughtFromContent(content); + } if (content.parts.some((part) => part.functionCall)) { hasToolCall = true; } + // Always add parts - thoughts will be filtered out later in recordHistory + modelResponseParts.push(...content.parts); } } else { logInvalidChunk( @@ -595,7 +633,13 @@ export class GeminiChat { isStreamInvalid = true; firstInvalidChunkEncountered = true; } - yield chunk; + + // Record token usage if this chunk has usageMetadata + if (chunk.usageMetadata) { + this.chatRecordingService.recordMessageTokens(chunk.usageMetadata); + } + + yield chunk; // Yield every chunk to the UI immediately. } if (!hasReceivedAnyChunk) { @@ -625,6 +669,21 @@ export class GeminiChat { } } + // Record model response text from the collected parts + if (modelResponseParts.length > 0) { + const responseText = modelResponseParts + .filter((part) => part.text && !part.thought) + .map((part) => part.text) + .join(''); + + if (responseText.trim()) { + this.chatRecordingService.recordMessage({ + type: 'gemini', + content: responseText, + }); + } + } + // Bundle all streamed parts into a single Content object const modelOutput: Content[] = modelResponseParts.length > 0 @@ -734,7 +793,64 @@ export class GeminiChat { this.history.push({ role: 'model', parts: [] }); } } + + /** + * Gets the chat recording service instance. + */ + getChatRecordingService(): ChatRecordingService { + return this.chatRecordingService; + } + + /** + * Records completed tool calls with full metadata. + * This is called by external components when tool calls complete, before sending responses to Gemini. + */ + recordCompletedToolCalls(toolCalls: CompletedToolCall[]): void { + const toolCallRecords = toolCalls.map((call) => { + const resultDisplayRaw = call.response?.resultDisplay; + const resultDisplay = + typeof resultDisplayRaw === 'string' ? resultDisplayRaw : undefined; + + return { + id: call.request.callId, + name: call.request.name, + args: call.request.args, + result: call.response?.responseParts || null, + status: call.status as 'error' | 'success' | 'cancelled', + timestamp: new Date().toISOString(), + resultDisplay, + }; + }); + + this.chatRecordingService.recordToolCalls(toolCallRecords); + } + + /** + * Extracts and records thought from thought content. + */ + private recordThoughtFromContent(content: Content): void { + if (!content.parts || content.parts.length === 0) { + return; + } + + const thoughtPart = content.parts[0]; + if (thoughtPart.text) { + // Extract subject and description using the same logic as turn.ts + const rawText = thoughtPart.text; + const subjectStringMatches = rawText.match(/\*\*(.*?)\*\*/s); + const subject = subjectStringMatches + ? subjectStringMatches[1].trim() + : ''; + const description = rawText.replace(/\*\*(.*?)\*\*/s, '').trim(); + + this.chatRecordingService.recordThought({ + subject, + description, + }); + } + } } + /** Visible for Testing */ export function isSchemaDepthError(errorMessage: string): boolean { return errorMessage.includes('maximum schema depth exceeded'); diff --git a/packages/core/src/core/nonInteractiveToolExecutor.test.ts b/packages/core/src/core/nonInteractiveToolExecutor.test.ts index 99cd5ab707..698b72a9ad 100644 --- a/packages/core/src/core/nonInteractiveToolExecutor.test.ts +++ b/packages/core/src/core/nonInteractiveToolExecutor.test.ts @@ -41,6 +41,7 @@ describe('executeToolCall', () => { model: 'test-model', authType: 'oauth-personal', }), + getGeminiClient: () => null, // No client needed for these tests } as unknown as Config; abortController = new AbortController(); diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts index 74e601e5df..081a4c63eb 100644 --- a/packages/core/src/core/turn.test.ts +++ b/packages/core/src/core/turn.test.ts @@ -105,7 +105,21 @@ describe('Turn', () => { expect(events).toEqual([ { type: GeminiEventType.Content, value: 'Hello' }, + { + type: GeminiEventType.Finished, + value: { + reason: undefined, + usageMetadata: undefined, + }, + }, { type: GeminiEventType.Content, value: ' world' }, + { + type: GeminiEventType.Finished, + value: { + reason: undefined, + usageMetadata: undefined, + }, + }, ]); expect(turn.getDebugResponses().length).toBe(2); }); @@ -135,7 +149,7 @@ describe('Turn', () => { events.push(event); } - expect(events.length).toBe(2); + expect(events.length).toBe(3); const event1 = events[0] as ServerGeminiToolCallRequestEvent; expect(event1.type).toBe(GeminiEventType.ToolCallRequest); expect(event1.value).toEqual( @@ -190,6 +204,13 @@ describe('Turn', () => { } expect(events).toEqual([ { type: GeminiEventType.Content, value: 'First part' }, + { + type: GeminiEventType.Finished, + value: { + reason: undefined, + usageMetadata: undefined, + }, + }, { type: GeminiEventType.UserCancelled }, ]); expect(turn.getDebugResponses().length).toBe(1); @@ -247,7 +268,7 @@ describe('Turn', () => { events.push(event); } - expect(events.length).toBe(3); + expect(events.length).toBe(4); const event1 = events[0] as ServerGeminiToolCallRequestEvent; expect(event1.type).toBe(GeminiEventType.ToolCallRequest); expect(event1.value).toEqual( @@ -295,6 +316,13 @@ describe('Turn', () => { finishReason: 'STOP', }, ], + usageMetadata: { + promptTokenCount: 17, + candidatesTokenCount: 50, + cachedContentTokenCount: 10, + thoughtsTokenCount: 5, + toolUsePromptTokenCount: 2, + }, } as unknown as GenerateContentResponse; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -310,7 +338,19 @@ describe('Turn', () => { expect(events).toEqual([ { type: GeminiEventType.Content, value: 'Partial response' }, - { type: GeminiEventType.Finished, value: 'STOP' }, + { + type: GeminiEventType.Finished, + value: { + reason: 'STOP', + usageMetadata: { + promptTokenCount: 17, + candidatesTokenCount: 50, + cachedContentTokenCount: 10, + thoughtsTokenCount: 5, + toolUsePromptTokenCount: 2, + }, + }, + }, ]); }); @@ -345,7 +385,10 @@ describe('Turn', () => { type: GeminiEventType.Content, value: 'This is a long response that was cut off...', }, - { type: GeminiEventType.Finished, value: 'MAX_TOKENS' }, + { + type: GeminiEventType.Finished, + value: { reason: 'MAX_TOKENS', usageMetadata: undefined }, + }, ]); }); @@ -373,11 +416,14 @@ describe('Turn', () => { expect(events).toEqual([ { type: GeminiEventType.Content, value: 'Content blocked' }, - { type: GeminiEventType.Finished, value: 'SAFETY' }, + { + type: GeminiEventType.Finished, + value: { reason: 'SAFETY', usageMetadata: undefined }, + }, ]); }); - it('should not yield finished event when there is no finish reason', async () => { + it('should yield finished event with undefined reason when there is no finish reason', async () => { const mockResponseStream = (async function* () { yield { candidates: [ @@ -404,8 +450,11 @@ describe('Turn', () => { type: GeminiEventType.Content, value: 'Response without finish reason', }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: undefined }, + }, ]); - // No Finished event should be emitted }); it('should handle multiple responses with different finish reasons', async () => { @@ -440,8 +489,18 @@ describe('Turn', () => { expect(events).toEqual([ { type: GeminiEventType.Content, value: 'First part' }, + { + type: GeminiEventType.Finished, + value: { + reason: undefined, + usageMetadata: undefined, + }, + }, { type: GeminiEventType.Content, value: 'Second part' }, - { type: GeminiEventType.Finished, value: 'OTHER' }, + { + type: GeminiEventType.Finished, + value: { reason: 'OTHER', usageMetadata: undefined }, + }, ]); }); @@ -480,7 +539,10 @@ describe('Turn', () => { type: GeminiEventType.Citation, value: 'Citations:\n(Source 1 Title) https://example.com/source1', }, - { type: GeminiEventType.Finished, value: 'STOP' }, + { + type: GeminiEventType.Finished, + value: { reason: 'STOP', usageMetadata: undefined }, + }, ]); }); @@ -524,7 +586,10 @@ describe('Turn', () => { value: 'Citations:\n(Title1) https://example.com/source1\n(Title2) https://example.com/source2', }, - { type: GeminiEventType.Finished, value: 'STOP' }, + { + type: GeminiEventType.Finished, + value: { reason: 'STOP', usageMetadata: undefined }, + }, ]); }); @@ -559,8 +624,12 @@ describe('Turn', () => { expect(events).toEqual([ { type: GeminiEventType.Content, value: 'Some text.' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: undefined }, + }, ]); - // No Citation or Finished event + // No Citation event (but we do get a Finished event with undefined reason) expect(events.some((e) => e.type === GeminiEventType.Citation)).toBe( false, ); @@ -605,7 +674,10 @@ describe('Turn', () => { type: GeminiEventType.Citation, value: 'Citations:\n(Good Source) https://example.com/source1', }, - { type: GeminiEventType.Finished, value: 'STOP' }, + { + type: GeminiEventType.Finished, + value: { reason: 'STOP', usageMetadata: undefined }, + }, ]); }); diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index f21138c7b4..69118d2a9a 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -11,6 +11,7 @@ import type { FunctionCall, FunctionDeclaration, FinishReason, + GenerateContentResponseUsageMetadata, } from '@google/genai'; import type { ToolCallConfirmationDetails, @@ -66,6 +67,11 @@ export interface GeminiErrorEventValue { error: StructuredError; } +export interface GeminiFinishedEventValue { + reason: FinishReason | undefined; + usageMetadata: GenerateContentResponseUsageMetadata | undefined; +} + export interface ToolCallRequestInfo { callId: string; name: string; @@ -157,7 +163,7 @@ export type ServerGeminiMaxSessionTurnsEvent = { export type ServerGeminiFinishedEvent = { type: GeminiEventType.Finished; - value: FinishReason; + value: GeminiFinishedEventValue; }; export type ServerGeminiLoopDetectedEvent = { @@ -272,11 +278,14 @@ export class Turn { } this.finishReason = finishReason; - yield { - type: GeminiEventType.Finished, - value: finishReason as FinishReason, - }; } + yield { + type: GeminiEventType.Finished, + value: { + reason: finishReason ? finishReason : undefined, + usageMetadata: resp.usageMetadata, + }, + }; } } catch (e) { if (signal.aborted) { diff --git a/packages/core/src/services/chatRecordingService.test.ts b/packages/core/src/services/chatRecordingService.test.ts index 82de5f0e7b..50ee04b182 100644 --- a/packages/core/src/services/chatRecordingService.test.ts +++ b/packages/core/src/services/chatRecordingService.test.ts @@ -19,7 +19,14 @@ import { getProjectHash } from '../utils/paths.js'; vi.mock('node:fs'); vi.mock('node:path'); -vi.mock('node:crypto'); +vi.mock('node:crypto', () => ({ + randomUUID: vi.fn(), + createHash: vi.fn(() => ({ + update: vi.fn(() => ({ + digest: vi.fn(() => 'mocked-hash'), + })), + })), +})); vi.mock('../utils/paths.js'); describe('ChatRecordingService', () => { @@ -40,6 +47,13 @@ describe('ChatRecordingService', () => { }, getModel: vi.fn().mockReturnValue('gemini-pro'), getDebugMode: vi.fn().mockReturnValue(false), + getToolRegistry: vi.fn().mockReturnValue({ + getTool: vi.fn().mockReturnValue({ + displayName: 'Test Tool', + description: 'A test tool', + isOutputMarkdown: false, + }), + }), } as unknown as Config; vi.mocked(getProjectHash).mockReturnValue('test-project-hash'); @@ -124,7 +138,7 @@ describe('ChatRecordingService', () => { expect(conversation.messages[0].type).toBe('user'); }); - it('should append to the last message if append is true and types match', () => { + it('should create separate messages when recording multiple messages', () => { const writeFileSyncSpy = vi .spyOn(fs, 'writeFileSync') .mockImplementation(() => undefined); @@ -146,8 +160,7 @@ describe('ChatRecordingService', () => { chatRecordingService.recordMessage({ type: 'user', - content: ' World', - append: true, + content: 'World', }); expect(mkdirSyncSpy).toHaveBeenCalled(); @@ -155,8 +168,9 @@ describe('ChatRecordingService', () => { const conversation = JSON.parse( writeFileSyncSpy.mock.calls[0][1] as string, ) as ConversationRecord; - expect(conversation.messages).toHaveLength(1); - expect(conversation.messages[0].content).toBe('Hello World'); + expect(conversation.messages).toHaveLength(2); + expect(conversation.messages[0].content).toBe('Hello'); + expect(conversation.messages[1].content).toBe('World'); }); }); @@ -204,10 +218,10 @@ describe('ChatRecordingService', () => { ); chatRecordingService.recordMessageTokens({ - input: 1, - output: 2, - total: 3, - cached: 0, + promptTokenCount: 1, + candidatesTokenCount: 2, + totalTokenCount: 3, + cachedContentTokenCount: 0, }); expect(mkdirSyncSpy).toHaveBeenCalled(); @@ -217,7 +231,14 @@ describe('ChatRecordingService', () => { ) as ConversationRecord; expect(conversation.messages[0]).toEqual({ ...initialConversation.messages[0], - tokens: { input: 1, output: 2, total: 3, cached: 0 }, + tokens: { + input: 1, + output: 2, + total: 3, + cached: 0, + thoughts: 0, + tool: 0, + }, }); }); @@ -240,10 +261,10 @@ describe('ChatRecordingService', () => { ); chatRecordingService.recordMessageTokens({ - input: 2, - output: 2, - total: 4, - cached: 0, + promptTokenCount: 2, + candidatesTokenCount: 2, + totalTokenCount: 4, + cachedContentTokenCount: 0, }); // @ts-expect-error private property @@ -252,6 +273,8 @@ describe('ChatRecordingService', () => { output: 2, total: 4, cached: 0, + thoughts: 0, + tool: 0, }); }); }); @@ -297,7 +320,14 @@ describe('ChatRecordingService', () => { ) as ConversationRecord; expect(conversation.messages[0]).toEqual({ ...initialConversation.messages[0], - toolCalls: [toolCall], + toolCalls: [ + { + ...toolCall, + displayName: 'Test Tool', + description: 'A test tool', + renderOutputAsMarkdown: false, + }, + ], }); }); @@ -343,7 +373,14 @@ describe('ChatRecordingService', () => { type: 'gemini', thoughts: [], content: '', - toolCalls: [toolCall], + toolCalls: [ + { + ...toolCall, + displayName: 'Test Tool', + description: 'A test tool', + renderOutputAsMarkdown: false, + }, + ], }); }); }); diff --git a/packages/core/src/services/chatRecordingService.ts b/packages/core/src/services/chatRecordingService.ts index 138c2d810d..6179ae7d08 100644 --- a/packages/core/src/services/chatRecordingService.ts +++ b/packages/core/src/services/chatRecordingService.ts @@ -11,7 +11,10 @@ import { getProjectHash } from '../utils/paths.js'; import path from 'node:path'; import fs from 'node:fs'; import { randomUUID } from 'node:crypto'; -import type { PartListUnion } from '@google/genai'; +import type { + PartListUnion, + GenerateContentResponseUsageMetadata, +} from '@google/genai'; /** * Token usage summary for a message or conversation. @@ -31,7 +34,7 @@ export interface TokensSummary { export interface BaseMessageRecord { id: string; timestamp: string; - content: string; + content: PartListUnion; } /** @@ -178,7 +181,7 @@ export class ChatRecordingService { private newMessage( type: ConversationRecordExtra['type'], - content: string, + content: PartListUnion, ): MessageRecord { return { id: randomUUID(), @@ -193,22 +196,12 @@ export class ChatRecordingService { */ recordMessage(message: { type: ConversationRecordExtra['type']; - content: string; - append?: boolean; + content: PartListUnion; }): void { if (!this.conversationFile) return; try { this.updateConversation((conversation) => { - if (message.append) { - const lastMsg = this.getLastMessage(conversation); - if (lastMsg && lastMsg.type === message.type) { - lastMsg.content += message.content; - return; - } - } - // We're not appending, or we are appending but the last message's type is not the same as - // the specified type, so just create a new message. const msg = this.newMessage(message.type, message.content); if (msg.type === 'gemini') { // If it's a new Gemini message then incorporate any queued thoughts. @@ -243,27 +236,28 @@ export class ChatRecordingService { timestamp: new Date().toISOString(), }); } catch (error) { - if (this.config.getDebugMode()) { - console.error('Error saving thought:', error); - throw error; - } + console.error('Error saving thought:', error); + throw error; } } /** * Updates the tokens for the last message in the conversation (which should be by Gemini). */ - recordMessageTokens(tokens: { - input: number; - output: number; - cached: number; - thoughts?: number; - tool?: number; - total: number; - }): void { + recordMessageTokens( + respUsageMetadata: GenerateContentResponseUsageMetadata, + ): void { if (!this.conversationFile) return; try { + const tokens = { + input: respUsageMetadata.promptTokenCount ?? 0, + output: respUsageMetadata.candidatesTokenCount ?? 0, + cached: respUsageMetadata.cachedContentTokenCount ?? 0, + thoughts: respUsageMetadata.thoughtsTokenCount ?? 0, + tool: respUsageMetadata.toolUsePromptTokenCount ?? 0, + total: respUsageMetadata.totalTokenCount ?? 0, + }; this.updateConversation((conversation) => { const lastMsg = this.getLastMessage(conversation); // If the last message already has token info, it's because this new token info is for a @@ -283,10 +277,23 @@ export class ChatRecordingService { /** * Adds tool calls to the last message in the conversation (which should be by Gemini). + * This method enriches tool calls with metadata from the ToolRegistry. */ recordToolCalls(toolCalls: ToolCallRecord[]): void { if (!this.conversationFile) return; + // Enrich tool calls with metadata from the ToolRegistry + const toolRegistry = this.config.getToolRegistry(); + const enrichedToolCalls = toolCalls.map((toolCall) => { + const toolInstance = toolRegistry.getTool(toolCall.name); + return { + ...toolCall, + displayName: toolInstance?.displayName || toolCall.name, + description: toolInstance?.description || '', + renderOutputAsMarkdown: toolInstance?.isOutputMarkdown || false, + }; + }); + try { this.updateConversation((conversation) => { const lastMsg = this.getLastMessage(conversation); @@ -309,7 +316,7 @@ export class ChatRecordingService { // resulting message's type, and so it thinks that toolCalls may // not be present. Confirming the type here satisfies it. type: 'gemini' as const, - toolCalls, + toolCalls: enrichedToolCalls, thoughts: this.queuedThoughts, model: this.config.getModel(), }; @@ -346,7 +353,7 @@ export class ChatRecordingService { }); // Add any new tools calls that aren't in the message yet. - for (const toolCall of toolCalls) { + for (const toolCall of enrichedToolCalls) { const existingToolCall = lastMsg.toolCalls.find( (tc) => tc.id === toolCall.id, ); diff --git a/packages/core/src/utils/nextSpeakerChecker.test.ts b/packages/core/src/utils/nextSpeakerChecker.test.ts index 3314ca3c84..b9e861998e 100644 --- a/packages/core/src/utils/nextSpeakerChecker.test.ts +++ b/packages/core/src/utils/nextSpeakerChecker.test.ts @@ -14,6 +14,32 @@ import type { NextSpeakerResponse } from './nextSpeakerChecker.js'; import { checkNextSpeaker } from './nextSpeakerChecker.js'; import { GeminiChat } from '../core/geminiChat.js'; +// Mock fs module to prevent actual file system operations during tests +const mockFileSystem = new Map(); + +vi.mock('node:fs', () => { + const fsModule = { + mkdirSync: vi.fn(), + writeFileSync: vi.fn((path: string, data: string) => { + mockFileSystem.set(path, data); + }), + readFileSync: vi.fn((path: string) => { + if (mockFileSystem.has(path)) { + return mockFileSystem.get(path); + } + throw Object.assign(new Error('ENOENT: no such file or directory'), { + code: 'ENOENT', + }); + }), + existsSync: vi.fn((path: string) => mockFileSystem.has(path)), + }; + + return { + default: fsModule, + ...fsModule, + }; +}); + // Mock GeminiClient and Config constructor vi.mock('../core/client.js'); vi.mock('../config/config.js'); @@ -64,6 +90,17 @@ describe('checkNextSpeaker', () => { undefined, ); + // Mock the methods that ChatRecordingService needs + mockConfigInstance.getSessionId = vi + .fn() + .mockReturnValue('test-session-id'); + mockConfigInstance.getProjectRoot = vi + .fn() + .mockReturnValue('/test/project/root'); + mockConfigInstance.storage = { + getProjectTempDir: vi.fn().mockReturnValue('/test/temp'), + }; + mockGeminiClient = new GeminiClient(mockConfigInstance); // Reset mocks before each test to ensure test isolation