mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-18 09:11:55 -07:00
feat(sessions): Integrate chat recording into GeminiChat (#6721)
This commit is contained in:
@@ -204,7 +204,7 @@ function toContent(content: ContentUnion): Content {
|
||||
};
|
||||
}
|
||||
|
||||
function toParts(parts: PartUnion[]): Part[] {
|
||||
export function toParts(parts: PartUnion[]): Part[] {
|
||||
return parts.map(toPart);
|
||||
}
|
||||
|
||||
|
||||
@@ -49,6 +49,32 @@ import { tokenLimit } from './tokenLimits.js';
|
||||
import { ideContext } from '../ide/ideContext.js';
|
||||
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
||||
|
||||
// Mock fs module to prevent actual file system operations during tests
|
||||
const mockFileSystem = new Map<string, string>();
|
||||
|
||||
vi.mock('node:fs', () => {
|
||||
const fsModule = {
|
||||
mkdirSync: vi.fn(),
|
||||
writeFileSync: vi.fn((path: string, data: string) => {
|
||||
mockFileSystem.set(path, data);
|
||||
}),
|
||||
readFileSync: vi.fn((path: string) => {
|
||||
if (mockFileSystem.has(path)) {
|
||||
return mockFileSystem.get(path);
|
||||
}
|
||||
throw Object.assign(new Error('ENOENT: no such file or directory'), {
|
||||
code: 'ENOENT',
|
||||
});
|
||||
}),
|
||||
existsSync: vi.fn((path: string) => mockFileSystem.has(path)),
|
||||
};
|
||||
|
||||
return {
|
||||
default: fsModule,
|
||||
...fsModule,
|
||||
};
|
||||
});
|
||||
|
||||
// --- Mocks ---
|
||||
const mockChatCreateFn = vi.fn();
|
||||
const mockGenerateContentFn = vi.fn();
|
||||
@@ -278,6 +304,10 @@ describe('Gemini Client (client.ts)', () => {
|
||||
setFallbackMode: vi.fn(),
|
||||
getChatCompression: vi.fn().mockReturnValue(undefined),
|
||||
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
|
||||
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
|
||||
storage: {
|
||||
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
|
||||
},
|
||||
};
|
||||
const MockedConfig = vi.mocked(Config, true);
|
||||
MockedConfig.mockImplementation(
|
||||
|
||||
@@ -30,6 +30,7 @@ import { retryWithBackoff } from '../utils/retry.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { tokenLimit } from './tokenLimits.js';
|
||||
import type { ChatRecordingService } from '../services/chatRecordingService.js';
|
||||
import type {
|
||||
ContentGenerator,
|
||||
ContentGeneratorConfig,
|
||||
@@ -222,6 +223,10 @@ export class GeminiClient {
|
||||
this.chat = await this.startChat();
|
||||
}
|
||||
|
||||
getChatRecordingService(): ChatRecordingService | undefined {
|
||||
return this.chat?.getChatRecordingService();
|
||||
}
|
||||
|
||||
async addDirectoryContext(): Promise<void> {
|
||||
if (!this.chat) {
|
||||
return;
|
||||
|
||||
@@ -168,6 +168,7 @@ describe('CoreToolScheduler', () => {
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -201,6 +202,7 @@ describe('CoreToolScheduler', () => {
|
||||
// Create mocked tool registry
|
||||
const mockConfig = {
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
} as unknown as Config;
|
||||
const mockToolRegistry = {
|
||||
getAllToolNames: () => ['list_files', 'read_file', 'write_file'],
|
||||
@@ -265,6 +267,7 @@ describe('CoreToolScheduler with payload', () => {
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -571,6 +574,7 @@ describe('CoreToolScheduler edit cancellation', () => {
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -662,6 +666,7 @@ describe('CoreToolScheduler YOLO mode', () => {
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -752,6 +757,7 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -868,6 +874,7 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -948,6 +955,7 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
@@ -1007,6 +1015,7 @@ describe('CoreToolScheduler request queueing', () => {
|
||||
setApprovalMode: (mode: ApprovalMode) => {
|
||||
approvalMode = mode;
|
||||
},
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
} as unknown as Config;
|
||||
|
||||
const testTool = new TestApprovalTool(mockConfig);
|
||||
|
||||
@@ -16,6 +16,32 @@ import { GeminiChat, EmptyStreamError } from './geminiChat.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { setSimulate429 } from '../utils/testUtils.js';
|
||||
|
||||
// Mock fs module to prevent actual file system operations during tests
|
||||
const mockFileSystem = new Map<string, string>();
|
||||
|
||||
vi.mock('node:fs', () => {
|
||||
const fsModule = {
|
||||
mkdirSync: vi.fn(),
|
||||
writeFileSync: vi.fn((path: string, data: string) => {
|
||||
mockFileSystem.set(path, data);
|
||||
}),
|
||||
readFileSync: vi.fn((path: string) => {
|
||||
if (mockFileSystem.has(path)) {
|
||||
return mockFileSystem.get(path);
|
||||
}
|
||||
throw Object.assign(new Error('ENOENT: no such file or directory'), {
|
||||
code: 'ENOENT',
|
||||
});
|
||||
}),
|
||||
existsSync: vi.fn((path: string) => mockFileSystem.has(path)),
|
||||
};
|
||||
|
||||
return {
|
||||
default: fsModule,
|
||||
...fsModule,
|
||||
};
|
||||
});
|
||||
|
||||
// Mocks
|
||||
const mockModelsModule = {
|
||||
generateContent: vi.fn(),
|
||||
@@ -59,6 +85,13 @@ describe('GeminiChat', () => {
|
||||
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
|
||||
setQuotaErrorOccurred: vi.fn(),
|
||||
flashFallbackHandler: undefined,
|
||||
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
|
||||
storage: {
|
||||
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
|
||||
},
|
||||
getToolRegistry: vi.fn().mockReturnValue({
|
||||
getTool: vi.fn(),
|
||||
}),
|
||||
} as unknown as Config;
|
||||
|
||||
// Disable 429 simulation for tests
|
||||
|
||||
@@ -15,6 +15,7 @@ import type {
|
||||
Part,
|
||||
Tool,
|
||||
} from '@google/genai';
|
||||
import { toParts } from '../code_assist/converter.js';
|
||||
import { createUserContent } from '@google/genai';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
@@ -23,16 +24,20 @@ import type { Config } from '../config/config.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { hasCycleInSchema } from '../tools/tools.js';
|
||||
import type { StructuredError } from './turn.js';
|
||||
import type { CompletedToolCall } from './coreToolScheduler.js';
|
||||
import {
|
||||
logContentRetry,
|
||||
logContentRetryFailure,
|
||||
logInvalidChunk,
|
||||
} from '../telemetry/loggers.js';
|
||||
import { ChatRecordingService } from '../services/chatRecordingService.js';
|
||||
import {
|
||||
ContentRetryEvent,
|
||||
ContentRetryFailureEvent,
|
||||
InvalidChunkEvent,
|
||||
} from '../telemetry/types.js';
|
||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { partListUnionToString } from './geminiRequest.js';
|
||||
|
||||
/**
|
||||
* Options for retrying due to invalid content from the model.
|
||||
@@ -151,6 +156,7 @@ export class GeminiChat {
|
||||
// A promise to represent the current state of the message being sent to the
|
||||
// model.
|
||||
private sendPromise: Promise<void> = Promise.resolve();
|
||||
private readonly chatRecordingService: ChatRecordingService;
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
@@ -159,6 +165,8 @@ export class GeminiChat {
|
||||
private history: Content[] = [],
|
||||
) {
|
||||
validateHistory(history);
|
||||
this.chatRecordingService = new ChatRecordingService(config);
|
||||
this.chatRecordingService.initialize();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -237,6 +245,18 @@ export class GeminiChat {
|
||||
): Promise<GenerateContentResponse> {
|
||||
await this.sendPromise;
|
||||
const userContent = createUserContent(params.message);
|
||||
|
||||
// Record user input - capture complete message with all parts (text, files, images, etc.)
|
||||
// but skip recording function responses (tool call results) as they should be stored in tool call records
|
||||
if (!isFunctionResponse(userContent)) {
|
||||
const userMessage = Array.isArray(params.message)
|
||||
? params.message
|
||||
: [params.message];
|
||||
this.chatRecordingService.recordMessage({
|
||||
type: 'user',
|
||||
content: userMessage,
|
||||
});
|
||||
}
|
||||
const requestContents = this.getHistory(true).concat(userContent);
|
||||
|
||||
let response: GenerateContentResponse;
|
||||
@@ -351,6 +371,19 @@ export class GeminiChat {
|
||||
|
||||
const userContent = createUserContent(params.message);
|
||||
|
||||
// Record user input - capture complete message with all parts (text, files, images, etc.)
|
||||
// but skip recording function responses (tool call results) as they should be stored in tool call records
|
||||
if (!isFunctionResponse(userContent)) {
|
||||
const userMessage = Array.isArray(params.message)
|
||||
? params.message
|
||||
: [params.message];
|
||||
const userMessageContent = partListUnionToString(toParts(userMessage));
|
||||
this.chatRecordingService.recordMessage({
|
||||
type: 'user',
|
||||
content: userMessageContent,
|
||||
});
|
||||
}
|
||||
|
||||
// Add user content to history ONCE before any attempts.
|
||||
this.history.push(userContent);
|
||||
const requestContents = this.getHistory(true);
|
||||
@@ -582,10 +615,15 @@ export class GeminiChat {
|
||||
|
||||
const content = chunk.candidates?.[0]?.content;
|
||||
if (content?.parts) {
|
||||
modelResponseParts.push(...content.parts);
|
||||
if (content.parts.some((part) => part.thought)) {
|
||||
// Record thoughts
|
||||
this.recordThoughtFromContent(content);
|
||||
}
|
||||
if (content.parts.some((part) => part.functionCall)) {
|
||||
hasToolCall = true;
|
||||
}
|
||||
// Always add parts - thoughts will be filtered out later in recordHistory
|
||||
modelResponseParts.push(...content.parts);
|
||||
}
|
||||
} else {
|
||||
logInvalidChunk(
|
||||
@@ -595,7 +633,13 @@ export class GeminiChat {
|
||||
isStreamInvalid = true;
|
||||
firstInvalidChunkEncountered = true;
|
||||
}
|
||||
yield chunk;
|
||||
|
||||
// Record token usage if this chunk has usageMetadata
|
||||
if (chunk.usageMetadata) {
|
||||
this.chatRecordingService.recordMessageTokens(chunk.usageMetadata);
|
||||
}
|
||||
|
||||
yield chunk; // Yield every chunk to the UI immediately.
|
||||
}
|
||||
|
||||
if (!hasReceivedAnyChunk) {
|
||||
@@ -625,6 +669,21 @@ export class GeminiChat {
|
||||
}
|
||||
}
|
||||
|
||||
// Record model response text from the collected parts
|
||||
if (modelResponseParts.length > 0) {
|
||||
const responseText = modelResponseParts
|
||||
.filter((part) => part.text && !part.thought)
|
||||
.map((part) => part.text)
|
||||
.join('');
|
||||
|
||||
if (responseText.trim()) {
|
||||
this.chatRecordingService.recordMessage({
|
||||
type: 'gemini',
|
||||
content: responseText,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Bundle all streamed parts into a single Content object
|
||||
const modelOutput: Content[] =
|
||||
modelResponseParts.length > 0
|
||||
@@ -734,7 +793,64 @@ export class GeminiChat {
|
||||
this.history.push({ role: 'model', parts: [] });
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the chat recording service instance.
|
||||
*/
|
||||
getChatRecordingService(): ChatRecordingService {
|
||||
return this.chatRecordingService;
|
||||
}
|
||||
|
||||
/**
|
||||
* Records completed tool calls with full metadata.
|
||||
* This is called by external components when tool calls complete, before sending responses to Gemini.
|
||||
*/
|
||||
recordCompletedToolCalls(toolCalls: CompletedToolCall[]): void {
|
||||
const toolCallRecords = toolCalls.map((call) => {
|
||||
const resultDisplayRaw = call.response?.resultDisplay;
|
||||
const resultDisplay =
|
||||
typeof resultDisplayRaw === 'string' ? resultDisplayRaw : undefined;
|
||||
|
||||
return {
|
||||
id: call.request.callId,
|
||||
name: call.request.name,
|
||||
args: call.request.args,
|
||||
result: call.response?.responseParts || null,
|
||||
status: call.status as 'error' | 'success' | 'cancelled',
|
||||
timestamp: new Date().toISOString(),
|
||||
resultDisplay,
|
||||
};
|
||||
});
|
||||
|
||||
this.chatRecordingService.recordToolCalls(toolCallRecords);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts and records thought from thought content.
|
||||
*/
|
||||
private recordThoughtFromContent(content: Content): void {
|
||||
if (!content.parts || content.parts.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const thoughtPart = content.parts[0];
|
||||
if (thoughtPart.text) {
|
||||
// Extract subject and description using the same logic as turn.ts
|
||||
const rawText = thoughtPart.text;
|
||||
const subjectStringMatches = rawText.match(/\*\*(.*?)\*\*/s);
|
||||
const subject = subjectStringMatches
|
||||
? subjectStringMatches[1].trim()
|
||||
: '';
|
||||
const description = rawText.replace(/\*\*(.*?)\*\*/s, '').trim();
|
||||
|
||||
this.chatRecordingService.recordThought({
|
||||
subject,
|
||||
description,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Visible for Testing */
|
||||
export function isSchemaDepthError(errorMessage: string): boolean {
|
||||
return errorMessage.includes('maximum schema depth exceeded');
|
||||
|
||||
@@ -41,6 +41,7 @@ describe('executeToolCall', () => {
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getGeminiClient: () => null, // No client needed for these tests
|
||||
} as unknown as Config;
|
||||
|
||||
abortController = new AbortController();
|
||||
|
||||
@@ -105,7 +105,21 @@ describe('Turn', () => {
|
||||
|
||||
expect(events).toEqual([
|
||||
{ type: GeminiEventType.Content, value: 'Hello' },
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: {
|
||||
reason: undefined,
|
||||
usageMetadata: undefined,
|
||||
},
|
||||
},
|
||||
{ type: GeminiEventType.Content, value: ' world' },
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: {
|
||||
reason: undefined,
|
||||
usageMetadata: undefined,
|
||||
},
|
||||
},
|
||||
]);
|
||||
expect(turn.getDebugResponses().length).toBe(2);
|
||||
});
|
||||
@@ -135,7 +149,7 @@ describe('Turn', () => {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(events.length).toBe(2);
|
||||
expect(events.length).toBe(3);
|
||||
const event1 = events[0] as ServerGeminiToolCallRequestEvent;
|
||||
expect(event1.type).toBe(GeminiEventType.ToolCallRequest);
|
||||
expect(event1.value).toEqual(
|
||||
@@ -190,6 +204,13 @@ describe('Turn', () => {
|
||||
}
|
||||
expect(events).toEqual([
|
||||
{ type: GeminiEventType.Content, value: 'First part' },
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: {
|
||||
reason: undefined,
|
||||
usageMetadata: undefined,
|
||||
},
|
||||
},
|
||||
{ type: GeminiEventType.UserCancelled },
|
||||
]);
|
||||
expect(turn.getDebugResponses().length).toBe(1);
|
||||
@@ -247,7 +268,7 @@ describe('Turn', () => {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(events.length).toBe(3);
|
||||
expect(events.length).toBe(4);
|
||||
const event1 = events[0] as ServerGeminiToolCallRequestEvent;
|
||||
expect(event1.type).toBe(GeminiEventType.ToolCallRequest);
|
||||
expect(event1.value).toEqual(
|
||||
@@ -295,6 +316,13 @@ describe('Turn', () => {
|
||||
finishReason: 'STOP',
|
||||
},
|
||||
],
|
||||
usageMetadata: {
|
||||
promptTokenCount: 17,
|
||||
candidatesTokenCount: 50,
|
||||
cachedContentTokenCount: 10,
|
||||
thoughtsTokenCount: 5,
|
||||
toolUsePromptTokenCount: 2,
|
||||
},
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
@@ -310,7 +338,19 @@ describe('Turn', () => {
|
||||
|
||||
expect(events).toEqual([
|
||||
{ type: GeminiEventType.Content, value: 'Partial response' },
|
||||
{ type: GeminiEventType.Finished, value: 'STOP' },
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: {
|
||||
reason: 'STOP',
|
||||
usageMetadata: {
|
||||
promptTokenCount: 17,
|
||||
candidatesTokenCount: 50,
|
||||
cachedContentTokenCount: 10,
|
||||
thoughtsTokenCount: 5,
|
||||
toolUsePromptTokenCount: 2,
|
||||
},
|
||||
},
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
@@ -345,7 +385,10 @@ describe('Turn', () => {
|
||||
type: GeminiEventType.Content,
|
||||
value: 'This is a long response that was cut off...',
|
||||
},
|
||||
{ type: GeminiEventType.Finished, value: 'MAX_TOKENS' },
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: { reason: 'MAX_TOKENS', usageMetadata: undefined },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
@@ -373,11 +416,14 @@ describe('Turn', () => {
|
||||
|
||||
expect(events).toEqual([
|
||||
{ type: GeminiEventType.Content, value: 'Content blocked' },
|
||||
{ type: GeminiEventType.Finished, value: 'SAFETY' },
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: { reason: 'SAFETY', usageMetadata: undefined },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('should not yield finished event when there is no finish reason', async () => {
|
||||
it('should yield finished event with undefined reason when there is no finish reason', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
candidates: [
|
||||
@@ -404,8 +450,11 @@ describe('Turn', () => {
|
||||
type: GeminiEventType.Content,
|
||||
value: 'Response without finish reason',
|
||||
},
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: { reason: undefined, usageMetadata: undefined },
|
||||
},
|
||||
]);
|
||||
// No Finished event should be emitted
|
||||
});
|
||||
|
||||
it('should handle multiple responses with different finish reasons', async () => {
|
||||
@@ -440,8 +489,18 @@ describe('Turn', () => {
|
||||
|
||||
expect(events).toEqual([
|
||||
{ type: GeminiEventType.Content, value: 'First part' },
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: {
|
||||
reason: undefined,
|
||||
usageMetadata: undefined,
|
||||
},
|
||||
},
|
||||
{ type: GeminiEventType.Content, value: 'Second part' },
|
||||
{ type: GeminiEventType.Finished, value: 'OTHER' },
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: { reason: 'OTHER', usageMetadata: undefined },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
@@ -480,7 +539,10 @@ describe('Turn', () => {
|
||||
type: GeminiEventType.Citation,
|
||||
value: 'Citations:\n(Source 1 Title) https://example.com/source1',
|
||||
},
|
||||
{ type: GeminiEventType.Finished, value: 'STOP' },
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: { reason: 'STOP', usageMetadata: undefined },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
@@ -524,7 +586,10 @@ describe('Turn', () => {
|
||||
value:
|
||||
'Citations:\n(Title1) https://example.com/source1\n(Title2) https://example.com/source2',
|
||||
},
|
||||
{ type: GeminiEventType.Finished, value: 'STOP' },
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: { reason: 'STOP', usageMetadata: undefined },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
@@ -559,8 +624,12 @@ describe('Turn', () => {
|
||||
|
||||
expect(events).toEqual([
|
||||
{ type: GeminiEventType.Content, value: 'Some text.' },
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: { reason: undefined, usageMetadata: undefined },
|
||||
},
|
||||
]);
|
||||
// No Citation or Finished event
|
||||
// No Citation event (but we do get a Finished event with undefined reason)
|
||||
expect(events.some((e) => e.type === GeminiEventType.Citation)).toBe(
|
||||
false,
|
||||
);
|
||||
@@ -605,7 +674,10 @@ describe('Turn', () => {
|
||||
type: GeminiEventType.Citation,
|
||||
value: 'Citations:\n(Good Source) https://example.com/source1',
|
||||
},
|
||||
{ type: GeminiEventType.Finished, value: 'STOP' },
|
||||
{
|
||||
type: GeminiEventType.Finished,
|
||||
value: { reason: 'STOP', usageMetadata: undefined },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
FunctionCall,
|
||||
FunctionDeclaration,
|
||||
FinishReason,
|
||||
GenerateContentResponseUsageMetadata,
|
||||
} from '@google/genai';
|
||||
import type {
|
||||
ToolCallConfirmationDetails,
|
||||
@@ -66,6 +67,11 @@ export interface GeminiErrorEventValue {
|
||||
error: StructuredError;
|
||||
}
|
||||
|
||||
export interface GeminiFinishedEventValue {
|
||||
reason: FinishReason | undefined;
|
||||
usageMetadata: GenerateContentResponseUsageMetadata | undefined;
|
||||
}
|
||||
|
||||
export interface ToolCallRequestInfo {
|
||||
callId: string;
|
||||
name: string;
|
||||
@@ -157,7 +163,7 @@ export type ServerGeminiMaxSessionTurnsEvent = {
|
||||
|
||||
export type ServerGeminiFinishedEvent = {
|
||||
type: GeminiEventType.Finished;
|
||||
value: FinishReason;
|
||||
value: GeminiFinishedEventValue;
|
||||
};
|
||||
|
||||
export type ServerGeminiLoopDetectedEvent = {
|
||||
@@ -272,11 +278,14 @@ export class Turn {
|
||||
}
|
||||
|
||||
this.finishReason = finishReason;
|
||||
yield {
|
||||
type: GeminiEventType.Finished,
|
||||
value: finishReason as FinishReason,
|
||||
};
|
||||
}
|
||||
yield {
|
||||
type: GeminiEventType.Finished,
|
||||
value: {
|
||||
reason: finishReason ? finishReason : undefined,
|
||||
usageMetadata: resp.usageMetadata,
|
||||
},
|
||||
};
|
||||
}
|
||||
} catch (e) {
|
||||
if (signal.aborted) {
|
||||
|
||||
@@ -19,7 +19,14 @@ import { getProjectHash } from '../utils/paths.js';
|
||||
|
||||
vi.mock('node:fs');
|
||||
vi.mock('node:path');
|
||||
vi.mock('node:crypto');
|
||||
vi.mock('node:crypto', () => ({
|
||||
randomUUID: vi.fn(),
|
||||
createHash: vi.fn(() => ({
|
||||
update: vi.fn(() => ({
|
||||
digest: vi.fn(() => 'mocked-hash'),
|
||||
})),
|
||||
})),
|
||||
}));
|
||||
vi.mock('../utils/paths.js');
|
||||
|
||||
describe('ChatRecordingService', () => {
|
||||
@@ -40,6 +47,13 @@ describe('ChatRecordingService', () => {
|
||||
},
|
||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
getDebugMode: vi.fn().mockReturnValue(false),
|
||||
getToolRegistry: vi.fn().mockReturnValue({
|
||||
getTool: vi.fn().mockReturnValue({
|
||||
displayName: 'Test Tool',
|
||||
description: 'A test tool',
|
||||
isOutputMarkdown: false,
|
||||
}),
|
||||
}),
|
||||
} as unknown as Config;
|
||||
|
||||
vi.mocked(getProjectHash).mockReturnValue('test-project-hash');
|
||||
@@ -124,7 +138,7 @@ describe('ChatRecordingService', () => {
|
||||
expect(conversation.messages[0].type).toBe('user');
|
||||
});
|
||||
|
||||
it('should append to the last message if append is true and types match', () => {
|
||||
it('should create separate messages when recording multiple messages', () => {
|
||||
const writeFileSyncSpy = vi
|
||||
.spyOn(fs, 'writeFileSync')
|
||||
.mockImplementation(() => undefined);
|
||||
@@ -146,8 +160,7 @@ describe('ChatRecordingService', () => {
|
||||
|
||||
chatRecordingService.recordMessage({
|
||||
type: 'user',
|
||||
content: ' World',
|
||||
append: true,
|
||||
content: 'World',
|
||||
});
|
||||
|
||||
expect(mkdirSyncSpy).toHaveBeenCalled();
|
||||
@@ -155,8 +168,9 @@ describe('ChatRecordingService', () => {
|
||||
const conversation = JSON.parse(
|
||||
writeFileSyncSpy.mock.calls[0][1] as string,
|
||||
) as ConversationRecord;
|
||||
expect(conversation.messages).toHaveLength(1);
|
||||
expect(conversation.messages[0].content).toBe('Hello World');
|
||||
expect(conversation.messages).toHaveLength(2);
|
||||
expect(conversation.messages[0].content).toBe('Hello');
|
||||
expect(conversation.messages[1].content).toBe('World');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -204,10 +218,10 @@ describe('ChatRecordingService', () => {
|
||||
);
|
||||
|
||||
chatRecordingService.recordMessageTokens({
|
||||
input: 1,
|
||||
output: 2,
|
||||
total: 3,
|
||||
cached: 0,
|
||||
promptTokenCount: 1,
|
||||
candidatesTokenCount: 2,
|
||||
totalTokenCount: 3,
|
||||
cachedContentTokenCount: 0,
|
||||
});
|
||||
|
||||
expect(mkdirSyncSpy).toHaveBeenCalled();
|
||||
@@ -217,7 +231,14 @@ describe('ChatRecordingService', () => {
|
||||
) as ConversationRecord;
|
||||
expect(conversation.messages[0]).toEqual({
|
||||
...initialConversation.messages[0],
|
||||
tokens: { input: 1, output: 2, total: 3, cached: 0 },
|
||||
tokens: {
|
||||
input: 1,
|
||||
output: 2,
|
||||
total: 3,
|
||||
cached: 0,
|
||||
thoughts: 0,
|
||||
tool: 0,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -240,10 +261,10 @@ describe('ChatRecordingService', () => {
|
||||
);
|
||||
|
||||
chatRecordingService.recordMessageTokens({
|
||||
input: 2,
|
||||
output: 2,
|
||||
total: 4,
|
||||
cached: 0,
|
||||
promptTokenCount: 2,
|
||||
candidatesTokenCount: 2,
|
||||
totalTokenCount: 4,
|
||||
cachedContentTokenCount: 0,
|
||||
});
|
||||
|
||||
// @ts-expect-error private property
|
||||
@@ -252,6 +273,8 @@ describe('ChatRecordingService', () => {
|
||||
output: 2,
|
||||
total: 4,
|
||||
cached: 0,
|
||||
thoughts: 0,
|
||||
tool: 0,
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -297,7 +320,14 @@ describe('ChatRecordingService', () => {
|
||||
) as ConversationRecord;
|
||||
expect(conversation.messages[0]).toEqual({
|
||||
...initialConversation.messages[0],
|
||||
toolCalls: [toolCall],
|
||||
toolCalls: [
|
||||
{
|
||||
...toolCall,
|
||||
displayName: 'Test Tool',
|
||||
description: 'A test tool',
|
||||
renderOutputAsMarkdown: false,
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
@@ -343,7 +373,14 @@ describe('ChatRecordingService', () => {
|
||||
type: 'gemini',
|
||||
thoughts: [],
|
||||
content: '',
|
||||
toolCalls: [toolCall],
|
||||
toolCalls: [
|
||||
{
|
||||
...toolCall,
|
||||
displayName: 'Test Tool',
|
||||
description: 'A test tool',
|
||||
renderOutputAsMarkdown: false,
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,7 +11,10 @@ import { getProjectHash } from '../utils/paths.js';
|
||||
import path from 'node:path';
|
||||
import fs from 'node:fs';
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import type { PartListUnion } from '@google/genai';
|
||||
import type {
|
||||
PartListUnion,
|
||||
GenerateContentResponseUsageMetadata,
|
||||
} from '@google/genai';
|
||||
|
||||
/**
|
||||
* Token usage summary for a message or conversation.
|
||||
@@ -31,7 +34,7 @@ export interface TokensSummary {
|
||||
export interface BaseMessageRecord {
|
||||
id: string;
|
||||
timestamp: string;
|
||||
content: string;
|
||||
content: PartListUnion;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -178,7 +181,7 @@ export class ChatRecordingService {
|
||||
|
||||
private newMessage(
|
||||
type: ConversationRecordExtra['type'],
|
||||
content: string,
|
||||
content: PartListUnion,
|
||||
): MessageRecord {
|
||||
return {
|
||||
id: randomUUID(),
|
||||
@@ -193,22 +196,12 @@ export class ChatRecordingService {
|
||||
*/
|
||||
recordMessage(message: {
|
||||
type: ConversationRecordExtra['type'];
|
||||
content: string;
|
||||
append?: boolean;
|
||||
content: PartListUnion;
|
||||
}): void {
|
||||
if (!this.conversationFile) return;
|
||||
|
||||
try {
|
||||
this.updateConversation((conversation) => {
|
||||
if (message.append) {
|
||||
const lastMsg = this.getLastMessage(conversation);
|
||||
if (lastMsg && lastMsg.type === message.type) {
|
||||
lastMsg.content += message.content;
|
||||
return;
|
||||
}
|
||||
}
|
||||
// We're not appending, or we are appending but the last message's type is not the same as
|
||||
// the specified type, so just create a new message.
|
||||
const msg = this.newMessage(message.type, message.content);
|
||||
if (msg.type === 'gemini') {
|
||||
// If it's a new Gemini message then incorporate any queued thoughts.
|
||||
@@ -243,27 +236,28 @@ export class ChatRecordingService {
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
} catch (error) {
|
||||
if (this.config.getDebugMode()) {
|
||||
console.error('Error saving thought:', error);
|
||||
throw error;
|
||||
}
|
||||
console.error('Error saving thought:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Updates the tokens for the last message in the conversation (which should be by Gemini).
|
||||
*/
|
||||
recordMessageTokens(tokens: {
|
||||
input: number;
|
||||
output: number;
|
||||
cached: number;
|
||||
thoughts?: number;
|
||||
tool?: number;
|
||||
total: number;
|
||||
}): void {
|
||||
recordMessageTokens(
|
||||
respUsageMetadata: GenerateContentResponseUsageMetadata,
|
||||
): void {
|
||||
if (!this.conversationFile) return;
|
||||
|
||||
try {
|
||||
const tokens = {
|
||||
input: respUsageMetadata.promptTokenCount ?? 0,
|
||||
output: respUsageMetadata.candidatesTokenCount ?? 0,
|
||||
cached: respUsageMetadata.cachedContentTokenCount ?? 0,
|
||||
thoughts: respUsageMetadata.thoughtsTokenCount ?? 0,
|
||||
tool: respUsageMetadata.toolUsePromptTokenCount ?? 0,
|
||||
total: respUsageMetadata.totalTokenCount ?? 0,
|
||||
};
|
||||
this.updateConversation((conversation) => {
|
||||
const lastMsg = this.getLastMessage(conversation);
|
||||
// If the last message already has token info, it's because this new token info is for a
|
||||
@@ -283,10 +277,23 @@ export class ChatRecordingService {
|
||||
|
||||
/**
|
||||
* Adds tool calls to the last message in the conversation (which should be by Gemini).
|
||||
* This method enriches tool calls with metadata from the ToolRegistry.
|
||||
*/
|
||||
recordToolCalls(toolCalls: ToolCallRecord[]): void {
|
||||
if (!this.conversationFile) return;
|
||||
|
||||
// Enrich tool calls with metadata from the ToolRegistry
|
||||
const toolRegistry = this.config.getToolRegistry();
|
||||
const enrichedToolCalls = toolCalls.map((toolCall) => {
|
||||
const toolInstance = toolRegistry.getTool(toolCall.name);
|
||||
return {
|
||||
...toolCall,
|
||||
displayName: toolInstance?.displayName || toolCall.name,
|
||||
description: toolInstance?.description || '',
|
||||
renderOutputAsMarkdown: toolInstance?.isOutputMarkdown || false,
|
||||
};
|
||||
});
|
||||
|
||||
try {
|
||||
this.updateConversation((conversation) => {
|
||||
const lastMsg = this.getLastMessage(conversation);
|
||||
@@ -309,7 +316,7 @@ export class ChatRecordingService {
|
||||
// resulting message's type, and so it thinks that toolCalls may
|
||||
// not be present. Confirming the type here satisfies it.
|
||||
type: 'gemini' as const,
|
||||
toolCalls,
|
||||
toolCalls: enrichedToolCalls,
|
||||
thoughts: this.queuedThoughts,
|
||||
model: this.config.getModel(),
|
||||
};
|
||||
@@ -346,7 +353,7 @@ export class ChatRecordingService {
|
||||
});
|
||||
|
||||
// Add any new tools calls that aren't in the message yet.
|
||||
for (const toolCall of toolCalls) {
|
||||
for (const toolCall of enrichedToolCalls) {
|
||||
const existingToolCall = lastMsg.toolCalls.find(
|
||||
(tc) => tc.id === toolCall.id,
|
||||
);
|
||||
|
||||
@@ -14,6 +14,32 @@ import type { NextSpeakerResponse } from './nextSpeakerChecker.js';
|
||||
import { checkNextSpeaker } from './nextSpeakerChecker.js';
|
||||
import { GeminiChat } from '../core/geminiChat.js';
|
||||
|
||||
// Mock fs module to prevent actual file system operations during tests
|
||||
const mockFileSystem = new Map<string, string>();
|
||||
|
||||
vi.mock('node:fs', () => {
|
||||
const fsModule = {
|
||||
mkdirSync: vi.fn(),
|
||||
writeFileSync: vi.fn((path: string, data: string) => {
|
||||
mockFileSystem.set(path, data);
|
||||
}),
|
||||
readFileSync: vi.fn((path: string) => {
|
||||
if (mockFileSystem.has(path)) {
|
||||
return mockFileSystem.get(path);
|
||||
}
|
||||
throw Object.assign(new Error('ENOENT: no such file or directory'), {
|
||||
code: 'ENOENT',
|
||||
});
|
||||
}),
|
||||
existsSync: vi.fn((path: string) => mockFileSystem.has(path)),
|
||||
};
|
||||
|
||||
return {
|
||||
default: fsModule,
|
||||
...fsModule,
|
||||
};
|
||||
});
|
||||
|
||||
// Mock GeminiClient and Config constructor
|
||||
vi.mock('../core/client.js');
|
||||
vi.mock('../config/config.js');
|
||||
@@ -64,6 +90,17 @@ describe('checkNextSpeaker', () => {
|
||||
undefined,
|
||||
);
|
||||
|
||||
// Mock the methods that ChatRecordingService needs
|
||||
mockConfigInstance.getSessionId = vi
|
||||
.fn()
|
||||
.mockReturnValue('test-session-id');
|
||||
mockConfigInstance.getProjectRoot = vi
|
||||
.fn()
|
||||
.mockReturnValue('/test/project/root');
|
||||
mockConfigInstance.storage = {
|
||||
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
|
||||
};
|
||||
|
||||
mockGeminiClient = new GeminiClient(mockConfigInstance);
|
||||
|
||||
// Reset mocks before each test to ensure test isolation
|
||||
|
||||
Reference in New Issue
Block a user