mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-15 06:12:50 -07:00
feat(core): implement bounded history retention for GeminiChat
- Improved ToolOutputMaskingService to mask exceptionally large outputs (> 2x threshold) even in the latest turn. - Implemented pruneHistory in GeminiChat with configurable turn and token limits to provide a hard safety net against OOM. - Added getters and initialization for experimental history truncation settings in Config. - Added comprehensive unit tests in geminiChat_pruning.test.ts. - Verified with npm test and build.
This commit is contained in:
@@ -1,15 +1,20 @@
|
||||
{
|
||||
"experimental": {
|
||||
"plan": true,
|
||||
"extensionReloading": true,
|
||||
"modelSteering": true,
|
||||
"memoryManager": true,
|
||||
"topicUpdateNarration": true
|
||||
},
|
||||
"general": {
|
||||
"devtools": true
|
||||
"devtools": true,
|
||||
"plan": {
|
||||
"enabled": true
|
||||
}
|
||||
},
|
||||
"security": {
|
||||
"toolSandboxing": true
|
||||
},
|
||||
"agents": {
|
||||
"overrides": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
+5
-2
@@ -166,8 +166,11 @@ too aggressively.
|
||||
ChatRecordingService and GeminiChat history).
|
||||
- [x] Implement bounded retention for ChatRecordingService (implemented
|
||||
memory-based cache eviction and leak prevention during resets).
|
||||
- [ ] Implement bounded retention for GeminiChat (improve Tool Output Masking
|
||||
or add hard history bounds).
|
||||
- [x] Implement bounded retention for GeminiChat (improve Tool Output Masking
|
||||
or add hard history bounds). - Improved `ToolOutputMaskingService` to
|
||||
mask massive outputs (> 2x threshold) even in the latest turn. -
|
||||
Implemented `pruneHistory` in `GeminiChat` with configurable turn and
|
||||
token limits to provide a hard memory safety net.
|
||||
- [ ] Audit React/Ink components for event listener leaks.
|
||||
|
||||
3. **Phase 3: Validation & CI**
|
||||
|
||||
@@ -957,6 +957,10 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
readonly injectionService: InjectionService;
|
||||
private approvedPlanPath: string | undefined;
|
||||
|
||||
private readonly experimentalAgentHistoryTruncation: boolean;
|
||||
private readonly experimentalAgentHistoryTruncationThreshold: number;
|
||||
private readonly experimentalAgentHistoryRetainedMessages: number;
|
||||
|
||||
constructor(params: ConfigParameters) {
|
||||
this._sessionId = params.sessionId;
|
||||
this.clientName = params.clientName;
|
||||
@@ -964,6 +968,12 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
this.approvedPlanPath = undefined;
|
||||
this.embeddingModel =
|
||||
params.embeddingModel ?? DEFAULT_GEMINI_EMBEDDING_MODEL;
|
||||
this.experimentalAgentHistoryTruncation =
|
||||
params.experimentalAgentHistoryTruncation ?? false;
|
||||
this.experimentalAgentHistoryTruncationThreshold =
|
||||
params.experimentalAgentHistoryTruncationThreshold ?? 100000;
|
||||
this.experimentalAgentHistoryRetainedMessages =
|
||||
params.experimentalAgentHistoryRetainedMessages ?? 20;
|
||||
this.sandbox = params.sandbox
|
||||
? {
|
||||
enabled: params.sandbox.enabled || params.toolSandboxing || false,
|
||||
@@ -3624,6 +3634,21 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
return this.disabledHooks;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get experimental agent history truncation settings
|
||||
*/
|
||||
isExperimentalAgentHistoryTruncationEnabled(): boolean {
|
||||
return this.experimentalAgentHistoryTruncation;
|
||||
}
|
||||
|
||||
getExperimentalAgentHistoryTruncationThreshold(): number {
|
||||
return this.experimentalAgentHistoryTruncationThreshold;
|
||||
}
|
||||
|
||||
getExperimentalAgentHistoryRetainedMessages(): number {
|
||||
return this.experimentalAgentHistoryRetainedMessages;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get experiments configuration
|
||||
*/
|
||||
|
||||
@@ -49,6 +49,7 @@ import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { partListUnionToString } from './geminiRequest.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import { estimateTokenCountSync } from '../utils/tokenCalculation.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import {
|
||||
applyModelSelection,
|
||||
createAvailabilityContextProvider,
|
||||
@@ -744,16 +745,71 @@ export class GeminiChat {
|
||||
*/
|
||||
addHistory(content: Content): void {
|
||||
this.history.push(content);
|
||||
this.pruneHistory();
|
||||
}
|
||||
|
||||
setHistory(history: readonly Content[]): void {
|
||||
this.history = [...history];
|
||||
this.pruneHistory();
|
||||
this.lastPromptTokenCount = estimateTokenCountSync(
|
||||
this.history.flatMap((c) => c.parts || []),
|
||||
);
|
||||
this.chatRecordingService.updateMessagesFromHistory(history);
|
||||
}
|
||||
|
||||
/**
|
||||
* Prunes the conversation history to stay within memory-safe bounds.
|
||||
* This is a "hard" truncation that acts as a safety net against OOM
|
||||
* when context management or compression is disabled or insufficient.
|
||||
*/
|
||||
private pruneHistory(): void {
|
||||
const config = this.context.config;
|
||||
if (!config.isExperimentalAgentHistoryTruncationEnabled()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const maxTokens = config.getExperimentalAgentHistoryTruncationThreshold();
|
||||
const maxMessages = config.getExperimentalAgentHistoryRetainedMessages();
|
||||
|
||||
// Check if we need to prune at all
|
||||
const totalTokens = estimateTokenCountSync(
|
||||
this.history.flatMap((c) => c.parts || []),
|
||||
);
|
||||
|
||||
if (this.history.length <= maxMessages && totalTokens <= maxTokens) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Always keep at least the last 'maxMessages' messages
|
||||
let prunedHistory =
|
||||
this.history.length > maxMessages
|
||||
? this.history.slice(-maxMessages)
|
||||
: [...this.history];
|
||||
|
||||
// Further prune based on token count if still over threshold
|
||||
let currentTokens = estimateTokenCountSync(
|
||||
prunedHistory.flatMap((c) => c.parts || []),
|
||||
);
|
||||
|
||||
while (
|
||||
prunedHistory.length > 2 && // Keep at least one exchange (user + model)
|
||||
currentTokens > maxTokens
|
||||
) {
|
||||
// Remove the oldest message from the pruned history
|
||||
prunedHistory = prunedHistory.slice(1);
|
||||
currentTokens = estimateTokenCountSync(
|
||||
prunedHistory.flatMap((c) => c.parts || []),
|
||||
);
|
||||
}
|
||||
|
||||
if (prunedHistory.length !== this.history.length) {
|
||||
debugLogger.debug(
|
||||
`[GeminiChat] Pruning history: ${this.history.length} -> ${prunedHistory.length} messages. Tokens: ${currentTokens.toLocaleString()}`,
|
||||
);
|
||||
this.history = prunedHistory;
|
||||
}
|
||||
}
|
||||
|
||||
stripThoughtsFromHistory(): void {
|
||||
this.history = this.history.map((content) => {
|
||||
const newContent = { ...content };
|
||||
@@ -994,6 +1050,7 @@ export class GeminiChat {
|
||||
}
|
||||
|
||||
this.history.push({ role: 'model', parts: consolidatedParts });
|
||||
this.pruneHistory();
|
||||
}
|
||||
|
||||
getLastPromptTokenCount(): number {
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
|
||||
import { GeminiChat } from './geminiChat.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { AgentLoopContext } from '../config/agent-loop-context.js';
|
||||
import { estimateTokenCountSync } from '../utils/tokenCalculation.js';
|
||||
import { ToolOutputMaskingService } from '../services/toolOutputMaskingService.js';
|
||||
import type { Content, Part } from '@google/genai';
|
||||
|
||||
// Mock token calculation to be predictable
|
||||
vi.mock('../utils/tokenCalculation.js', () => ({
|
||||
estimateTokenCountSync: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('GeminiChat Pruning and Masking', () => {
|
||||
let mockConfig: Partial<Config>;
|
||||
let context: AgentLoopContext;
|
||||
let chat: GeminiChat;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
mockConfig = {
|
||||
isExperimentalAgentHistoryTruncationEnabled: vi
|
||||
.fn()
|
||||
.mockReturnValue(true),
|
||||
getExperimentalAgentHistoryTruncationThreshold: vi
|
||||
.fn()
|
||||
.mockReturnValue(1000),
|
||||
getExperimentalAgentHistoryRetainedMessages: vi.fn().mockReturnValue(10),
|
||||
getToolOutputMaskingConfig: vi.fn().mockResolvedValue({
|
||||
enabled: true,
|
||||
toolProtectionThreshold: 50,
|
||||
minPrunableTokensThreshold: 10,
|
||||
protectLatestTurn: true,
|
||||
}),
|
||||
getSessionId: vi.fn().mockReturnValue('test-session'),
|
||||
getProjectRoot: vi.fn().mockReturnValue('/test/project'),
|
||||
getUsageStatisticsEnabled: vi.fn().mockReturnValue(true),
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: 'oauth' }),
|
||||
isInteractive: vi.fn().mockReturnValue(false),
|
||||
getExperiments: vi.fn().mockReturnValue({ experimentIds: [] }),
|
||||
storage: {
|
||||
getProjectTempDir: vi.fn().mockReturnValue('/tmp/test'),
|
||||
} as unknown as Config['storage'],
|
||||
modelConfigService: {
|
||||
getResolvedConfig: vi.fn().mockReturnValue({ model: 'gemini-pro' }),
|
||||
} as unknown as Config['modelConfigService'],
|
||||
};
|
||||
|
||||
context = {
|
||||
config: mockConfig as Config,
|
||||
promptId: 'test-session',
|
||||
} as unknown as AgentLoopContext;
|
||||
|
||||
chat = new GeminiChat(context);
|
||||
|
||||
// Default token estimation: 1 token per message for simplicity
|
||||
vi.mocked(estimateTokenCountSync).mockImplementation(() => 1);
|
||||
});
|
||||
|
||||
describe('History Pruning', () => {
|
||||
it('should prune history when turn limit is exceeded', () => {
|
||||
(
|
||||
mockConfig.getExperimentalAgentHistoryRetainedMessages as Mock
|
||||
).mockReturnValue(4);
|
||||
|
||||
// Add 6 messages (3 turns: user + model)
|
||||
for (let i = 0; i < 6; i++) {
|
||||
chat.addHistory({
|
||||
role: i % 2 === 0 ? 'user' : 'model',
|
||||
parts: [{ text: `msg ${i}` }],
|
||||
});
|
||||
}
|
||||
|
||||
const history = chat.getHistory();
|
||||
expect(history.length).toBe(4);
|
||||
expect(history[0].parts![0].text).toBe('msg 2');
|
||||
expect(history[3].parts![0].text).toBe('msg 5');
|
||||
});
|
||||
|
||||
it('should prune history when token limit is exceeded', () => {
|
||||
(
|
||||
mockConfig.getExperimentalAgentHistoryRetainedMessages as Mock
|
||||
).mockReturnValue(10);
|
||||
(
|
||||
mockConfig.getExperimentalAgentHistoryTruncationThreshold as Mock
|
||||
).mockReturnValue(5);
|
||||
|
||||
// Mock token count: each message is 2 tokens
|
||||
vi.mocked(estimateTokenCountSync).mockImplementation(
|
||||
(parts: readonly Part[]) => {
|
||||
if (parts.length === 0) return 0;
|
||||
// If it's a list of parts from multiple messages
|
||||
if (Array.isArray(parts) && parts.length > 1) {
|
||||
return parts.length * 2;
|
||||
}
|
||||
return 2;
|
||||
},
|
||||
);
|
||||
|
||||
// Add 6 messages. Total tokens should be 12.
|
||||
for (let i = 0; i < 6; i++) {
|
||||
chat.addHistory({
|
||||
role: i % 2 === 0 ? 'user' : 'model',
|
||||
parts: [{ text: `msg ${i}` }],
|
||||
});
|
||||
}
|
||||
|
||||
const history = chat.getHistory();
|
||||
// Threshold is 5.
|
||||
// 3 messages = 6 tokens (over)
|
||||
// 2 messages = 4 tokens (under)
|
||||
// So it should prune to 2 messages.
|
||||
expect(history.length).toBe(2);
|
||||
expect(history[0].parts![0].text).toBe('msg 4');
|
||||
expect(history[1].parts![0].text).toBe('msg 5');
|
||||
});
|
||||
|
||||
it('should NOT prune if experimental feature is disabled', () => {
|
||||
(
|
||||
mockConfig.isExperimentalAgentHistoryTruncationEnabled as Mock
|
||||
).mockReturnValue(false);
|
||||
(
|
||||
mockConfig.getExperimentalAgentHistoryRetainedMessages as Mock
|
||||
).mockReturnValue(4);
|
||||
|
||||
for (let i = 0; i < 6; i++) {
|
||||
chat.addHistory({
|
||||
role: i % 2 === 0 ? 'user' : 'model',
|
||||
parts: [{ text: `msg ${i}` }],
|
||||
});
|
||||
}
|
||||
|
||||
const history = chat.getHistory();
|
||||
expect(history.length).toBe(6);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Tool Output Masking (Improved)', () => {
|
||||
it('should mask large outputs even in the latest turn if they exceed 2x threshold', async () => {
|
||||
const maskingService = new ToolOutputMaskingService();
|
||||
|
||||
(mockConfig.getToolOutputMaskingConfig as Mock).mockResolvedValue({
|
||||
enabled: true,
|
||||
toolProtectionThreshold: 50,
|
||||
minPrunableTokensThreshold: 10,
|
||||
protectLatestTurn: true,
|
||||
});
|
||||
|
||||
const history: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: 'huge_tool',
|
||||
response: { output: 'X'.repeat(1000) },
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
// Mock token count: huge_tool output is 200 tokens (> 50 * 2)
|
||||
vi.mocked(estimateTokenCountSync).mockImplementation(
|
||||
(parts: readonly Part[]) => {
|
||||
const response = parts[0]?.functionResponse?.response as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
if (
|
||||
typeof response === 'object' &&
|
||||
typeof response?.output === 'string' &&
|
||||
response.output.includes('tool_output_masked')
|
||||
) {
|
||||
return 5; // Small value for masked content
|
||||
}
|
||||
if (parts[0]?.functionResponse?.name === 'huge_tool') return 200;
|
||||
return 1;
|
||||
},
|
||||
);
|
||||
|
||||
const result = await maskingService.mask(history, mockConfig as Config);
|
||||
expect(result.maskedCount).toBe(1);
|
||||
expect(JSON.stringify(result.newHistory)).toContain('tool_output_masked');
|
||||
});
|
||||
|
||||
it('should NOT mask latest turn if it is below 2x threshold and protectLatestTurn is true', async () => {
|
||||
const maskingService = new ToolOutputMaskingService();
|
||||
|
||||
(mockConfig.getToolOutputMaskingConfig as Mock).mockResolvedValue({
|
||||
enabled: true,
|
||||
toolProtectionThreshold: 50,
|
||||
minPrunableTokensThreshold: 10,
|
||||
protectLatestTurn: true,
|
||||
});
|
||||
|
||||
const history: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: 'normal_tool',
|
||||
response: { output: 'normal' },
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
// Mock token count: normal_tool output is 60 tokens (> 50 but < 50 * 2)
|
||||
vi.mocked(estimateTokenCountSync).mockImplementation(
|
||||
(parts: readonly Part[]) => {
|
||||
if (parts[0]?.functionResponse?.name === 'normal_tool') return 60;
|
||||
return 1;
|
||||
},
|
||||
);
|
||||
|
||||
const result = await maskingService.mask(history, mockConfig as Config);
|
||||
expect(result.maskedCount).toBe(0);
|
||||
expect(JSON.stringify(result.newHistory)).not.toContain(
|
||||
'tool_output_masked',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -90,10 +90,9 @@ export class ToolOutputMaskingService {
|
||||
}> = [];
|
||||
|
||||
// Decide where to start scanning.
|
||||
// If PROTECT_LATEST_TURN is true, we skip the most recent message (index history.length - 1).
|
||||
const scanStartIdx = maskingConfig.protectLatestTurn
|
||||
? history.length - 2
|
||||
: history.length - 1;
|
||||
// If PROTECT_LATEST_TURN is true, we still scan the latest turn but are more
|
||||
// conservative about masking it.
|
||||
const scanStartIdx = history.length - 1;
|
||||
|
||||
// Backward scan to identify prunable tool outputs
|
||||
for (let i = scanStartIdx; i >= 0; i--) {
|
||||
@@ -121,8 +120,22 @@ export class ToolOutputMaskingService {
|
||||
}
|
||||
|
||||
const partTokens = estimateTokenCountSync([part]);
|
||||
const isLatestTurn = i === history.length - 1;
|
||||
|
||||
if (!protectionBoundaryReached) {
|
||||
// If we are in the latest turn and protectLatestTurn is enabled,
|
||||
// we only mask if the part itself is exceptionally large (> 2x threshold).
|
||||
// This ensures that the model usually has full context for its current
|
||||
// task while preventing massive outputs from causing OOM or context overflow.
|
||||
if (
|
||||
isLatestTurn &&
|
||||
maskingConfig.protectLatestTurn &&
|
||||
partTokens <= maskingConfig.toolProtectionThreshold * 2
|
||||
) {
|
||||
cumulativeToolTokens += partTokens;
|
||||
continue;
|
||||
}
|
||||
|
||||
cumulativeToolTokens += partTokens;
|
||||
if (cumulativeToolTokens > maskingConfig.toolProtectionThreshold) {
|
||||
protectionBoundaryReached = true;
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
2026-04-03: Completed Phase 1. Implemented scripts/simulate-long-session.ts to reproduce memory growth. Identified ChatRecordingService and GeminiChat history as primary growth sources. Captured baseline metrics showing ~180MB growth per 200MB of tool output data.
|
||||
2026-04-03: Implemented memory-based cache eviction in ChatRecordingService (50MB threshold). Optimized initialization to prevent carrying over large session records in memory across chat resets. Verified with new unit tests.
|
||||
2026-04-03: Implemented bounded retention for GeminiChat. Improved ToolOutputMaskingService to allow masking massive outputs in the latest turn. Added hard history pruning in GeminiChat with configurable token and turn limits. Verified with new unit tests and successful build.
|
||||
|
||||
Reference in New Issue
Block a user