refactor(core): Centralize context management logic into src/context (#24380)

This commit is contained in:
joshualitt
2026-03-31 17:01:46 -07:00
committed by GitHub
parent cdc602edd7
commit fd5c103f99
21 changed files with 51 additions and 20 deletions
@@ -0,0 +1,31 @@
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
exports[`ToolOutputMaskingService > should match the expected snapshot for a masked tool output 1`] = `
"<tool_output_masked>
Line
Line
Line
Line
Line
Line
Line
Line
Line
Line
... [6 lines omitted] ...
Line
Line
Line
Line
Line
Line
Line
Line
Line
Output too large. Full output available at: /mock/temp/tool-outputs/session-mock-session/run_shell_command_deterministic.txt
</tool_output_masked>"
`;
@@ -0,0 +1,506 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { AgentHistoryProvider } from './agentHistoryProvider.js';
import { estimateTokenCountSync } from '../utils/tokenCalculation.js';
// Vitest hoists vi.mock() calls above the imports, so this factory replaces
// '../utils/tokenCalculation.js' for every importer in this file, including
// the AgentHistoryProvider module under test. The token-per-char constants
// mirror the real module's exports so code that reads them keeps working.
vi.mock('../utils/tokenCalculation.js', () => ({
  estimateTokenCountSync: vi.fn(),
  ASCII_TOKENS_PER_CHAR: 0.25,
  NON_ASCII_TOKENS_PER_CHAR: 1.3,
}));
import type { Content, GenerateContentResponse, Part } from '@google/genai';
import type { Config, ContextManagementConfig } from '../config/config.js';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
import type { AgentHistoryProviderConfig } from '../services/types.js';
import {
TEXT_TRUNCATION_PREFIX,
TOOL_TRUNCATION_PREFIX,
truncateProportionally,
} from './truncation.js';
describe('AgentHistoryProvider', () => {
  let config: Config;
  let provider: AgentHistoryProvider;
  let providerConfig: AgentHistoryProviderConfig;
  let generateContentMock: ReturnType<typeof vi.fn>;

  beforeEach(() => {
    config = {
      isExperimentalAgentHistoryTruncationEnabled: vi
        .fn()
        .mockReturnValue(false),
      getContextManagementConfig: vi.fn().mockReturnValue(false),
      getBaseLlmClient: vi.fn(),
    } as unknown as Config;
    // By default, messages are small: 100 tokens per part.
    vi.mocked(estimateTokenCountSync).mockImplementation(
      (parts: Part[]) => parts.length * 100,
    );
    generateContentMock = vi.fn().mockResolvedValue({
      candidates: [{ content: { parts: [{ text: 'Mock intent summary' }] } }],
    } as unknown as GenerateContentResponse);
    config.getBaseLlmClient = vi.fn().mockReturnValue({
      generateContent: generateContentMock,
    } as unknown as BaseLlmClient);
    providerConfig = {
      maxTokens: 60000,
      retainedTokens: 40000,
      normalMessageTokens: 2500,
      maximumMessageTokens: 10000,
      normalizationHeadRatio: 0.2,
      isSummarizationEnabled: false,
      isTruncationEnabled: false,
    };
    // The provider keeps a reference to this exact providerConfig object, so
    // individual tests can mutate fields after construction and the provider
    // observes the new values.
    provider = new AgentHistoryProvider(providerConfig, config);
  });

  // Builds an alternating user/model history where message i has text `Message ${i}`.
  const createMockHistory = (count: number): Content[] =>
    Array.from({ length: count }).map((_, i) => ({
      role: i % 2 === 0 ? 'user' : 'model',
      parts: [{ text: `Message ${i}` }],
    }));

  it('should return history unchanged if truncation is disabled', async () => {
    providerConfig.isTruncationEnabled = false;
    const history = createMockHistory(40);
    const result = await provider.manageHistory(history);
    // Identity check: the same array instance must come back untouched.
    expect(result).toBe(history);
    expect(result.length).toBe(40);
  });

  it('should return history unchanged if length is under threshold', async () => {
    providerConfig.isTruncationEnabled = true;
    const history = createMockHistory(20); // 20 msgs * 100 tokens = 2000 tokens, well under maxTokens (60000)
    const result = await provider.manageHistory(history);
    expect(result).toBe(history);
    expect(result.length).toBe(20);
  });

  it('should truncate when total tokens exceed budget, preserving structural integrity', async () => {
    providerConfig.isTruncationEnabled = true;
    providerConfig.maxTokens = 60000;
    providerConfig.retainedTokens = 60000;
    vi.spyOn(config, 'getContextManagementConfig').mockReturnValue({
      enabled: false,
    } as unknown as ContextManagementConfig);
    // Make each message cost 4000 tokens
    vi.mocked(estimateTokenCountSync).mockImplementation(
      (parts: Part[]) => parts.length * 4000,
    );
    const history = createMockHistory(35); // 35 * 4000 = 140,000 total tokens > maxTokens
    const result = await provider.manageHistory(history);
    // Budget = 60000. Each message costs 4000. 60000 / 4000 = 15.
    // However, some messages get normalized.
    // The grace period is 15 messages. Their target is MAXIMUM_MESSAGE_TOKENS (10000).
    // So the 15 newest messages remain at 4000 tokens each.
    // That's 15 * 4000 = 60000 tokens EXACTLY!
    // The next older message will push it over budget.
    // So EXACTLY 15 messages will be retained.
    // If the 15th newest message is a user message with a functionResponse, it might pull in the model call.
    // In our createMockHistory, we don't use functionResponses.
    expect(result.length).toBe(15);
    // Summarization is disabled, so the LLM summarizer must never be invoked.
    expect(generateContentMock).not.toHaveBeenCalled();
    // The fallback note is merged into the first retained message, which is a user turn.
    expect(result[0].role).toBe('user');
    expect(result[0].parts![0].text).toContain(
      '### [System Note: Conversation History Truncated]',
    );
  });

  it('should call summarizer and prepend summary when summarization is enabled', async () => {
    providerConfig.isTruncationEnabled = true;
    providerConfig.isSummarizationEnabled = true;
    providerConfig.maxTokens = 60000;
    providerConfig.retainedTokens = 60000;
    vi.spyOn(config, 'getContextManagementConfig').mockReturnValue({
      enabled: true,
    } as unknown as ContextManagementConfig);
    vi.mocked(estimateTokenCountSync).mockImplementation(
      (parts: Part[]) => parts.length * 4000,
    );
    const history = createMockHistory(35);
    const result = await provider.manageHistory(history);
    expect(generateContentMock).toHaveBeenCalled();
    expect(result.length).toBe(15);
    expect(result[0].role).toBe('user');
    // The mocked LLM response is wrapped in <intent_summary> tags by the provider.
    expect(result[0].parts![0].text).toContain('<intent_summary>');
    expect(result[0].parts![0].text).toContain('Mock intent summary');
  });

  it('should handle summarizer failures gracefully', async () => {
    providerConfig.isTruncationEnabled = true;
    providerConfig.isSummarizationEnabled = true;
    providerConfig.maxTokens = 60000;
    providerConfig.retainedTokens = 60000;
    vi.spyOn(config, 'getContextManagementConfig').mockReturnValue({
      enabled: true,
    } as unknown as ContextManagementConfig);
    vi.mocked(estimateTokenCountSync).mockImplementation(
      (parts: Part[]) => parts.length * 4000,
    );
    generateContentMock.mockRejectedValue(new Error('API Error'));
    const history = createMockHistory(35);
    const result = await provider.manageHistory(history);
    expect(generateContentMock).toHaveBeenCalled();
    expect(result.length).toBe(15);
    // Falls back to the deterministic local summary note when the LLM call rejects.
    expect(result[0].parts![0].text).toContain(
      '[System Note: Conversation History Truncated]',
    );
  });

  it('should pass the contextual bridge to the summarizer', async () => {
    providerConfig.isTruncationEnabled = true;
    providerConfig.isSummarizationEnabled = true;
    vi.spyOn(config, 'getContextManagementConfig').mockReturnValue({
      enabled: true,
    } as unknown as ContextManagementConfig);
    // Max tokens 30 means if total tokens > 30, it WILL truncate.
    providerConfig.maxTokens = 30;
    // budget 20 tokens means it will keep 2 messages if they are 10 each.
    providerConfig.retainedTokens = 20;
    vi.mocked(estimateTokenCountSync).mockImplementation(
      (parts: Part[]) => parts.length * 10,
    );
    const history: Content[] = [
      { role: 'user', parts: [{ text: 'Old Message' }] },
      { role: 'model', parts: [{ text: 'Old Response' }] },
      { role: 'user', parts: [{ text: 'Keep 1' }] },
      { role: 'user', parts: [{ text: 'Keep 2' }] },
    ];
    await provider.manageHistory(history);
    expect(generateContentMock).toHaveBeenCalled();
    // Inspect the prompt that was actually sent to the summarizer.
    const callArgs = generateContentMock.mock.calls[0][0];
    const prompt = callArgs.contents[0].parts[0].text;
    // The retained (grace-zone) messages appear in the lookahead bridge section.
    expect(prompt).toContain('ACTIVE BRIDGE (LOOKAHEAD):');
    expect(prompt).toContain('Keep 1');
    expect(prompt).toContain('Keep 2');
  });

  it('should detect a previous summary in the truncated head', async () => {
    providerConfig.isTruncationEnabled = true;
    providerConfig.isSummarizationEnabled = true;
    vi.spyOn(config, 'getContextManagementConfig').mockReturnValue({
      enabled: true,
    } as unknown as ContextManagementConfig);
    providerConfig.maxTokens = 20;
    providerConfig.retainedTokens = 10;
    vi.mocked(estimateTokenCountSync).mockImplementation(
      (parts: Part[]) => parts.length * 10,
    );
    const history: Content[] = [
      {
        role: 'user',
        parts: [{ text: '<intent_summary>Previous Mandate</intent_summary>' }],
      },
      { role: 'model', parts: [{ text: 'Work' }] },
      { role: 'user', parts: [{ text: 'New Work' }] },
    ];
    await provider.manageHistory(history);
    expect(generateContentMock).toHaveBeenCalled();
    const callArgs = generateContentMock.mock.calls[0][0];
    const prompt = callArgs.contents[0].parts[0].text;
    // An existing <intent_summary> in the truncated head switches the prompt
    // into its "previous summary" variant.
    expect(prompt).toContain('1. **Previous Summary:**');
    expect(prompt).toContain('PREVIOUS SUMMARY AND TRUNCATED HISTORY:');
  });

  it('should include the Action Path (necklace of function names) in the prompt', async () => {
    providerConfig.isTruncationEnabled = true;
    providerConfig.isSummarizationEnabled = true;
    vi.spyOn(config, 'getContextManagementConfig').mockReturnValue({
      enabled: true,
    } as unknown as ContextManagementConfig);
    providerConfig.maxTokens = 20;
    providerConfig.retainedTokens = 10;
    vi.mocked(estimateTokenCountSync).mockImplementation(
      (parts: Part[]) => parts.length * 10,
    );
    const history: Content[] = [
      {
        role: 'model',
        parts: [
          { functionCall: { name: 'tool_a', args: {} } },
          { functionCall: { name: 'tool_b', args: {} } },
        ],
      },
      { role: 'user', parts: [{ text: 'Keep' }] },
    ];
    await provider.manageHistory(history);
    expect(generateContentMock).toHaveBeenCalled();
    const callArgs = generateContentMock.mock.calls[0][0];
    const prompt = callArgs.contents[0].parts[0].text;
    // Tool names from truncated model turns are joined with arrows in call order.
    expect(prompt).toContain('The Action Path:');
    expect(prompt).toContain('tool_a → tool_b');
  });

  describe('Tiered Normalization Logic', () => {
    it('normalizes large messages incrementally: newest and exit-grace', async () => {
      providerConfig.isTruncationEnabled = true;
      providerConfig.retainedTokens = 30000;
      providerConfig.maximumMessageTokens = 10000;
      providerConfig.normalMessageTokens = 2500; // History of 35 messages.
      // Index 34: Newest (Grace Zone) -> Target 10000 tokens (~40000 chars)
      // Index 19: Exit Grace (35-1-15=19) -> Target 2500 tokens (~10000 chars)
      // Index 10: Archived -> Also normalized to the NORMAL limit (2500 tokens,
      //           ~10000 chars); see assertion 3 below.
      const history = createMockHistory(35);
      const hugeText = 'H'.repeat(100000);
      history[34] = { role: 'user', parts: [{ text: hugeText }] };
      history[19] = { role: 'model', parts: [{ text: hugeText }] };
      history[10] = { role: 'user', parts: [{ text: hugeText }] };
      // Mock token count to trigger normalization (100k chars = 25k tokens @ 4 chars/token)
      vi.mocked(estimateTokenCountSync).mockImplementation((parts: Part[]) => {
        if (!parts?.[0]) return 10;
        const text = parts[0].text || '';
        if (text.startsWith('H')) return 25000;
        return 10;
      });
      const result = await provider.manageHistory(history);
      // 1. Newest message (index 34) normalized to ~40000 chars
      const normalizedLast = result[34].parts![0].text!;
      expect(normalizedLast).toContain(TEXT_TRUNCATION_PREFIX);
      expect(normalizedLast.length).toBeLessThan(50000);
      expect(normalizedLast.length).toBeGreaterThan(30000);
      // 2. Exit grace message (index 19) normalized to ~10000 chars
      const normalizedArchived = result[19].parts![0].text!;
      expect(normalizedArchived).toContain(TEXT_TRUNCATION_PREFIX);
      expect(normalizedArchived.length).toBeLessThan(15000);
      expect(normalizedArchived.length).toBeGreaterThan(8000);
      // 3. Archived message (index 10) IS touched and normalized to ~10000 chars
      const normalizedPastArchived = result[10].parts![0].text!;
      expect(normalizedPastArchived).toContain(TEXT_TRUNCATION_PREFIX);
      expect(normalizedPastArchived.length).toBeLessThan(15000);
      expect(normalizedPastArchived.length).toBeGreaterThan(8000);
    });

    it('normalize function responses correctly by targeting large string values', async () => {
      providerConfig.isTruncationEnabled = true;
      providerConfig.maximumMessageTokens = 1000;
      const hugeValue = 'O'.repeat(5000);
      const history: Content[] = [
        {
          role: 'user',
          parts: [
            {
              functionResponse: {
                name: 'test_tool',
                id: '1',
                response: {
                  stdout: hugeValue,
                  stderr: 'small error',
                  exitCode: 0,
                },
              },
            },
          ],
        },
      ];
      vi.mocked(estimateTokenCountSync).mockImplementation(
        (parts: readonly Part[]) => {
          if (parts?.[0]?.functionResponse) return 5000;
          return 10;
        },
      );
      const result = await provider.manageHistory(history);
      const fr = result[0].parts![0].functionResponse!;
      const resp = fr.response as Record<string, unknown>;
      // stdout should be truncated
      expect(resp['stdout']).toContain(TOOL_TRUNCATION_PREFIX);
      expect((resp['stdout'] as string).length).toBeLessThan(hugeValue.length);
      // stderr and exitCode should be PRESERVED (JSON integrity)
      expect(resp['stderr']).toBe('small error');
      expect(resp['exitCode']).toBe(0);
      // Schema should be intact
      expect(fr.name).toBe('test_tool');
      expect(fr.id).toBe('1');
    });
  });

  describe('truncateProportionally', () => {
    it('returns original string if under target chars', () => {
      const str = 'A'.repeat(50);
      expect(truncateProportionally(str, 100, TEXT_TRUNCATION_PREFIX)).toBe(
        str,
      );
    });

    it('truncates proportionally with prefix and ellipsis', () => {
      const str = 'A'.repeat(500) + 'B'.repeat(500); // 1000 chars
      const target = 100;
      const result = truncateProportionally(
        str,
        target,
        TEXT_TRUNCATION_PREFIX,
      );
      expect(result.startsWith(TEXT_TRUNCATION_PREFIX)).toBe(true);
      expect(result).toContain('\n...\n');
      // The prefix and ellipsis take up some space
      // It should keep ~20% head and ~80% tail of the *available* space
      const ellipsis = '\n...\n';
      const overhead = TEXT_TRUNCATION_PREFIX.length + ellipsis.length + 1; // +1 for the newline after prefix
      const availableChars = Math.max(0, target - overhead);
      const expectedHeadChars = Math.floor(availableChars * 0.2);
      const expectedTailChars = availableChars - expectedHeadChars;
      // Extract parts around the ellipsis
      const parts = result.split(ellipsis);
      expect(parts.length).toBe(2);
      // Remove prefix + newline from the first part to check head length
      const actualHead = parts[0].replace(TEXT_TRUNCATION_PREFIX + '\n', '');
      const actualTail = parts[1];
      expect(actualHead.length).toBe(expectedHeadChars);
      expect(actualTail.length).toBe(expectedTailChars);
    });

    it('handles very small targets gracefully by just returning prefix', () => {
      const str = 'A'.repeat(100);
      // A target smaller than the prefix-plus-ellipsis overhead degrades to
      // just the prefix rather than producing a negative-length slice.
      const result = truncateProportionally(str, 10, TEXT_TRUNCATION_PREFIX);
      expect(result).toBe(TEXT_TRUNCATION_PREFIX);
    });
  });

  describe('Multi-part Proportional Normalization', () => {
    it('distributes token budget proportionally across multiple large parts', async () => {
      providerConfig.isTruncationEnabled = true;
      providerConfig.maximumMessageTokens = 2500; // Small limit to trigger normalization on last msg
      const history = createMockHistory(35);
      // Make newest message (index 34) have two large parts
      // Part 1: 10000 chars (~2500 tokens at 4 chars/token)
      // Part 2: 30000 chars (~7500 tokens at 4 chars/token)
      // Total tokens = 10000. Target = 2500. Ratio = 0.25.
      const part1Text = 'A'.repeat(10000);
      const part2Text = 'B'.repeat(30000);
      history[34] = {
        role: 'user',
        parts: [{ text: part1Text }, { text: part2Text }],
      };
      vi.mocked(estimateTokenCountSync).mockImplementation(
        (parts: readonly Part[]) => {
          if (!parts || parts.length === 0) return 0;
          let tokens = 0;
          for (const p of parts) {
            if (p.text?.startsWith('A')) tokens += 2500;
            else if (p.text?.startsWith('B')) tokens += 7500;
            else tokens += 10;
          }
          return tokens;
        },
      );
      const result = await provider.manageHistory(history);
      const normalizedMsg = result[34];
      // Both parts survive: normalization never drops a part entirely.
      expect(normalizedMsg.parts!.length).toBe(2);
      const p1 = normalizedMsg.parts![0].text!;
      const p2 = normalizedMsg.parts![1].text!;
      expect(p1).toContain(TEXT_TRUNCATION_PREFIX);
      expect(p2).toContain(TEXT_TRUNCATION_PREFIX);
      // Part 1: Target chars ~ 2500 * 0.25 * 4 = 2500
      // Part 2: Target chars ~ 7500 * 0.25 * 4 = 7500
      expect(p1.length).toBeLessThan(3500);
      expect(p2.length).toBeLessThan(9000);
      expect(p1.length).toBeLessThan(p2.length);
    });

    it('preserves small parts while truncating large parts in the same message', async () => {
      providerConfig.isTruncationEnabled = true;
      providerConfig.maximumMessageTokens = 2500;
      const history = createMockHistory(35);
      const smallText = 'Hello I am small';
      const hugeText = 'B'.repeat(40000); // 10000 tokens
      history[34] = {
        role: 'user',
        parts: [{ text: smallText }, { text: hugeText }],
      };
      vi.mocked(estimateTokenCountSync).mockImplementation(
        (parts: readonly Part[]) => {
          if (!parts || parts.length === 0) return 0;
          let tokens = 0;
          for (const p of parts) {
            if (p.text === smallText) tokens += 10;
            else if (p.text?.startsWith('B')) tokens += 10000;
            else tokens += 10;
          }
          return tokens;
        },
      );
      const result = await provider.manageHistory(history);
      const normalizedMsg = result[34];
      expect(normalizedMsg.parts!.length).toBe(2);
      const p1 = normalizedMsg.parts![0].text!;
      const p2 = normalizedMsg.parts![1].text!;
      // Small part should be preserved
      expect(p1).toBe(smallText);
      // Huge part should be truncated
      expect(p2).toContain(TEXT_TRUNCATION_PREFIX);
      // Target tokens for huge part = ~2500 * (10000/10010) = ~2500
      // Target chars = ~10000
      expect(p2.length).toBeLessThan(12000);
    });
  });
});
@@ -0,0 +1,422 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Content, Part } from '@google/genai';
import { getResponseText } from '../utils/partUtils.js';
import { estimateTokenCountSync } from '../utils/tokenCalculation.js';
import { LlmRole } from '../telemetry/llmRole.js';
import { debugLogger } from '../utils/debugLogger.js';
import type { AgentHistoryProviderConfig } from '../services/types.js';
import type { Config } from '../config/config.js';
import {
MIN_TARGET_TOKENS,
MIN_CHARS_FOR_TRUNCATION,
TEXT_TRUNCATION_PREFIX,
estimateCharsFromTokens,
truncateProportionally,
normalizeFunctionResponse,
} from './truncation.js';
/**
 * Manages an agent's conversation history so it stays within a token budget.
 *
 * The pipeline (see {@link manageHistory}) is:
 *  1. Normalize oversized individual messages (tiered by recency).
 *  2. If total tokens still exceed `maxTokens`, split the history at a
 *     structurally-safe boundary.
 *  3. Replace the truncated head with either an LLM-generated intent summary
 *     or a deterministic fallback note.
 */
