From ffcd9963667ffc792f0591df93ee70ec2a287f9a Mon Sep 17 00:00:00 2001 From: Sandy Tao Date: Sat, 27 Sep 2025 11:56:10 -0700 Subject: [PATCH] feat(core): Use lastPromptTokenCount to determine if we need to compress (#10000) --- packages/core/src/core/client.test.ts | 374 ++++++++++++++++++-------- packages/core/src/core/client.ts | 42 +-- 2 files changed, 277 insertions(+), 139 deletions(-) diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 6ab25ce658..54050f0ff3 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -115,6 +115,7 @@ vi.mock('../ide/ideContext.js'); vi.mock('../telemetry/uiTelemetry.js', () => ({ uiTelemetryService: { setLastPromptTokenCount: vi.fn(), + getLastPromptTokenCount: vi.fn(), }, })); @@ -261,7 +262,6 @@ describe('Gemini Client (client.ts)', () => { mockContentGenerator = { generateContent: mockGenerateContentFn, generateContentStream: vi.fn(), - countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }), batchEmbedContents: vi.fn(), } as unknown as ContentGenerator; @@ -402,102 +402,161 @@ describe('Gemini Client (client.ts)', () => { { role: 'user', parts: [{ text: 'Long conversation' }] }, { role: 'model', parts: [{ text: 'Long response' }] }, ] as Content[], + originalTokenCount = 1000, + summaryText = 'This is a summary.', } = {}) { - const mockChat: Partial = { - getHistory: vi.fn().mockReturnValue(chatHistory), + const mockOriginalChat: Partial = { + getHistory: vi.fn((_curated?: boolean) => chatHistory), setHistory: vi.fn(), }; - vi.mocked(mockContentGenerator.countTokens) - .mockResolvedValueOnce({ totalTokens: 1000 }) - .mockResolvedValueOnce({ totalTokens: 5000 }); + client['chat'] = mockOriginalChat as GeminiChat; - client['chat'] = mockChat as GeminiChat; - client['startChat'] = vi.fn().mockResolvedValue({ ...mockChat }); + vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( + originalTokenCount, + ); - return { client, mockChat }; + mockGenerateContentFn.mockResolvedValue({ + candidates: [ + { + content: { + role: 'model', + parts: [{ text: summaryText }], + }, + }, + ], + } as unknown as GenerateContentResponse); + + // Calculate what the new history will be + const splitPoint = findCompressSplitPoint(chatHistory, 0.7); // 1 - 0.3 + const historyToKeep = chatHistory.slice(splitPoint); + + // This is the history that the new chat will have. + // It includes the default startChat history + the extra history from tryCompressChat + const newCompressedHistory: Content[] = [ + // Mocked envParts + canned response from startChat + { + role: 'user', + parts: [{ text: 'Mocked env context' }], + }, + { + role: 'model', + parts: [{ text: 'Got it. Thanks for the context!' }], + }, + // extraHistory from tryCompressChat + { + role: 'user', + parts: [{ text: summaryText }], + }, + { + role: 'model', + parts: [{ text: 'Got it. Thanks for the additional context!' }], + }, + ...historyToKeep, + ]; + + const mockNewChat: Partial = { + getHistory: vi.fn().mockReturnValue(newCompressedHistory), + setHistory: vi.fn(), + }; + + client['startChat'] = vi + .fn() + .mockResolvedValue(mockNewChat as GeminiChat); + + const totalChars = newCompressedHistory.reduce( + (total, content) => total + JSON.stringify(content).length, + 0, + ); + const estimatedNewTokenCount = Math.floor(totalChars / 4); + + return { + client, + mockOriginalChat, + mockNewChat, + estimatedNewTokenCount, + }; } describe('when compression inflates the token count', () => { it('allows compression to be forced/manual after a failure', async () => { - const { client } = setup(); - - vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ - totalTokens: 1000, + // Call 1 (Fails): Setup with a long summary to inflate tokens + const longSummary = 'long summary '.repeat(100); + const { client, estimatedNewTokenCount: inflatedTokenCount } = setup({ + originalTokenCount: 100, + summaryText: longSummary, }); + expect(inflatedTokenCount).toBeGreaterThan(100); // Ensure setup is correct + await client.tryCompressChat('prompt-id-4', false); // Fails - const result = await client.tryCompressChat('prompt-id-4', true); + + // Call 2 (Forced): Re-setup with a short summary + const shortSummary = 'short'; + const { estimatedNewTokenCount: compressedTokenCount } = setup({ + originalTokenCount: 100, + summaryText: shortSummary, + }); + expect(compressedTokenCount).toBeLessThanOrEqual(100); // Ensure setup is correct + + const result = await client.tryCompressChat('prompt-id-4', true); // Forced expect(result).toEqual({ compressionStatus: CompressionStatus.COMPRESSED, - newTokenCount: 1000, - originalTokenCount: 1000, + newTokenCount: compressedTokenCount, + originalTokenCount: 100, }); }); it('yields the result even if the compression inflated the tokens', async () => { - const { client } = setup(); - vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ - totalTokens: 1000, + const longSummary = 'long summary '.repeat(100); + const { client, estimatedNewTokenCount } = setup({ + originalTokenCount: 100, + summaryText: longSummary, }); + expect(estimatedNewTokenCount).toBeGreaterThan(100); // Ensure setup is correct + const result = await client.tryCompressChat('prompt-id-4', false); expect(result).toEqual({ compressionStatus: CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, - newTokenCount: 5000, - originalTokenCount: 1000, + newTokenCount: estimatedNewTokenCount, + originalTokenCount: 100, }); - expect(uiTelemetryService.setLastPromptTokenCount).toHaveBeenCalledWith( - 5000, - ); + // IMPORTANT: The change in client.ts means setLastPromptTokenCount is NOT called on failure expect( uiTelemetryService.setLastPromptTokenCount, - ).toHaveBeenCalledTimes(1); + ).not.toHaveBeenCalled(); }); it('does not manipulate the source chat', async () => { - const { client, mockChat } = setup(); + const longSummary = 'long summary '.repeat(100); + const { client, mockOriginalChat, estimatedNewTokenCount } = setup({ + originalTokenCount: 100, + summaryText: longSummary, + }); + expect(estimatedNewTokenCount).toBeGreaterThan(100); // Ensure setup is correct + await client.tryCompressChat('prompt-id-4', false); - expect(client['chat']).toBe(mockChat); // a new chat session was not created - }); - - it('restores the history back to the original', async () => { - vi.mocked(tokenLimit).mockReturnValue(1000); - vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ - totalTokens: 999, - }); - - const originalHistory: Content[] = [ - { role: 'user', parts: [{ text: 'what is your wisdom?' }] }, - { role: 'model', parts: [{ text: 'some wisdom' }] }, - { role: 'user', parts: [{ text: 'ahh that is a good a wisdom' }] }, - ]; - - const { client } = setup({ - chatHistory: originalHistory, - }); - const { compressionStatus } = await client.tryCompressChat( - 'prompt-id-4', - false, - ); - - expect(compressionStatus).toBe( - CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, - ); - expect(client['chat']?.setHistory).toHaveBeenCalledWith( - originalHistory, - ); + // On failure, the chat should NOT be replaced + expect(client['chat']).toBe(mockOriginalChat); }); it('will not attempt to compress context after a failure', async () => { - const { client } = setup(); - await client.tryCompressChat('prompt-id-4', false); + const longSummary = 'long summary '.repeat(100); + const { client, estimatedNewTokenCount } = setup({ + originalTokenCount: 100, + summaryText: longSummary, + }); + expect(estimatedNewTokenCount).toBeGreaterThan(100); // Ensure setup is correct + await client.tryCompressChat('prompt-id-4', false); // This fails and sets hasFailedCompressionAttempt = true + + // This call should now be a NOOP const result = await client.tryCompressChat('prompt-id-5', false); - // it counts tokens for {original, compressed} and then never again - expect(mockContentGenerator.countTokens).toHaveBeenCalledTimes(2); + // generateContent (for summary) should only have been called once + expect(mockGenerateContentFn).toHaveBeenCalledTimes(1); expect(result).toEqual({ compressionStatus: CompressionStatus.NOOP, newTokenCount: 0, @@ -512,9 +571,10 @@ describe('Gemini Client (client.ts)', () => { mockGetHistory.mockReturnValue([ { role: 'user', parts: [{ text: '...history...' }] }, ]); - vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({ - totalTokens: MOCKED_TOKEN_LIMIT * 0.699, // TOKEN_THRESHOLD_FOR_SUMMARIZATION = 0.7 - }); + const originalTokenCount = MOCKED_TOKEN_LIMIT * 0.699; + vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( + originalTokenCount, + ); const initialChat = client.getChat(); const result = await client.tryCompressChat('prompt-id-2', false); @@ -523,8 +583,8 @@ describe('Gemini Client (client.ts)', () => { expect(tokenLimit).toHaveBeenCalled(); expect(result).toEqual({ compressionStatus: CompressionStatus.NOOP, - newTokenCount: 699, - originalTokenCount: 699, + newTokenCount: originalTokenCount, + originalTokenCount, }); expect(newChat).toBe(initialChat); }); @@ -538,17 +598,43 @@ describe('Gemini Client (client.ts)', () => { vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({ contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD, }); - mockGetHistory.mockReturnValue([ - { role: 'user', parts: [{ text: '...history...' }] }, - ]); + const history = [{ role: 'user', parts: [{ text: '...history...' }] }]; + mockGetHistory.mockReturnValue(history); const originalTokenCount = MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD; - const newTokenCount = 100; - vi.mocked(mockContentGenerator.countTokens) - .mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check - .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history + vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( + originalTokenCount, + ); + + // We need to control the estimated new token count. + // We mock startChat to return a chat with a known history. + const summaryText = 'This is a summary.'; + const splitPoint = findCompressSplitPoint(history, 0.7); + const historyToKeep = history.slice(splitPoint); + const newCompressedHistory: Content[] = [ + { role: 'user', parts: [{ text: 'Mocked env context' }] }, + { role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] }, + { role: 'user', parts: [{ text: summaryText }] }, + { + role: 'model', + parts: [{ text: 'Got it. Thanks for the additional context!' }], + }, + ...historyToKeep, + ]; + const mockNewChat: Partial = { + getHistory: vi.fn().mockReturnValue(newCompressedHistory), + }; + client['startChat'] = vi + .fn() + .mockResolvedValue(mockNewChat as GeminiChat); + + const totalChars = newCompressedHistory.reduce( + (total, content) => total + JSON.stringify(content).length, + 0, + ); + const newTokenCount = Math.floor(totalChars / 4); // Mock the summary response from the chat mockGenerateContentFn.mockResolvedValue({ @@ -556,7 +642,7 @@ describe('Gemini Client (client.ts)', () => { { content: { role: 'model', - parts: [{ text: 'This is a summary.' }], + parts: [{ text: summaryText }], }, }, ], @@ -587,17 +673,42 @@ describe('Gemini Client (client.ts)', () => { vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({ contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD, }); - mockGetHistory.mockReturnValue([ - { role: 'user', parts: [{ text: '...history...' }] }, - ]); + const history = [{ role: 'user', parts: [{ text: '...history...' }] }]; + mockGetHistory.mockReturnValue(history); const originalTokenCount = MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD; - const newTokenCount = 100; - vi.mocked(mockContentGenerator.countTokens) - .mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check - .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history + vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( + originalTokenCount, + ); + + // Mock summary and new chat + const summaryText = 'This is a summary.'; + const splitPoint = findCompressSplitPoint(history, 0.7); + const historyToKeep = history.slice(splitPoint); + const newCompressedHistory: Content[] = [ + { role: 'user', parts: [{ text: 'Mocked env context' }] }, + { role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] }, + { role: 'user', parts: [{ text: summaryText }] }, + { + role: 'model', + parts: [{ text: 'Got it. Thanks for the additional context!' }], + }, + ...historyToKeep, + ]; + const mockNewChat: Partial = { + getHistory: vi.fn().mockReturnValue(newCompressedHistory), + }; + client['startChat'] = vi + .fn() + .mockResolvedValue(mockNewChat as GeminiChat); + + const totalChars = newCompressedHistory.reduce( + (total, content) => total + JSON.stringify(content).length, + 0, + ); + const newTokenCount = Math.floor(totalChars / 4); // Mock the summary response from the chat mockGenerateContentFn.mockResolvedValue({ @@ -605,7 +716,7 @@ describe('Gemini Client (client.ts)', () => { { content: { role: 'model', - parts: [{ text: 'This is a summary.' }], + parts: [{ text: summaryText }], }, }, ], @@ -632,7 +743,7 @@ describe('Gemini Client (client.ts)', () => { it('should not compress across a function call response', async () => { const MOCKED_TOKEN_LIMIT = 1000; vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT); - mockGetHistory.mockReturnValue([ + const history: Content[] = [ { role: 'user', parts: [{ text: '...history 1...' }] }, { role: 'model', parts: [{ text: '...history 2...' }] }, { role: 'user', parts: [{ text: '...history 3...' }] }, @@ -649,14 +760,45 @@ describe('Gemini Client (client.ts)', () => { { role: 'model', parts: [{ text: '...history 10...' }] }, // Instead we will break here. { role: 'user', parts: [{ text: '...history 10...' }] }, - ]); + ]; + mockGetHistory.mockReturnValue(history); const originalTokenCount = 1000 * 0.7; - const newTokenCount = 100; + vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( + originalTokenCount, + ); - vi.mocked(mockContentGenerator.countTokens) - .mockResolvedValueOnce({ totalTokens: originalTokenCount }) // First call for the check - .mockResolvedValueOnce({ totalTokens: newTokenCount }); // Second call for the new history + // Mock summary and new chat + const summaryText = 'This is a summary.'; + const splitPoint = findCompressSplitPoint(history, 0.7); // This should be 10 + expect(splitPoint).toBe(10); // Verify split point logic + const historyToKeep = history.slice(splitPoint); // Should keep last user message + expect(historyToKeep).toEqual([ + { role: 'user', parts: [{ text: '...history 10...' }] }, + ]); + + const newCompressedHistory: Content[] = [ + { role: 'user', parts: [{ text: 'Mocked env context' }] }, + { role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] }, + { role: 'user', parts: [{ text: summaryText }] }, + { + role: 'model', + parts: [{ text: 'Got it. Thanks for the additional context!' }], + }, + ...historyToKeep, + ]; + const mockNewChat: Partial = { + getHistory: vi.fn().mockReturnValue(newCompressedHistory), + }; + client['startChat'] = vi + .fn() + .mockResolvedValue(mockNewChat as GeminiChat); + + const totalChars = newCompressedHistory.reduce( + (total, content) => total + JSON.stringify(content).length, + 0, + ); + const newTokenCount = Math.floor(totalChars / 4); // Mock the summary response from the chat mockGenerateContentFn.mockResolvedValue({ @@ -664,7 +806,7 @@ describe('Gemini Client (client.ts)', () => { { content: { role: 'model', - parts: [{ text: 'This is a summary.' }], + parts: [{ text: summaryText }], }, }, ], @@ -686,25 +828,49 @@ describe('Gemini Client (client.ts)', () => { // Assert that the chat was reset expect(newChat).not.toBe(initialChat); - // 1. standard start context message - // 2. standard canned user start message - // 3. compressed summary message - // 4. standard canned user summary message - // 5. The last user message (not the last 3 because that would start with a function response) + // 1. standard start context message (env) + // 2. standard canned model response + // 3. compressed summary message (user) + // 4. standard canned model response + // 5. The last user message (historyToKeep) expect(newChat.getHistory().length).toEqual(5); }); it('should always trigger summarization when force is true, regardless of token count', async () => { - mockGetHistory.mockReturnValue([ - { role: 'user', parts: [{ text: '...history...' }] }, - ]); + const history = [{ role: 'user', parts: [{ text: '...history...' }] }]; + mockGetHistory.mockReturnValue(history); - const originalTokenCount = 10; // Well below threshold - const newTokenCount = 5; + const originalTokenCount = 100; // Well below threshold, but > estimated new count + vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue( + originalTokenCount, + ); - vi.mocked(mockContentGenerator.countTokens) - .mockResolvedValueOnce({ totalTokens: originalTokenCount }) - .mockResolvedValueOnce({ totalTokens: newTokenCount }); + // Mock summary and new chat + const summaryText = 'This is a summary.'; + const splitPoint = findCompressSplitPoint(history, 0.7); + const historyToKeep = history.slice(splitPoint); + const newCompressedHistory: Content[] = [ + { role: 'user', parts: [{ text: 'Mocked env context' }] }, + { role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] }, + { role: 'user', parts: [{ text: summaryText }] }, + { + role: 'model', + parts: [{ text: 'Got it. Thanks for the additional context!' }], + }, + ...historyToKeep, + ]; + const mockNewChat: Partial = { + getHistory: vi.fn().mockReturnValue(newCompressedHistory), + }; + client['startChat'] = vi + .fn() + .mockResolvedValue(mockNewChat as GeminiChat); + + const totalChars = newCompressedHistory.reduce( + (total, content) => total + JSON.stringify(content).length, + 0, + ); + const newTokenCount = Math.floor(totalChars / 4); // Mock the summary response from the chat mockGenerateContentFn.mockResolvedValue({ @@ -712,14 +878,14 @@ describe('Gemini Client (client.ts)', () => { { content: { role: 'model', - parts: [{ text: 'This is a summary.' }], + parts: [{ text: summaryText }], }, }, ], } as unknown as GenerateContentResponse); const initialChat = client.getChat(); - const result = await client.tryCompressChat('prompt-id-1', false); // force = true + const result = await client.tryCompressChat('prompt-id-1', true); // force = true const newChat = client.getChat(); expect(mockGenerateContentFn).toHaveBeenCalled(); @@ -776,10 +942,6 @@ describe('Gemini Client (client.ts)', () => { CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, }, { compressionStatus: CompressionStatus.NOOP }, - { - compressionStatus: - CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR, - }, ])( 'does not emit a compression event when the status is $compressionStatus', async ({ compressionStatus }) => { diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 07e645ba9e..940c8349dd 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -679,21 +679,7 @@ export class GeminiClient { }; } - const { totalTokens: originalTokenCount } = - await this.getContentGeneratorOrFail().countTokens({ - model, - contents: curatedHistory, - }); - if (originalTokenCount === undefined) { - console.warn(`Could not determine token count for model ${model}.`); - this.hasFailedCompressionAttempt = !force && true; - return { - originalTokenCount: 0, - newTokenCount: 0, - compressionStatus: - CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR, - }; - } + const originalTokenCount = uiTelemetryService.getLastPromptTokenCount(); const contextPercentageThreshold = this.config.getChatCompression()?.contextPercentageThreshold; @@ -756,23 +742,13 @@ export class GeminiClient { ]); this.forceFullIdeContext = true; - const { totalTokens: newTokenCount } = - await this.getContentGeneratorOrFail().countTokens({ - model, - contents: chat.getHistory(), - }); - if (newTokenCount === undefined) { - console.warn('Could not determine compressed history token count.'); - this.hasFailedCompressionAttempt = !force && true; - return { - originalTokenCount, - newTokenCount: originalTokenCount, - compressionStatus: - CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR, - }; - } - - uiTelemetryService.setLastPromptTokenCount(newTokenCount); + // Estimate token count 1 token ≈ 4 characters + const newTokenCount = Math.floor( + chat + .getHistory() + .reduce((total, content) => total + JSON.stringify(content).length, 0) / + 4, + ); logChatCompression( this.config, @@ -783,7 +759,6 @@ export class GeminiClient { ); if (newTokenCount > originalTokenCount) { - this.getChat().setHistory(curatedHistory); this.hasFailedCompressionAttempt = !force && true; return { originalTokenCount, @@ -793,6 +768,7 @@ export class GeminiClient { }; } else { this.chat = chat; // Chat compression successful, set new state. + uiTelemetryService.setLastPromptTokenCount(newTokenCount); } return {