diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index aac4ee84a1..6487e04e80 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -166,6 +166,7 @@ describe('Gemini Client (client.ts)', () => {
       generateContent: mockGenerateContentFn,
       generateContentStream: vi.fn(),
       batchEmbedContents: vi.fn(),
+      countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
     } as unknown as ContentGenerator;
 
     // Because the GeminiClient constructor kicks off an async process (startChat)
@@ -902,6 +903,75 @@ ${JSON.stringify(
     });
   });
 
+  it('should use local estimation for text-only requests and NOT call countTokens', async () => {
+    const request = [{ text: 'Hello world' }];
+    const generator = client['getContentGeneratorOrFail']();
+    const countTokensSpy = vi.spyOn(generator, 'countTokens');
+
+    const stream = client.sendMessageStream(
+      request,
+      new AbortController().signal,
+      'test-prompt-id',
+    );
+    await stream.next(); // Trigger the generator
+
+    expect(countTokensSpy).not.toHaveBeenCalled();
+  });
+
+  it('should use countTokens API for requests with non-text parts', async () => {
+    const request = [
+      { text: 'Describe this image' },
+      { inlineData: { mimeType: 'image/png', data: 'base64...' } },
+    ];
+    const generator = client['getContentGeneratorOrFail']();
+    const countTokensSpy = vi
+      .spyOn(generator, 'countTokens')
+      .mockResolvedValue({ totalTokens: 123 });
+
+    const stream = client.sendMessageStream(
+      request,
+      new AbortController().signal,
+      'test-prompt-id',
+    );
+    await stream.next(); // Trigger the generator
+
+    expect(countTokensSpy).toHaveBeenCalledWith(
+      expect.objectContaining({
+        contents: expect.arrayContaining([
+          expect.objectContaining({
+            parts: expect.arrayContaining([
+              { text: 'Describe this image' },
+              { inlineData: { mimeType: 'image/png', data: 'base64...' } },
+            ]),
+          }),
+        ]),
+      }),
+    );
+  });
+
+  it('should estimate CJK characters more conservatively (closer to 1 token/char)', async () => {
+    const request = [{ text: '你好世界' }]; // 4 chars
+    const generator = client['getContentGeneratorOrFail']();
+    const countTokensSpy = vi.spyOn(generator, 'countTokens');
+
+    // 4 CJK chars.
+    // Old logic: floor(4 / 4) = 1 token, badly undercounting CJK text.
+    // New heuristic: 4 * 1.3 = 5.2 -> floor = 5 tokens,
+    // i.e. roughly one-plus tokens per CJK char instead of 0.25.
+
+    const stream = client.sendMessageStream(
+      request,
+      new AbortController().signal,
+      'test-prompt-id',
+    );
+    await stream.next();
+
+    // Should NOT call countTokens (it's text only)
+    expect(countTokensSpy).not.toHaveBeenCalled();
+
+    // The actual token calculation is unit tested in tokenCalculation.test.ts
+  });
+
   it('should return the turn instance after the stream is complete', async () => {
     // Arrange
     const mockStream = (async function* () {
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index 95988e14e7..df42b1694e 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -56,47 +56,10 @@ import { handleFallback } from '../fallback/handler.js';
 import type { RoutingContext } from '../routing/routingStrategy.js';
 import { debugLogger } from '../utils/debugLogger.js';
 import type { ModelConfigKey } from '../services/modelConfigService.js';
+import { calculateRequestTokenCount } from '../utils/tokenCalculation.js';
 
 const MAX_TURNS = 100;
 
-/**
- * Estimates the character length of text-only parts in a request.
- * Binary data (inline_data, fileData) is excluded from the estimation
- * because Gemini counts these as fixed token values, not based on their size.
- * @param request The request to estimate tokens for
- * @returns Estimated character length of text content
- */
-function estimateTextOnlyLength(request: PartListUnion): number {
-  if (typeof request === 'string') {
-    return request.length;
-  }
-
-  // Ensure request is an array before iterating
-  if (!Array.isArray(request)) {
-    return 0;
-  }
-
-  let textLength = 0;
-  for (const part of request) {
-    // Handle string elements in the array
-    if (typeof part === 'string') {
-      textLength += part.length;
-    }
-    // Handle object elements with text property
-    else if (
-      typeof part === 'object' &&
-      part !== null &&
-      'text' in part &&
-      part.text
-    ) {
-      textLength += part.text.length;
-    }
-    // inlineData, fileData, and other binary parts are ignored
-    // as they are counted as fixed tokens by Gemini
-  }
-  return textLength;
-}
-
 export class GeminiClient {
   private chat?: GeminiChat;
   private sessionTurnCount = 0;
@@ -493,11 +456,13 @@ export class GeminiClient {
 
     // Check for context window overflow
     const modelForLimitCheck = this._getEffectiveModelForCurrentTurn();
-    // Estimate tokens based on text content only.
-    // Binary data (PDFs, images) are counted as fixed tokens by Gemini,
-    // not based on their base64-encoded size.
-    const estimatedRequestTokenCount = Math.floor(
-      estimateTextOnlyLength(request) / 4,
+    // Estimate tokens. For text-only requests, we estimate based on character length.
+    // For requests with non-text parts (like images, tools), we use the countTokens API.
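+    // (If the countTokens call fails, the utility falls back to the local estimate.)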
+    const estimatedRequestTokenCount = await calculateRequestTokenCount(
+      request,
+      this.getContentGeneratorOrFail(),
+      modelForLimitCheck,
     );
 
     const remainingTokenCount =
diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts
index bafc12dcae..96db6f214e 100644
--- a/packages/core/src/core/geminiChat.test.ts
+++ b/packages/core/src/core/geminiChat.test.ts
@@ -174,15 +174,15 @@ describe('GeminiChat', () => {
         { role: 'model', parts: [{ text: 'Hi there' }] },
       ];
       const chatWithHistory = new GeminiChat(mockConfig, '', [], history);
-      const estimatedTokens = Math.ceil(JSON.stringify(history).length / 4);
-      expect(chatWithHistory.getLastPromptTokenCount()).toBe(estimatedTokens);
+      // 'Hello': 5 chars * 0.25 = 1.25
+      // 'Hi there': 8 chars * 0.25 = 2.0
+      // Total: 3.25 -> floor(3.25) = 3
+      expect(chatWithHistory.getLastPromptTokenCount()).toBe(3);
     });
 
     it('should initialize lastPromptTokenCount for empty history', () => {
       const chatEmpty = new GeminiChat(mockConfig);
-      expect(chatEmpty.getLastPromptTokenCount()).toBe(
-        Math.ceil(JSON.stringify([]).length / 4),
-      );
+      expect(chatEmpty.getLastPromptTokenCount()).toBe(0);
     });
   });
 
diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts
index 3e48bcfcb9..a62fa73108 100644
--- a/packages/core/src/core/geminiChat.ts
+++ b/packages/core/src/core/geminiChat.ts
@@ -46,6 +46,7 @@ import { handleFallback } from '../fallback/handler.js';
 import { isFunctionResponse } from '../utils/messageInspectors.js';
 import { partListUnionToString } from './geminiRequest.js';
 import type { ModelConfigKey } from '../services/modelConfigService.js';
+import { estimateTokenCountSync } from '../utils/tokenCalculation.js';
 
 export enum StreamEventType {
   /** A regular content chunk from the API. */
@@ -213,8 +214,8 @@ export class GeminiChat {
     validateHistory(history);
     this.chatRecordingService = new ChatRecordingService(config);
     this.chatRecordingService.initialize(resumedSessionData);
-    this.lastPromptTokenCount = Math.ceil(
-      JSON.stringify(this.history).length / 4,
+    this.lastPromptTokenCount = estimateTokenCountSync(
+      this.history.flatMap((c) => c.parts || []),
     );
   }
 
diff --git a/packages/core/src/services/chatCompressionService.test.ts b/packages/core/src/services/chatCompressionService.test.ts
index 84d91ff192..4a27a72510 100644
--- a/packages/core/src/services/chatCompressionService.test.ts
+++ b/packages/core/src/services/chatCompressionService.test.ts
@@ -154,6 +154,9 @@ describe('ChatCompressionService', () => {
         generateContent: mockGenerateContent,
       }),
       isInteractive: vi.fn().mockReturnValue(false),
+      getContentGenerator: vi.fn().mockReturnValue({
+        countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
+      }),
     } as unknown as Config;
 
     vi.mocked(tokenLimit).mockReturnValue(1000);
@@ -286,6 +289,11 @@ describe('ChatCompressionService', () => {
       ],
     } as unknown as GenerateContentResponse);
 
+    // Override mock to simulate high token count for this specific test
+    vi.mocked(mockConfig.getContentGenerator().countTokens).mockResolvedValue({
+      totalTokens: 10000,
+    });
+
     const result = await service.compress(
       mockChat,
       mockPromptId,
diff --git a/packages/core/src/services/chatCompressionService.ts b/packages/core/src/services/chatCompressionService.ts
index 817f6cf48a..b3dbf74512 100644
--- a/packages/core/src/services/chatCompressionService.ts
+++ b/packages/core/src/services/chatCompressionService.ts
@@ -14,6 +14,7 @@ import { getResponseText } from '../utils/partUtils.js';
 import { logChatCompression } from '../telemetry/loggers.js';
 import { makeChatCompressionEvent } from '../telemetry/types.js';
 import { getInitialChatHistory } from '../utils/environmentContext.js';
+import { calculateRequestTokenCount } from '../utils/tokenCalculation.js';
 import {
   DEFAULT_GEMINI_FLASH_LITE_MODEL,
   DEFAULT_GEMINI_FLASH_MODEL,
@@ -195,12 +196,11 @@ export class ChatCompressionService {
 
     // Use a shared utility to construct the initial history for an accurate token count.
     const fullNewHistory = await getInitialChatHistory(config, extraHistory);
 
-    // Estimate token count 1 token ≈ 4 characters
-    const newTokenCount = Math.floor(
-      fullNewHistory.reduce(
-        (total, content) => total + JSON.stringify(content).length,
-        0,
-      ) / 4,
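+    // May call the countTokens API when the initial history contains media parts.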
+    const newTokenCount = await calculateRequestTokenCount(
+      fullNewHistory.flatMap((c) => c.parts || []),
+      config.getContentGenerator(),
+      model,
     );
 
     logChatCompression(
diff --git a/packages/core/src/utils/tokenCalculation.test.ts b/packages/core/src/utils/tokenCalculation.test.ts
new file mode 100644
index 0000000000..7e1eae3e88
--- /dev/null
+++ b/packages/core/src/utils/tokenCalculation.test.ts
@@ -0,0 +1,147 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect, vi } from 'vitest';
+import { calculateRequestTokenCount } from './tokenCalculation.js';
+import type { ContentGenerator } from '../core/contentGenerator.js';
+
+describe('calculateRequestTokenCount', () => {
+  const mockContentGenerator = {
+    countTokens: vi.fn(),
+  } as unknown as ContentGenerator;
+
+  const model = 'gemini-pro';
+
+  it('should use countTokens API for media requests (images/files)', async () => {
+    vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
+      totalTokens: 100,
+    });
+    const request = [{ inlineData: { mimeType: 'image/png', data: 'data' } }];
+
+    const count = await calculateRequestTokenCount(
+      request,
+      mockContentGenerator,
+      model,
+    );
+
+    expect(count).toBe(100);
+    expect(mockContentGenerator.countTokens).toHaveBeenCalled();
+  });
+
+  it('should estimate tokens locally for tool calls', async () => {
+    vi.mocked(mockContentGenerator.countTokens).mockClear();
+    const request = [{ functionCall: { name: 'foo', args: { bar: 'baz' } } }];
+
+    const count = await calculateRequestTokenCount(
+      request,
+      mockContentGenerator,
+      model,
+    );
+
+    // Estimation logic: JSON.stringify(part).length / 4
+    // JSON: {"functionCall":{"name":"foo","args":{"bar":"baz"}}}
+    // Length: 52 chars. floor(52 / 4) = 13.
+    expect(count).toBeGreaterThan(0);
+    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
+  });
+
+  it('should estimate tokens locally for simple ASCII text', async () => {
+    vi.mocked(mockContentGenerator.countTokens).mockClear();
+    // 12 chars. 12 * 0.25 = 3 tokens.
+    const request = 'Hello world!';
+
+    const count = await calculateRequestTokenCount(
+      request,
+      mockContentGenerator,
+      model,
+    );
+
+    expect(count).toBe(3);
+    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
+  });
+
+  it('should estimate tokens locally for CJK text with higher weight', async () => {
+    vi.mocked(mockContentGenerator.countTokens).mockClear();
+    // 2 chars. 2 * 1.3 = 2.6 -> floor(2.6) = 2.
+    // Old logic would be 2/4 = 0.5 -> 0.
+    const request = '你好';
+
+    const count = await calculateRequestTokenCount(
+      request,
+      mockContentGenerator,
+      model,
+    );
+
+    expect(count).toBeGreaterThanOrEqual(2);
+    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
+  });
+
+  it('should handle mixed content', async () => {
+    vi.mocked(mockContentGenerator.countTokens).mockClear();
+    // 'Hi': 2 * 0.25 = 0.5
+    // '你好': 2 * 1.3 = 2.6
+    // Total: 3.1 -> 3
+    const request = 'Hi你好';
+
+    const count = await calculateRequestTokenCount(
+      request,
+      mockContentGenerator,
+      model,
+    );
+
+    expect(count).toBe(3);
+    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
+  });
+
+  it('should handle empty text', async () => {
+    const request = '';
+    const count = await calculateRequestTokenCount(
+      request,
+      mockContentGenerator,
+      model,
+    );
+    expect(count).toBe(0);
+  });
+
+  it('should fallback to local estimation when countTokens API fails', async () => {
+    vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
+      new Error('API error'),
+    );
+    const request = [
+      { text: 'Hello' },
+      { inlineData: { mimeType: 'image/png', data: 'data' } },
+    ];
+
+    const count = await calculateRequestTokenCount(
+      request,
+      mockContentGenerator,
+      model,
+    );
+
+    // Should fallback to estimation:
+    // 'Hello': 5 chars * 0.25 = 1.25
+    // inlineData: JSON.stringify length / 4
+    expect(count).toBeGreaterThan(0);
+    expect(mockContentGenerator.countTokens).toHaveBeenCalled();
+  });
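+
+  // Illustrative sketch: string elements in a PartListUnion array are
+  // coerced to { text } parts, so they follow the local text heuristic.
+  it('should estimate string array elements as text', async () => {
+    vi.mocked(mockContentGenerator.countTokens).mockClear();
+    // 'Hello ' (6 chars) + 'world!' (6 chars) = 12 ASCII chars * 0.25 = 3.
+    const request = ['Hello ', 'world!'];
+
+    const count = await calculateRequestTokenCount(
+      request,
+      mockContentGenerator,
+      model,
+    );
+
+    expect(count).toBe(3);
+    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
+  });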
+});
diff --git a/packages/core/src/utils/tokenCalculation.ts b/packages/core/src/utils/tokenCalculation.ts
new file mode 100644
index 0000000000..0359cb3e7c
--- /dev/null
+++ b/packages/core/src/utils/tokenCalculation.ts
@@ -0,0 +1,87 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import type { PartListUnion, Part } from '@google/genai';
+import type { ContentGenerator } from '../core/contentGenerator.js';
+
+// Token estimation constants
+// ASCII characters (0-127) are roughly 4 chars per token
+const ASCII_TOKENS_PER_CHAR = 0.25;
+// Non-ASCII characters (including CJK) are often 1-2 tokens per char.
+// We use 1.3 as a conservative estimate to avoid underestimation.
+const NON_ASCII_TOKENS_PER_CHAR = 1.3;
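+// Worked example: 'Hi你好' -> 2 ASCII * 0.25 + 2 CJK * 1.3 = 0.5 + 2.6 = 3.1,
+// floored to 3 tokens.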
+
+/**
+ * Estimates token count for parts synchronously using a heuristic.
+ * - Text: character-based heuristic (ASCII vs CJK).
+ * - Non-text (Tools, etc): JSON string length / 4.
+ */
+export function estimateTokenCountSync(parts: Part[]): number {
+  let totalTokens = 0;
+  for (const part of parts) {
+    if (typeof part.text === 'string') {
+      for (const char of part.text) {
+        if (char.codePointAt(0)! <= 127) {
+          totalTokens += ASCII_TOKENS_PER_CHAR;
+        } else {
+          totalTokens += NON_ASCII_TOKENS_PER_CHAR;
+        }
+      }
+    } else {
+      // For non-text parts (functionCall, functionResponse, executableCode, etc.),
+      // we fall back to the JSON string length heuristic.
+      // Note: This is an approximation.
+      totalTokens += JSON.stringify(part).length / 4;
+    }
+  }
+  return Math.floor(totalTokens);
+}
+
+/**
+ * Calculates the token count of the request.
+ * If the request contains only text or tools, it estimates the token count locally.
+ * If the request contains media (images, files), it uses the countTokens API.
+ */
+export async function calculateRequestTokenCount(
+  request: PartListUnion,
+  contentGenerator: ContentGenerator,
+  model: string,
+): Promise<number> {
+  const parts: Part[] = Array.isArray(request)
+    ? request.map((p) => (typeof p === 'string' ? { text: p } : p))
+    : typeof request === 'string'
+      ? [{ text: request }]
+      : [request];
+
+  // Use the countTokens API only for heavy media parts that are hard to estimate.
+  const hasMedia = parts.some((p) => 'inlineData' in p || 'fileData' in p);
+
+  if (hasMedia) {
+    try {
+      const response = await contentGenerator.countTokens({
+        model,
+        contents: [{ role: 'user', parts }],
+      });
+      return response.totalTokens ?? 0;
+    } catch {
+      // Fallback to local estimation if the API call fails
+      return estimateTokenCountSync(parts);
+    }
+  }
+
+  return estimateTokenCountSync(parts);
+}
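+
+// Usage sketch (illustrative; mirrors the call site in client.ts):
+//
+//   const estimatedRequestTokenCount = await calculateRequestTokenCount(
+//     request,
+//     contentGenerator,
+//     model,
+//   );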