perf(core): minimalist optimization for ToolOutputMaskingService

Optimizes ToolOutputMaskingService by skipping redundant serialization for already-masked items.
- Introduced a fast structural check early in the scan loop to detect and skip previously masked tool outputs using the isRecord type guard.
- Maintains canonical estimateTokenCountSync logic for unmasked items, ensuring consistency with project-wide token limits.
- Resolves the O(N^2) overhead identified in the issue with minimal risk of regression or logic drift.
- Explicitly supports array responses for masking.
This commit is contained in:
Coco Sheng
2026-04-29 12:07:12 -04:00
parent 7ab932c8bf
commit 6e1180dd6b
3 changed files with 104 additions and 34 deletions
@@ -328,11 +328,84 @@ describe('ToolOutputMaskingService', () => {
},
],
},
{ role: 'user', parts: [{ text: 'latest' }] },
];
mockedEstimateTokenCountSync.mockReturnValue(60000);
const getContentSpy = vi.spyOn(
service as unknown as {
getToolOutputContent: (part: Part) => string | null;
},
'getToolOutputContent',
);
mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => {
if (parts[0].functionResponse?.name === 'tool2') return 60000;
return 100;
});
const result = await service.mask(history, mockConfig);
expect(result.maskedCount).toBe(0); // tool1 skipped, tool2 is the "latest" which is protected
// tool1 is already masked and should be skipped by the structural check
// without ever calling getToolOutputContent.
// tool2 is NOT masked but it is scanned.
expect(getContentSpy).toHaveBeenCalledTimes(1);
expect(getContentSpy).toHaveBeenCalledWith(
expect.objectContaining({
functionResponse: expect.objectContaining({ name: 'tool2' }),
}),
);
expect(getContentSpy).not.toHaveBeenCalledWith(
expect.objectContaining({
functionResponse: expect.objectContaining({ name: 'tool1' }),
}),
);
expect(result.maskedCount).toBe(0);
});
it('should support masking of tool responses that are arrays', async () => {
mockConfig.getToolOutputMaskingConfig = async () => ({
enabled: true,
protectionThresholdTokens: 50,
minPrunableThresholdTokens: 10,
protectLatestTurn: true,
});
const arrayHistory: Content[] = [
{
role: 'user',
parts: [
{
functionResponse: {
name: 'array_tool',
response: {
data: Array.from({ length: 100 }, (_, i) => ({
id: i,
data: 'A'.repeat(100),
})),
},
},
},
],
},
{ role: 'user', parts: [{ text: 'latest' }] },
];
mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => {
const resp = parts[0].functionResponse?.response as Record<
string,
unknown
>;
const content = (resp?.['output'] as string) ?? JSON.stringify(resp);
if (content.includes(MASKING_INDICATOR_TAG)) return 100;
if (parts[0].functionResponse?.name === 'array_tool') return 20000;
return 100;
});
const result = await service.mask(arrayHistory, mockConfig);
expect(result.maskedCount).toBe(1);
expect(getToolResponse(result.newHistory[0].parts?.[0])).toContain(
MASKING_INDICATOR_TAG,
);
});
it('should handle different response keys in masked update', async () => {
@@ -12,6 +12,7 @@ import { debugLogger } from '../utils/debugLogger.js';
import { sanitizeFilenamePart } from '../utils/fileUtils.js';
import type { Config } from '../config/config.js';
import { logToolOutputMasking } from '../telemetry/loggers.js';
import { isRecord } from '../utils/markdownUtils.js';
import {
SHELL_TOOL_NAME,
ACTIVATE_SKILL_TOOL_NAME,
@@ -115,8 +116,20 @@ export class ToolOutputMaskingService {
continue;
}
const response = part.functionResponse.response;
// Fast structural check to skip already-masked items without stringifying them.
if (isRecord(response)) {
const output = response['output'];
if (
typeof output === 'string' &&
output.startsWith(`<${MASKING_INDICATOR_TAG}>`)
) {
continue;
}
}
const toolOutputContent = this.getToolOutputContent(part);
if (!toolOutputContent || this.isAlreadyMasked(toolOutputContent)) {
if (!toolOutputContent) {
continue;
}
@@ -273,9 +286,8 @@ export class ToolOutputMaskingService {
private getToolOutputContent(part: Part): string | null {
if (!part.functionResponse) return null;
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const response = part.functionResponse.response as Record<string, unknown>;
if (!response) return null;
const response = part.functionResponse.response;
if (typeof response !== 'object' || response === null) return null;
// Stringify the entire response for saving.
// This handles any tool output schema automatically.
@@ -287,10 +299,6 @@ export class ToolOutputMaskingService {
return content;
}
private isAlreadyMasked(content: string): boolean {
return content.includes(`<${MASKING_INDICATOR_TAG}`);
}
private formatShellPreview(response: Record<string, unknown>): string {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const content = (response['output'] || response['stdout'] || '') as string;
+13 -24
View File
@@ -29,14 +29,12 @@ const MAX_CHARS_FOR_FULL_HEURISTIC = 100_000;
// standard multimodal responses are typically depth 1.
const MAX_RECURSION_DEPTH = 3;
const DEFAULT_CHARS_PER_TOKEN = 4;
/**
* Heuristic estimation of tokens for a text string.
*/
function estimateTextTokens(text: string, charsPerToken: number): number {
function estimateTextTokens(text: string): number {
if (text.length > MAX_CHARS_FOR_FULL_HEURISTIC) {
return text.length / charsPerToken;
return text.length / 4;
}
let tokens = 0;
@@ -75,33 +73,25 @@ function estimateMediaTokens(part: Part): number | undefined {
* Heuristic estimation for tool responses, avoiding massive string copies
* and accounting for nested Gemini 3 multimodal parts.
*/
function estimateFunctionResponseTokens(
part: Part,
depth: number,
charsPerToken: number,
): number {
function estimateFunctionResponseTokens(part: Part, depth: number): number {
const fr = part.functionResponse;
if (!fr) return 0;
let totalTokens = (fr.name?.length ?? 0) / charsPerToken;
let totalTokens = (fr.name?.length ?? 0) / 4;
const response = fr.response as unknown;
if (typeof response === 'string') {
totalTokens += response.length / charsPerToken;
totalTokens += response.length / 4;
} else if (response !== undefined && response !== null) {
// For objects, stringify only the payload, not the whole Part object.
totalTokens += JSON.stringify(response).length / charsPerToken;
totalTokens += JSON.stringify(response).length / 4;
}
// Gemini 3: Handle nested multimodal parts recursively.
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const nestedParts = (fr as unknown as { parts?: Part[] }).parts;
if (nestedParts && nestedParts.length > 0) {
totalTokens += estimateTokenCountSync(
nestedParts,
depth + 1,
charsPerToken,
);
totalTokens += estimateTokenCountSync(nestedParts, depth + 1);
}
return totalTokens;
@@ -110,12 +100,11 @@ function estimateFunctionResponseTokens(
/**
* Estimates token count for parts synchronously using a heuristic.
* - Text: character-based heuristic (ASCII vs CJK) for small strings, length/4 for massive ones.
* - Non-text (Tools, etc): JSON string length / charsPerToken.
* - Non-text (Tools, etc): JSON string length / 4.
*/
export function estimateTokenCountSync(
parts: Part[],
depth: number = 0,
charsPerToken: number = DEFAULT_CHARS_PER_TOKEN,
): number {
if (depth > MAX_RECURSION_DEPTH) {
return 0;
@@ -124,9 +113,9 @@ export function estimateTokenCountSync(
let totalTokens = 0;
for (const part of parts) {
if (typeof part.text === 'string') {
totalTokens += estimateTextTokens(part.text, charsPerToken);
totalTokens += estimateTextTokens(part.text);
} else if (part.functionResponse) {
totalTokens += estimateFunctionResponseTokens(part, depth, charsPerToken);
totalTokens += estimateFunctionResponseTokens(part, depth);
} else {
const mediaEstimate = estimateMediaTokens(part);
if (mediaEstimate !== undefined) {
@@ -134,7 +123,7 @@ export function estimateTokenCountSync(
} else {
// Fallback for other non-text parts (e.g., functionCall).
// Note: JSON.stringify(part) here is safe as these parts are typically small.
totalTokens += JSON.stringify(part).length / charsPerToken;
totalTokens += JSON.stringify(part).length / 4;
}
}
}
@@ -173,9 +162,9 @@ export async function calculateRequestTokenCount(
} catch (error) {
// Fallback to local estimation if the API call fails
debugLogger.debug('countTokens API failed:', error);
return estimateTokenCountSync(parts, 0, DEFAULT_CHARS_PER_TOKEN);
return estimateTokenCountSync(parts);
}
}
return estimateTokenCountSync(parts, 0, DEFAULT_CHARS_PER_TOKEN);
return estimateTokenCountSync(parts);
}