export class AgentHistoryProvider {
  // TODO(joshualitt): just pass the BaseLlmClient instead of the whole Config.
  constructor(
    private readonly providerConfig: AgentHistoryProviderConfig,
    private readonly config: Config,
  ) {}

  /**
   * Evaluates the chat history and performs truncation and summarization if necessary.
   * Returns a new array of Content if truncation occurred, otherwise returns the original array.
   *
   * @param history The full conversation so far (never mutated).
   * @param abortSignal Optional signal forwarded to the summarizer LLM call.
   */
  async manageHistory(
    history: readonly Content[],
    abortSignal?: AbortSignal,
  ): Promise<readonly Content[]> {
    if (!this.providerConfig.isTruncationEnabled || history.length === 0) {
      return history;
    }
    // Step 1: Normalize newest messages.
    const normalizedHistory = this.enforceMessageSizeLimits(history);
    const totalTokens = estimateTokenCountSync(
      normalizedHistory.flatMap((c) => c.parts || []),
    );
    // Step 2: Check if truncation is needed based on the token threshold (High Watermark)
    if (totalTokens <= this.providerConfig.maxTokens) {
      return normalizedHistory;
    }
    // Step 3: Split into keep/truncate boundaries
    const { messagesToKeep, messagesToTruncate } =
      this.splitHistoryForTruncation(normalizedHistory);
    if (messagesToTruncate.length === 0) {
      return messagesToKeep;
    }
    debugLogger.log(
      `AgentHistoryProvider: Truncating ${messagesToTruncate.length} messages, retaining ${messagesToKeep.length} messages.`,
    );
    const summaryText = await this.getSummaryText(
      messagesToTruncate,
      messagesToKeep,
      abortSignal,
    );
    return this.mergeSummaryWithHistory(summaryText, messagesToKeep);
  }

  /**
   * Enforces per-message size limits over the whole history, with a tiered
   * target based on recency:
   * - Messages inside the "grace zone" (the newest messages whose cumulative
   *   token count fits within `retainedTokens`) get the generous
   *   `maximumMessageTokens` limit.
   * - Every message older than the grace zone is restricted to the tighter
   *   `normalMessageTokens` limit.
   * Returns the original array unchanged if no message needed normalization.
   */
  private enforceMessageSizeLimits(
    history: readonly Content[],
  ): readonly Content[] {
    if (history.length === 0) return history;
    let hasChanges = false;
    let accumulatedTokens = 0;
    // Scan backwards to find the index where the token budget is exhausted
    let graceStartIndex = 0;
    for (let i = history.length - 1; i >= 0; i--) {
      const msgTokens = estimateTokenCountSync(history[i].parts || []);
      accumulatedTokens += msgTokens;
      if (accumulatedTokens > this.providerConfig.retainedTokens) {
        graceStartIndex = i + 1;
        break;
      }
    }
    const newHistory = history.map((msg, i) => {
      // Messages before the grace zone get the strict limit; grace-zone
      // messages (i >= graceStartIndex) get the generous one.
      const targetTokens =
        i < graceStartIndex
          ? this.providerConfig.normalMessageTokens
          : this.providerConfig.maximumMessageTokens;
      const normalizedMsg = this.normalizeMessage(msg, targetTokens);
      if (normalizedMsg !== msg) {
        hasChanges = true;
      }
      return normalizedMsg;
    });
    // Preserve referential identity when nothing changed so callers can use
    // a cheap === check.
    return hasChanges ? newHistory : history;
  }

  /**
   * Normalizes a message by proportionally masking its text or function response
   * if its total token count exceeds the target token limit.
   * Parts are never dropped; each oversized part is compressed by the same
   * ratio so the message as a whole fits the budget.
   */
  private normalizeMessage(msg: Content, targetTokens: number): Content {
    const currentTokens = estimateTokenCountSync(msg.parts || []);
    if (currentTokens <= targetTokens) {
      return msg;
    }
    // Calculate the compression ratio to apply to all large parts
    const ratio = targetTokens / currentTokens;
    // Proportional compression of the parts to fit the targetTokens budget
    // while maintaining API structure (never dropping a part completely).
    const newParts: Part[] = [];
    for (const part of msg.parts || []) {
      if (part.text) {
        const partTokens = estimateTokenCountSync([part]);
        const targetPartTokens = Math.max(
          MIN_TARGET_TOKENS,
          Math.floor(partTokens * ratio),
        );
        const targetChars = estimateCharsFromTokens(
          part.text,
          targetPartTokens,
        );
        // Only truncate when it actually shrinks the part and the target is
        // large enough to be worth it; otherwise keep the part as-is.
        if (
          part.text.length > targetChars &&
          targetChars > MIN_CHARS_FOR_TRUNCATION
        ) {
          const newText = truncateProportionally(
            part.text,
            targetChars,
            TEXT_TRUNCATION_PREFIX,
            this.providerConfig.normalizationHeadRatio,
          );
          newParts.push({ text: newText });
        } else {
          newParts.push(part);
        }
      } else if (part.functionResponse) {
        // Function responses are compressed field-by-field by the helper,
        // keeping the JSON structure intact.
        newParts.push(
          normalizeFunctionResponse(
            part,
            ratio,
            this.providerConfig.normalizationHeadRatio,
          ),
        );
      } else {
        // Non-text, non-functionResponse parts (e.g. media) pass through.
        newParts.push(part);
      }
    }
    return { ...msg, parts: newParts };
  }

  /**
   * Determines the boundary for splitting history based on the token budget,
   * keeping recent messages under a specific target token threshold,
   * while ensuring structural integrity (e.g. keeping functionCall/functionResponse pairs).
   */
  private splitHistoryForTruncation(history: readonly Content[]): {
    messagesToKeep: readonly Content[];
    messagesToTruncate: readonly Content[];
  } {
    let accumulatedTokens = 0;
    let truncationBoundary = 0; // The index of the first message to keep
    // Scan backwards to calculate the boundary based on token budget
    for (let i = history.length - 1; i >= 0; i--) {
      const msg = history[i];
      const msgTokens = estimateTokenCountSync(msg.parts || []);
      // Token Budget
      if (accumulatedTokens + msgTokens > this.providerConfig.retainedTokens) {
        // Exceeded budget, stop retaining messages here.
        truncationBoundary = i + 1;
        break;
      }
      accumulatedTokens += msgTokens;
    }
    // Ensure structural integrity of the boundary
    truncationBoundary = this.adjustBoundaryForIntegrity(
      history,
      truncationBoundary,
    );
    const messagesToKeep = history.slice(truncationBoundary);
    const messagesToTruncate = history.slice(0, truncationBoundary);
    return {
      messagesToKeep,
      messagesToTruncate,
    };
  }

  /**
   * Adjusts the truncation boundary backwards to prevent breaking functionCall/functionResponse pairs.
   * If the first retained message is a user functionResponse whose matching
   * model functionCall would be truncated, the boundary moves back to include
   * the call (after which the loop condition fails, since the new boundary
   * message has role 'model').
   */
  private adjustBoundaryForIntegrity(
    history: readonly Content[],
    boundary: number,
  ): number {
    let currentBoundary = boundary;
    // Ensure we don't start at index 0 or out of bounds.
    if (currentBoundary <= 0 || currentBoundary >= history.length) {
      return currentBoundary;
    }
    while (
      currentBoundary > 0 &&
      currentBoundary < history.length &&
      history[currentBoundary].role === 'user' &&
      history[currentBoundary].parts?.some((p) => p.functionResponse) &&
      history[currentBoundary - 1].role === 'model' &&
      history[currentBoundary - 1].parts?.some((p) => p.functionCall)
    ) {
      currentBoundary--; // Include the functionCall in the retained history
    }
    return currentBoundary;
  }

  /**
   * Builds a deterministic, LLM-free summary note from the truncated messages:
   * the last user message's text plus the chronological chain of tool names
   * called by the model. Used when summarization is disabled or fails.
   */
  private getFallbackSummaryText(
    messagesToTruncate: readonly Content[],
  ): string {
    const userMessages = messagesToTruncate.filter((m) => m.role === 'user');
    const modelMessages = messagesToTruncate.filter((m) => m.role === 'model');
    const lastUserText = userMessages
      .slice(-1)[0]
      ?.parts?.map((p) => p.text || '')
      .join('')
      .trim();
    const actionPath = modelMessages
      .flatMap(
        (m) =>
          m.parts
            ?.filter((p) => p.functionCall)
            .map((p) => p.functionCall!.name) || [],
      )
      .join(' → ');
    const summaryParts = [
      '### [System Note: Conversation History Truncated]',
      'Prior context was offloaded to maintain performance. Key highlights from the truncated history:',
    ];
    if (lastUserText) {
      summaryParts.push(`- **Last User Intent:** "${lastUserText}"`);
    }
    if (actionPath) {
      summaryParts.push(`- **Action Path:** ${actionPath}`);
    }
    summaryParts.push(
      '- **Notice:** For deeper context, review persistent memory or task-specific logs.',
    );
    return summaryParts.join('\n');
  }

  /**
   * Produces the summary text for the truncated head. Prefers the LLM-based
   * intent summary when enabled, falling back to {@link getFallbackSummaryText}
   * when summarization is disabled or the LLM call throws.
   */
  private async getSummaryText(
    messagesToTruncate: readonly Content[],
    messagesToKeep: readonly Content[],
    abortSignal?: AbortSignal,
  ): Promise<string> {
    if (messagesToTruncate.length === 0) return '';
    if (!this.providerConfig.isSummarizationEnabled) {
      debugLogger.log(
        'AgentHistoryProvider: Summarization disabled, using fallback note.',
      );
      return this.getFallbackSummaryText(messagesToTruncate);
    }
    try {
      // Use the first few messages of the Grace Zone as a "contextual bridge"
      // to give the summarizer lookahead into the current state.
      const bridge = messagesToKeep.slice(0, 5);
      return await this.generateIntentSummary(
        messagesToTruncate,
        bridge,
        abortSignal,
      );
    } catch (error) {
      debugLogger.log('AgentHistoryProvider: Summarization failed.', error);
      return this.getFallbackSummaryText(messagesToTruncate);
    }
  }

  /**
   * Combines the summary with the retained messages, keeping the Gemini API's
   * strict user/model role alternation: the summary is merged into the first
   * retained message when it is a user turn, otherwise prepended as a new
   * user message.
   */
  private mergeSummaryWithHistory(
    summaryText: string,
    messagesToKeep: readonly Content[],
  ): readonly Content[] {
    if (!summaryText) return messagesToKeep;
    if (messagesToKeep.length === 0) {
      return [{ role: 'user', parts: [{ text: summaryText }] }];
    }
    // To ensure strict user/model alternating roles required by the Gemini API,
    // we merge the summary into the first retained message if it's from the 'user'.
    const firstRetainedMessage = messagesToKeep[0];
    if (firstRetainedMessage.role === 'user') {
      const mergedParts = [
        { text: summaryText },
        ...(firstRetainedMessage.parts || []),
      ];
      const mergedMessage: Content = {
        role: 'user',
        parts: mergedParts,
      };
      return [mergedMessage, ...messagesToKeep.slice(1)];
    } else {
      const summaryMessage: Content = {
        role: 'user',
        parts: [{ text: summaryText }],
      };
      return [summaryMessage, ...messagesToKeep];
    }
  }

