feat(core): Improve request token calculation accuracy (#13824)

This commit is contained in:
Sandy Tao
2025-11-26 12:20:46 +08:00
committed by GitHub
parent 36a0a3d37b
commit e1d2653a7a
8 changed files with 307 additions and 56 deletions
@@ -0,0 +1,130 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi } from 'vitest';
import { calculateRequestTokenCount } from './tokenCalculation.js';
import type { ContentGenerator } from '../core/contentGenerator.js';
// Unit tests for calculateRequestTokenCount: requests containing media parts
// (inlineData/fileData) are routed through the countTokens API, while
// text/tool-only requests are estimated locally without an API call.
describe('calculateRequestTokenCount', () => {
  // Shared mock generator; tests call mockClear() before asserting on
  // call counts so earlier tests' invocations don't leak through.
  const mockContentGenerator = {
    countTokens: vi.fn(),
  } as unknown as ContentGenerator;
  const model = 'gemini-pro';
  // Media parts are hard to estimate locally, so the API result is used.
  it('should use countTokens API for media requests (images/files)', async () => {
    vi.mocked(mockContentGenerator.countTokens).mockResolvedValue({
      totalTokens: 100,
    });
    const request = [{ inlineData: { mimeType: 'image/png', data: 'data' } }];
    const count = await calculateRequestTokenCount(
      request,
      mockContentGenerator,
      model,
    );
    expect(count).toBe(100);
    expect(mockContentGenerator.countTokens).toHaveBeenCalled();
  });
  // Tool-call parts are estimated locally via the JSON-length heuristic.
  it('should estimate tokens locally for tool calls', async () => {
    vi.mocked(mockContentGenerator.countTokens).mockClear();
    const request = [{ functionCall: { name: 'foo', args: { bar: 'baz' } } }];
    const count = await calculateRequestTokenCount(
      request,
      mockContentGenerator,
      model,
    );
    // Estimation logic: JSON.stringify(part).length / 4
    // JSON: {"functionCall":{"name":"foo","args":{"bar":"baz"}}}
    // Length: ~53 chars. 53 / 4 = 13.25 -> 13.
    expect(count).toBeGreaterThan(0);
    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
  });
  // ASCII text is weighted at 0.25 tokens per character.
  it('should estimate tokens locally for simple ASCII text', async () => {
    vi.mocked(mockContentGenerator.countTokens).mockClear();
    // 12 chars. 12 * 0.25 = 3 tokens.
    const request = 'Hello world!';
    const count = await calculateRequestTokenCount(
      request,
      mockContentGenerator,
      model,
    );
    expect(count).toBe(3);
    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
  });
  // Non-ASCII (e.g. CJK) text is weighted higher (1.3 tokens per character)
  // so short CJK strings are not underestimated to zero.
  it('should estimate tokens locally for CJK text with higher weight', async () => {
    vi.mocked(mockContentGenerator.countTokens).mockClear();
    // 2 chars. 2 * 1.3 = 2.6 -> floor(2.6) = 2.
    // Old logic would be 2/4 = 0.5 -> 0.
    const request = '你好';
    const count = await calculateRequestTokenCount(
      request,
      mockContentGenerator,
      model,
    );
    expect(count).toBeGreaterThanOrEqual(2);
    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
  });
  // Mixed ASCII/CJK text sums the per-character weights before flooring.
  it('should handle mixed content', async () => {
    vi.mocked(mockContentGenerator.countTokens).mockClear();
    // 'Hi': 2 * 0.25 = 0.5
    // '你好': 2 * 1.3 = 2.6
    // Total: 3.1 -> 3
    const request = 'Hi你好';
    const count = await calculateRequestTokenCount(
      request,
      mockContentGenerator,
      model,
    );
    expect(count).toBe(3);
    expect(mockContentGenerator.countTokens).not.toHaveBeenCalled();
  });
  // Empty input must yield zero, not NaN or a negative value.
  it('should handle empty text', async () => {
    const request = '';
    const count = await calculateRequestTokenCount(
      request,
      mockContentGenerator,
      model,
    );
    expect(count).toBe(0);
  });
  // When the API path is taken but fails, the local heuristic is used
  // as a fallback rather than propagating the error.
  it('should fallback to local estimation when countTokens API fails', async () => {
    vi.mocked(mockContentGenerator.countTokens).mockRejectedValue(
      new Error('API error'),
    );
    const request = [
      { text: 'Hello' },
      { inlineData: { mimeType: 'image/png', data: 'data' } },
    ];
    const count = await calculateRequestTokenCount(
      request,
      mockContentGenerator,
      model,
    );
    // Should fallback to estimation:
    // 'Hello': 5 chars * 0.25 = 1.25
    // inlineData: JSON.stringify length / 4
    expect(count).toBeGreaterThan(0);
    expect(mockContentGenerator.countTokens).toHaveBeenCalled();
  });
});
@@ -0,0 +1,79 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { PartListUnion, Part } from '@google/genai';
import type { ContentGenerator } from '../core/contentGenerator.js';
// Token estimation constants
// ASCII characters (0-127) are roughly 4 chars per token
const ASCII_TOKENS_PER_CHAR = 0.25;
// Non-ASCII characters (including CJK) are often 1-2 tokens per char.
// We use 1.3 as a conservative estimate to avoid underestimation.
const NON_ASCII_TOKENS_PER_CHAR = 1.3;
/**
 * Estimates token count for parts synchronously using a heuristic.
 * - Text: character-based heuristic (ASCII vs non-ASCII, e.g. CJK).
 * - Non-text (functionCall, functionResponse, executableCode, etc.):
 *   JSON string length / 4. Note: this is an approximation.
 *
 * @param parts The request parts to estimate.
 * @returns The estimated token count, floored to an integer (0 for empty input).
 */
