mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-14 13:53:02 -07:00
perf(core): minimalist optimization for ToolOutputMaskingService
Optimizes ToolOutputMaskingService by skipping redundant serialization for already-masked items.

- Introduces a fast structural check early in the scan loop that uses the isRecord type guard to detect and skip previously masked tool outputs (a string `output` field that starts with the masking indicator tag) without stringifying them.
- Maintains the canonical estimateTokenCountSync logic for unmasked items, ensuring consistency with project-wide token limits.
- Resolves the O(N^2) overhead identified in the issue with minimal risk of regression or logic drift.
- Explicitly supports array responses for masking.
This commit is contained in:
@@ -328,11 +328,84 @@ describe('ToolOutputMaskingService', () => {
|
||||
},
|
||||
],
|
||||
},
|
||||
{ role: 'user', parts: [{ text: 'latest' }] },
|
||||
];
|
||||
mockedEstimateTokenCountSync.mockReturnValue(60000);
|
||||
|
||||
const getContentSpy = vi.spyOn(
|
||||
service as unknown as {
|
||||
getToolOutputContent: (part: Part) => string | null;
|
||||
},
|
||||
'getToolOutputContent',
|
||||
);
|
||||
mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => {
|
||||
if (parts[0].functionResponse?.name === 'tool2') return 60000;
|
||||
return 100;
|
||||
});
|
||||
|
||||
const result = await service.mask(history, mockConfig);
|
||||
expect(result.maskedCount).toBe(0); // tool1 skipped, tool2 is the "latest" which is protected
|
||||
|
||||
// tool1 is already masked and should be skipped by the structural check
|
||||
// without ever calling getToolOutputContent.
|
||||
// tool2 is NOT masked but it is scanned.
|
||||
|
||||
expect(getContentSpy).toHaveBeenCalledTimes(1);
|
||||
expect(getContentSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({ name: 'tool2' }),
|
||||
}),
|
||||
);
|
||||
expect(getContentSpy).not.toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({ name: 'tool1' }),
|
||||
}),
|
||||
);
|
||||
expect(result.maskedCount).toBe(0);
|
||||
});
|
||||
|
||||
it('should support masking of tool responses that are arrays', async () => {
|
||||
mockConfig.getToolOutputMaskingConfig = async () => ({
|
||||
enabled: true,
|
||||
protectionThresholdTokens: 50,
|
||||
minPrunableThresholdTokens: 10,
|
||||
protectLatestTurn: true,
|
||||
});
|
||||
|
||||
const arrayHistory: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: 'array_tool',
|
||||
response: {
|
||||
data: Array.from({ length: 100 }, (_, i) => ({
|
||||
id: i,
|
||||
data: 'A'.repeat(100),
|
||||
})),
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{ role: 'user', parts: [{ text: 'latest' }] },
|
||||
];
|
||||
|
||||
mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => {
|
||||
const resp = parts[0].functionResponse?.response as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
const content = (resp?.['output'] as string) ?? JSON.stringify(resp);
|
||||
if (content.includes(MASKING_INDICATOR_TAG)) return 100;
|
||||
if (parts[0].functionResponse?.name === 'array_tool') return 20000;
|
||||
return 100;
|
||||
});
|
||||
|
||||
const result = await service.mask(arrayHistory, mockConfig);
|
||||
expect(result.maskedCount).toBe(1);
|
||||
expect(getToolResponse(result.newHistory[0].parts?.[0])).toContain(
|
||||
MASKING_INDICATOR_TAG,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle different response keys in masked update', async () => {
|
||||
|
||||
@@ -12,6 +12,7 @@ import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { sanitizeFilenamePart } from '../utils/fileUtils.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { logToolOutputMasking } from '../telemetry/loggers.js';
|
||||
import { isRecord } from '../utils/markdownUtils.js';
|
||||
import {
|
||||
SHELL_TOOL_NAME,
|
||||
ACTIVATE_SKILL_TOOL_NAME,
|
||||
@@ -115,8 +116,20 @@ export class ToolOutputMaskingService {
|
||||
continue;
|
||||
}
|
||||
|
||||
const response = part.functionResponse.response;
|
||||
// Fast structural check to skip already-masked items without stringifying them.
|
||||
if (isRecord(response)) {
|
||||
const output = response['output'];
|
||||
if (
|
||||
typeof output === 'string' &&
|
||||
output.startsWith(`<${MASKING_INDICATOR_TAG}>`)
|
||||
) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const toolOutputContent = this.getToolOutputContent(part);
|
||||
if (!toolOutputContent || this.isAlreadyMasked(toolOutputContent)) {
|
||||
if (!toolOutputContent) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -273,9 +286,8 @@ export class ToolOutputMaskingService {
|
||||
|
||||
private getToolOutputContent(part: Part): string | null {
|
||||
if (!part.functionResponse) return null;
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const response = part.functionResponse.response as Record<string, unknown>;
|
||||
if (!response) return null;
|
||||
const response = part.functionResponse.response;
|
||||
if (typeof response !== 'object' || response === null) return null;
|
||||
|
||||
// Stringify the entire response for saving.
|
||||
// This handles any tool output schema automatically.
|
||||
@@ -287,10 +299,6 @@ export class ToolOutputMaskingService {
|
||||
return content;
|
||||
}
|
||||
|
||||
private isAlreadyMasked(content: string): boolean {
|
||||
return content.includes(`<${MASKING_INDICATOR_TAG}`);
|
||||
}
|
||||
|
||||
private formatShellPreview(response: Record<string, unknown>): string {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const content = (response['output'] || response['stdout'] || '') as string;
|
||||
|
||||
@@ -29,14 +29,12 @@ const MAX_CHARS_FOR_FULL_HEURISTIC = 100_000;
|
||||
// standard multimodal responses are typically depth 1.
|
||||
const MAX_RECURSION_DEPTH = 3;
|
||||
|
||||
const DEFAULT_CHARS_PER_TOKEN = 4;
|
||||
|
||||
/**
|
||||
* Heuristic estimation of tokens for a text string.
|
||||
*/
|
||||
function estimateTextTokens(text: string, charsPerToken: number): number {
|
||||
function estimateTextTokens(text: string): number {
|
||||
if (text.length > MAX_CHARS_FOR_FULL_HEURISTIC) {
|
||||
return text.length / charsPerToken;
|
||||
return text.length / 4;
|
||||
}
|
||||
|
||||
let tokens = 0;
|
||||
@@ -75,33 +73,25 @@ function estimateMediaTokens(part: Part): number | undefined {
|
||||
* Heuristic estimation for tool responses, avoiding massive string copies
|
||||
* and accounting for nested Gemini 3 multimodal parts.
|
||||
*/
|
||||
function estimateFunctionResponseTokens(
|
||||
part: Part,
|
||||
depth: number,
|
||||
charsPerToken: number,
|
||||
): number {
|
||||
function estimateFunctionResponseTokens(part: Part, depth: number): number {
|
||||
const fr = part.functionResponse;
|
||||
if (!fr) return 0;
|
||||
|
||||
let totalTokens = (fr.name?.length ?? 0) / charsPerToken;
|
||||
let totalTokens = (fr.name?.length ?? 0) / 4;
|
||||
const response = fr.response as unknown;
|
||||
|
||||
if (typeof response === 'string') {
|
||||
totalTokens += response.length / charsPerToken;
|
||||
totalTokens += response.length / 4;
|
||||
} else if (response !== undefined && response !== null) {
|
||||
// For objects, stringify only the payload, not the whole Part object.
|
||||
totalTokens += JSON.stringify(response).length / charsPerToken;
|
||||
totalTokens += JSON.stringify(response).length / 4;
|
||||
}
|
||||
|
||||
// Gemini 3: Handle nested multimodal parts recursively.
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const nestedParts = (fr as unknown as { parts?: Part[] }).parts;
|
||||
if (nestedParts && nestedParts.length > 0) {
|
||||
totalTokens += estimateTokenCountSync(
|
||||
nestedParts,
|
||||
depth + 1,
|
||||
charsPerToken,
|
||||
);
|
||||
totalTokens += estimateTokenCountSync(nestedParts, depth + 1);
|
||||
}
|
||||
|
||||
return totalTokens;
|
||||
@@ -110,12 +100,11 @@ function estimateFunctionResponseTokens(
|
||||
/**
|
||||
* Estimates token count for parts synchronously using a heuristic.
|
||||
* - Text: character-based heuristic (ASCII vs CJK) for small strings, length/4 for massive ones.
|
||||
* - Non-text (Tools, etc): JSON string length / charsPerToken.
|
||||
* - Non-text (Tools, etc): JSON string length / 4.
|
||||
*/
|
||||
export function estimateTokenCountSync(
|
||||
parts: Part[],
|
||||
depth: number = 0,
|
||||
charsPerToken: number = DEFAULT_CHARS_PER_TOKEN,
|
||||
): number {
|
||||
if (depth > MAX_RECURSION_DEPTH) {
|
||||
return 0;
|
||||
@@ -124,9 +113,9 @@ export function estimateTokenCountSync(
|
||||
let totalTokens = 0;
|
||||
for (const part of parts) {
|
||||
if (typeof part.text === 'string') {
|
||||
totalTokens += estimateTextTokens(part.text, charsPerToken);
|
||||
totalTokens += estimateTextTokens(part.text);
|
||||
} else if (part.functionResponse) {
|
||||
totalTokens += estimateFunctionResponseTokens(part, depth, charsPerToken);
|
||||
totalTokens += estimateFunctionResponseTokens(part, depth);
|
||||
} else {
|
||||
const mediaEstimate = estimateMediaTokens(part);
|
||||
if (mediaEstimate !== undefined) {
|
||||
@@ -134,7 +123,7 @@ export function estimateTokenCountSync(
|
||||
} else {
|
||||
// Fallback for other non-text parts (e.g., functionCall).
|
||||
// Note: JSON.stringify(part) here is safe as these parts are typically small.
|
||||
totalTokens += JSON.stringify(part).length / charsPerToken;
|
||||
totalTokens += JSON.stringify(part).length / 4;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -173,9 +162,9 @@ export async function calculateRequestTokenCount(
|
||||
} catch (error) {
|
||||
// Fallback to local estimation if the API call fails
|
||||
debugLogger.debug('countTokens API failed:', error);
|
||||
return estimateTokenCountSync(parts, 0, DEFAULT_CHARS_PER_TOKEN);
|
||||
return estimateTokenCountSync(parts);
|
||||
}
|
||||
}
|
||||
|
||||
return estimateTokenCountSync(parts, 0, DEFAULT_CHARS_PER_TOKEN);
|
||||
return estimateTokenCountSync(parts);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user