  /**
   * Asks the utility LLM for an <intent_summary> of the truncated messages.
   * The prompt includes the chain of tool names ("Action Path"), the raw
   * truncated history, and a short lookahead bridge of retained messages.
   * The response is stripped of markdown fences and wrapped in
   * <intent_summary> tags if the model omitted them.
   *
   * @throws Propagates any error from the LLM call; callers handle fallback.
   */
  private async generateIntentSummary(
    messagesToTruncate: readonly Content[],
    bridge: readonly Content[],
    abortSignal?: AbortSignal,
  ): Promise<string> {
    // 1. Identify and extract any existing summary from the truncated head
    const firstMsg = messagesToTruncate[0];
    const firstPartText = firstMsg?.parts?.[0]?.text || '';
    const hasPreviousSummary = firstPartText.includes('<intent_summary>');
    // 2. Extract "The Action Path" (necklace of function names)
    const actionPath = messagesToTruncate
      .filter((m) => m.role === 'model')
      .flatMap(
        (m) =>
          m.parts
            ?.filter((p) => p.functionCall)
            .map((p) => p.functionCall!.name) || [],
      )
      .join(' → ');
    // NOTE: the template content is intentionally flush-left so no stray
    // indentation leaks into the prompt.
    const prompt = `### State Update: Agent Continuity
The conversation history has been truncated. You are generating a highly factual state summary to preserve the agent's exact working context.
You have these signals to synthesize:
${hasPreviousSummary ? '1. **Previous Summary:** The existing state before this truncation.\n' : ''}2. **The Action Path:** A chronological list of tools called: [${actionPath}]
3. **Truncated History:** The specific actions, tool inputs, and tool outputs being offloaded.
4. **Active Bridge:** The first few turns of the "Grace Zone" (what follows immediately after this summary), showing the current tactical moment.
### Your Goal:
Distill these into a high-density Markdown block that orientates the agent on the CONCRETE STATE of the workspace:
- **Primary Goal:** The ultimate objective requested by the user.
- **Verified Facts:** What has been definitively completed or proven (e.g., "File X was created", "Bug Y was reproduced").
- **Working Set:** The exact file paths currently being analyzed or modified.
- **Active Blockers:** Exact error messages or failing test names currently preventing progress.
### Constraints:
- **Format:** Wrap the entire response in <intent_summary> tags.
- **Factuality:** Base all points strictly on the provided history. Do not invent rationale or assume success without proof. Use exact names and quotes.
- **Brevity:** Maximum 15 lines. No conversational preamble.
${hasPreviousSummary ? 'PREVIOUS SUMMARY AND TRUNCATED HISTORY:' : 'TRUNCATED HISTORY:'}
${JSON.stringify(messagesToTruncate)}
ACTIVE BRIDGE (LOOKAHEAD):
${JSON.stringify(bridge)}`;
    const summaryResponse = await this.config
      .getBaseLlmClient()
      .generateContent({
        modelConfigKey: { model: 'agent-history-provider-summarizer' },
        contents: [
          {
            role: 'user',
            parts: [{ text: prompt }],
          },
        ],
        promptId: 'agent-history-provider',
        // A fresh AbortController's signal is a never-aborting placeholder.
        abortSignal: abortSignal ?? new AbortController().signal,
        role: LlmRole.UTILITY_COMPRESSOR,
      });
    let summary = getResponseText(summaryResponse) ?? '';
    // Clean up if the model included extra tags or markdown
    summary = summary
      .replace(/```markdown/g, '')
      .replace(/```/g, '')
      .trim();
    if (!summary.includes('<intent_summary>')) {
      summary = `<intent_summary>\n${summary}\n</intent_summary>`;
    }
    return summary;
  }
}
@@ -0,0 +1,900 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
ChatCompressionService,
findCompressSplitPoint,
modelStringToModelConfigAlias,
} from './chatCompressionService.js';
import type { Content, GenerateContentResponse, Part } from '@google/genai';
import { CompressionStatus } from '../core/turn.js';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
import type { GeminiChat } from '../core/geminiChat.js';
import type { Config } from '../config/config.js';
import * as fileUtils from '../utils/fileUtils.js';
import { getInitialChatHistory } from '../utils/environmentContext.js';
const { TOOL_OUTPUTS_DIR } = fileUtils;
import * as tokenCalculation from '../utils/tokenCalculation.js';
import { tokenLimit } from '../core/tokenLimits.js';
import os from 'node:os';
import path from 'node:path';
import fs from 'node:fs';
vi.mock('../telemetry/loggers.js');
vi.mock('../utils/environmentContext.js');
vi.mock('../core/tokenLimits.js');
describe('findCompressSplitPoint', () => {
  // Message texts reused by the alternating user/model fixtures below.
  const messageTexts = [
    'This is the first message.',
    'This is the second message.',
    'This is the third message.',
    'This is the fourth message.',
    'This is the fifth message.',
  ];

  // Builds a history that alternates user/model roles, starting with user.
  const buildAlternatingHistory = (texts: readonly string[]): Content[] =>
    texts.map((text, index) => ({
      role: index % 2 === 0 ? 'user' : 'model',
      parts: [{ text }],
    }));

  it('should throw an error for non-positive numbers', () => {
    expect(() => findCompressSplitPoint([], 0)).toThrow(
      'Fraction must be between 0 and 1',
    );
  });

  it('should throw an error for a fraction greater than or equal to 1', () => {
    expect(() => findCompressSplitPoint([], 1)).toThrow(
      'Fraction must be between 0 and 1',
    );
  });

  it('should handle an empty history', () => {
    expect(findCompressSplitPoint([], 0.5)).toBe(0);
  });

  it('should handle a fraction in the middle', () => {
    // Five messages of near-equal JSON length; the 50% mark lands at index 4.
    const history = buildAlternatingHistory(messageTexts);
    expect(findCompressSplitPoint(history, 0.5)).toBe(4);
  });

  it('should handle a fraction of last index', () => {
    const history = buildAlternatingHistory(messageTexts);
    expect(findCompressSplitPoint(history, 0.9)).toBe(4);
  });

  it('should handle a fraction of after last index', () => {
    // Four messages; a fraction past the last boundary clamps to the end.
    const history = buildAlternatingHistory(messageTexts.slice(0, 4));
    expect(findCompressSplitPoint(history, 0.8)).toBe(4);
  });

  it('should return earlier splitpoint if no valid ones are after threshold', () => {
    // The final model turn carries a functionCall, so index 4 is not a valid
    // split point; the function must fall back to an earlier boundary.
    const history: Content[] = [
      ...buildAlternatingHistory(messageTexts.slice(0, 3)),
      { role: 'model', parts: [{ functionCall: { name: 'foo', args: {} } }] },
    ];
    expect(findCompressSplitPoint(history, 0.99)).toBe(2);
  });

  it('should handle a history with only one item', () => {
    const singleMessage = buildAlternatingHistory(['Message 1']);
    expect(findCompressSplitPoint(singleMessage, 0.5)).toBe(0);
  });

  it('should handle history with weird parts', () => {
    // A non-text middle part (fileData) should not break boundary selection.
    const historyWithEmptyParts: Content[] = [
      { role: 'user', parts: [{ text: 'Message 1' }] },
      {
        role: 'model',
        parts: [{ fileData: { fileUri: 'derp', mimeType: 'text/plain' } }],
      },
      { role: 'user', parts: [{ text: 'Message 2' }] },
    ];
    expect(findCompressSplitPoint(historyWithEmptyParts, 0.5)).toBe(2);
  });
});
// Tests for the model-string -> compression model-config alias mapping.
describe('modelStringToModelConfigAlias', () => {
  it('should return the default model for unexpected aliases', () => {
    expect(modelStringToModelConfigAlias('gemini-flash-flash')).toBe(
      'chat-compression-default',
    );
  });
  it('should handle valid names', () => {
    // Table-driven: each known model string maps to its compression alias.
    const expectedAliases: Array<[model: string, alias: string]> = [
      ['gemini-3-pro-preview', 'chat-compression-3-pro'],
      ['gemini-2.5-pro', 'chat-compression-2.5-pro'],
      ['gemini-2.5-flash', 'chat-compression-2.5-flash'],
      ['gemini-2.5-flash-lite', 'chat-compression-2.5-flash-lite'],
    ];
    for (const [model, alias] of expectedAliases) {
      expect(modelStringToModelConfigAlias(model)).toBe(alias);
    }
  });
});
// Integration-style unit tests for ChatCompressionService.compress covering:
// NOOP short-circuits, threshold gating, two-phase (summary + verification)
// compression, prompt construction, and reverse-token-budget truncation of
// large functionResponse parts. Several tests depend on the exact ordering of
// the mockResolvedValueOnce chain configured in beforeEach.
describe('ChatCompressionService', () => {
  let service: ChatCompressionService;
  let mockChat: GeminiChat;
  let mockConfig: Config;
  let testTempDir: string;
  const mockModel = 'gemini-2.5-pro';
  const mockPromptId = 'test-prompt-id';
  beforeEach(() => {
    // Real temp directory on disk: the truncation tests assert that truncated
    // tool output files are actually written beneath it.
    testTempDir = fs.mkdtempSync(
      path.join(os.tmpdir(), 'chat-compression-test-'),
    );
    service = new ChatCompressionService();
    mockChat = {
      getHistory: vi.fn(),
      getLastPromptTokenCount: vi.fn().mockReturnValue(500),
    } as unknown as GeminiChat;
    // Two-phase default: first call yields the initial summary, second call
    // yields the verification ("probe") pass result.
    const mockGenerateContent = vi
      .fn()
      .mockResolvedValueOnce({
        candidates: [
          {
            content: {
              parts: [{ text: 'Initial Summary' }],
            },
          },
        ],
      } as unknown as GenerateContentResponse)
      .mockResolvedValueOnce({
        candidates: [
          {
            content: {
              parts: [{ text: 'Verified Summary' }],
            },
          },
        ],
      } as unknown as GenerateContentResponse);
    mockConfig = {
      // Self-referential getter so code that reads `config.config` still
      // reaches this same mock object.
      get config() {
        return this;
      },
      getCompressionThreshold: vi.fn(),
      getBaseLlmClient: vi.fn().mockReturnValue({
        generateContent: mockGenerateContent,
      }),
      isInteractive: vi.fn().mockReturnValue(false),
      getActiveModel: vi.fn().mockReturnValue(mockModel),
      getContentGenerator: vi.fn().mockReturnValue({
        countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
      }),
      getEnableHooks: vi.fn().mockReturnValue(false),
      getMessageBus: vi.fn().mockReturnValue(undefined),
      getHookSystem: () => undefined,
      getNextCompressionTruncationId: vi.fn().mockReturnValue(1),
      getTruncateToolOutputThreshold: vi.fn().mockReturnValue(40000),
      storage: {
        getProjectTempDir: vi.fn().mockReturnValue(testTempDir),
      },
      getApprovedPlanPath: vi.fn().mockReturnValue('/path/to/plan.md'),
    } as unknown as Config;
    // Pass extra history straight through instead of building a real
    // environment context.
    vi.mocked(getInitialChatHistory).mockImplementation(
      async (_config, extraHistory) => extraHistory || [],
    );
  });
  afterEach(() => {
    // Restore spies and remove the on-disk temp directory created above.
    vi.restoreAllMocks();
    if (fs.existsSync(testTempDir)) {
      fs.rmSync(testTempDir, { recursive: true, force: true });
    }
  });
  it('should return NOOP if history is empty', async () => {
    vi.mocked(mockChat.getHistory).mockReturnValue([]);
    const result = await service.compress(
      mockChat,
      mockPromptId,
      false,
      mockModel,
      mockConfig,
      false,
    );
    expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP);
    expect(result.newHistory).toBeNull();
  });
  it('should return NOOP if previously failed and not forced', async () => {
    vi.mocked(mockChat.getHistory).mockReturnValue([
      { role: 'user', parts: [{ text: 'hi' }] },
    ]);
    const result = await service.compress(
      mockChat,
      mockPromptId,
      false,
      mockModel,
      mockConfig,
      false,
    );
    // It should now attempt compression even if previously failed (logic removed)
    // But since history is small, it will be NOOP due to threshold
    expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP);
    expect(result.newHistory).toBeNull();
  });
  it('should return NOOP if under token threshold and not forced', async () => {
    vi.mocked(mockChat.getHistory).mockReturnValue([
      { role: 'user', parts: [{ text: 'hi' }] },
    ]);
    vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600);
    vi.mocked(tokenLimit).mockReturnValue(1000);
    // Threshold is 0.5 * 1000 = 500. 600 > 500, so it SHOULD compress.
    // Wait, the default threshold is 0.5.
    // Let's set it explicitly.
    vi.mocked(mockConfig.getCompressionThreshold).mockResolvedValue(0.7);
    // 600 < 700, so NOOP.
    const result = await service.compress(
      mockChat,
      mockPromptId,
      false,
      mockModel,
      mockConfig,
      false,
    );
    expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP);
    expect(result.newHistory).toBeNull();
  });
  it('should compress if over token threshold with verification turn', async () => {
    const history: Content[] = [
      { role: 'user', parts: [{ text: 'msg1' }] },
      { role: 'model', parts: [{ text: 'msg2' }] },
      { role: 'user', parts: [{ text: 'msg3' }] },
      { role: 'model', parts: [{ text: 'msg4' }] },
    ];
    vi.mocked(mockChat.getHistory).mockReturnValue(history);
    vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600000);
    // 600k > 500k (0.5 * 1M), so should compress.
    const result = await service.compress(
      mockChat,
      mockPromptId,
      false,
      mockModel,
      mockConfig,
      false,
    );
    expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED);
    expect(result.newHistory).not.toBeNull();
    // It should contain the final verified summary
    expect(result.newHistory![0].parts![0].text).toBe('Verified Summary');
    // Two LLM calls: initial summarization + verification probe.
    expect(mockConfig.getBaseLlmClient().generateContent).toHaveBeenCalledTimes(
      2,
    );
  });
  it('should fall back to initial summary if verification response is empty', async () => {
    const history: Content[] = [
      { role: 'user', parts: [{ text: 'msg1' }] },
      { role: 'model', parts: [{ text: 'msg2' }] },
    ];
    vi.mocked(mockChat.getHistory).mockReturnValue(history);
    vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600000);
    // Completely override the LLM client for this test to avoid conflicting with beforeEach mocks
    const mockLlmClient = {
      generateContent: vi
        .fn()
        .mockResolvedValueOnce({
          candidates: [{ content: { parts: [{ text: 'Initial Summary' }] } }],
        } as unknown as GenerateContentResponse)
        .mockResolvedValueOnce({
          candidates: [{ content: { parts: [{ text: ' ' }] } }],
        } as unknown as GenerateContentResponse),
    };
    vi.mocked(mockConfig.getBaseLlmClient).mockReturnValue(
      mockLlmClient as unknown as BaseLlmClient,
    );
    const result = await service.compress(
      mockChat,
      mockPromptId,
      false,
      mockModel,
      mockConfig,
      false,
    );
    expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED);
    expect(result.newHistory![0].parts![0].text).toBe('Initial Summary');
  });
  it('should use anchored instruction when a previous snapshot is present', async () => {
    const history: Content[] = [
      {
        role: 'user',
        parts: [{ text: '<state_snapshot>old</state_snapshot>' }],
      },
      { role: 'model', parts: [{ text: 'msg2' }] },
      { role: 'user', parts: [{ text: 'msg3' }] },
      { role: 'model', parts: [{ text: 'msg4' }] },
    ];
    vi.mocked(mockChat.getHistory).mockReturnValue(history);
    vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(800);
    vi.mocked(tokenLimit).mockReturnValue(1000);
    await service.compress(
      mockChat,
      mockPromptId,
      false,
      mockModel,
      mockConfig,
      false,
    );
    // Inspect the request of the first (summarization) LLM call: its final
    // user turn should carry the snapshot-integration instruction.
    const firstCall = vi.mocked(mockConfig.getBaseLlmClient().generateContent)
      .mock.calls[0][0];
    const lastContent = firstCall.contents?.[firstCall.contents.length - 1];
    expect(lastContent?.parts?.[0].text).toContain(
      'A previous <state_snapshot> exists',
    );
  });
  it('should include the approved plan path in the system instruction', async () => {
    const planPath = '/custom/plan/path.md';
    vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(planPath);
    vi.mocked(mockConfig.getActiveModel).mockReturnValue(
      'gemini-3.1-pro-preview',
    );
    const history: Content[] = [
      { role: 'user', parts: [{ text: 'msg1' }] },
      { role: 'model', parts: [{ text: 'msg2' }] },
    ];
    vi.mocked(mockChat.getHistory).mockReturnValue(history);
    vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600000);
    await service.compress(
      mockChat,
      mockPromptId,
      false,
      mockModel,
      mockConfig,
      false,
    );
    const firstCallText = (
      vi.mocked(mockConfig.getBaseLlmClient().generateContent).mock.calls[0][0]
        .systemInstruction as Part
    ).text;
    expect(firstCallText).toContain('### APPROVED PLAN PRESERVATION');
    expect(firstCallText).toContain(planPath);
  });
  it('should not include the approved plan section if no approved plan path exists', async () => {
    vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(undefined);
    const history: Content[] = [
      { role: 'user', parts: [{ text: 'msg1' }] },
      { role: 'model', parts: [{ text: 'msg2' }] },
    ];
    vi.mocked(mockChat.getHistory).mockReturnValue(history);
    vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600000);
    await service.compress(
      mockChat,
      mockPromptId,
      false,
      mockModel,
      mockConfig,
      false,
    );
    const firstCallText = (
      vi.mocked(mockConfig.getBaseLlmClient().generateContent).mock.calls[0][0]
        .systemInstruction as Part
    ).text;
    expect(firstCallText).not.toContain('### APPROVED PLAN PRESERVATION');
  });
  it('should force compress even if under threshold', async () => {
    const history: Content[] = [
      { role: 'user', parts: [{ text: 'msg1' }] },
      { role: 'model', parts: [{ text: 'msg2' }] },
      { role: 'user', parts: [{ text: 'msg3' }] },
      { role: 'model', parts: [{ text: 'msg4' }] },
    ];
    vi.mocked(mockChat.getHistory).mockReturnValue(history);
    vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(100);
    const result = await service.compress(
      mockChat,
      mockPromptId,
      true, // forced
      mockModel,
      mockConfig,
      false,
    );
    expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED);
    expect(result.newHistory).not.toBeNull();
  });
  it('should return FAILED if new token count is inflated', async () => {
    const history: Content[] = [
      { role: 'user', parts: [{ text: 'msg1' }] },
      { role: 'model', parts: [{ text: 'msg2' }] },
    ];
    vi.mocked(mockChat.getHistory).mockReturnValue(history);
    vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(100);
    const longSummary = 'a'.repeat(1000); // Long summary to inflate token count
    vi.mocked(mockConfig.getBaseLlmClient().generateContent).mockResolvedValue({
      candidates: [
        {
          content: {
            parts: [{ text: longSummary }],
          },
        },
      ],
    } as unknown as GenerateContentResponse);
    // Inflate the token count by spying on calculateRequestTokenCount
    vi.spyOn(tokenCalculation, 'calculateRequestTokenCount').mockResolvedValue(
      10000,
    );
    const result = await service.compress(
      mockChat,
      mockPromptId,
      true,
      mockModel,
      mockConfig,
      false,
    );
    expect(result.info.compressionStatus).toBe(
      CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
    );
    expect(result.newHistory).toBeNull();
  });
  it('should return COMPRESSION_FAILED_EMPTY_SUMMARY if summary is empty', async () => {
    const history: Content[] = [
      { role: 'user', parts: [{ text: 'msg1' }] },
      { role: 'model', parts: [{ text: 'msg2' }] },
    ];
    vi.mocked(mockChat.getHistory).mockReturnValue(history);
    vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(800);
    vi.mocked(tokenLimit).mockReturnValue(1000);
    // Completely override the LLM client for this test
    const mockLlmClient = {
      generateContent: vi.fn().mockResolvedValue({
        candidates: [
          {
            content: {
              parts: [{ text: ' ' }],
            },
          },
        ],
      } as unknown as GenerateContentResponse),
    };
    vi.mocked(mockConfig.getBaseLlmClient).mockReturnValue(
      mockLlmClient as unknown as BaseLlmClient,
    );
    const result = await service.compress(
      mockChat,
      mockPromptId,
      false,
      mockModel,
      mockConfig,
      false,
    );
    expect(result.info.compressionStatus).toBe(
      CompressionStatus.COMPRESSION_FAILED_EMPTY_SUMMARY,
    );
    expect(result.newHistory).toBeNull();
  });
  // Tests for the reverse-token-budget strategy: walking history from newest
  // to oldest, recent tool outputs are kept verbatim until the budget is
  // spent, after which older large functionResponse parts are truncated and
  // spilled to files on disk.
  describe('Reverse Token Budget Truncation', () => {
    it('should truncate older function responses when budget is exceeded', async () => {
      vi.mocked(mockConfig.getCompressionThreshold).mockResolvedValue(0.5);
      vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600000);
      // Large response part that exceeds budget (40k tokens).
      // Heuristic is roughly chars / 4, so 170k chars should exceed it.
      const largeResponse = 'a'.repeat(170000);
      const history: Content[] = [
        { role: 'user', parts: [{ text: 'old msg' }] },
        { role: 'model', parts: [{ text: 'old resp' }] },
        // History to keep
        { role: 'user', parts: [{ text: 'msg 1' }] },
        {
          role: 'user',
          parts: [
            {
              functionResponse: {
                name: 'grep',
                response: { content: largeResponse },
              },
            },
          ],
        },
        { role: 'model', parts: [{ text: 'resp 2' }] },
        {
          role: 'user',
          parts: [
            {
              functionResponse: {
                name: 'grep',
                response: { content: largeResponse },
              },
            },
          ],
        },
      ];
      vi.mocked(mockChat.getHistory).mockReturnValue(history);
      const result = await service.compress(
        mockChat,
        mockPromptId,
        true,
        mockModel,
        mockConfig,
        false,
      );
      expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED);
      // Verify the new history contains the truncated message
      const keptHistory = result.newHistory!.slice(2); // After summary and 'Got it'
      const truncatedPart = keptHistory[1].parts![0].functionResponse;
      expect(truncatedPart?.response?.['output']).toContain(
        'Output too large.',
      );
      // Verify a file was actually created in the tool_output subdirectory
      const toolOutputDir = path.join(testTempDir, TOOL_OUTPUTS_DIR);
      const files = fs.readdirSync(toolOutputDir);
      expect(files.length).toBeGreaterThan(0);
      expect(files[0]).toMatch(/grep_.*\.txt/);
    });
    it('should correctly handle massive single-line strings inside JSON by using multi-line Elephant Line logic', async () => {
      vi.mocked(mockConfig.getCompressionThreshold).mockResolvedValue(0.5);
      vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600000);
      // 170,000 chars on a single line to exceed budget
      const massiveSingleLine = 'a'.repeat(170000);
      const history: Content[] = [
        { role: 'user', parts: [{ text: 'old msg 1' }] },
        { role: 'model', parts: [{ text: 'old resp 1' }] },
        { role: 'user', parts: [{ text: 'old msg 2' }] },
        { role: 'model', parts: [{ text: 'old resp 2' }] },
        { role: 'user', parts: [{ text: 'old msg 3' }] },
        { role: 'model', parts: [{ text: 'old resp 3' }] },
        { role: 'user', parts: [{ text: 'msg 1' }] },
        {
          role: 'user',
          parts: [
            {
              functionResponse: {
                name: 'shell',
                response: { output: massiveSingleLine },
              },
            },
          ],
        },
        {
          role: 'user',
          parts: [
            {
              functionResponse: {
                name: 'shell',
                response: { output: massiveSingleLine },
              },
            },
          ],
        },
      ];
      vi.mocked(mockChat.getHistory).mockReturnValue(history);
      const result = await service.compress(
        mockChat,
        mockPromptId,
        true,
        mockModel,
        mockConfig,
        false,
      );
      // Verify it compressed
      expect(result.newHistory).not.toBeNull();
      // Find the shell response in the kept history (the older one was truncated)
      const keptHistory = result.newHistory!.slice(2); // after summary and 'Got it'
      const shellResponse = keptHistory.find(
        (h) =>
          h.parts?.some((p) => p.functionResponse?.name === 'shell') &&
          (h.parts?.[0].functionResponse?.response?.['output'] as string)
            ?.length < 100000,
      );
      const truncatedPart = shellResponse!.parts![0].functionResponse;
      const content = truncatedPart?.response?.['output'] as string;
      // DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD = 40000 -> head=8000 (20%), tail=32000 (80%)
      expect(content).toContain(
        'Showing first 8,000 and last 32,000 characters',
      );
    });
    it('should use character-based truncation for massive single-line raw strings', async () => {
      vi.mocked(mockConfig.getCompressionThreshold).mockResolvedValue(0.5);
      vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600000);
      const massiveRawString = 'c'.repeat(170000);
      const history: Content[] = [
        { role: 'user', parts: [{ text: 'old msg 1' }] },
        { role: 'model', parts: [{ text: 'old resp 1' }] },
        { role: 'user', parts: [{ text: 'old msg 2' }] },
        { role: 'model', parts: [{ text: 'old resp 2' }] },
        { role: 'user', parts: [{ text: 'msg 1' }] },
        {
          role: 'user',
          parts: [
            {
              functionResponse: {
                name: 'raw_tool',
                response: { content: massiveRawString },
              },
            },
          ],
        },
        {
          role: 'user',
          parts: [
            {
              functionResponse: {
                name: 'raw_tool',
                response: { content: massiveRawString },
              },
            },
          ],
        },
      ];
      vi.mocked(mockChat.getHistory).mockReturnValue(history);
      const result = await service.compress(
        mockChat,
        mockPromptId,
        true,
        mockModel,
        mockConfig,
        false,
      );
      expect(result.newHistory).not.toBeNull();
      const keptHistory = result.newHistory!.slice(2);
      const rawResponse = keptHistory.find(
        (h) =>
          h.parts?.some((p) => p.functionResponse?.name === 'raw_tool') &&
          (h.parts?.[0].functionResponse?.response?.['output'] as string)
            ?.length < 100000,
      );
      const truncatedPart = rawResponse!.parts![0].functionResponse;
      const content = truncatedPart?.response?.['output'] as string;
      // DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD = 40000 -> head=8000 (20%), tail=32000 (80%)
      expect(content).toContain(
        'Showing first 8,000 and last 32,000 characters',
      );
    });
    it('should fallback to original content and still update budget if truncation fails', async () => {
      vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600000);
      const largeResponse = 'd'.repeat(170000);
      const history: Content[] = [
        { role: 'user', parts: [{ text: 'old msg 1' }] },
        { role: 'model', parts: [{ text: 'old resp 1' }] },
        { role: 'user', parts: [{ text: 'old msg 2' }] },
        { role: 'model', parts: [{ text: 'old resp 2' }] },
        { role: 'user', parts: [{ text: 'msg 1' }] },
        {
          role: 'user',
          parts: [
            {
              functionResponse: {
                name: 'grep',
                response: { content: largeResponse },
              },
            },
          ],
        },
        {
          role: 'user',
          parts: [
            {
              functionResponse: {
                name: 'grep',
                response: { content: largeResponse },
              },
            },
          ],
        },
      ];
      vi.mocked(mockChat.getHistory).mockReturnValue(history);
      // Simulate failure in saving the truncated output
      vi.spyOn(fileUtils, 'saveTruncatedToolOutput').mockRejectedValue(
        new Error('Disk Full'),
      );
      const result = await service.compress(
        mockChat,
        mockPromptId,
        true,
        mockModel,
        mockConfig,
        false,
      );
      expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED);
      // Verify the new history contains the ORIGINAL message (not truncated)
      const keptHistory = result.newHistory!.slice(2);
      const toolResponseTurn = keptHistory.find((h) =>
        h.parts?.some((p) => p.functionResponse?.name === 'grep'),
      );
      const preservedPart = toolResponseTurn!.parts![0].functionResponse;
      expect(preservedPart?.response).toEqual({ content: largeResponse });
    });
    it('should use high-fidelity original history for summarization when under the limit, but truncated version for active window', async () => {
      // Large response in the "to compress" section (first message)
      // 300,000 chars is ~75k tokens, well under the 1,000,000 summarizer limit.
      const massiveText = 'a'.repeat(300000);
      const history: Content[] = [
        {
          role: 'user',
          parts: [
            {
              functionResponse: {
                name: 'grep',
                response: { content: massiveText },
              },
            },
          ],
        },
        // More history to ensure the first message is in the "to compress" group
        { role: 'user', parts: [{ text: 'msg 2' }] },
        { role: 'model', parts: [{ text: 'resp 2' }] },
        { role: 'user', parts: [{ text: 'preserved msg' }] },
        {
          role: 'user',
          parts: [
            {
              functionResponse: {
                name: 'massive_preserved',
                response: { content: massiveText },
              },
            },
          ],
        },
      ];
      vi.mocked(mockChat.getHistory).mockReturnValue(history);
      vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(600000);
      vi.mocked(tokenLimit).mockReturnValue(1_000_000);
      const result = await service.compress(
        mockChat,
        mockPromptId,
        true,
        mockModel,
        mockConfig,
        false,
      );
      expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED);
      // 1. Verify that the summary was generated from the ORIGINAL high-fidelity history
      const generateContentCall = vi.mocked(
        mockConfig.getBaseLlmClient().generateContent,
      ).mock.calls[0][0];
      const historySentToSummarizer = generateContentCall.contents;
      const summarizerGrepResponse =
        historySentToSummarizer[0].parts![0].functionResponse;
      // Should be original content because total tokens < 1M
      expect(summarizerGrepResponse?.response).toEqual({
        content: massiveText,
      });
      // 2. Verify that the PRESERVED history (the active window) IS truncated
      const keptHistory = result.newHistory!.slice(2); // Skip summary + ack
      const preservedToolTurn = keptHistory.find((h) =>
        h.parts?.some((p) => p.functionResponse?.name === 'massive_preserved'),
      );
      const preservedPart = preservedToolTurn!.parts![0].functionResponse;
      expect(preservedPart?.response?.['output']).toContain(
        'Output too large.',
      );
    });
    it('should fall back to truncated history for summarization when original is massive (>1M tokens)', async () => {
      // 5,000,000 chars is ~1.25M tokens, exceeding the 1M limit.
      const superMassiveText = 'a'.repeat(5000000);
      const history: Content[] = [
        {
          role: 'user',
          parts: [
            {
              functionResponse: {
                name: 'grep',
                response: { content: superMassiveText },
              },
            },
          ],
        },
        { role: 'user', parts: [{ text: 'msg 2' }] },
        { role: 'model', parts: [{ text: 'resp 2' }] },
      ];
      vi.mocked(mockChat.getHistory).mockReturnValue(history);
      vi.mocked(tokenLimit).mockReturnValue(1_000_000);
      const result = await service.compress(
        mockChat,
        mockPromptId,
        true,
        mockModel,
        mockConfig,
        false,
      );
      expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED);
      // Verify that the summary was generated from the TRUNCATED history
      const generateContentCall = vi.mocked(
        mockConfig.getBaseLlmClient().generateContent,
      ).mock.calls[0][0];
      const historySentToSummarizer = generateContentCall.contents;
      const summarizerGrepResponse =
        historySentToSummarizer[0].parts![0].functionResponse;
      // Should be truncated because original > 1M tokens
      expect(summarizerGrepResponse?.response?.['output']).toContain(
        'Output too large.',
      );
    });
  });
});
@@ -0,0 +1,479 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Content } from '@google/genai';
import type { Config } from '../config/config.js';
import type { GeminiChat } from '../core/geminiChat.js';
import { type ChatCompressionInfo, CompressionStatus } from '../core/turn.js';
import { tokenLimit } from '../core/tokenLimits.js';
import { getCompressionPrompt } from '../core/prompts.js';
import { getResponseText } from '../utils/partUtils.js';
import { logChatCompression } from '../telemetry/loggers.js';
import { makeChatCompressionEvent, LlmRole } from '../telemetry/types.js';
import {
saveTruncatedToolOutput,
formatTruncatedToolOutput,
} from '../utils/fileUtils.js';
import { debugLogger } from '../utils/debugLogger.js';
import { getInitialChatHistory } from '../utils/environmentContext.js';
import {
calculateRequestTokenCount,
estimateTokenCountSync,
} from '../utils/tokenCalculation.js';
import {
DEFAULT_GEMINI_FLASH_LITE_MODEL,
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_3_1_MODEL,
PREVIEW_GEMINI_3_1_FLASH_LITE_MODEL,
} from '../config/models.js';
import { PreCompressTrigger } from '../hooks/types.js';
/**
 * Default threshold for compression token count as a fraction of the model's
 * token limit. If the chat history exceeds this threshold, it will be compressed.
 * Used when Config.getCompressionThreshold() does not supply a value.
 */