export function estimateTokenCountSync(parts: Part[]): number {
  let totalTokens = 0;
  for (const part of parts) {
    if (typeof part.text === 'string') {
      // Count characters per weight class and multiply once at the end,
      // instead of accumulating fractional weights per character. This
      // avoids floating-point drift on long strings (repeated additions
      // of 1.3 accumulate binary rounding error) and is cheaper.
      let asciiChars = 0;
      let nonAsciiChars = 0;
      for (const char of part.text) {
        // for..of iterates code points, so `char` is never empty and
        // codePointAt(0) is always defined; `?? 0` avoids a non-null
        // assertion while preserving behavior.
        if ((char.codePointAt(0) ?? 0) <= 127) {
          asciiChars++;
        } else {
          nonAsciiChars++;
        }
      }
      totalTokens +=
        asciiChars * ASCII_TOKENS_PER_CHAR +
        nonAsciiChars * NON_ASCII_TOKENS_PER_CHAR;
    } else {
      // For non-text parts (functionCall, functionResponse, executableCode,
      // etc.), fall back to the JSON string length heuristic.
      totalTokens += JSON.stringify(part).length / 4;
    }
  }
  return Math.floor(totalTokens);
}
/**
 * Calculates the token count of the request.
 * If the request contains only text or tools, it estimates the token count locally.
 * If the request contains media (images, files), it uses the countTokens API.
 */
export async function calculateRequestTokenCount(
  request: PartListUnion,
  contentGenerator: ContentGenerator,
  model: string,
): Promise<number> {
  // Normalize the (string | Part | (string | Part)[]) union into Part[].
  const toPart = (item: string | Part): Part =>
    typeof item === 'string' ? { text: item } : item;
  const parts: Part[] = Array.isArray(request)
    ? request.map(toPart)
    : [toPart(request)];
  // Only heavy media parts (images, files) are hard to estimate locally;
  // everything else short-circuits to the synchronous heuristic.
  const containsMedia = parts.some(
    (part) => 'inlineData' in part || 'fileData' in part,
  );
  if (!containsMedia) {
    return estimateTokenCountSync(parts);
  }
  try {
    const { totalTokens } = await contentGenerator.countTokens({
      model,
      contents: [{ role: 'user', parts }],
    });
    return totalTokens ?? 0;
  } catch {
    // Fallback to local estimation if the API call fails
    return estimateTokenCountSync(parts);
  }
}