From 6e1180dd6b9833951a49e6dfdc961ee058be5867 Mon Sep 17 00:00:00 2001 From: Coco Sheng Date: Wed, 29 Apr 2026 12:07:12 -0400 Subject: [PATCH] perf(core): minimalist optimization for ToolOutputMaskingService Optimizes ToolOutputMaskingService by skipping redundant serialization for already-masked items. - Introduced a fast structural check early in the scan loop to detect and skip previously masked tool outputs using the isRecord type guard. - Maintains canonical estimateTokenCountSync logic for unmasked items, ensuring consistency with project-wide token limits. - Resolves the O(N^2) overhead identified in the issue with minimal risk of regression or logic drift. - Explicitly supports array responses for masking. --- .../context/toolOutputMaskingService.test.ts | 77 ++++++++++++++++++- .../src/context/toolOutputMaskingService.ts | 24 ++++-- packages/core/src/utils/tokenCalculation.ts | 37 ++++----- 3 files changed, 104 insertions(+), 34 deletions(-) diff --git a/packages/core/src/context/toolOutputMaskingService.test.ts b/packages/core/src/context/toolOutputMaskingService.test.ts index 037890b443..26dd48e7c9 100644 --- a/packages/core/src/context/toolOutputMaskingService.test.ts +++ b/packages/core/src/context/toolOutputMaskingService.test.ts @@ -328,11 +328,84 @@ describe('ToolOutputMaskingService', () => { }, ], }, + { role: 'user', parts: [{ text: 'latest' }] }, ]; - mockedEstimateTokenCountSync.mockReturnValue(60000); + + const getContentSpy = vi.spyOn( + service as unknown as { + getToolOutputContent: (part: Part) => string | null; + }, + 'getToolOutputContent', + ); + mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => { + if (parts[0].functionResponse?.name === 'tool2') return 60000; + return 100; + }); const result = await service.mask(history, mockConfig); - expect(result.maskedCount).toBe(0); // tool1 skipped, tool2 is the "latest" which is protected + + // tool1 is already masked and should be skipped by the structural check + // without ever calling getToolOutputContent. + // tool2 is NOT masked but it is scanned. + + expect(getContentSpy).toHaveBeenCalledTimes(1); + expect(getContentSpy).toHaveBeenCalledWith( + expect.objectContaining({ + functionResponse: expect.objectContaining({ name: 'tool2' }), + }), + ); + expect(getContentSpy).not.toHaveBeenCalledWith( + expect.objectContaining({ + functionResponse: expect.objectContaining({ name: 'tool1' }), + }), + ); + expect(result.maskedCount).toBe(0); + }); + + it('should support masking of tool responses that are arrays', async () => { + mockConfig.getToolOutputMaskingConfig = async () => ({ + enabled: true, + protectionThresholdTokens: 50, + minPrunableThresholdTokens: 10, + protectLatestTurn: true, + }); + + const arrayHistory: Content[] = [ + { + role: 'user', + parts: [ + { + functionResponse: { + name: 'array_tool', + response: { + data: Array.from({ length: 100 }, (_, i) => ({ + id: i, + data: 'A'.repeat(100), + })), + }, + }, + }, + ], + }, + { role: 'user', parts: [{ text: 'latest' }] }, + ]; + + mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => { + const resp = parts[0].functionResponse?.response as Record< + string, + unknown + >; + const content = (resp?.['output'] as string) ?? JSON.stringify(resp); + if (content.includes(MASKING_INDICATOR_TAG)) return 100; + if (parts[0].functionResponse?.name === 'array_tool') return 20000; + return 100; + }); + + const result = await service.mask(arrayHistory, mockConfig); + expect(result.maskedCount).toBe(1); + expect(getToolResponse(result.newHistory[0].parts?.[0])).toContain( + MASKING_INDICATOR_TAG, + ); }); it('should handle different response keys in masked update', async () => { diff --git a/packages/core/src/context/toolOutputMaskingService.ts b/packages/core/src/context/toolOutputMaskingService.ts index 77158040ca..b25c9e722a 100644 --- a/packages/core/src/context/toolOutputMaskingService.ts +++ b/packages/core/src/context/toolOutputMaskingService.ts @@ -12,6 +12,7 @@ import { debugLogger } from '../utils/debugLogger.js'; import { sanitizeFilenamePart } from '../utils/fileUtils.js'; import type { Config } from '../config/config.js'; import { logToolOutputMasking } from '../telemetry/loggers.js'; +import { isRecord } from '../utils/markdownUtils.js'; import { SHELL_TOOL_NAME, ACTIVATE_SKILL_TOOL_NAME, @@ -115,8 +116,20 @@ export class ToolOutputMaskingService { continue; } + const response = part.functionResponse.response; + // Fast structural check to skip already-masked items without stringifying them. + if (isRecord(response)) { + const output = response['output']; + if ( + typeof output === 'string' && + output.startsWith(`<${MASKING_INDICATOR_TAG}>`) + ) { + continue; + } + } + const toolOutputContent = this.getToolOutputContent(part); - if (!toolOutputContent || this.isAlreadyMasked(toolOutputContent)) { + if (!toolOutputContent) { continue; } @@ -273,9 +286,8 @@ export class ToolOutputMaskingService { private getToolOutputContent(part: Part): string | null { if (!part.functionResponse) return null; - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const response = part.functionResponse.response as Record; - if (!response) return null; + const response = part.functionResponse.response; + if (typeof response !== 'object' || response === null) return null; // Stringify the entire response for saving. // This handles any tool output schema automatically. @@ -287,10 +299,6 @@ export class ToolOutputMaskingService { return content; } - private isAlreadyMasked(content: string): boolean { - return content.includes(`<${MASKING_INDICATOR_TAG}`); - } - private formatShellPreview(response: Record): string { // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const content = (response['output'] || response['stdout'] || '') as string; diff --git a/packages/core/src/utils/tokenCalculation.ts b/packages/core/src/utils/tokenCalculation.ts index a1115bcf74..b61b7cbb5d 100644 --- a/packages/core/src/utils/tokenCalculation.ts +++ b/packages/core/src/utils/tokenCalculation.ts @@ -29,14 +29,12 @@ const MAX_CHARS_FOR_FULL_HEURISTIC = 100_000; // standard multimodal responses are typically depth 1. const MAX_RECURSION_DEPTH = 3; -const DEFAULT_CHARS_PER_TOKEN = 4; - /** * Heuristic estimation of tokens for a text string. */ -function estimateTextTokens(text: string, charsPerToken: number): number { +function estimateTextTokens(text: string): number { if (text.length > MAX_CHARS_FOR_FULL_HEURISTIC) { - return text.length / charsPerToken; + return text.length / 4; } let tokens = 0; @@ -75,33 +73,25 @@ function estimateMediaTokens(part: Part): number | undefined { * Heuristic estimation for tool responses, avoiding massive string copies * and accounting for nested Gemini 3 multimodal parts. */ -function estimateFunctionResponseTokens( - part: Part, - depth: number, - charsPerToken: number, -): number { +function estimateFunctionResponseTokens(part: Part, depth: number): number { const fr = part.functionResponse; if (!fr) return 0; - let totalTokens = (fr.name?.length ?? 0) / charsPerToken; + let totalTokens = (fr.name?.length ?? 0) / 4; const response = fr.response as unknown; if (typeof response === 'string') { - totalTokens += response.length / charsPerToken; + totalTokens += response.length / 4; } else if (response !== undefined && response !== null) { // For objects, stringify only the payload, not the whole Part object. - totalTokens += JSON.stringify(response).length / charsPerToken; + totalTokens += JSON.stringify(response).length / 4; } // Gemini 3: Handle nested multimodal parts recursively. // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const nestedParts = (fr as unknown as { parts?: Part[] }).parts; if (nestedParts && nestedParts.length > 0) { - totalTokens += estimateTokenCountSync( - nestedParts, - depth + 1, - charsPerToken, - ); + totalTokens += estimateTokenCountSync(nestedParts, depth + 1); } return totalTokens; @@ -110,12 +100,11 @@ function estimateFunctionResponseTokens( /** * Estimates token count for parts synchronously using a heuristic. * - Text: character-based heuristic (ASCII vs CJK) for small strings, length/4 for massive ones. - * - Non-text (Tools, etc): JSON string length / charsPerToken. + * - Non-text (Tools, etc): JSON string length / 4. */ export function estimateTokenCountSync( parts: Part[], depth: number = 0, - charsPerToken: number = DEFAULT_CHARS_PER_TOKEN, ): number { if (depth > MAX_RECURSION_DEPTH) { return 0; @@ -124,9 +113,9 @@ export function estimateTokenCountSync( let totalTokens = 0; for (const part of parts) { if (typeof part.text === 'string') { - totalTokens += estimateTextTokens(part.text, charsPerToken); + totalTokens += estimateTextTokens(part.text); } else if (part.functionResponse) { - totalTokens += estimateFunctionResponseTokens(part, depth, charsPerToken); + totalTokens += estimateFunctionResponseTokens(part, depth); } else { const mediaEstimate = estimateMediaTokens(part); if (mediaEstimate !== undefined) { @@ -134,7 +123,7 @@ export function estimateTokenCountSync( } else { // Fallback for other non-text parts (e.g., functionCall). // Note: JSON.stringify(part) here is safe as these parts are typically small. - totalTokens += JSON.stringify(part).length / charsPerToken; + totalTokens += JSON.stringify(part).length / 4; } } } @@ -173,9 +162,9 @@ export async function calculateRequestTokenCount( } catch (error) { // Fallback to local estimation if the API call fails debugLogger.debug('countTokens API failed:', error); - return estimateTokenCountSync(parts, 0, DEFAULT_CHARS_PER_TOKEN); + return estimateTokenCountSync(parts); } } - return estimateTokenCountSync(parts, 0, DEFAULT_CHARS_PER_TOKEN); + return estimateTokenCountSync(parts); }