const DEFAULT_COMPRESSION_TOKEN_THRESHOLD = 0.5;
/**
 * The fraction of the latest chat history to keep. A value of 0.3
 * means that only the last 30% of the chat history will be kept after compression.
 * The complement (1 - this value) is passed to findCompressSplitPoint.
 */
const COMPRESSION_PRESERVE_THRESHOLD = 0.3;
/**
 * The budget for function response tokens in the preserved history.
 * Consumed newest-first by truncateHistoryToBudget; once spent, older
 * tool responses are truncated and spilled to disk.
 */
const COMPRESSION_FUNCTION_RESPONSE_TOKEN_BUDGET = 50_000;
/**
 * Returns the index of the oldest item to keep when compressing, i.e. the
 * history is split into [0, splitPoint) to compress and [splitPoint, end) to
 * preserve. May return contents.length, which indicates that everything
 * should be compressed.
 *
 * A valid split point is a plain user turn (no functionResponse parts), so a
 * tool call is never separated from its response. The first such turn at or
 * past `fraction` of the serialized history size is chosen; otherwise the
 * last one seen is used as a fallback, unless the history ends with a
 * completed model turn in which case the whole history may be compressed.
 *
 * Exported for testing purposes.
 *
 * @param contents The chat history to split.
 * @param fraction Target fraction (exclusive 0..1) of serialized characters
 *   to place in the compressed portion.
 * @throws Error if fraction is outside the open interval (0, 1).
 */
export function findCompressSplitPoint(
  contents: Content[],
  fraction: number,
): number {
  if (fraction <= 0 || fraction >= 1) {
    throw new Error('Fraction must be between 0 and 1');
  }
  // Size each entry by its JSON serialization as a cheap token proxy.
  const sizes = contents.map((content) => JSON.stringify(content).length);
  const targetCharCount = fraction * sizes.reduce((sum, size) => sum + size, 0);
  let fallbackSplit = 0; // Compressing nothing is always valid.
  let seenChars = 0;
  for (const [index, content] of contents.entries()) {
    const isPlainUserTurn =
      content.role === 'user' &&
      !content.parts?.some((part) => !!part.functionResponse);
    if (isPlainUserTurn) {
      if (seenChars >= targetCharCount) {
        return index;
      }
      fallbackSplit = index;
    }
    seenChars += sizes[index];
  }
  // No split point found at/after the target. Compress everything only if the
  // history ends on a model turn without a pending function call.
  const finalContent = contents.at(-1);
  const endsOnCompletedModelTurn =
    finalContent?.role === 'model' &&
    !finalContent?.parts?.some((part) => part.functionCall);
  return endsOnCompletedModelTurn ? contents.length : fallbackSplit;
}
/**
 * Maps a concrete model string to its chat-compression model-config alias.
 * Unrecognized model strings fall back to 'chat-compression-default'.
 */
export function modelStringToModelConfigAlias(model: string): string {
  if (model === PREVIEW_GEMINI_MODEL || model === PREVIEW_GEMINI_3_1_MODEL) {
    return 'chat-compression-3-pro';
  }
  if (model === PREVIEW_GEMINI_FLASH_MODEL) {
    return 'chat-compression-3-flash';
  }
  if (model === PREVIEW_GEMINI_3_1_FLASH_LITE_MODEL) {
    return 'chat-compression-3.1-flash-lite';
  }
  if (model === DEFAULT_GEMINI_MODEL) {
    return 'chat-compression-2.5-pro';
  }
  if (model === DEFAULT_GEMINI_FLASH_MODEL) {
    return 'chat-compression-2.5-flash';
  }
  if (model === DEFAULT_GEMINI_FLASH_LITE_MODEL) {
    return 'chat-compression-2.5-flash-lite';
  }
  return 'chat-compression-default';
}
/**
 * Processes the chat history to ensure function responses don't exceed a specific token budget.
 *
 * This function implements a "Reverse Token Budget" strategy:
 * 1. It iterates through the history from the most recent turn to the oldest.
 * 2. It keeps a running tally of tokens used by function responses.
 * 3. Recent tool outputs are preserved in full to maintain high-fidelity context for the current turn.
 * 4. Once the budget (COMPRESSION_FUNCTION_RESPONSE_TOKEN_BUDGET) is exceeded, any older large
 *    tool responses are truncated and saved to a temporary file. (NOTE(review):
 *    the exact truncated shape is delegated to formatTruncatedToolOutput with
 *    config.getTruncateToolOutputThreshold() — the old "last 30 lines" wording
 *    here may be stale; confirm against fileUtils.)
 *
 * This ensures that compression effectively reduces context size even when recent turns
 * contain massive tool outputs (like large grep results or logs).
 *
 * @param history The curated chat history, oldest first. Not mutated.
 * @param config Supplies the truncation id, temp dir, and output threshold.
 * @returns A new history array of the same length and order, where some
 *   functionResponse parts may have been replaced by truncated placeholders
 *   of the form `{ output: <truncated message> }`.
 */
