From 1d2f90c7e76c5c199a9ce00fe9aa6fbb153af54d Mon Sep 17 00:00:00 2001 From: Silvio Junior Date: Wed, 5 Nov 2025 16:15:28 -0500 Subject: [PATCH] Add compression mechanism to subagent (#12506) --- packages/core/src/agents/executor.test.ts | 249 +++++++++++++++++++++- packages/core/src/agents/executor.ts | 36 +++- 2 files changed, 280 insertions(+), 5 deletions(-) diff --git a/packages/core/src/agents/executor.test.ts b/packages/core/src/agents/executor.test.ts index 13e56c6a87..3d58df3704 100644 --- a/packages/core/src/agents/executor.test.ts +++ b/packages/core/src/agents/executor.test.ts @@ -4,7 +4,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; import { AgentExecutor, type ActivityCallback } from './executor.js'; import { makeFakeConfig } from '../test-utils/config.js'; import { ToolRegistry } from '../tools/tool-registry.js'; @@ -20,6 +28,7 @@ import { type Part, type GenerateContentResponse, type GenerateContentConfig, + type Content, } from '@google/genai'; import type { Config } from '../config/config.js'; import { MockTool } from '../test-utils/mock-tool.js'; @@ -44,10 +53,26 @@ import type { } from './types.js'; import { AgentTerminateMode } from './types.js'; import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js'; +import { CompressionStatus } from '../core/turn.js'; +import { ChatCompressionService } from '../services/chatCompressionService.js'; -const { mockSendMessageStream, mockExecuteToolCall } = vi.hoisted(() => ({ - mockSendMessageStream: vi.fn(), - mockExecuteToolCall: vi.fn(), +const { mockSendMessageStream, mockExecuteToolCall, mockCompress } = vi.hoisted( + () => ({ + mockSendMessageStream: vi.fn(), + mockExecuteToolCall: vi.fn(), + mockCompress: vi.fn(), + }), +); + +let mockChatHistory: Content[] = []; +const mockSetHistory = vi.fn((newHistory: Content[]) => { + mockChatHistory = newHistory; +}); + +vi.mock('../services/chatCompressionService.js', () => ({ + ChatCompressionService: vi.fn().mockImplementation(() => ({ + compress: mockCompress, + })), })); vi.mock('../core/geminiChat.js', async (importOriginal) => { @@ -56,6 +81,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => { ...actual, GeminiChat: vi.fn().mockImplementation(() => ({ sendMessageStream: mockSendMessageStream, + getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]), + setHistory: mockSetHistory, })), }; }); @@ -193,6 +220,8 @@ describe('AgentExecutor', () => { beforeEach(async () => { vi.resetAllMocks(); + mockCompress.mockClear(); + mockSetHistory.mockClear(); mockSendMessageStream.mockReset(); mockExecuteToolCall.mockReset(); mockedLogAgentStart.mockReset(); @@ -200,10 +229,21 @@ describe('AgentExecutor', () => { mockedPromptIdContext.getStore.mockReset(); mockedPromptIdContext.run.mockImplementation((_id, fn) => fn()); + (ChatCompressionService as Mock).mockImplementation(() => ({ + compress: mockCompress, + })); + mockCompress.mockResolvedValue({ + newHistory: null, + info: { compressionStatus: CompressionStatus.NOOP }, + }); + MockedGeminiChat.mockImplementation( () => ({ sendMessageStream: mockSendMessageStream, + getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]), + getLastPromptTokenCount: vi.fn(() => 100), + setHistory: mockSetHistory, }) as unknown as GeminiChat, ); @@ -1440,4 +1480,205 @@ describe('AgentExecutor', () => { expect(recoveryEvent.reason).toBe(AgentTerminateMode.MAX_TURNS); }); }); + describe('Chat Compression', () => { + const mockWorkResponse = (id: string) => { + mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]); + mockExecuteToolCall.mockResolvedValueOnce({ + status: 'success', + request: { + callId: id, + name: LS_TOOL_NAME, + args: { path: '.' }, + isClientInitiated: false, + prompt_id: 'test-prompt', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: id, + resultDisplay: 'ok', + responseParts: [ + { functionResponse: { name: LS_TOOL_NAME, response: {}, id } }, + ], + error: undefined, + errorType: undefined, + contentLength: undefined, + }, + }); + }; + + it('should attempt to compress chat history on each turn', async () => { + const definition = createTestDefinition(); + const executor = await AgentExecutor.create( + definition, + mockConfig, + onActivity, + ); + + // Mock compression to do nothing + mockCompress.mockResolvedValue({ + newHistory: null, + info: { compressionStatus: CompressionStatus.NOOP }, + }); + + // Turn 1 + mockWorkResponse('t1'); + + // Turn 2: Complete + mockModelResponse( + [ + { + name: TASK_COMPLETE_TOOL_NAME, + args: { finalResult: 'Done' }, + id: 'call2', + }, + ], + 'T2', + ); + + await executor.run({ goal: 'Compress test' }, signal); + + expect(mockCompress).toHaveBeenCalledTimes(2); + }); + + it('should update chat history when compression is successful', async () => { + const definition = createTestDefinition(); + const executor = await AgentExecutor.create( + definition, + mockConfig, + onActivity, + ); + const compressedHistory: Content[] = [ + { role: 'user', parts: [{ text: 'compressed' }] }, + ]; + + mockCompress.mockResolvedValue({ + newHistory: compressedHistory, + info: { compressionStatus: CompressionStatus.COMPRESSED }, + }); + + // Turn 1: Complete + mockModelResponse( + [ + { + name: TASK_COMPLETE_TOOL_NAME, + args: { finalResult: 'Done' }, + id: 'call1', + }, + ], + 'T1', + ); + + await executor.run({ goal: 'Compress success' }, signal); + + expect(mockCompress).toHaveBeenCalledTimes(1); + expect(mockSetHistory).toHaveBeenCalledTimes(1); + expect(mockSetHistory).toHaveBeenCalledWith(compressedHistory); + }); + + it('should pass hasFailedCompressionAttempt=true to compression after a failure', async () => { + const definition = createTestDefinition(); + const executor = await AgentExecutor.create( + definition, + mockConfig, + onActivity, + ); + + // First call fails + mockCompress.mockResolvedValueOnce({ + newHistory: null, + info: { + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }, + }); + // Second call is neutral + mockCompress.mockResolvedValueOnce({ + newHistory: null, + info: { compressionStatus: CompressionStatus.NOOP }, + }); + + // Turn 1 + mockWorkResponse('t1'); + // Turn 2: Complete + mockModelResponse( + [ + { + name: TASK_COMPLETE_TOOL_NAME, + args: { finalResult: 'Done' }, + id: 't2', + }, + ], + 'T2', + ); + + await executor.run({ goal: 'Compress fail' }, signal); + + expect(mockCompress).toHaveBeenCalledTimes(2); + // First call, hasFailedCompressionAttempt is false + expect(mockCompress.mock.calls[0][5]).toBe(false); + // Second call, hasFailedCompressionAttempt is true + expect(mockCompress.mock.calls[1][5]).toBe(true); + }); + + it('should reset hasFailedCompressionAttempt flag after a successful compression', async () => { + const definition = createTestDefinition(); + const executor = await AgentExecutor.create( + definition, + mockConfig, + onActivity, + ); + const compressedHistory: Content[] = [ + { role: 'user', parts: [{ text: 'compressed' }] }, + ]; + + // Turn 1: Fails + mockCompress.mockResolvedValueOnce({ + newHistory: null, + info: { + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }, + }); + // Turn 2: Succeeds + mockCompress.mockResolvedValueOnce({ + newHistory: compressedHistory, + info: { compressionStatus: CompressionStatus.COMPRESSED }, + }); + // Turn 3: Neutral + mockCompress.mockResolvedValueOnce({ + newHistory: null, + info: { compressionStatus: CompressionStatus.NOOP }, + }); + + // Turn 1 + mockWorkResponse('t1'); + // Turn 2 + mockWorkResponse('t2'); + // Turn 3: Complete + mockModelResponse( + [ + { + name: TASK_COMPLETE_TOOL_NAME, + args: { finalResult: 'Done' }, + id: 't3', + }, + ], + 'T3', + ); + + await executor.run({ goal: 'Compress reset' }, signal); + + expect(mockCompress).toHaveBeenCalledTimes(3); + // Call 1: hasFailed... is false + expect(mockCompress.mock.calls[0][5]).toBe(false); + // Call 2: hasFailed... is true + expect(mockCompress.mock.calls[1][5]).toBe(true); + // Call 3: hasFailed... is false again + expect(mockCompress.mock.calls[2][5]).toBe(false); + + expect(mockSetHistory).toHaveBeenCalledTimes(1); + expect(mockSetHistory).toHaveBeenCalledWith(compressedHistory); + }); + }); }); diff --git a/packages/core/src/agents/executor.ts b/packages/core/src/agents/executor.ts index 8928a75e69..ac0b9a27af 100644 --- a/packages/core/src/agents/executor.ts +++ b/packages/core/src/agents/executor.ts @@ -18,7 +18,8 @@ import type { } from '@google/genai'; import { executeToolCall } from '../core/nonInteractiveToolExecutor.js'; import { ToolRegistry } from '../tools/tool-registry.js'; -import type { ToolCallRequestInfo } from '../core/turn.js'; +import { type ToolCallRequestInfo, CompressionStatus } from '../core/turn.js'; +import { ChatCompressionService } from '../services/chatCompressionService.js'; import { getDirectoryContextString } from '../utils/environmentContext.js'; import { GLOB_TOOL_NAME, @@ -84,6 +85,8 @@ export class AgentExecutor { private readonly toolRegistry: ToolRegistry; private readonly runtimeContext: Config; private readonly onActivity?: ActivityCallback; + private readonly compressionService: ChatCompressionService; + private hasFailedCompressionAttempt = false; /** * Creates and validates a new `AgentExecutor` instance. @@ -159,6 +162,7 @@ export class AgentExecutor { this.runtimeContext = runtimeContext; this.toolRegistry = toolRegistry; this.onActivity = onActivity; + this.compressionService = new ChatCompressionService(); const randomIdPart = Math.random().toString(36).slice(2, 8); // parentPromptId will be undefined if this agent is invoked directly @@ -184,6 +188,8 @@ export class AgentExecutor { ): Promise { const promptId = `${this.agentId}#${turnCounter}`; + await this.tryCompressChat(chat, promptId); + const { functionCalls } = await promptIdContext.run(promptId, async () => this.callModel(chat, currentMessage, tools, combinedSignal, promptId), ); @@ -548,6 +554,34 @@ export class AgentExecutor { } } + private async tryCompressChat( + chat: GeminiChat, + prompt_id: string, + ): Promise { + const model = this.definition.modelConfig.model; + + const { newHistory, info } = await this.compressionService.compress( + chat, + prompt_id, + false, + model, + this.runtimeContext, + this.hasFailedCompressionAttempt, + ); + + if ( + info.compressionStatus === + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT + ) { + this.hasFailedCompressionAttempt = true; + } else if (info.compressionStatus === CompressionStatus.COMPRESSED) { + if (newHistory) { + chat.setHistory(newHistory); + this.hasFailedCompressionAttempt = false; + } + } + } + /** * Calls the generative model with the current context and tools. *