mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-30 23:14:32 -07:00
feat(core): Improve request token calculation accuracy (#13824)
This commit is contained in:
@@ -166,6 +166,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
generateContent: mockGenerateContentFn,
|
||||
generateContentStream: vi.fn(),
|
||||
batchEmbedContents: vi.fn(),
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
// Because the GeminiClient constructor kicks off an async process (startChat)
|
||||
@@ -902,6 +903,75 @@ ${JSON.stringify(
|
||||
});
|
||||
});
|
||||
|
||||
it('should use local estimation for text-only requests and NOT call countTokens', async () => {
|
||||
const request = [{ text: 'Hello world' }];
|
||||
const generator = client['getContentGeneratorOrFail']();
|
||||
const countTokensSpy = vi.spyOn(generator, 'countTokens');
|
||||
|
||||
const stream = client.sendMessageStream(
|
||||
request,
|
||||
new AbortController().signal,
|
||||
'test-prompt-id',
|
||||
);
|
||||
await stream.next(); // Trigger the generator
|
||||
|
||||
expect(countTokensSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use countTokens API for requests with non-text parts', async () => {
|
||||
const request = [
|
||||
{ text: 'Describe this image' },
|
||||
{ inlineData: { mimeType: 'image/png', data: 'base64...' } },
|
||||
];
|
||||
const generator = client['getContentGeneratorOrFail']();
|
||||
const countTokensSpy = vi
|
||||
.spyOn(generator, 'countTokens')
|
||||
.mockResolvedValue({ totalTokens: 123 });
|
||||
|
||||
const stream = client.sendMessageStream(
|
||||
request,
|
||||
new AbortController().signal,
|
||||
'test-prompt-id',
|
||||
);
|
||||
await stream.next(); // Trigger the generator
|
||||
|
||||
expect(countTokensSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
contents: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
parts: expect.arrayContaining([
|
||||
{ text: 'Describe this image' },
|
||||
{ inlineData: { mimeType: 'image/png', data: 'base64...' } },
|
||||
]),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should estimate CJK characters more conservatively (closer to 1 token/char)', async () => {
|
||||
const request = [{ text: '你好世界' }]; // 4 chars
|
||||
const generator = client['getContentGeneratorOrFail']();
|
||||
const countTokensSpy = vi.spyOn(generator, 'countTokens');
|
||||
|
||||
// 4 chars.
|
||||
// Old logic: 4/4 = 1.
|
||||
// New logic (heuristic): 4 * 1 = 4. (Or at least > 1).
|
||||
// Let's assert it's roughly accurate.
|
||||
|
||||
const stream = client.sendMessageStream(
|
||||
request,
|
||||
new AbortController().signal,
|
||||
'test-prompt-id',
|
||||
);
|
||||
await stream.next();
|
||||
|
||||
// Should NOT call countTokens (it's text only)
|
||||
expect(countTokensSpy).not.toHaveBeenCalled();
|
||||
|
||||
// The actual token calculation is unit tested in tokenCalculation.test.ts
|
||||
});
|
||||
|
||||
it('should return the turn instance after the stream is complete', async () => {
|
||||
// Arrange
|
||||
const mockStream = (async function* () {
|
||||
|
||||
@@ -56,47 +56,10 @@ import { handleFallback } from '../fallback/handler.js';
|
||||
import type { RoutingContext } from '../routing/routingStrategy.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import { calculateRequestTokenCount } from '../utils/tokenCalculation.js';
|
||||
|
||||
const MAX_TURNS = 100;
|
||||
|
||||
/**
|
||||
* Estimates the character length of text-only parts in a request.
|
||||
* Binary data (inline_data, fileData) is excluded from the estimation
|
||||
* because Gemini counts these as fixed token values, not based on their size.
|
||||
* @param request The request to estimate tokens for
|
||||
* @returns Estimated character length of text content
|
||||
*/
|
||||
function estimateTextOnlyLength(request: PartListUnion): number {
|
||||
if (typeof request === 'string') {
|
||||
return request.length;
|
||||
}
|
||||
|
||||
// Ensure request is an array before iterating
|
||||
if (!Array.isArray(request)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let textLength = 0;
|
||||
for (const part of request) {
|
||||
// Handle string elements in the array
|
||||
if (typeof part === 'string') {
|
||||
textLength += part.length;
|
||||
}
|
||||
// Handle object elements with text property
|
||||
else if (
|
||||
typeof part === 'object' &&
|
||||
part !== null &&
|
||||
'text' in part &&
|
||||
part.text
|
||||
) {
|
||||
textLength += part.text.length;
|
||||
}
|
||||
// inlineData, fileData, and other binary parts are ignored
|
||||
// as they are counted as fixed tokens by Gemini
|
||||
}
|
||||
return textLength;
|
||||
}
|
||||
|
||||
export class GeminiClient {
|
||||
private chat?: GeminiChat;
|
||||
private sessionTurnCount = 0;
|
||||
@@ -493,11 +456,12 @@ export class GeminiClient {
|
||||
// Check for context window overflow
|
||||
const modelForLimitCheck = this._getEffectiveModelForCurrentTurn();
|
||||
|
||||
// Estimate tokens based on text content only.
|
||||
// Binary data (PDFs, images) are counted as fixed tokens by Gemini,
|
||||
// not based on their base64-encoded size.
|
||||
const estimatedRequestTokenCount = Math.floor(
|
||||
estimateTextOnlyLength(request) / 4,
|
||||
// Estimate tokens. For text-only requests, we estimate based on character length.
|
||||
// For requests with non-text parts (like images, tools), we use the countTokens API.
|
||||
const estimatedRequestTokenCount = await calculateRequestTokenCount(
|
||||
request,
|
||||
this.getContentGeneratorOrFail(),
|
||||
modelForLimitCheck,
|
||||
);
|
||||
|
||||
const remainingTokenCount =
|
||||
|
||||
@@ -174,15 +174,15 @@ describe('GeminiChat', () => {
|
||||
{ role: 'model', parts: [{ text: 'Hi there' }] },
|
||||
];
|
||||
const chatWithHistory = new GeminiChat(mockConfig, '', [], history);
|
||||
const estimatedTokens = Math.ceil(JSON.stringify(history).length / 4);
|
||||
expect(chatWithHistory.getLastPromptTokenCount()).toBe(estimatedTokens);
|
||||
// 'Hello': 5 chars * 0.25 = 1.25
|
||||
// 'Hi there': 8 chars * 0.25 = 2.0
|
||||
// Total: 3.25 -> floor(3.25) = 3
|
||||
expect(chatWithHistory.getLastPromptTokenCount()).toBe(3);
|
||||
});
|
||||
|
||||
it('should initialize lastPromptTokenCount for empty history', () => {
|
||||
const chatEmpty = new GeminiChat(mockConfig);
|
||||
expect(chatEmpty.getLastPromptTokenCount()).toBe(
|
||||
Math.ceil(JSON.stringify([]).length / 4),
|
||||
);
|
||||
expect(chatEmpty.getLastPromptTokenCount()).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ import { handleFallback } from '../fallback/handler.js';
|
||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { partListUnionToString } from './geminiRequest.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import { estimateTokenCountSync } from '../utils/tokenCalculation.js';
|
||||
|
||||
export enum StreamEventType {
|
||||
/** A regular content chunk from the API. */
|
||||
@@ -213,8 +214,8 @@ export class GeminiChat {
|
||||
validateHistory(history);
|
||||
this.chatRecordingService = new ChatRecordingService(config);
|
||||
this.chatRecordingService.initialize(resumedSessionData);
|
||||
this.lastPromptTokenCount = Math.ceil(
|
||||
JSON.stringify(this.history).length / 4,
|
||||
this.lastPromptTokenCount = estimateTokenCountSync(
|
||||
this.history.flatMap((c) => c.parts || []),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -154,6 +154,9 @@ describe('ChatCompressionService', () => {
|
||||
generateContent: mockGenerateContent,
|
||||
}),
|
||||
isInteractive: vi.fn().mockReturnValue(false),
|
||||
getContentGenerator: vi.fn().mockReturnValue({
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
|
||||
}),
|
||||
} as unknown as Config;
|
||||
|
||||
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||
@@ -286,6 +289,11 @@ describe('ChatCompressionService', () => {
|
||||
],
|
||||
} as unknown as GenerateContentResponse);
|
||||
|
||||
// Override mock to simulate high token count for this specific test
|
||||
vi.mocked(mockConfig.getContentGenerator().countTokens).mockResolvedValue({
|
||||
totalTokens: 10000,
|
||||
});
|
||||
|
||||
const result = await service.compress(
|
||||
mockChat,
|
||||
mockPromptId,
|
||||
|
||||
@@ -14,6 +14,7 @@ import { getResponseText } from '../utils/partUtils.js';
|
||||
import { logChatCompression } from '../telemetry/loggers.js';
|
||||
import { makeChatCompressionEvent } from '../telemetry/types.js';
|
||||
import { getInitialChatHistory } from '../utils/environmentContext.js';
|
||||
import { calculateRequestTokenCount } from '../utils/tokenCalculation.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
@@ -195,12 +196,10 @@ export class ChatCompressionService {
|
||||
// Use a shared utility to construct the initial history for an accurate token count.
|
||||
const fullNewHistory = await getInitialChatHistory(config, extraHistory);
|
||||
|
||||
// Estimate token count 1 token ≈ 4 characters
|
||||
const newTokenCount = Math.floor(
|
||||
fullNewHistory.reduce(
|
||||
(total, content) => total + JSON.stringify(content).length,
|
||||
0,
|
||||
) / 4,
|
||||
const newTokenCount = await calculateRequestTokenCount(
|
||||
fullNewHistory.flatMap((c) => c.parts || []),
|
||||
config.getContentGenerator(),
|
||||
model,
|
||||
);
|
||||
|
||||
logChatCompression(
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { calculateRequestTokenCount } from './tokenCalculation.js';
|
||||
import type { ContentGenerator } from '../core/contentGenerator.js';
|
||||
|
||||
describe('calculateRequestTokenCount', () => {
|
||||
const mockContentGenerator = {
|
||||
countTokens: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
const model = 'gemini-pro';
|
||||
|
||||
it('should use countTokens API for media requests (images/files)', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
|
||||
totalTokens: 100,
|
||||
});
|
||||
const request = [{ inlineData: { mimeType: 'image/png', data: 'data' } }];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBe(100);
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should estimate tokens locally for tool calls', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockClear();
|
||||
const request = [{ functionCall: { name: 'foo', args: { bar: 'baz' } } }];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
// Estimation logic: JSON.stringify(part).length / 4
|
||||
// JSON: {"functionCall":{"name":"foo","args":{"bar":"baz"}}}
|
||||
// Length: ~53 chars. 53 / 4 = 13.25 -> 13.
|
||||
expect(count).toBeGreaterThan(0);
|
||||
expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should estimate tokens locally for simple ASCII text', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockClear();
|
||||
// 12 chars. 12 * 0.25 = 3 tokens.
|
||||
const request = 'Hello world!';
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBe(3);
|
||||
expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should estimate tokens locally for CJK text with higher weight', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockClear();
|
||||
// 2 chars. 2 * 1.3 = 2.6 -> floor(2.6) = 2.
|
||||
// Old logic would be 2/4 = 0.5 -> 0.
|
||||
const request = '你好';
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBeGreaterThanOrEqual(2);
|
||||
expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle mixed content', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockClear();
|
||||
// 'Hi': 2 * 0.25 = 0.5
|
||||
// '你好': 2 * 1.3 = 2.6
|
||||
// Total: 3.1 -> 3
|
||||
const request = 'Hi你好';
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
expect(count).toBe(3);
|
||||
expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle empty text', async () => {
|
||||
const request = '';
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
expect(count).toBe(0);
|
||||
});
|
||||
|
||||
it('should fallback to local estimation when countTokens API fails', async () => {
|
||||
vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
|
||||
new Error('API error'),
|
||||
);
|
||||
const request = [
|
||||
{ text: 'Hello' },
|
||||
{ inlineData: { mimeType: 'image/png', data: 'data' } },
|
||||
];
|
||||
|
||||
const count = await calculateRequestTokenCount(
|
||||
request,
|
||||
mockContentGenerator,
|
||||
model,
|
||||
);
|
||||
|
||||
// Should fallback to estimation:
|
||||
// 'Hello': 5 chars * 0.25 = 1.25
|
||||
// inlineData: JSON.stringify length / 4
|
||||
expect(count).toBeGreaterThan(0);
|
||||
expect(mockContentGenerator.countTokens).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,79 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { PartListUnion, Part } from '@google/genai';
|
||||
import type { ContentGenerator } from '../core/contentGenerator.js';
|
||||
|
||||
// Token estimation constants
|
||||
// ASCII characters (0-127) are roughly 4 chars per token
|
||||
const ASCII_TOKENS_PER_CHAR = 0.25;
|
||||
// Non-ASCII characters (including CJK) are often 1-2 tokens per char.
|
||||
// We use 1.3 as a conservative estimate to avoid underestimation.
|
||||
const NON_ASCII_TOKENS_PER_CHAR = 1.3;
|
||||
|
||||
/**
|
||||
* Estimates token count for parts synchronously using a heuristic.
|
||||
* - Text: character-based heuristic (ASCII vs CJK).
|
||||
* - Non-text (Tools, etc): JSON string length / 4.
|
||||
*/
|
||||
export function estimateTokenCountSync(parts: Part[]): number {
|
||||
let totalTokens = 0;
|
||||
for (const part of parts) {
|
||||
if (typeof part.text === 'string') {
|
||||
for (const char of part.text) {
|
||||
if (char.codePointAt(0)! <= 127) {
|
||||
totalTokens += ASCII_TOKENS_PER_CHAR;
|
||||
} else {
|
||||
totalTokens += NON_ASCII_TOKENS_PER_CHAR;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// For non-text parts (functionCall, functionResponse, executableCode, etc.),
|
||||
// we fallback to the JSON string length heuristic.
|
||||
// Note: This is an approximation.
|
||||
totalTokens += JSON.stringify(part).length / 4;
|
||||
}
|
||||
}
|
||||
return Math.floor(totalTokens);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the token count of the request.
|
||||
* If the request contains only text or tools, it estimates the token count locally.
|
||||
* If the request contains media (images, files), it uses the countTokens API.
|
||||
*/
|
||||
export async function calculateRequestTokenCount(
|
||||
request: PartListUnion,
|
||||
contentGenerator: ContentGenerator,
|
||||
model: string,
|
||||
): Promise<number> {
|
||||
const parts: Part[] = Array.isArray(request)
|
||||
? request.map((p) => (typeof p === 'string' ? { text: p } : p))
|
||||
: typeof request === 'string'
|
||||
? [{ text: request }]
|
||||
: [request];
|
||||
|
||||
// Use countTokens API only for heavy media parts that are hard to estimate.
|
||||
const hasMedia = parts.some((p) => {
|
||||
const isMedia = 'inlineData' in p || 'fileData' in p;
|
||||
return isMedia;
|
||||
});
|
||||
|
||||
if (hasMedia) {
|
||||
try {
|
||||
const response = await contentGenerator.countTokens({
|
||||
model,
|
||||
contents: [{ role: 'user', parts }],
|
||||
});
|
||||
return response.totalTokens ?? 0;
|
||||
} catch {
|
||||
// Fallback to local estimation if the API call fails
|
||||
return estimateTokenCountSync(parts);
|
||||
}
|
||||
}
|
||||
|
||||
return estimateTokenCountSync(parts);
|
||||
}
|
||||
Reference in New Issue
Block a user