async function truncateHistoryToBudget(
  history: readonly Content[],
  config: Config,
): Promise<Content[]> {
  // Running total of (estimated) tokens consumed by functionResponse parts,
  // counted newest-first.
  let functionResponseTokenCounter = 0;
  const truncatedHistory: Content[] = [];
  // Iterate backwards: newest messages first to prioritize their context.
  for (let i = history.length - 1; i >= 0; i--) {
    const content = history[i];
    const newParts = [];
    if (content.parts) {
      // Process parts of the message backwards as well.
      for (let j = content.parts.length - 1; j >= 0; j--) {
        const part = content.parts[j];
        if (part.functionResponse) {
          const responseObj = part.functionResponse.response;
          // Ensure we have a string representation to truncate.
          // If the response is an object, we try to extract a primary string field (output or content).
          let contentStr: string;
          if (typeof responseObj === 'string') {
            contentStr = responseObj;
          } else if (responseObj && typeof responseObj === 'object') {
            if (
              'output' in responseObj &&
              // eslint-disable-next-line no-restricted-syntax
              typeof responseObj['output'] === 'string'
            ) {
              contentStr = responseObj['output'];
            } else if (
              'content' in responseObj &&
              // eslint-disable-next-line no-restricted-syntax
              typeof responseObj['content'] === 'string'
            ) {
              contentStr = responseObj['content'];
            } else {
              // No primary string field: fall back to pretty-printed JSON.
              contentStr = JSON.stringify(responseObj, null, 2);
            }
          } else {
            contentStr = JSON.stringify(responseObj, null, 2);
          }
          const tokens = estimateTokenCountSync([{ text: contentStr }]);
          if (
            functionResponseTokenCounter + tokens >
            COMPRESSION_FUNCTION_RESPONSE_TOKEN_BUDGET
          ) {
            try {
              // Budget exceeded: Truncate this response.
              // Full output is persisted to a temp file so nothing is lost.
              const { outputFile } = await saveTruncatedToolOutput(
                contentStr,
                part.functionResponse.name ?? 'unknown_tool',
                config.getNextCompressionTruncationId(),
                config.storage.getProjectTempDir(),
              );
              const truncatedMessage = formatTruncatedToolOutput(
                contentStr,
                outputFile,
                config.getTruncateToolOutputThreshold(),
              );
              // unshift preserves the original part order despite the
              // reverse iteration.
              newParts.unshift({
                functionResponse: {
                  // eslint-disable-next-line @typescript-eslint/no-misused-spread
                  ...part.functionResponse,
                  response: { output: truncatedMessage },
                },
              });
              // Count the small truncated placeholder towards the budget.
              functionResponseTokenCounter += estimateTokenCountSync([
                { text: truncatedMessage },
              ]);
            } catch (error) {
              // Fallback: if truncation fails, keep the original part to avoid data loss in the chat.
              // The full token cost still counts against the budget.
              debugLogger.debug('Failed to truncate history to budget:', error);
              newParts.unshift(part);
              functionResponseTokenCounter += tokens;
            }
          } else {
            // Within budget: keep the full response.
            functionResponseTokenCounter += tokens;
            newParts.unshift(part);
          }
        } else {
          // Non-tool response part: always keep.
          newParts.unshift(part);
        }
      }
    }
    // Reconstruct the message with processed (potentially truncated) parts.
    truncatedHistory.unshift({ ...content, parts: newParts });
  }
  return truncatedHistory;
}
export class ChatCompressionService {
async compress(
chat: GeminiChat,
promptId: string,
force: boolean,
model: string,
config: Config,
hasFailedCompressionAttempt: boolean,
abortSignal?: AbortSignal,
): Promise<{ newHistory: Content[] | null; info: ChatCompressionInfo }> {
const curatedHistory = chat.getHistory(true);
// Regardless of `force`, don't do anything if the history is empty.
if (curatedHistory.length === 0) {
return {
newHistory: null,
info: {
originalTokenCount: 0,
newTokenCount: 0,
compressionStatus: CompressionStatus.NOOP,
},
};
}
// Fire PreCompress hook before compression
// This fires for both manual and auto compression attempts
const trigger = force ? PreCompressTrigger.Manual : PreCompressTrigger.Auto;
await config.getHookSystem()?.firePreCompressEvent(trigger);
const originalTokenCount = chat.getLastPromptTokenCount();
// Don't compress if not forced and we are under the limit.
if (!force) {
const threshold =
(await config.getCompressionThreshold()) ??
DEFAULT_COMPRESSION_TOKEN_THRESHOLD;
if (originalTokenCount < threshold * tokenLimit(model)) {
return {
newHistory: null,
info: {
originalTokenCount,
newTokenCount: originalTokenCount,
compressionStatus: CompressionStatus.NOOP,
},
};
}
}
// Apply token-based truncation to the entire history before splitting.
// This ensures that even the "to compress" portion is within safe limits for the summarization model.
const truncatedHistory = await truncateHistoryToBudget(
curatedHistory,
config,
);
// If summarization previously failed (and not forced), we only rely on truncation.
// We do NOT attempt to invoke the LLM for summarization again to avoid repeated failures/costs.
if (hasFailedCompressionAttempt && !force) {
const truncatedTokenCount = estimateTokenCountSync(
truncatedHistory.flatMap((c) => c.parts || []),
);
// If truncation reduced the size, we consider it a successful "compression" (truncation only).
if (truncatedTokenCount < originalTokenCount) {
return {
newHistory: truncatedHistory,
info: {
originalTokenCount,
newTokenCount: truncatedTokenCount,
compressionStatus: CompressionStatus.CONTENT_TRUNCATED,
},
};
}
return {
newHistory: null,
info: {
originalTokenCount,
newTokenCount: originalTokenCount,
compressionStatus: CompressionStatus.NOOP,
},
};
}
const splitPoint = findCompressSplitPoint(
truncatedHistory,
1 - COMPRESSION_PRESERVE_THRESHOLD,
);
const historyToCompressTruncated = truncatedHistory.slice(0, splitPoint);
const historyToKeepTruncated = truncatedHistory.slice(splitPoint);
if (historyToCompressTruncated.length === 0) {
return {
newHistory: null,
info: {
originalTokenCount,
newTokenCount: originalTokenCount,
compressionStatus: CompressionStatus.NOOP,
},
};
}
// High Fidelity Decision: Should we send the original or truncated history to the summarizer?
const originalHistoryToCompress = curatedHistory.slice(0, splitPoint);
const originalToCompressTokenCount = estimateTokenCountSync(
originalHistoryToCompress.flatMap((c) => c.parts || []),
);
const historyForSummarizer =
originalToCompressTokenCount < tokenLimit(model)
? originalHistoryToCompress
: historyToCompressTruncated;
const hasPreviousSnapshot = historyForSummarizer.some((c) =>
c.parts?.some((p) => p.text?.includes('<state_snapshot>')),
);
const anchorInstruction = hasPreviousSnapshot
? 'A previous <state_snapshot> exists in the history. You MUST integrate all still-relevant information from that snapshot into the new one, updating it with the more recent events. Do not lose established constraints or critical knowledge.'
: 'Generate a new <state_snapshot> based on the provided history.';
const summaryResponse = await config.getBaseLlmClient().generateContent({
modelConfigKey: { model: modelStringToModelConfigAlias(model) },
contents: [
...historyForSummarizer,
{
role: 'user',
parts: [
{
text: `${anchorInstruction}\n\nFirst, reason in your scratchpad. Then, generate the updated <state_snapshot>.`,
},
],
},
],
systemInstruction: { text: getCompressionPrompt(config) },
promptId,
// TODO(joshualitt): wire up a sensible abort signal,
abortSignal: abortSignal ?? new AbortController().signal,
role: LlmRole.UTILITY_COMPRESSOR,
});
const summary = getResponseText(summaryResponse) ?? '';
// Phase 3: The "Probe" Verification (Self-Correction)
// We perform a second lightweight turn to ensure no critical information was lost.
const verificationResponse = await config
.getBaseLlmClient()
.generateContent({
modelConfigKey: { model: modelStringToModelConfigAlias(model) },
contents: [
...historyForSummarizer,
{
role: 'model',
parts: [{ text: summary }],
},
{
role: 'user',
parts: [
{
text: 'Critically evaluate the <state_snapshot> you just generated. Did you omit any specific technical details, file paths, tool results, or user constraints mentioned in the history? If anything is missing or could be more precise, generate a FINAL, improved <state_snapshot>. Otherwise, repeat the exact same <state_snapshot> again.',
},
],
},
],
systemInstruction: { text: getCompressionPrompt(config) },
promptId: `${promptId}-verify`,
role: LlmRole.UTILITY_COMPRESSOR,
abortSignal: abortSignal ?? new AbortController().signal,
});
const finalSummary = (
getResponseText(verificationResponse)?.trim() || summary
).trim();
if (!finalSummary) {
logChatCompression(
config,
makeChatCompressionEvent({
tokens_before: originalTokenCount,
tokens_after: originalTokenCount, // No change since it failed
}),
);
return {
newHistory: null,
info: {
originalTokenCount,
newTokenCount: originalTokenCount,
compressionStatus: CompressionStatus.COMPRESSION_FAILED_EMPTY_SUMMARY,
},
};
}
const extraHistory: Content[] = [
{
role: 'user',
parts: [{ text: finalSummary }],
},
{
role: 'model',
parts: [{ text: 'Got it. Thanks for the additional context!' }],
},
...historyToKeepTruncated,
];
// Use a shared utility to construct the initial history for an accurate token count.
const fullNewHistory = await getInitialChatHistory(config, extraHistory);
const newTokenCount = await calculateRequestTokenCount(
fullNewHistory.flatMap((c) => c.parts || []),
config.getContentGenerator(),
model,
);
logChatCompression(
config,
makeChatCompressionEvent({
tokens_before: originalTokenCount,
tokens_after: newTokenCount,
}),
);
if (newTokenCount > originalTokenCount) {
return {
newHistory: null,
info: {
originalTokenCount,
newTokenCount,
compressionStatus:
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
},
};
} else {
return {
newHistory: extraHistory,
info: {
originalTokenCount,
newTokenCount,
compressionStatus: CompressionStatus.COMPRESSED,
},
};
}
}
}
@@ -0,0 +1,262 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ContextManager } from './contextManager.js';
import * as memoryDiscovery from '../utils/memoryDiscovery.js';
import type { Config } from '../config/config.js';
import { coreEvents, CoreEvent } from '../utils/events.js';
// Mock the memoryDiscovery module: every discovery/read helper is stubbed so
// each test controls exactly which GEMINI.md paths and contents exist, while
// concatenateInstructions keeps its real implementation so formatting
// assertions exercise the genuine output.
vi.mock('../utils/memoryDiscovery.js', async (importOriginal) => {
  const actual =
    await importOriginal<typeof import('../utils/memoryDiscovery.js')>();
  return {
    ...actual,
    getGlobalMemoryPaths: vi.fn(),
    getUserProjectMemoryPaths: vi.fn(),
    getExtensionMemoryPaths: vi.fn(),
    getEnvironmentMemoryPaths: vi.fn(),
    readGeminiMdFiles: vi.fn(),
    loadJitSubdirectoryMemory: vi.fn(),
    deduplicatePathsByFileIdentity: vi.fn(),
    concatenateInstructions: vi
      .fn()
      .mockImplementation(actual.concatenateInstructions),
  };
});
describe('ContextManager', () => {
  let contextManager: ContextManager;
  let mockConfig: Config;
  beforeEach(() => {
    // Minimal trusted-folder Config: one workspace dir (/app), 'tree' import
    // format, '.git' boundary marker, no extensions, fixed MCP instructions.
    mockConfig = {
      getWorkingDir: vi.fn().mockReturnValue('/app'),
      getImportFormat: vi.fn().mockReturnValue('tree'),
      getWorkspaceContext: vi.fn().mockReturnValue({
        getDirectories: vi.fn().mockReturnValue(['/app']),
      }),
      getExtensionLoader: vi.fn().mockReturnValue({
        getExtensions: vi.fn().mockReturnValue([]),
      }),
      getMcpClientManager: vi.fn().mockReturnValue({
        getMcpInstructions: vi.fn().mockReturnValue('MCP Instructions'),
      }),
      isTrustedFolder: vi.fn().mockReturnValue(true),
      getMemoryBoundaryMarkers: vi.fn().mockReturnValue(['.git']),
      storage: {
        getProjectMemoryDir: vi
          .fn()
          .mockReturnValue('/home/user/.gemini/memory/test-project'),
      },
    } as unknown as Config;
    contextManager = new ContextManager(mockConfig);
    // Reset call history before wiring the per-test defaults below; the
    // constructor only captures the config reference, so clearing is safe.
    vi.clearAllMocks();
    vi.spyOn(coreEvents, 'emit');
    vi.mocked(memoryDiscovery.getExtensionMemoryPaths).mockReturnValue([]);
    vi.mocked(memoryDiscovery.getUserProjectMemoryPaths).mockResolvedValue([]);
    // default mock: deduplication returns paths as-is (no deduplication)
    vi.mocked(
      memoryDiscovery.deduplicatePathsByFileIdentity,
    ).mockImplementation(async (paths: string[]) => ({
      paths,
      identityMap: new Map<string, string>(),
    }));
  });
  describe('refresh', () => {
    // Happy path: global + environment files are discovered, read with the
    // configured import format/markers, and categorized into the right tiers
    // (environment memory also picks up the MCP instructions).
    it('should load and format global and environment memory', async () => {
      const globalPaths = ['/home/user/.gemini/GEMINI.md'];
      const envPaths = ['/app/GEMINI.md'];
      vi.mocked(memoryDiscovery.getGlobalMemoryPaths).mockResolvedValue(
        globalPaths,
      );
      vi.mocked(memoryDiscovery.getEnvironmentMemoryPaths).mockResolvedValue(
        envPaths,
      );
      vi.mocked(memoryDiscovery.readGeminiMdFiles).mockResolvedValue([
        { filePath: globalPaths[0], content: 'Global Content' },
        { filePath: envPaths[0], content: 'Env Content' },
      ]);
      await contextManager.refresh();
      expect(memoryDiscovery.getGlobalMemoryPaths).toHaveBeenCalled();
      expect(memoryDiscovery.getEnvironmentMemoryPaths).toHaveBeenCalledWith(
        ['/app'],
        ['.git'],
      );
      expect(memoryDiscovery.readGeminiMdFiles).toHaveBeenCalledWith(
        expect.arrayContaining([...globalPaths, ...envPaths]),
        'tree',
        ['.git'],
      );
      expect(contextManager.getGlobalMemory()).toContain('Global Content');
      expect(contextManager.getEnvironmentMemory()).toContain('Env Content');
      expect(contextManager.getEnvironmentMemory()).toContain(
        'MCP Instructions',
      );
      expect(contextManager.getLoadedPaths()).toContain(globalPaths[0]);
      expect(contextManager.getLoadedPaths()).toContain(envPaths[0]);
    });
    // refresh() must notify listeners with the number of loaded files.
    it('should emit MemoryChanged event when memory is refreshed', async () => {
      vi.mocked(memoryDiscovery.getGlobalMemoryPaths).mockResolvedValue([
        '/app/GEMINI.md',
      ]);
      vi.mocked(memoryDiscovery.getEnvironmentMemoryPaths).mockResolvedValue([
        '/app/src/GEMINI.md',
      ]);
      vi.mocked(memoryDiscovery.readGeminiMdFiles).mockResolvedValue([
        { filePath: '/app/GEMINI.md', content: 'content' },
        { filePath: '/app/src/GEMINI.md', content: 'env content' },
      ]);
      await contextManager.refresh();
      expect(coreEvents.emit).toHaveBeenCalledWith(CoreEvent.MemoryChanged, {
        fileCount: 2,
      });
    });
    // Untrusted folders must skip environment discovery entirely while still
    // loading the user's global memory.
    it('should not load environment memory if folder is not trusted', async () => {
      vi.mocked(mockConfig.isTrustedFolder).mockReturnValue(false);
      vi.mocked(memoryDiscovery.getGlobalMemoryPaths).mockResolvedValue([
        '/home/user/.gemini/GEMINI.md',
      ]);
      vi.mocked(memoryDiscovery.readGeminiMdFiles).mockResolvedValue([
        { filePath: '/home/user/.gemini/GEMINI.md', content: 'Global Content' },
      ]);
      await contextManager.refresh();
      expect(memoryDiscovery.getEnvironmentMemoryPaths).not.toHaveBeenCalled();
      expect(contextManager.getEnvironmentMemory()).toBe('');
      expect(contextManager.getGlobalMemory()).toContain('Global Content');
    });
    // Two paths that differ only by case resolve to one physical file; only
    // the deduplicated path list may be passed to readGeminiMdFiles.
    it('should deduplicate files by file identity in case-insensitive filesystems', async () => {
      const globalPaths = ['/home/user/.gemini/GEMINI.md'];
      const envPaths = ['/app/gemini.md', '/app/GEMINI.md'];
      vi.mocked(memoryDiscovery.getGlobalMemoryPaths).mockResolvedValue(
        globalPaths,
      );
      vi.mocked(memoryDiscovery.getEnvironmentMemoryPaths).mockResolvedValue(
        envPaths,
      );
      // mock deduplication to return deduplicated paths (simulating same file)
      vi.mocked(
        memoryDiscovery.deduplicatePathsByFileIdentity,
      ).mockResolvedValue({
        paths: ['/home/user/.gemini/GEMINI.md', '/app/gemini.md'],
        identityMap: new Map<string, string>(),
      });
      vi.mocked(memoryDiscovery.readGeminiMdFiles).mockResolvedValue([
        { filePath: '/home/user/.gemini/GEMINI.md', content: 'Global Content' },
        { filePath: '/app/gemini.md', content: 'Project Content' },
      ]);
      await contextManager.refresh();
      expect(
        memoryDiscovery.deduplicatePathsByFileIdentity,
      ).toHaveBeenCalledWith(
        expect.arrayContaining([
          '/home/user/.gemini/GEMINI.md',
          '/app/gemini.md',
          '/app/GEMINI.md',
        ]),
      );
      expect(memoryDiscovery.readGeminiMdFiles).toHaveBeenCalledWith(
        ['/home/user/.gemini/GEMINI.md', '/app/gemini.md'],
        'tree',
        ['.git'],
      );
      expect(contextManager.getEnvironmentMemory()).toContain(
        'Project Content',
      );
    });
  });
  describe('discoverContext', () => {
    // JIT discovery passes the already-loaded path/identity sets and boundary
    // markers through, records the new file, and formats its content.
    it('should discover and load new context', async () => {
      const mockResult: memoryDiscovery.MemoryLoadResult = {
        files: [{ path: '/app/src/GEMINI.md', content: 'Src Content' }],
      };
      vi.mocked(memoryDiscovery.loadJitSubdirectoryMemory).mockResolvedValue(
        mockResult,
      );
      const result = await contextManager.discoverContext('/app/src/file.ts', [
        '/app',
      ]);
      expect(memoryDiscovery.loadJitSubdirectoryMemory).toHaveBeenCalledWith(
        '/app/src/file.ts',
        ['/app'],
        expect.any(Set),
        expect.any(Set),
        ['.git'],
      );
      expect(result).toMatch(/--- Context from: \/app\/src\/GEMINI\.md ---/);
      expect(result).toContain('Src Content');
      expect(contextManager.getLoadedPaths()).toContain('/app/src/GEMINI.md');
    });
    it('should return empty string if no new files found', async () => {
      const mockResult: memoryDiscovery.MemoryLoadResult = { files: [] };
      vi.mocked(memoryDiscovery.loadJitSubdirectoryMemory).mockResolvedValue(
        mockResult,
      );
      const result = await contextManager.discoverContext('/app/src/file.ts', [
        '/app',
      ]);
      expect(result).toBe('');
    });
    // Untrusted folders short-circuit before any JIT discovery happens.
    it('should return empty string if folder is not trusted', async () => {
      vi.mocked(mockConfig.isTrustedFolder).mockReturnValue(false);
      const result = await contextManager.discoverContext('/app/src/file.ts', [
        '/app',
      ]);
      expect(memoryDiscovery.loadJitSubdirectoryMemory).not.toHaveBeenCalled();
      expect(result).toBe('');
    });
    // Boundary markers come from config, not a hard-coded default.
    it('should pass custom boundary markers from config', async () => {
      const customMarkers = ['.monorepo-root', 'package.json'];
      vi.mocked(mockConfig.getMemoryBoundaryMarkers).mockReturnValue(
        customMarkers,
      );
      vi.mocked(memoryDiscovery.loadJitSubdirectoryMemory).mockResolvedValue({
        files: [],
      });
      await contextManager.discoverContext('/app/src/file.ts', ['/app']);
      expect(memoryDiscovery.loadJitSubdirectoryMemory).toHaveBeenCalledWith(
        '/app/src/file.ts',
        ['/app'],
        expect.any(Set),
        expect.any(Set),
        customMarkers,
      );
    });
  });
});
+203
View File
@@ -0,0 +1,203 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
loadJitSubdirectoryMemory,
concatenateInstructions,
getGlobalMemoryPaths,
getUserProjectMemoryPaths,
getExtensionMemoryPaths,
getEnvironmentMemoryPaths,
readGeminiMdFiles,
categorizeAndConcatenate,
type GeminiFileContent,
deduplicatePathsByFileIdentity,
} from '../utils/memoryDiscovery.js';
import type { Config } from '../config/config.js';
import { coreEvents, CoreEvent } from '../utils/events.js';
/**
 * Central store for hierarchical instruction memory (GEMINI.md files).
 *
 * Memory is gathered from four sources — global (user home), extensions,
 * the trusted project environment, and the per-project user memory dir —
 * plus MCP server instructions. Additional subdirectory context can be
 * pulled in just-in-time via discoverContext().
 */
