diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index da0479ecae..e716157ff6 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -265,6 +265,8 @@ describe('Gemini Client (client.ts)', () => { client = new GeminiClient(mockConfig); await client.initialize(); vi.mocked(mockConfig.getGeminiClient).mockReturnValue(client); + + vi.mocked(uiTelemetryService.setLastPromptTokenCount).mockClear(); }); afterEach(() => { @@ -328,6 +330,7 @@ describe('Gemini Client (client.ts)', () => { getHistory: mockGetHistory, addHistory: vi.fn(), setHistory: vi.fn(), + getLastPromptTokenCount: vi.fn(), } as unknown as GeminiChat; }); @@ -343,6 +346,7 @@ describe('Gemini Client (client.ts)', () => { const mockOriginalChat: Partial = { getHistory: vi.fn((_curated?: boolean) => chatHistory), setHistory: vi.fn(), + getLastPromptTokenCount: vi.fn().mockReturnValue(originalTokenCount), }; client['chat'] = mockOriginalChat as GeminiChat; @@ -370,6 +374,7 @@ describe('Gemini Client (client.ts)', () => { const mockNewChat: Partial = { getHistory: vi.fn().mockReturnValue(newHistory), setHistory: vi.fn(), + getLastPromptTokenCount: vi.fn().mockReturnValue(newTokenCount), }; client['startChat'] = vi @@ -651,6 +656,7 @@ describe('Gemini Client (client.ts)', () => { const mockChat = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), } as unknown as GeminiChat; client['chat'] = mockChat; @@ -713,6 +719,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -773,6 +780,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -849,6 +857,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -895,6 +904,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -942,6 +952,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -1006,6 +1017,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -1062,6 +1074,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -1118,9 +1131,11 @@ ${JSON.stringify( // Set last prompt token count const lastPromptTokenCount = 900; - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( - lastPromptTokenCount, - ); + const mockChat: Partial = { + getLastPromptTokenCount: vi.fn().mockReturnValue(lastPromptTokenCount), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; // Remaining = 100. Threshold (95%) = 95. // We need a request > 95 tokens. @@ -1177,9 +1192,11 @@ ${JSON.stringify( // Set token count const lastPromptTokenCount = 900; - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( - lastPromptTokenCount, - ); + const mockChat: Partial = { + getLastPromptTokenCount: vi.fn().mockReturnValue(lastPromptTokenCount), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; // Remaining (sticky) = 100. Threshold (95%) = 95. // We need a request > 95 tokens. @@ -1236,6 +1253,13 @@ ${JSON.stringify( yield { type: 'content', value: 'Hello' }; })(), ); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), + }; + client['chat'] = mockChat as GeminiChat; }); it('should use the model router service to select a model on the first turn', async () => { @@ -1409,6 +1433,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -1460,6 +1485,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -1493,6 +1519,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -1539,6 +1566,7 @@ ${JSON.stringify( .mockReturnValue([ { role: 'user', parts: [{ text: 'previous message' }] }, ]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; }); @@ -1797,6 +1825,7 @@ ${JSON.stringify( addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), // Default empty history setHistory: vi.fn(), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -2137,6 +2166,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -2173,6 +2203,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; @@ -2209,6 +2240,7 @@ ${JSON.stringify( const mockChat: Partial = { addHistory: vi.fn(), getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), }; client['chat'] = mockChat as GeminiChat; diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 6b22ee99b7..f190fb5e6a 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -46,10 +46,10 @@ import { ContentRetryFailureEvent, NextSpeakerCheckEvent, } from '../telemetry/types.js'; +import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; import type { IdeContext, File } from '../ide/types.js'; import { handleFallback } from '../fallback/handler.js'; import type { RoutingContext } from '../routing/routingStrategy.js'; -import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; import { debugLogger } from '../utils/debugLogger.js'; export function isThinkingSupported(model: string) { @@ -92,8 +92,17 @@ export class GeminiClient { this.lastPromptId = this.config.getSessionId(); } + private updateTelemetryTokenCount() { + if (this.chat) { + uiTelemetryService.setLastPromptTokenCount( + this.chat.getLastPromptTokenCount(), + ); + } + } + async initialize() { this.chat = await this.startChat(); + this.updateTelemetryTokenCount(); } private getContentGeneratorOrFail(): ContentGenerator { @@ -140,6 +149,7 @@ export class GeminiClient { async resetChat(): Promise { this.chat = await this.startChat(); + this.updateTelemetryTokenCount(); } getChatRecordingService(): ChatRecordingService | undefined { @@ -424,8 +434,7 @@ export class GeminiClient { ); const remainingTokenCount = - tokenLimit(modelForLimitCheck) - - uiTelemetryService.getLastPromptTokenCount(); + tokenLimit(modelForLimitCheck) - this.getChat().getLastPromptTokenCount(); if (estimatedRequestTokenCount > remainingTokenCount * 0.95) { yield { @@ -506,6 +515,9 @@ export class GeminiClient { return turn; } yield event; + + this.updateTelemetryTokenCount(); + if (event.type === GeminiEventType.InvalidStream) { if (this.config.getContinueOnFailedApiCall()) { if (isInvalidStreamRetry) { @@ -671,6 +683,7 @@ export class GeminiClient { } else if (info.compressionStatus === CompressionStatus.COMPRESSED) { if (newHistory) { this.chat = await this.startChat(newHistory); + this.updateTelemetryTokenCount(); this.forceFullIdeContext = true; } } diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index e5a017e711..f5facae5dc 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -141,6 +141,25 @@ describe('GeminiChat', () => { vi.resetAllMocks(); }); + describe('constructor', () => { + it('should initialize lastPromptTokenCount based on history size', () => { + const history: Content[] = [ + { role: 'user', parts: [{ text: 'Hello' }] }, + { role: 'model', parts: [{ text: 'Hi there' }] }, + ]; + const chatWithHistory = new GeminiChat(mockConfig, config, history); + const estimatedTokens = Math.ceil(JSON.stringify(history).length / 4); + expect(chatWithHistory.getLastPromptTokenCount()).toBe(estimatedTokens); + }); + + it('should initialize lastPromptTokenCount for empty history', () => { + const chatEmpty = new GeminiChat(mockConfig, config, []); + expect(chatEmpty.getLastPromptTokenCount()).toBe( + Math.ceil(JSON.stringify([]).length / 4), + ); + }); + }); + describe('sendMessageStream', () => { it('should succeed if a tool call is followed by an empty part', async () => { // 1. Mock a stream that contains a tool call, then an invalid (empty) part. @@ -708,14 +727,6 @@ describe('GeminiChat', () => { }, 'prompt-id-1', ); - - // Verify that token counting is called when usageMetadata is present - expect(uiTelemetryService.setLastPromptTokenCount).toHaveBeenCalledWith( - 42, - ); - expect(uiTelemetryService.setLastPromptTokenCount).toHaveBeenCalledTimes( - 1, - ); }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index f4dc1c552c..54553bc4e4 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -38,7 +38,6 @@ import { import { handleFallback } from '../fallback/handler.js'; import { isFunctionResponse } from '../utils/messageInspectors.js'; import { partListUnionToString } from './geminiRequest.js'; -import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; export enum StreamEventType { /** A regular content chunk from the API. */ @@ -186,6 +185,7 @@ export class GeminiChat { // model. private sendPromise: Promise = Promise.resolve(); private readonly chatRecordingService: ChatRecordingService; + private lastPromptTokenCount: number; constructor( private readonly config: Config, @@ -195,6 +195,9 @@ export class GeminiChat { validateHistory(history); this.chatRecordingService = new ChatRecordingService(config); this.chatRecordingService.initialize(); + this.lastPromptTokenCount = Math.ceil( + JSON.stringify(this.history).length / 4, + ); } setSystemInstruction(sysInstr: string) { @@ -521,9 +524,7 @@ export class GeminiChat { if (chunk.usageMetadata) { this.chatRecordingService.recordMessageTokens(chunk.usageMetadata); if (chunk.usageMetadata.promptTokenCount !== undefined) { - uiTelemetryService.setLastPromptTokenCount( - chunk.usageMetadata.promptTokenCount, - ); + this.lastPromptTokenCount = chunk.usageMetadata.promptTokenCount; } } @@ -584,6 +585,10 @@ export class GeminiChat { this.history.push({ role: 'model', parts: consolidatedParts }); } + getLastPromptTokenCount(): number { + return this.lastPromptTokenCount; + } + /** * Gets the chat recording service instance. */ diff --git a/packages/core/src/services/chatCompressionService.test.ts b/packages/core/src/services/chatCompressionService.test.ts index ba5688b458..f7ffe55eed 100644 --- a/packages/core/src/services/chatCompressionService.test.ts +++ b/packages/core/src/services/chatCompressionService.test.ts @@ -11,14 +11,12 @@ import { } from './chatCompressionService.js'; import type { Content, GenerateContentResponse } from '@google/genai'; import { CompressionStatus } from '../core/turn.js'; -import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; import { tokenLimit } from '../core/tokenLimits.js'; import type { GeminiChat } from '../core/geminiChat.js'; import type { Config } from '../config/config.js'; import { getInitialChatHistory } from '../utils/environmentContext.js'; import type { ContentGenerator } from '../core/contentGenerator.js'; -vi.mock('../telemetry/uiTelemetry.js'); vi.mock('../core/tokenLimits.js'); vi.mock('../telemetry/loggers.js'); vi.mock('../utils/environmentContext.js'); @@ -114,6 +112,7 @@ describe('ChatCompressionService', () => { service = new ChatCompressionService(); mockChat = { getHistory: vi.fn(), + getLastPromptTokenCount: vi.fn().mockReturnValue(500), } as unknown as GeminiChat; mockConfig = { getChatCompression: vi.fn(), @@ -121,7 +120,6 @@ describe('ChatCompressionService', () => { } as unknown as Config; vi.mocked(tokenLimit).mockReturnValue(1000); - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(500); vi.mocked(getInitialChatHistory).mockImplementation( async (_config, extraHistory) => extraHistory || [], ); @@ -165,7 +163,7 @@ describe('ChatCompressionService', () => { vi.mocked(mockChat.getHistory).mockReturnValue([ { role: 'user', parts: [{ text: 'hi' }] }, ]); - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(600); + vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600); vi.mocked(tokenLimit).mockReturnValue(1000); // Threshold is 0.7 * 1000 = 700. 600 < 700, so NOOP. @@ -189,7 +187,7 @@ describe('ChatCompressionService', () => { { role: 'model', parts: [{ text: 'msg4' }] }, ]; vi.mocked(mockChat.getHistory).mockReturnValue(history); - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(800); + vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(800); vi.mocked(tokenLimit).mockReturnValue(1000); const mockGenerateContent = vi.fn().mockResolvedValue({ candidates: [ @@ -227,7 +225,7 @@ describe('ChatCompressionService', () => { { role: 'model', parts: [{ text: 'msg4' }] }, ]; vi.mocked(mockChat.getHistory).mockReturnValue(history); - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(100); + vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(100); vi.mocked(tokenLimit).mockReturnValue(1000); const mockGenerateContent = vi.fn().mockResolvedValue({ @@ -262,7 +260,7 @@ describe('ChatCompressionService', () => { { role: 'model', parts: [{ text: 'msg2' }] }, ]; vi.mocked(mockChat.getHistory).mockReturnValue(history); - vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(10); + vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(10); vi.mocked(tokenLimit).mockReturnValue(1000); const longSummary = 'a'.repeat(1000); // Long summary to inflate token count diff --git a/packages/core/src/services/chatCompressionService.ts b/packages/core/src/services/chatCompressionService.ts index 0c94130398..fa96b27749 100644 --- a/packages/core/src/services/chatCompressionService.ts +++ b/packages/core/src/services/chatCompressionService.ts @@ -8,7 +8,6 @@ import type { Content } from '@google/genai'; import type { Config } from '../config/config.js'; import type { GeminiChat } from '../core/geminiChat.js'; import { type ChatCompressionInfo, CompressionStatus } from '../core/turn.js'; -import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; import { tokenLimit } from '../core/tokenLimits.js'; import { getCompressionPrompt } from '../core/prompts.js'; import { getResponseText } from '../utils/partUtils.js'; @@ -102,7 +101,7 @@ export class ChatCompressionService { }; } - const originalTokenCount = uiTelemetryService.getLastPromptTokenCount(); + const originalTokenCount = chat.getLastPromptTokenCount(); // Don't compress if not forced and we are under the limit. if (!force) { @@ -204,7 +203,6 @@ export class ChatCompressionService { }, }; } else { - uiTelemetryService.setLastPromptTokenCount(newTokenCount); return { newHistory: extraHistory, info: {