export class ContextManager {
  // Absolute paths of every memory file already loaded, so JIT discovery
  // never re-loads a file that refresh() (or a prior discovery) has seen.
  private readonly loadedPaths: Set<string> = new Set();
  // File-identity keys for loaded files, so the same physical file reached
  // via a differently-cased path is not loaded twice on case-insensitive
  // filesystems (see deduplicatePathsByFileIdentity).
  private readonly loadedFileIdentities: Set<string> = new Set();
  private readonly config: Config;
  private globalMemory: string = '';
  private extensionMemory: string = '';
  // Project/environment memory concatenated with MCP instructions; kept
  // empty when the folder is untrusted.
  private projectMemory: string = '';
  private userProjectMemoryContent: string = '';
  constructor(config: Config) {
    this.config = config;
  }
  /**
   * Refreshes the memory by reloading global, extension, and project memory.
   * Clears all previously tracked paths/identities, then re-runs the full
   * discover -> load -> categorize pipeline and notifies listeners.
   */
  async refresh(): Promise<void> {
    this.loadedPaths.clear();
    this.loadedFileIdentities.clear();
    const paths = await this.discoverMemoryPaths();
    const contentsMap = await this.loadMemoryContents(paths);
    this.categorizeMemoryContents(paths, contentsMap);
    this.emitMemoryChanged();
  }
  // Discovers candidate memory file paths for each tier in parallel.
  // Environment (project) paths are only searched when the folder is trusted.
  private async discoverMemoryPaths() {
    const [global, extension, project, userProjectMemory] = await Promise.all([
      getGlobalMemoryPaths(),
      Promise.resolve(
        getExtensionMemoryPaths(this.config.getExtensionLoader()),
      ),
      this.config.isTrustedFolder()
        ? getEnvironmentMemoryPaths(
            [...this.config.getWorkspaceContext().getDirectories()],
            this.config.getMemoryBoundaryMarkers(),
          )
        : Promise.resolve([]),
      getUserProjectMemoryPaths(this.config.storage.getProjectMemoryDir()),
    ]);
    return { global, extension, project, userProjectMemory };
  }
  // Reads all discovered files (after string and file-identity deduplication)
  // and records what was actually loaded. Returns contents keyed by path.
  private async loadMemoryContents(paths: {
    global: string[];
    extension: string[];
    project: string[];
    userProjectMemory: string[];
  }) {
    const allPathsStringDeduped = Array.from(
      new Set([
        ...paths.global,
        ...paths.extension,
        ...paths.project,
        ...paths.userProjectMemory,
      ]),
    );
    // deduplicate by file identity to handle case-insensitive filesystems
    const { paths: allPaths, identityMap: pathIdentityMap } =
      await deduplicatePathsByFileIdentity(allPathsStringDeduped);
    const allContents = await readGeminiMdFiles(
      allPaths,
      this.config.getImportFormat(),
      this.config.getMemoryBoundaryMarkers(),
    );
    // Only files that actually produced content count as "loaded".
    const loadedFilePaths = allContents
      .filter((c) => c.content !== null)
      .map((c) => c.filePath);
    this.markAsLoaded(loadedFilePaths);
    // Cache file identities for performance optimization
    for (const filePath of loadedFilePaths) {
      const identity = pathIdentityMap.get(filePath);
      if (identity) {
        this.loadedFileIdentities.add(identity);
      }
    }
    return new Map(allContents.map((c) => [c.filePath, c]));
  }
  // Splits loaded contents back into per-tier memory strings and appends MCP
  // server instructions to the project tier (trusted folders only).
  private categorizeMemoryContents(
    paths: {
      global: string[];
      extension: string[];
      project: string[];
      userProjectMemory: string[];
    },
    contentsMap: Map<string, GeminiFileContent>,
  ) {
    const hierarchicalMemory = categorizeAndConcatenate(paths, contentsMap);
    this.globalMemory = hierarchicalMemory.global || '';
    this.extensionMemory = hierarchicalMemory.extension || '';
    this.userProjectMemoryContent = hierarchicalMemory.userProjectMemory || '';
    const mcpInstructions =
      this.config.getMcpClientManager()?.getMcpInstructions() || '';
    const projectMemoryWithMcp = [
      hierarchicalMemory.project,
      mcpInstructions.trimStart(),
    ]
      .filter(Boolean)
      .join('\n\n');
    this.projectMemory = this.config.isTrustedFolder()
      ? projectMemoryWithMcp
      : '';
  }
  /**
   * Discovers and loads context for a specific accessed path (Tier 3 - JIT).
   * Traverses upwards from the accessed path to the project root.
   * Returns the newly discovered instructions concatenated into one string,
   * or '' when nothing new was found or the folder is untrusted.
   */
  async discoverContext(
    accessedPath: string,
    trustedRoots: string[],
  ): Promise<string> {
    if (!this.config.isTrustedFolder()) {
      return '';
    }
    const result = await loadJitSubdirectoryMemory(
      accessedPath,
      trustedRoots,
      this.loadedPaths,
      this.loadedFileIdentities,
      this.config.getMemoryBoundaryMarkers(),
    );
    if (result.files.length === 0) {
      return '';
    }
    const newFilePaths = result.files.map((f) => f.path);
    this.markAsLoaded(newFilePaths);
    // Cache identities for newly loaded files
    if (result.fileIdentities) {
      for (const identity of result.fileIdentities) {
        this.loadedFileIdentities.add(identity);
      }
    }
    return concatenateInstructions(
      result.files.map((f) => ({ filePath: f.path, content: f.content })),
    );
  }
  // Notifies listeners that memory changed, with the number of distinct
  // files currently loaded.
  private emitMemoryChanged(): void {
    coreEvents.emit(CoreEvent.MemoryChanged, {
      fileCount: this.loadedPaths.size,
    });
  }
  /** Memory from the user's global config dir. */
  getGlobalMemory(): string {
    return this.globalMemory;
  }
  /** Memory contributed by installed extensions. */
  getExtensionMemory(): string {
    return this.extensionMemory;
  }
  /** Project memory (plus MCP instructions); '' when the folder is untrusted. */
  getEnvironmentMemory(): string {
    return this.projectMemory;
  }
  /** Memory from the per-project user memory directory. */
  getUserProjectMemory(): string {
    return this.userProjectMemoryContent;
  }
  private markAsLoaded(paths: string[]): void {
    paths.forEach((p) => this.loadedPaths.add(p));
  }
  /** Paths of all memory files loaded so far (refresh + JIT discovery). */
  getLoadedPaths(): ReadonlySet<string> {
    return this.loadedPaths;
  }
}
@@ -0,0 +1,101 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ToolOutputDistillationService } from './toolDistillationService.js';
import type { Config, Part } from '../index.js';
import type { GeminiClient } from '../core/client.js';
describe('ToolOutputDistillationService', () => {
  let mockConfig: Config;
  let mockGeminiClient: GeminiClient;
  let service: ToolOutputDistillationService;
  beforeEach(() => {
    // Config with a 100-token output cap (=> ~400-char threshold at 4
    // chars/token) and a 100-token summarization threshold; disk paths and
    // telemetry are stubbed. The mock client always returns a fixed summary.
    mockConfig = {
      getToolMaxOutputTokens: vi.fn().mockReturnValue(100),
      getToolSummarizationThresholdTokens: vi.fn().mockReturnValue(100),
      getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
      storage: {
        getProjectTempDir: vi.fn().mockReturnValue('/tmp/gemini'),
      },
      telemetry: {
        logEvent: vi.fn(),
      },
    } as unknown as Config;
    mockGeminiClient = {
      generateContent: vi.fn().mockResolvedValue({
        candidates: [{ content: { parts: [{ text: 'Mock Intent Summary' }] } }],
      }),
    } as unknown as GeminiClient;
    service = new ToolOutputDistillationService(
      mockConfig,
      mockGeminiClient,
      'test-prompt-id',
    );
  });
  // Large-but-not-huge string output: the LLM is consulted and its summary is
  // embedded under the "Strategic Significance" banner.
  it('should generate a structural map for oversized content within limits', async () => {
    // > threshold * SUMMARIZATION_THRESHOLD (100 * 4 = 400)
    const largeContent = 'A'.repeat(500);
    const result = await service.distill('test-tool', 'call-1', largeContent);
    expect(mockGeminiClient.generateContent).toHaveBeenCalled();
    const text =
      typeof result.truncatedContent === 'string'
        ? result.truncatedContent
        : (result.truncatedContent as Array<{ text: string }>)[0].text;
    expect(text).toContain('Strategic Significance');
  });
  // functionResponse parts: only the oversized field is normalized; the part
  // shape (name/id) and small fields survive untouched.
  it('should structurally truncate functionResponse while preserving schema', async () => {
    // threshold is 100
    const hugeValue = 'H'.repeat(1000);
    const content = [
      {
        functionResponse: {
          name: 'test_tool',
          id: '123',
          response: {
            stdout: hugeValue,
            stderr: 'no error',
          },
        },
      },
    ] as unknown as Part[];
    const result = await service.distill('test-tool', 'call-1', content);
    const truncatedParts = result.truncatedContent as Part[];
    expect(truncatedParts.length).toBe(1);
    const fr = truncatedParts[0].functionResponse!;
    const resp = fr.response as Record<string, unknown>;
    expect(fr.name).toBe('test_tool');
    expect(resp['stderr']).toBe('no error');
    expect(resp['stdout'] as string).toContain('[Message Normalized');
    expect(resp['stdout'] as string).toContain('Full output saved to');
  });
  // Above MAX_DISTILLATION_SIZE the LLM call is skipped entirely (too costly
  // and unrepresentative), so no summary banner appears.
  it('should skip structural map for extremely large content exceeding MAX_DISTILLATION_SIZE', async () => {
    const massiveContent = 'A'.repeat(1_000_001); // > MAX_DISTILLATION_SIZE
    const result = await service.distill('test-tool', 'call-2', massiveContent);
    expect(mockGeminiClient.generateContent).not.toHaveBeenCalled();
    const text =
      typeof result.truncatedContent === 'string'
        ? result.truncatedContent
        : (result.truncatedContent as Array<{ text: string }>)[0].text;
    expect(text).not.toContain('Strategic Significance');
  });
  // Over the truncation threshold but under the summarization threshold:
  // truncate without calling the LLM.
  it('should skip structural map for content below summarization threshold', async () => {
    // > threshold but < threshold * SUMMARIZATION_THRESHOLD
    const mediumContent = 'A'.repeat(110);
    const result = await service.distill('test-tool', 'call-3', mediumContent);
    expect(mockGeminiClient.generateContent).not.toHaveBeenCalled();
    expect(result.truncatedContent).not.toContain('Mock Intent Summary');
  });
});
@@ -0,0 +1,293 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
LlmRole,
ToolOutputTruncatedEvent,
logToolOutputTruncated,
debugLogger,
type Config,
} from '../index.js';
import type { PartListUnion } from '@google/genai';
import { type GeminiClient } from '../core/client.js';
import { saveTruncatedToolOutput } from '../utils/fileUtils.js';
import {
READ_FILE_TOOL_NAME,
READ_MANY_FILES_TOOL_NAME,
} from '../tools/tool-names.js';
import {
truncateProportionally,
TOOL_TRUNCATION_PREFIX,
MIN_TARGET_TOKENS,
estimateCharsFromTokens,
normalizeFunctionResponse,
} from './truncation.js';
// Skip structural map generation for outputs larger than this threshold (in characters)
// as it consumes excessive tokens and may not be representative of the full content.
const MAX_DISTILLATION_SIZE = 1_000_000;
/**
 * Result of distilling a tool output: the (possibly truncated) content, and
 * the path the raw output was saved to when it was offloaded to disk.
 */
export interface DistilledToolOutput {
  truncatedContent: PartListUnion;
  outputFile?: string;
}
/**
 * Protects the agent's context window from oversized tool outputs.
 *
 * Outputs exceeding the configured token budget are saved verbatim to the
 * project temp dir, replaced in the conversation with a proportionally
 * truncated placeholder, and — when massively oversized — augmented with a
 * short LLM-generated extraction of the most critical facts.
 */
export class ToolOutputDistillationService {
  constructor(
    private readonly config: Config,
    private readonly geminiClient: GeminiClient,
    private readonly promptId: string,
  ) {}
  /**
   * Distills a tool's output if it exceeds configured length thresholds, preserving
   * the agent's context window. This includes saving the raw output to disk, replacing
   * the output with a truncated placeholder, and optionally summarizing the output
   * via a secondary LLM call if the output is massively oversized.
   *
   * @param toolName Name of the tool that produced the output.
   * @param callId Unique id of the tool call (used in the saved file name).
   * @param content The raw tool output.
   * @returns The content unchanged when exempt or under the threshold;
   *   otherwise the truncated content plus the on-disk path of the raw output.
   */
  async distill(
    toolName: string,
    callId: string,
    content: PartListUnion,
  ): Promise<DistilledToolOutput> {
    // Explicitly bypass escape hatches that natively handle large outputs
    if (this.isExemptFromDistillation(toolName)) {
      return { truncatedContent: content };
    }
    const maxTokens = this.config.getToolMaxOutputTokens();
    // ~4 characters per token heuristic (used throughout this service).
    const thresholdChars = maxTokens * 4;
    if (thresholdChars <= 0) {
      // A non-positive budget means distillation is effectively disabled.
      return { truncatedContent: content };
    }
    const originalContentLength = this.calculateContentLength(content);
    if (originalContentLength > thresholdChars) {
      return this.performDistillation(
        toolName,
        callId,
        content,
        originalContentLength,
        thresholdChars,
      );
    }
    return { truncatedContent: content };
  }
  // File-reading tools natively handle large outputs, so they bypass
  // distillation entirely.
  private isExemptFromDistillation(toolName: string): boolean {
    return (
      toolName === READ_FILE_TOOL_NAME || toolName === READ_MANY_FILES_TOOL_NAME
    );
  }
  // Character-length estimate of the content; functionResponse parts are
  // measured via their JSON serialization. Parts with neither text nor a
  // functionResponse contribute 0.
  private calculateContentLength(content: PartListUnion): number {
    if (typeof content === 'string') {
      return content.length;
    }
    if (Array.isArray(content)) {
      return content.reduce((acc, part) => {
        if (typeof part === 'string') return acc + part.length;
        if (part.text) return acc + part.text.length;
        if (part.functionResponse?.response) {
          // Estimate length of the response object
          return acc + JSON.stringify(part.functionResponse.response).length;
        }
        return acc;
      }, 0);
    }
    return 0;
  }
  // Serializes content for saving to disk.
  private stringifyContent(content: PartListUnion): string {
    if (typeof content === 'string') return content;
    // For arrays or other objects, we preserve the structural JSON to maintain
    // the ability to reconstruct the parts if needed from the saved output.
    return JSON.stringify(content, null, 2);
  }
  // Offloads the raw output to disk, optionally generates an intent summary,
  // structurally truncates the content, and logs telemetry about the event.
  private async performDistillation(
    toolName: string,
    callId: string,
    content: PartListUnion,
    originalContentLength: number,
    threshold: number,
  ): Promise<DistilledToolOutput> {
    const stringifiedContent = this.stringifyContent(content);
    // Save the raw, untruncated string to disk for human review
    const { outputFile: savedPath } = await saveTruncatedToolOutput(
      stringifiedContent,
      toolName,
      callId,
      this.config.storage.getProjectTempDir(),
      this.promptId,
    );
    // If the output is massively oversized, attempt to generate an intent
    // summary. Outputs beyond MAX_DISTILLATION_SIZE are skipped: summarizing
    // them would consume excessive tokens and may not be representative.
    let intentSummaryText = '';
    const summarizationThresholdTokens =
      this.config.getToolSummarizationThresholdTokens();
    const summarizationThresholdChars = summarizationThresholdTokens * 4;
    if (
      originalContentLength > summarizationThresholdChars &&
      originalContentLength <= MAX_DISTILLATION_SIZE
    ) {
      const summary = await this.generateIntentSummary(
        toolName,
        stringifiedContent,
        MAX_DISTILLATION_SIZE, // already an integer; Math.floor was redundant
      );
      if (summary) {
        intentSummaryText = `\n\n--- Strategic Significance of Truncated Content ---\n${summary}`;
      }
    }
    // Perform structural truncation
    const ratio = threshold / originalContentLength;
    const truncatedContent = this.truncateContentStructurally(
      content,
      ratio,
      savedPath || 'Output offloaded to disk',
      intentSummaryText,
    );
    logToolOutputTruncated(
      this.config,
      new ToolOutputTruncatedEvent(this.promptId, {
        toolName,
        originalContentLength,
        truncatedContentLength: this.calculateContentLength(truncatedContent),
        threshold,
      }),
    );
    return {
      truncatedContent,
      outputFile: savedPath,
    };
  }
  /**
   * Proportionally truncates a text chunk to `ratio` of its estimated token
   * count, never going below MIN_TARGET_TOKENS.
   */
  private truncateTextByRatio(text: string, ratio: number): string {
    const targetTokens = Math.max(
      MIN_TARGET_TOKENS,
      Math.floor((text.length / 4) * ratio),
    );
    const targetChars = estimateCharsFromTokens(text, targetTokens);
    return truncateProportionally(text, targetChars, TOOL_TRUNCATION_PREFIX);
  }
  /**
   * Truncates content while maintaining its Part structure.
   */
  private truncateContentStructurally(
    content: PartListUnion,
    ratio: number,
    savedPath: string,
    intentSummary: string,
  ): PartListUnion {
    if (typeof content === 'string') {
      return (
        this.truncateTextByRatio(content, ratio) +
        `\n\nFull output saved to: ${savedPath}` +
        intentSummary
      );
    }
    if (!Array.isArray(content)) return content;
    return content.map((part) => {
      if (typeof part === 'string') {
        // NOTE(review): bare string parts get no saved-path suffix, unlike
        // `part.text` parts below — confirm this asymmetry is intentional.
        return this.truncateTextByRatio(part, ratio);
      }
      if (part.text) {
        return {
          text:
            this.truncateTextByRatio(part.text, ratio) +
            `\n\nFull output saved to: ${savedPath}` +
            intentSummary,
        };
      }
      if (part.functionResponse) {
        return normalizeFunctionResponse(
          part,
          ratio,
          0.2, // default headRatio
          savedPath,
          intentSummary,
        );
      }
      return part;
    });
  }
  /**
   * Calls the secondary model to distill the strategic "why" signals and intent
   * of the truncated content before it is offloaded.
   *
   * Best-effort: failures are logged at debug level and yield `undefined`.
   */
  private async generateIntentSummary(
    toolName: string,
    stringifiedContent: string,
    maxPreviewLen: number,
  ): Promise<string | undefined> {
    const controller = new AbortController();
    // Abort the summarization call after 15s; it is a progressive enhancement
    // and must not stall the main agent loop.
    const timeoutId = setTimeout(() => controller.abort(), 15000);
    try {
      const promptText = `The following output from the tool '${toolName}' is large and has been truncated. Extract the most critical factual information from this output so the main agent doesn't lose context.
Focus strictly on concrete data points:
1. Exact error messages, exception types, or exit codes.
2. Specific file paths or line numbers mentioned.
3. Definitive outcomes (e.g., 'Compilation succeeded', '3 tests failed').
Do not philosophize about the strategic intent. Keep the extraction under 10 lines and use exact quotes where helpful.
Output to summarize:
${stringifiedContent.slice(0, maxPreviewLen)}...`;
      const summaryResponse = await this.geminiClient.generateContent(
        { model: 'agent-history-provider-summarizer' },
        [{ role: 'user', parts: [{ text: promptText }] }],
        controller.signal,
        LlmRole.UTILITY_COMPRESSOR,
      );
      return summaryResponse.candidates?.[0]?.content?.parts?.[0]?.text;
    } catch (e) {
      // Fail gracefully, summarization is a progressive enhancement
      debugLogger.debug(
        'Failed to generate intent summary for truncated output:',
        e instanceof Error ? e.message : String(e),
      );
      return undefined;
    } finally {
      // Bug fix: always clear the timer. Previously clearTimeout only ran on
      // the success path, so a thrown generateContent leaked the timer and a
      // stray abort() fired after the call had already settled.
      clearTimeout(timeoutId);
    }
  }
}
@@ -0,0 +1,665 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import fs from 'node:fs';
import path from 'node:path';
import os from 'node:os';
import {
ToolOutputMaskingService,
MASKING_INDICATOR_TAG,
} from './toolOutputMaskingService.js';
import {
SHELL_TOOL_NAME,
ACTIVATE_SKILL_TOOL_NAME,
MEMORY_TOOL_NAME,
} from '../tools/tool-names.js';
import { estimateTokenCountSync } from '../utils/tokenCalculation.js';
import type { Config } from '../config/config.js';
import type { Content, Part } from '@google/genai';
vi.mock('../utils/tokenCalculation.js', () => ({
estimateTokenCountSync: vi.fn(),
}));
describe('ToolOutputMaskingService', () => {
let service: ToolOutputMaskingService;
let mockConfig: Config;
let testTempDir: string;
const mockedEstimateTokenCountSync = vi.mocked(estimateTokenCountSync);
  beforeEach(async () => {
    // Real temp directory on disk so the service can offload masked outputs.
    testTempDir = await fs.promises.mkdtemp(
      path.join(os.tmpdir(), 'tool-masking-test-'),
    );
    service = new ToolOutputMaskingService();
    // Masking enabled with a 50k-token protection window, a 30k minimum
    // prunable total, and latest-turn protection on.
    mockConfig = {
      storage: {
        getHistoryDir: () => path.join(testTempDir, 'history'),
        getProjectTempDir: () => testTempDir,
      },
      getSessionId: () => 'mock-session',
      getUsageStatisticsEnabled: () => false,
      getToolOutputMaskingEnabled: () => true,
      getToolOutputMaskingConfig: async () => ({
        enabled: true,
        toolProtectionThreshold: 50000,
        minPrunableTokensThreshold: 30000,
        protectLatestTurn: true,
      }),
    } as unknown as Config;
    vi.clearAllMocks();
  });
  afterEach(async () => {
    vi.restoreAllMocks();
    if (testTempDir) {
      await fs.promises.rm(testTempDir, { recursive: true, force: true });
    }
  });
  // Per-call config overrides the defaults, including turning off
  // latest-turn protection so even the newest output is maskable.
  it('should respect remote configuration overrides', async () => {
    mockConfig.getToolOutputMaskingConfig = async () => ({
      enabled: true,
      toolProtectionThreshold: 100, // Very low threshold
      minPrunableTokensThreshold: 50,
      protectLatestTurn: false,
    });
    const history: Content[] = [
      {
        role: 'user',
        parts: [
          {
            functionResponse: {
              name: 'test_tool',
              response: { output: 'A'.repeat(200) },
            },
          },
        ],
      },
    ];
    // Token estimator stub: masked outputs count as 10 tokens, raw as 200.
    mockedEstimateTokenCountSync.mockImplementation((parts) => {
      const resp = parts[0].functionResponse?.response as Record<
        string,
        unknown
      >;
      const content = (resp?.['output'] as string) ?? JSON.stringify(resp);
      return content.includes(MASKING_INDICATOR_TAG) ? 10 : 200;
    });
    const result = await service.mask(history, mockConfig);
    // With low thresholds and protectLatestTurn=false, it should mask even the latest turn
    expect(result.maskedCount).toBe(1);
    expect(result.tokensSaved).toBeGreaterThan(0);
  });
  // Below the protection threshold nothing is masked and the history is
  // returned unchanged.
  it('should not mask if total tool tokens are below protection threshold', async () => {
    const history: Content[] = [
      {
        role: 'user',
        parts: [
          {
            functionResponse: {
              name: 'test_tool',
              response: { output: 'small output' },
            },
          },
        ],
      },
    ];
    mockedEstimateTokenCountSync.mockReturnValue(100);
    const result = await service.mask(history, mockConfig);
    expect(result.maskedCount).toBe(0);
    expect(result.newHistory).toEqual(history);
  });
  // Extracts the output string from a functionResponse part, tolerating both
  // { output } objects and bare string responses.
  const getToolResponse = (part: Part | undefined): string => {
    const resp = part?.functionResponse?.response as
      | { output: string }
      | undefined;
    return resp?.output ?? (resp as unknown as string) ?? '';
  };
  it('should protect the latest turn and mask older outputs beyond 50k window if total > 30k', async () => {
    // History:
    // Turn 1: 60k (Oldest)
    // Turn 2: 20k
    // Turn 3: 10k (Latest) - Protected because PROTECT_LATEST_TURN is true
    const history: Content[] = [
      {
        role: 'user',
        parts: [
          {
            functionResponse: {
              name: 't1',
              response: { output: 'A'.repeat(60000) },
            },
          },
        ],
      },
      {
        role: 'user',
        parts: [
          {
            functionResponse: {
              name: 't2',
              response: { output: 'B'.repeat(20000) },
            },
          },
        ],
      },
      {
        role: 'user',
        parts: [
          {
            functionResponse: {
              name: 't3',
              response: { output: 'C'.repeat(10000) },
            },
          },
        ],
      },
    ];
    // Token estimator stub keyed by tool name; masked outputs shrink to 100.
    mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => {
      const toolName = parts[0].functionResponse?.name;
      const resp = parts[0].functionResponse?.response as Record<
        string,
        unknown
      >;
      const content = (resp?.['output'] as string) ?? JSON.stringify(resp);
      if (content.includes(`<${MASKING_INDICATOR_TAG}`)) return 100;
      if (toolName === 't1') return 60000;
      if (toolName === 't2') return 20000;
      if (toolName === 't3') return 10000;
      return 0;
    });
    // Scanned: Turn 2 (20k), Turn 1 (60k). Total = 80k.
    // Turn 2: Cumulative = 20k. Protected (<= 50k).
    // Turn 1: Cumulative = 80k. Crossed 50k boundary. Prunabled.
    // Total Prunable = 60k (> 30k trigger).
    const result = await service.mask(history, mockConfig);
    expect(result.maskedCount).toBe(1);
    expect(getToolResponse(result.newHistory[0].parts?.[0])).toContain(
      `<${MASKING_INDICATOR_TAG}`,
    );
    expect(getToolResponse(result.newHistory[1].parts?.[0])).toEqual(
      'B'.repeat(20000),
    );
    expect(getToolResponse(result.newHistory[2].parts?.[0])).toEqual(
      'C'.repeat(10000),
    );
  });
it('should perform global aggregation for many small parts once boundary is hit', async () => {
  // history.length = 12. Skip index 11 (latest).
  // Indices 0-10: 10k each. Walking backwards, the cumulative sum crosses
  // the 50k protection boundary at index 6, so indices 6..1 of the scanned
  // range become prunable (7 parts seen, 6 actually maskable) while the
  // newest ones remain protected.
  const conversation: Content[] = Array.from({ length: 12 }, (_, i) => ({
    role: 'user',
    parts: [
      {
        functionResponse: {
          name: `tool${i}`,
          response: { output: 'A'.repeat(10000) },
        },
      },
    ],
  }));
  mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => {
    const raw = parts[0].functionResponse?.response as
      | { output?: string; result?: string }
      | string
      | undefined;
    let text: string | undefined;
    if (typeof raw === 'string') {
      text = raw;
    } else {
      text = raw?.output || raw?.result || JSON.stringify(raw);
    }
    if (text?.includes(`<${MASKING_INDICATOR_TAG}`)) {
      return 100;
    }
    return text?.length || 0;
  });
  const outcome = await service.mask(conversation, mockConfig);
  expect(outcome.maskedCount).toBe(6); // boundary at 50k protects 0-5
  expect(outcome.tokensSaved).toBeGreaterThan(0);
});
it('should verify tool-aware previews (shell vs generic)', async () => {
  const shellOutput =
    'Output: line1\nline2\nline3\nline4\nline5\nError: failed\nExit Code: 1';
  const conversation: Content[] = [
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: SHELL_TOOL_NAME,
            response: { output: shellOutput },
          },
        },
      ],
    },
    // Protection buffer
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: 'p',
            response: { output: 'p'.repeat(60000) },
          },
        },
      ],
    },
    // Latest turn
    {
      role: 'user',
      parts: [{ functionResponse: { name: 'l', response: { output: 'l' } } }],
    },
  ];
  mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => {
    const fr = parts[0].functionResponse;
    const resp = fr?.response as Record<string, unknown>;
    const text = (resp?.['output'] as string) ?? JSON.stringify(resp);
    if (text.includes(`<${MASKING_INDICATOR_TAG}`)) return 100;
    switch (fr?.name) {
      case SHELL_TOOL_NAME:
        return 100000;
      case 'p':
        return 60000;
      default:
        return 100;
    }
  });
  const outcome = await service.mask(conversation, mockConfig);
  // The shell preview keeps the Output head plus the high-signal sections.
  const maskedShell = getToolResponse(outcome.newHistory[0].parts?.[0]);
  expect(maskedShell).toContain('Output: line1\nline2\nline3\nline4\nline5');
  expect(maskedShell).toContain('Exit Code: 1');
  expect(maskedShell).toContain('Error: failed');
});
it('should skip already masked content and not count it towards totals', async () => {
  mockedEstimateTokenCountSync.mockReturnValue(60000);
  const previouslyMasked = `<${MASKING_INDICATOR_TAG}>...</${MASKING_INDICATOR_TAG}>`;
  const conversation: Content[] = [
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: 'tool1',
            response: { output: previouslyMasked },
          },
        },
      ],
    },
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: 'tool2',
            response: { output: 'A'.repeat(60000) },
          },
        },
      ],
    },
  ];
  const outcome = await service.mask(conversation, mockConfig);
  expect(outcome.maskedCount).toBe(0); // tool1 skipped, tool2 is the "latest" which is protected
});
it('should handle different response keys in masked update', async () => {
  const conversation: Content[] = [
    {
      role: 'model',
      parts: [
        {
          functionResponse: {
            name: 't1',
            response: { result: 'A'.repeat(60000) },
          },
        },
      ],
    },
    {
      role: 'model',
      parts: [
        {
          functionResponse: {
            name: 'p',
            response: { output: 'P'.repeat(60000) },
          },
        },
      ],
    },
    { role: 'user', parts: [{ text: 'latest' }] },
  ];
  mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => {
    const resp = parts[0].functionResponse?.response as Record<
      string,
      unknown
    >;
    const text =
      (resp?.['output'] as string) ??
      (resp?.['result'] as string) ??
      JSON.stringify(resp);
    return text.includes(`<${MASKING_INDICATOR_TAG}`) ? 100 : 60000;
  });
  const outcome = await service.mask(conversation, mockConfig);
  expect(outcome.maskedCount).toBe(2); // both t1 and p are prunable (cumulative 60k and 120k)
  // Masking normalizes whatever key the tool used down to a single 'output'.
  const maskedResponse = outcome.newHistory[0].parts?.[0].functionResponse
    ?.response as Record<string, unknown>;
  expect(Object.keys(maskedResponse)).toEqual(['output']);
});
it('should preserve multimodal parts while masking tool responses', async () => {
  const history: Content[] = [
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: 't1',
            response: { output: 'A'.repeat(60000) },
          },
        },
        // Sibling multimodal part that must survive masking untouched.
        {
          inlineData: {
            data: 'base64data',
            mimeType: 'image/png',
          },
        },
      ],
    },
    // Protection buffer
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: 'p',
            response: { output: 'p'.repeat(60000) },
          },
        },
      ],
    },
    // Latest turn
    { role: 'user', parts: [{ text: 'latest' }] },
  ];
  // Stub: 60k tokens for each bulky tool output, 100 once masked.
  mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => {
    const resp = parts[0].functionResponse?.response as Record<
      string,
      unknown
    >;
    const content = (resp?.['output'] as string) ?? JSON.stringify(resp);
    if (content.includes(`<${MASKING_INDICATOR_TAG}`)) return 100;
    if (parts[0].functionResponse?.name === 't1') return 60000;
    if (parts[0].functionResponse?.name === 'p') return 60000;
    return 100;
  });
  const result = await service.mask(history, mockConfig);
  expect(result.maskedCount).toBe(2); // Both t1 and p are prunable (cumulative 60k each > 50k protection)
  // The part count is preserved: masked functionResponse + untouched image.
  expect(result.newHistory[0].parts).toHaveLength(2);
  expect(result.newHistory[0].parts?.[0].functionResponse).toBeDefined();
  expect(
    (
      result.newHistory[0].parts?.[0].functionResponse?.response as Record<
        string,
        unknown
      >
    )['output'],
  ).toContain(`<${MASKING_INDICATOR_TAG}`);
  expect(result.newHistory[0].parts?.[1].inlineData).toEqual({
    data: 'base64data',
    mimeType: 'image/png',
  });
});
it('should match the expected snapshot for a masked tool output', async () => {
  const history: Content[] = [
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: SHELL_TOOL_NAME,
            response: {
              // 25 "Line" rows -> head/tail preview with 6 lines omitted.
              output: 'Line\n'.repeat(25),
              exitCode: 0,
            },
          },
        },
      ],
    },
    // Buffer to push shell_tool into prunable territory
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: 'padding',
            response: { output: 'B'.repeat(60000) },
          },
        },
      ],
    },
    { role: 'user', parts: [{ text: 'latest' }] },
  ];
  mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => {
    const resp = parts[0].functionResponse?.response as Record<
      string,
      unknown
    >;
    const content = (resp?.['output'] as string) ?? JSON.stringify(resp);
    if (content.includes(`<${MASKING_INDICATOR_TAG}`)) return 100;
    if (parts[0].functionResponse?.name === SHELL_TOOL_NAME) return 1000;
    if (parts[0].functionResponse?.name === 'padding') return 60000;
    return 10;
  });
  const result = await service.mask(history, mockConfig);
  // Verify complete masking: only 'output' key should exist
  const responseObj = result.newHistory[0].parts?.[0].functionResponse
    ?.response as Record<string, unknown>;
  expect(Object.keys(responseObj)).toEqual(['output']);
  const response = responseObj['output'] as string;
  // We replace the random part of the filename for deterministic snapshots
  // and normalize path separators for cross-platform compatibility
  const normalizedResponse = response.replace(/\\/g, '/');
  const deterministicResponse = normalizedResponse
    .replace(new RegExp(testTempDir.replace(/\\/g, '/'), 'g'), '/mock/temp')
    .replace(
      new RegExp(`${SHELL_TOOL_NAME}_[^\\s"]+\\.txt`, 'g'),
      `${SHELL_TOOL_NAME}_deterministic.txt`,
    );
  expect(deterministicResponse).toMatchSnapshot();
});
it('should not mask if masking increases token count (due to overhead)', async () => {
  const conversation: Content[] = [
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: 'tiny_tool',
            response: { output: 'tiny' },
          },
        },
      ],
    },
    // Protection buffer to push tiny_tool into prunable territory
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: 'padding',
            response: { output: 'B'.repeat(60000) },
          },
        },
      ],
    },
    { role: 'user', parts: [{ text: 'latest' }] },
  ];
  mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => {
    switch (parts[0].functionResponse?.name) {
      case 'tiny_tool':
        return 5;
      case 'padding':
        return 60000;
      default:
        return 1000; // The masked version would be huge due to boilerplate
    }
  });
  const outcome = await service.mask(conversation, mockConfig);
  expect(outcome.maskedCount).toBe(0); // padding is protected, tiny_tool would increase size
});
it('should never mask exempt tools (like activate_skill) even if they are deep in history', async () => {
  const history: Content[] = [
    // Oldest: exempt tools whose outputs must survive any masking pass.
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: ACTIVATE_SKILL_TOOL_NAME,
            response: { output: 'High value instructions for skill' },
          },
        },
      ],
    },
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: MEMORY_TOOL_NAME,
            response: { output: 'Important user preference' },
          },
        },
      ],
    },
    // Bulky non-exempt output that should get masked.
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: 'bulky_tool',
            response: { output: 'A'.repeat(60000) },
          },
        },
      ],
    },
    // Protection buffer
    {
      role: 'user',
      parts: [
        {
          functionResponse: {
            name: 'padding',
            response: { output: 'B'.repeat(60000) },
          },
        },
      ],
    },
    { role: 'user', parts: [{ text: 'latest' }] },
  ];
  // Stub token sizes per tool; masked parts report a small fixed cost.
  mockedEstimateTokenCountSync.mockImplementation((parts: Part[]) => {
    const resp = parts[0].functionResponse?.response as Record<
      string,
      unknown
    >;
    const content = (resp?.['output'] as string) ?? JSON.stringify(resp);
    if (content.includes(`<${MASKING_INDICATOR_TAG}`)) return 100;
    const name = parts[0].functionResponse?.name;
    if (name === ACTIVATE_SKILL_TOOL_NAME) return 1000;
    if (name === MEMORY_TOOL_NAME) return 500;
    if (name === 'bulky_tool') return 60000;
    if (name === 'padding') return 60000;
    return 10;
  });
  const result = await service.mask(history, mockConfig);
  // Both 'bulky_tool' and 'padding' should be masked.
  // 'padding' (Index 3) crosses the 50k protection boundary immediately.
  // ACTIVATE_SKILL and MEMORY are exempt.
  expect(result.maskedCount).toBe(2);
  expect(result.newHistory[0].parts?.[0].functionResponse?.name).toBe(
    ACTIVATE_SKILL_TOOL_NAME,
  );
  expect(
    (
      result.newHistory[0].parts?.[0].functionResponse?.response as Record<
        string,
        unknown
      >
    )['output'],
  ).toBe('High value instructions for skill');
  expect(result.newHistory[1].parts?.[0].functionResponse?.name).toBe(
    MEMORY_TOOL_NAME,
  );
  expect(
    (
      result.newHistory[1].parts?.[0].functionResponse?.response as Record<
        string,
        unknown
      >
    )['output'],
  ).toBe('Important user preference');
  expect(result.newHistory[2].parts?.[0].functionResponse?.name).toBe(
    'bulky_tool',
  );
  expect(
    (
      result.newHistory[2].parts?.[0].functionResponse?.response as Record<
        string,
        unknown
      >
    )['output'],
  ).toContain(MASKING_INDICATOR_TAG);
});
});
@@ -0,0 +1,379 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Content, Part } from '@google/genai';
import path from 'node:path';
import * as fsPromises from 'node:fs/promises';
import { estimateTokenCountSync } from '../utils/tokenCalculation.js';
import { debugLogger } from '../utils/debugLogger.js';
import { sanitizeFilenamePart } from '../utils/fileUtils.js';
import type { Config } from '../config/config.js';
import { logToolOutputMasking } from '../telemetry/loggers.js';
import {
SHELL_TOOL_NAME,
ACTIVATE_SKILL_TOOL_NAME,
MEMORY_TOOL_NAME,
ASK_USER_TOOL_NAME,
ENTER_PLAN_MODE_TOOL_NAME,
EXIT_PLAN_MODE_TOOL_NAME,
} from '../tools/tool-names.js';
import { ToolOutputMaskingEvent } from '../telemetry/types.js';
// Tool output masking defaults
/** Newest tool-output tokens shielded from masking (protection window). */
export const DEFAULT_TOOL_PROTECTION_THRESHOLD = 50000;
/** Minimum prunable tokens required before a masking pass is triggered. */
export const DEFAULT_MIN_PRUNABLE_TOKENS_THRESHOLD = 30000;
/** Whether the most recent conversation turn is always exempt from masking. */
export const DEFAULT_PROTECT_LATEST_TURN = true;
/** Tag wrapped around masked snippets; also used to detect already-masked output. */
export const MASKING_INDICATOR_TAG = 'tool_output_masked';
/** Subdirectory (under the project temp dir) where full outputs are offloaded. */
export const TOOL_OUTPUTS_DIR = 'tool-outputs';
/**
 * Tools whose outputs are always high-signal and should never be masked,
 * regardless of their position in the conversation history.
 */
const EXEMPT_TOOLS = new Set([
  ACTIVATE_SKILL_TOOL_NAME,
  MEMORY_TOOL_NAME,
  ASK_USER_TOOL_NAME,
  ENTER_PLAN_MODE_TOOL_NAME,
  EXIT_PLAN_MODE_TOOL_NAME,
]);
/** Outcome of one masking pass over a conversation history. */
export interface MaskingResult {
  /** History with bulky outputs replaced (original reference if nothing changed). */
  newHistory: readonly Content[];
  /** Number of tool-output parts actually replaced. */
  maskedCount: number;
  /** Estimated token savings across all replaced parts. */
  tokensSaved: number;
}
/**
 * Service to manage context window efficiency by masking bulky tool outputs (Tool Output Masking).
 *
 * It implements a "Hybrid Backward Scanned FIFO" algorithm to balance context relevance with
 * token savings:
 * 1. **Protection Window**: Protects the newest `toolProtectionThreshold` (default 50k) tool tokens
 *    from pruning. Optionally skips the entire latest conversation turn to ensure full context for
 *    the model's next response.
 * 2. **Global Aggregation**: Scans backwards past the protection window to identify all remaining
 *    tool outputs that haven't been masked yet.
 * 3. **Batch Trigger**: Trigger masking only if the total prunable tokens exceed
 *    `minPrunableTokensThreshold` (default 30k).
 *
 * @remarks
 * Effectively, this means masking only starts once the conversation contains approximately 80k
 * tokens of prunable tool outputs (50k protected + 30k prunable buffer). Small tool outputs
 * are preserved until they collectively reach the threshold.
 */
export class ToolOutputMaskingService {
  /**
   * Runs one masking pass over `history`.
   *
   * Identifies prunable tool outputs via a backward scan, offloads their full
   * content to per-session files under the project temp dir, and replaces
   * them with short `<tool_output_masked>` snippets in a shallow copy of the
   * history. The input array is never mutated.
   *
   * @param history Conversation turns to scan (read-only).
   * @param config Source of masking thresholds, storage paths and telemetry.
   * @returns The (possibly unchanged) history plus masking statistics.
   */
  async mask(
    history: readonly Content[],
    config: Config,
  ): Promise<MaskingResult> {
    const maskingConfig = await config.getToolOutputMaskingConfig();
    if (!maskingConfig.enabled || history.length === 0) {
      return { newHistory: history, maskedCount: 0, tokensSaved: 0 };
    }
    // Running total of tool tokens seen while still inside the protection window.
    let cumulativeToolTokens = 0;
    let protectionBoundaryReached = false;
    let totalPrunableTokens = 0;
    let maskedCount = 0;
    const prunableParts: Array<{
      contentIndex: number;
      partIndex: number;
      tokens: number;
      content: string;
      originalPart: Part;
    }> = [];
    // Decide where to start scanning.
    // If PROTECT_LATEST_TURN is true, we skip the most recent message (index history.length - 1).
    const scanStartIdx = maskingConfig.protectLatestTurn
      ? history.length - 2
      : history.length - 1;
    // Backward scan to identify prunable tool outputs
    for (let i = scanStartIdx; i >= 0; i--) {
      const content = history[i];
      const parts = content.parts || [];
      for (let j = parts.length - 1; j >= 0; j--) {
        const part = parts[j];
        // Tool outputs (functionResponse) are the primary targets for pruning because
        // they often contain voluminous data (e.g., shell logs, file content) that
        // can exceed context limits. We preserve other parts—such as user text,
        // model reasoning, and multimodal data—because they define the conversation's
        // core intent and logic, which are harder for the model to recover if lost.
        if (!part.functionResponse) continue;
        const toolName = part.functionResponse.name;
        if (toolName && EXEMPT_TOOLS.has(toolName)) {
          continue;
        }
        const toolOutputContent = this.getToolOutputContent(part);
        if (!toolOutputContent || this.isAlreadyMasked(toolOutputContent)) {
          continue;
        }
        const partTokens = estimateTokenCountSync([part]);
        if (!protectionBoundaryReached) {
          cumulativeToolTokens += partTokens;
          if (cumulativeToolTokens > maskingConfig.toolProtectionThreshold) {
            protectionBoundaryReached = true;
            // The part that crossed the boundary is prunable.
            totalPrunableTokens += partTokens;
            prunableParts.push({
              contentIndex: i,
              partIndex: j,
              tokens: partTokens,
              content: toolOutputContent,
              originalPart: part,
            });
          }
        } else {
          // Past the protection window: every remaining tool output is prunable.
          totalPrunableTokens += partTokens;
          prunableParts.push({
            contentIndex: i,
            partIndex: j,
            tokens: partTokens,
            content: toolOutputContent,
            originalPart: part,
          });
        }
      }
    }
    // Trigger pruning only if we have accumulated enough savings to justify the
    // overhead of masking and file I/O (batch pruning threshold).
    if (totalPrunableTokens < maskingConfig.minPrunableTokensThreshold) {
      return { newHistory: history, maskedCount: 0, tokensSaved: 0 };
    }
    debugLogger.debug(
      `[ToolOutputMasking] Triggering masking. Prunable tool tokens: ${totalPrunableTokens.toLocaleString()} (> ${maskingConfig.minPrunableTokensThreshold.toLocaleString()})`,
    );
    // Perform masking and offloading
    const newHistory = [...history]; // Shallow copy of history
    let actualTokensSaved = 0;
    let toolOutputsDir = path.join(
      config.storage.getProjectTempDir(),
      TOOL_OUTPUTS_DIR,
    );
    const sessionId = config.getSessionId();
    if (sessionId) {
      // Keep offloaded outputs from different sessions separated on disk.
      const safeSessionId = sanitizeFilenamePart(sessionId);
      toolOutputsDir = path.join(toolOutputsDir, `session-${safeSessionId}`);
    }
    await fsPromises.mkdir(toolOutputsDir, { recursive: true });
    for (const item of prunableParts) {
      const { contentIndex, partIndex, content, tokens } = item;
      const contentRecord = newHistory[contentIndex];
      const part = contentRecord.parts![partIndex];
      if (!part.functionResponse) continue;
      const toolName = part.functionResponse.name || 'unknown_tool';
      const callId = part.functionResponse.id || Date.now().toString();
      const safeToolName = sanitizeFilenamePart(toolName).toLowerCase();
      const safeCallId = sanitizeFilenamePart(callId).toLowerCase();
      // Random suffix avoids collisions when the same call id repeats.
      const fileName = `${safeToolName}_${safeCallId}_${Math.random()
        .toString(36)
        .substring(7)}.txt`;
      const filePath = path.join(toolOutputsDir, fileName);
      // NOTE(review): the full output is written before we know whether the
      // masked form actually saves tokens; if the `savings > 0` check below
      // fails, this file stays on disk unused.
      await fsPromises.writeFile(filePath, content, 'utf-8');
      const originalResponse =
        // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
        (part.functionResponse.response as Record<string, unknown>) || {};
      const totalLines = content.split('\n').length;
      const fileSizeMB = (
        Buffer.byteLength(content, 'utf8') /
        1024 /
        1024
      ).toFixed(2);
      let preview = '';
      if (toolName === SHELL_TOOL_NAME) {
        preview = this.formatShellPreview(originalResponse);
      } else {
        // General tools: Head + Tail preview (250 chars each)
        if (content.length > 500) {
          preview = `${content.slice(0, 250)}\n... [TRUNCATED] ...\n${content.slice(-250)}`;
        } else {
          preview = content;
        }
      }
      const maskedSnippet = this.formatMaskedSnippet({
        toolName,
        filePath,
        fileSizeMB,
        totalLines,
        tokens,
        preview,
      });
      const maskedPart = {
        ...part,
        functionResponse: {
          // eslint-disable-next-line @typescript-eslint/no-misused-spread
          ...part.functionResponse,
          response: { output: maskedSnippet },
        },
      };
      // Only commit the replacement if it actually shrinks the part.
      const newTaskTokens = estimateTokenCountSync([maskedPart]);
      const savings = tokens - newTaskTokens;
      if (savings > 0) {
        const newParts = [...contentRecord.parts!];
        newParts[partIndex] = maskedPart;
        newHistory[contentIndex] = { ...contentRecord, parts: newParts };
        actualTokensSaved += savings;
        maskedCount++;
      }
    }
    debugLogger.debug(
      `[ToolOutputMasking] Masked ${maskedCount} tool outputs. Saved ~${actualTokensSaved.toLocaleString()} tokens.`,
    );
    const result = {
      newHistory,
      maskedCount,
      tokensSaved: actualTokensSaved,
    };
    // Skip telemetry when nothing was effectively saved.
    if (actualTokensSaved <= 0) {
      return result;
    }
    logToolOutputMasking(
      config,
      new ToolOutputMaskingEvent({
        tokens_before: totalPrunableTokens,
        tokens_after: totalPrunableTokens - actualTokensSaved,
        masked_count: maskedCount,
        total_prunable_tokens: totalPrunableTokens,
      }),
    );
    return result;
  }
  /**
   * Serializes a functionResponse part's payload for offloading to disk.
   * Returns null when the part has no functionResponse or no response body.
   */
  private getToolOutputContent(part: Part): string | null {
    if (!part.functionResponse) return null;
    // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
    const response = part.functionResponse.response as Record<string, unknown>;
    if (!response) return null;
    // Stringify the entire response for saving.
    // This handles any tool output schema automatically.
    const content = JSON.stringify(response, null, 2);
    // Multimodal safety check: Sibling parts (inlineData, etc.) are handled by mask()
    // by keeping the original part structure and only replacing the functionResponse content.
    return content;
  }
  /** True when the serialized output already contains the masking tag. */
  private isAlreadyMasked(content: string): boolean {
    return content.includes(`<${MASKING_INDICATOR_TAG}`);
  }
  /**
   * Builds a preview for shell-tool output: bulky "Output" sections are
   * head/tail-truncated while small high-signal sections (Error, Exit Code,
   * ...) are kept in full.
   */
  private formatShellPreview(response: Record<string, unknown>): string {
    // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
    const content = (response['output'] || response['stdout'] || '') as string;
    if (typeof content !== 'string') {
      return typeof content === 'object'
        ? JSON.stringify(content)
        : String(content);
    }
    // The shell tool output is structured in shell.ts with specific section prefixes:
    const sectionRegex =
      /^(Output|Error|Exit Code|Signal|Background PIDs|Process Group PGID): /m;
    const parts = content.split(sectionRegex);
    if (parts.length < 3) {
      // Fallback to simple head/tail if not in expected shell.ts format
      return this.formatSimplePreview(content);
    }
    const previewParts: string[] = [];
    if (parts[0].trim()) {
      previewParts.push(this.formatSimplePreview(parts[0].trim()));
    }
    // split() with a capturing group interleaves section names and bodies.
    for (let i = 1; i < parts.length; i += 2) {
      const name = parts[i];
      const sectionContent = parts[i + 1]?.trim() || '';
      if (name === 'Output') {
        previewParts.push(
          `Output: ${this.formatSimplePreview(sectionContent)}`,
        );
      } else {
        // Keep other sections (Error, Exit Code, etc.) in full as they are usually high-signal and small
        previewParts.push(`${name}: ${sectionContent}`);
      }
    }
    let preview = previewParts.join('\n');
    // Also check root levels just in case some tool uses them or for future-proofing
    const exitCode = response['exitCode'] ?? response['exit_code'];
    const error = response['error'];
    if (
      exitCode !== undefined &&
      exitCode !== 0 &&
      exitCode !== null &&
      !content.includes(`Exit Code: ${exitCode}`)
    ) {
      preview += `\n[Exit Code: ${exitCode}]`;
    }
    if (error && !content.includes(`Error: ${error}`)) {
      preview += `\n[Error: ${error}]`;
    }
    return preview;
  }
  /** Head/tail line preview: first 10 and last 10 lines of `content`. */
  private formatSimplePreview(content: string): string {
    const lines = content.split('\n');
    if (lines.length <= 20) return content;
    const head = lines.slice(0, 10);
    const tail = lines.slice(-10);
    return `${head.join('\n')}\n\n... [${
      lines.length - head.length - tail.length
    } lines omitted] ...\n\n${tail.join('\n')}`;
  }
  /**
   * Renders the snippet that replaces a masked tool output in history.
   * NOTE(review): only `filePath` and `preview` are read here; the other
   * MaskedSnippetParams fields are currently ignored.
   */
  private formatMaskedSnippet(params: MaskedSnippetParams): string {
    const { filePath, preview } = params;
    return `<${MASKING_INDICATOR_TAG}>
${preview}
Output too large. Full output available at: ${filePath}
</${MASKING_INDICATOR_TAG}>`;
  }
}
/** Inputs for rendering the masked-snippet replacement text. */
interface MaskedSnippetParams {
  // NOTE(review): formatMaskedSnippet currently destructures only `filePath`
  // and `preview`; the remaining fields are computed by the caller but unused.
  toolName: string;
  filePath: string;
  fileSizeMB: string;
  totalLines: number;
  tokens: number;
  preview: string;
}
+142
View File
@@ -0,0 +1,142 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Part } from '@google/genai';
import {
estimateTokenCountSync,
ASCII_TOKENS_PER_CHAR,
NON_ASCII_TOKENS_PER_CHAR,
} from '../utils/tokenCalculation.js';
/** Floor for the per-value truncation token budget. */
export const MIN_TARGET_TOKENS = 10;
/** Strings at or below this length are never considered for truncation. */
export const MIN_CHARS_FOR_TRUNCATION = 100;
/** Prefix prepended to truncated free-form text content. */
export const TEXT_TRUNCATION_PREFIX =
  '[Message Normalized: Exceeded size limit]';
/** Prefix prepended to truncated tool-output string values. */
export const TOOL_TRUNCATION_PREFIX =
  '[Message Normalized: Tool output exceeded size limit]';
/**
 * Estimates the character limit for a target token count, accounting for ASCII vs Non-ASCII.
 *
 * Samples up to the first 1000 characters of `text` to measure its ASCII
 * fraction, blends the per-character token costs accordingly, and converts
 * the token budget into a character budget.
 */
export function estimateCharsFromTokens(
  text: string,
  targetTokens: number,
): number {
  if (text.length === 0) {
    return 0;
  }
  // Sample a bounded prefix so very large inputs stay cheap to analyze.
  const sampleLen = Math.min(text.length, 1000);
  let asciiChars = 0;
  for (let idx = 0; idx < sampleLen; idx++) {
    if (text.charCodeAt(idx) <= 127) {
      asciiChars++;
    }
  }
  const asciiFraction = asciiChars / sampleLen;
  // Weighted tokens-per-character based on the observed ASCII fraction.
  const tokensPerChar =
    asciiFraction * ASCII_TOKENS_PER_CHAR +
    (1 - asciiFraction) * NON_ASCII_TOKENS_PER_CHAR;
  // characters = tokens / (tokens per character)
  return Math.floor(targetTokens / tokensPerChar);
}
/**
 * Truncates a string to a target length, keeping a proportional amount of the head and tail,
 * and prepending a prefix.
 *
 * When the string already fits inside `targetChars` it is returned as-is;
 * when the budget cannot even cover the prefix and separator, only the
 * prefix is returned.
 */
export function truncateProportionally(
  str: string,
  targetChars: number,
  prefix: string,
  headRatio: number = 0.2,
): string {
  if (str.length <= targetChars) {
    return str;
  }
  const separator = '\n...\n';
  // Budget left for content once the prefix, its trailing newline, and the
  // separator have been accounted for.
  const contentBudget = Math.max(
    0,
    targetChars - (prefix.length + separator.length + 1),
  );
  if (contentBudget <= 0) {
    return prefix; // Safe fallback if target is extremely small
  }
  const keepFromHead = Math.floor(contentBudget * headRatio);
  const keepFromTail = contentBudget - keepFromHead;
  const head = str.substring(0, keepFromHead);
  const tail = str.substring(str.length - keepFromTail);
  return `${prefix}\n${head}${separator}${tail}`;
}
/**
 * Safely normalizes a function response by truncating large string values
 * within the response object while maintaining its JSON structure.
 *
 * Only top-level string values longer than MIN_CHARS_FOR_TRUNCATION are
 * candidates; each is shrunk to roughly `ratio` of its estimated token
 * count (never below MIN_TARGET_TOKENS). Non-string values and short
 * strings pass through unchanged.
 *
 * @param part Part expected to carry a functionResponse.
 * @param ratio Fraction of each value's estimated tokens to keep.
 * @param headRatio Fraction of the kept characters taken from the head.
 * @param savedPath Optional path appended so the full output can be located.
 * @param intentSummary Optional summary appended verbatim after truncation.
 * @returns The original part (by reference) when nothing changed; otherwise
 *   a new Part with truncated response values and all sibling fields kept.
 */
export function normalizeFunctionResponse(
  part: Part,
  ratio: number,
  headRatio: number = 0.2,
  savedPath?: string,
  intentSummary?: string,
): Part {
  const fr = part.functionResponse;
  if (!fr || !fr.response) return part;
  const responseObj = fr.response;
  if (typeof responseObj !== 'object' || responseObj === null) return part;
  let hasChanges = false;
  const newResponse: Record<string, unknown> = {};
  // For function responses, we truncate individual string values that are large.
  // This preserves the schema keys (stdout, stderr, etc).
  for (const [key, value] of Object.entries(responseObj)) {
    if (typeof value === 'string' && value.length > MIN_CHARS_FOR_TRUNCATION) {
      const valueTokens = estimateTokenCountSync([{ text: value }]);
      const targetValueTokens = Math.max(
        MIN_TARGET_TOKENS,
        Math.floor(valueTokens * ratio),
      );
      const targetChars = estimateCharsFromTokens(value, targetValueTokens);
      if (value.length > targetChars) {
        let truncated = truncateProportionally(
          value,
          targetChars,
          TOOL_TRUNCATION_PREFIX,
          headRatio,
        );
        if (savedPath) {
          truncated += `\n\nFull output saved to: ${savedPath}`;
        }
        if (intentSummary) {
          truncated += intentSummary;
        }
        newResponse[key] = truncated;
        hasChanges = true;
      } else {
        newResponse[key] = value;
      }
    } else {
      newResponse[key] = value;
    }
  }
  if (!hasChanges) return part;
  // Spread the original part so sibling Part fields survive normalization;
  // previously only `functionResponse` was copied onto the new Part.
  return {
    ...part,
    functionResponse: {
      // This spread should be safe as we mostly care about the function
      // response properties.
      // eslint-disable-next-line @typescript-eslint/no-misused-spread
      ...fr,
      response: newResponse,
    },
  };
}