mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 14:40:52 -07:00
Add compression mechanism to subagent (#12506)
This commit is contained in:
@@ -4,7 +4,15 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { AgentExecutor, type ActivityCallback } from './executor.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
||||
@@ -20,6 +28,7 @@ import {
|
||||
type Part,
|
||||
type GenerateContentResponse,
|
||||
type GenerateContentConfig,
|
||||
type Content,
|
||||
} from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { MockTool } from '../test-utils/mock-tool.js';
|
||||
@@ -44,10 +53,26 @@ import type {
|
||||
} from './types.js';
|
||||
import { AgentTerminateMode } from './types.js';
|
||||
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
|
||||
import { CompressionStatus } from '../core/turn.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
|
||||
const { mockSendMessageStream, mockExecuteToolCall } = vi.hoisted(() => ({
|
||||
mockSendMessageStream: vi.fn(),
|
||||
mockExecuteToolCall: vi.fn(),
|
||||
// The spies are created inside vi.hoisted so that the vi.mock factories in
// this file can refer to them.
const { mockSendMessageStream, mockExecuteToolCall, mockCompress } = vi.hoisted(
  () => {
    const spies = {
      mockSendMessageStream: vi.fn(),
      mockExecuteToolCall: vi.fn(),
      mockCompress: vi.fn(),
    };
    return spies;
  },
);

// Backing store for the fake chat's history. `mockSetHistory` replaces it
// wholesale with whatever history it is handed.
let mockChatHistory: Content[] = [];
const mockSetHistory = vi.fn((replacement: Content[]) => {
  mockChatHistory = replacement;
});

// Replace the real compression service with a constructor whose instances
// expose only `compress`, backed by the shared `mockCompress` spy.
vi.mock('../services/chatCompressionService.js', () => {
  const FakeCompressionService = vi.fn().mockImplementation(() => {
    const instance = { compress: mockCompress };
    return instance;
  });
  return { ChatCompressionService: FakeCompressionService };
});
|
||||
|
||||
vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
||||
@@ -56,6 +81,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
||||
...actual,
|
||||
GeminiChat: vi.fn().mockImplementation(() => ({
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
setHistory: mockSetHistory,
|
||||
})),
|
||||
};
|
||||
});
|
||||
@@ -193,6 +220,8 @@ describe('AgentExecutor', () => {
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.resetAllMocks();
|
||||
mockCompress.mockClear();
|
||||
mockSetHistory.mockClear();
|
||||
mockSendMessageStream.mockReset();
|
||||
mockExecuteToolCall.mockReset();
|
||||
mockedLogAgentStart.mockReset();
|
||||
@@ -200,10 +229,21 @@ describe('AgentExecutor', () => {
|
||||
mockedPromptIdContext.getStore.mockReset();
|
||||
mockedPromptIdContext.run.mockImplementation((_id, fn) => fn());
|
||||
|
||||
(ChatCompressionService as Mock).mockImplementation(() => ({
|
||||
compress: mockCompress,
|
||||
}));
|
||||
mockCompress.mockResolvedValue({
|
||||
newHistory: null,
|
||||
info: { compressionStatus: CompressionStatus.NOOP },
|
||||
});
|
||||
|
||||
MockedGeminiChat.mockImplementation(
|
||||
() =>
|
||||
({
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
getLastPromptTokenCount: vi.fn(() => 100),
|
||||
setHistory: mockSetHistory,
|
||||
}) as unknown as GeminiChat,
|
||||
);
|
||||
|
||||
@@ -1440,4 +1480,205 @@ describe('AgentExecutor', () => {
|
||||
expect(recoveryEvent.reason).toBe(AgentTerminateMode.MAX_TURNS);
|
||||
});
|
||||
});
|
||||
describe('Chat Compression', () => {
  // Zero-based position of `hasFailedCompressionAttempt` in the argument list
  // the executor passes to `compress(...)` (it is the sixth argument).
  const HAS_FAILED_ARG_INDEX = 5;

  // --- Compression result factories (a fresh object on every call) ---------

  /** Result reporting that compression did nothing. */
  const noopResult = () => ({
    newHistory: null,
    info: { compressionStatus: CompressionStatus.NOOP },
  });

  /** Result reporting a failed attempt that inflated the token count. */
  const inflatedFailureResult = () => ({
    newHistory: null,
    info: {
      compressionStatus:
        CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
    },
  });

  /** Result reporting a successful compression that yields `newHistory`. */
  const compressedResult = (newHistory: Content[]) => ({
    newHistory,
    info: { compressionStatus: CompressionStatus.COMPRESSED },
  });

  /** Minimal one-message history standing in for compressed output. */
  const makeCompressedHistory = (): Content[] => [
    { role: 'user', parts: [{ text: 'compressed' }] },
  ];

  // --- Turn scripting helpers -------------------------------------------------

  /** Arguments used for every scripted `ls` call; fresh object per use. */
  const lsArgs = () => ({ path: '.' });

  /**
   * Scripts one non-terminal turn: the model asks for an `ls` tool call with
   * the given id, and the tool executor resolves it successfully.
   */
  const mockWorkResponse = (id: string) => {
    mockModelResponse([{ name: LS_TOOL_NAME, args: lsArgs(), id }]);

    const request = {
      callId: id,
      name: LS_TOOL_NAME,
      args: lsArgs(),
      isClientInitiated: false,
      prompt_id: 'test-prompt',
    };
    const response = {
      callId: id,
      resultDisplay: 'ok',
      responseParts: [
        { functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
      ],
      error: undefined,
      errorType: undefined,
      contentLength: undefined,
    };

    mockExecuteToolCall.mockResolvedValueOnce({
      status: 'success',
      request,
      tool: {} as AnyDeclarativeTool,
      invocation: {} as AnyToolInvocation,
      response,
    });
  };

  /** Scripts the terminal turn: the model calls the task-complete tool. */
  const mockCompletion = (id: string, thought: string) => {
    const completeCall = {
      name: TASK_COMPLETE_TOOL_NAME,
      args: { finalResult: 'Done' },
      id,
    };
    mockModelResponse([completeCall], thought);
  };

  /** Builds an executor from the default test definition. */
  const createExecutor = async () => {
    const definition = createTestDefinition();
    return AgentExecutor.create(definition, mockConfig, onActivity);
  };

  /**
   * Asserts, call by call, the `hasFailedCompressionAttempt` value that was
   * handed to `compress`.
   */
  const expectFailedAttemptFlags = (...expectedFlags: boolean[]) => {
    expectedFlags.forEach((flag, callIndex) => {
      expect(mockCompress.mock.calls[callIndex][HAS_FAILED_ARG_INDEX]).toBe(
        flag,
      );
    });
  };

  // --- Tests --------------------------------------------------------------------

  it('should attempt to compress chat history on each turn', async () => {
    const executor = await createExecutor();

    // Compression is a no-op on every turn of this run.
    mockCompress.mockResolvedValue(noopResult());

    // Turn 1 does work, turn 2 completes the task.
    mockWorkResponse('t1');
    mockCompletion('call2', 'T2');

    await executor.run({ goal: 'Compress test' }, signal);

    expect(mockCompress).toHaveBeenCalledTimes(2);
  });

  it('should update chat history when compression is successful', async () => {
    const executor = await createExecutor();
    const compressedHistory = makeCompressedHistory();

    mockCompress.mockResolvedValue(compressedResult(compressedHistory));

    // Single turn that completes immediately.
    mockCompletion('call1', 'T1');

    await executor.run({ goal: 'Compress success' }, signal);

    expect(mockCompress).toHaveBeenCalledTimes(1);
    expect(mockSetHistory).toHaveBeenCalledTimes(1);
    expect(mockSetHistory).toHaveBeenCalledWith(compressedHistory);
  });

  it('should pass hasFailedCompressionAttempt=true to compression after a failure', async () => {
    const executor = await createExecutor();

    // The first compression attempt fails; the second is neutral.
    mockCompress.mockResolvedValueOnce(inflatedFailureResult());
    mockCompress.mockResolvedValueOnce(noopResult());

    // Turn 1 does work, turn 2 completes the task.
    mockWorkResponse('t1');
    mockCompletion('t2', 'T2');

    await executor.run({ goal: 'Compress fail' }, signal);

    expect(mockCompress).toHaveBeenCalledTimes(2);
    // The flag is false on the first call and true once a failure was seen.
    expectFailedAttemptFlags(false, true);
  });

  it('should reset hasFailedCompressionAttempt flag after a successful compression', async () => {
    const executor = await createExecutor();
    const compressedHistory = makeCompressedHistory();

    // Compression outcomes per turn: fail, succeed, neutral.
    mockCompress.mockResolvedValueOnce(inflatedFailureResult());
    mockCompress.mockResolvedValueOnce(compressedResult(compressedHistory));
    mockCompress.mockResolvedValueOnce(noopResult());

    // Turns 1 and 2 do work, turn 3 completes the task.
    mockWorkResponse('t1');
    mockWorkResponse('t2');
    mockCompletion('t3', 'T3');

    await executor.run({ goal: 'Compress reset' }, signal);

    expect(mockCompress).toHaveBeenCalledTimes(3);
    // false before any failure, true after it, false again after the success.
    expectFailedAttemptFlags(false, true, false);

    expect(mockSetHistory).toHaveBeenCalledTimes(1);
    expect(mockSetHistory).toHaveBeenCalledWith(compressedHistory);
  });
});
|
||||
});
|
||||
|
||||
@@ -18,7 +18,8 @@ import type {
|
||||
} from '@google/genai';
|
||||
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
|
||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
||||
import type { ToolCallRequestInfo } from '../core/turn.js';
|
||||
import { type ToolCallRequestInfo, CompressionStatus } from '../core/turn.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
import { getDirectoryContextString } from '../utils/environmentContext.js';
|
||||
import {
|
||||
GLOB_TOOL_NAME,
|
||||
@@ -84,6 +85,8 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
private readonly toolRegistry: ToolRegistry;
|
||||
private readonly runtimeContext: Config;
|
||||
private readonly onActivity?: ActivityCallback;
|
||||
private readonly compressionService: ChatCompressionService;
|
||||
private hasFailedCompressionAttempt = false;
|
||||
|
||||
/**
|
||||
* Creates and validates a new `AgentExecutor` instance.
|
||||
@@ -159,6 +162,7 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
this.runtimeContext = runtimeContext;
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.onActivity = onActivity;
|
||||
this.compressionService = new ChatCompressionService();
|
||||
|
||||
const randomIdPart = Math.random().toString(36).slice(2, 8);
|
||||
// parentPromptId will be undefined if this agent is invoked directly
|
||||
@@ -184,6 +188,8 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
): Promise<AgentTurnResult> {
|
||||
const promptId = `${this.agentId}#${turnCounter}`;
|
||||
|
||||
await this.tryCompressChat(chat, promptId);
|
||||
|
||||
const { functionCalls } = await promptIdContext.run(promptId, async () =>
|
||||
this.callModel(chat, currentMessage, tools, combinedSignal, promptId),
|
||||
);
|
||||
@@ -548,6 +554,34 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Asks the compression service to compress `chat` and records the outcome
 * on this executor.
 *
 * - On `COMPRESSION_FAILED_INFLATED_TOKEN_COUNT`, sets
 *   `hasFailedCompressionAttempt` so the next call passes `true` to the
 *   service.
 * - On `COMPRESSED` with a non-empty `newHistory`, installs that history on
 *   the chat and clears `hasFailedCompressionAttempt`.
 * - Any other status leaves both the chat and the flag untouched.
 *
 * @param chat The chat whose history may be replaced.
 * @param prompt_id Identifier forwarded unchanged to the compression service.
 */
private async tryCompressChat(
  chat: GeminiChat,
  prompt_id: string,
): Promise<void> {
  const modelName = this.definition.modelConfig.model;

  // NOTE(review): the third argument is always `false`; presumably this is
  // the "force" switch of ChatCompressionService.compress — confirm there.
  const outcome = await this.compressionService.compress(
    chat,
    prompt_id,
    false,
    modelName,
    this.runtimeContext,
    this.hasFailedCompressionAttempt,
  );

  switch (outcome.info.compressionStatus) {
    case CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT:
      this.hasFailedCompressionAttempt = true;
      return;

    case CompressionStatus.COMPRESSED:
      // A COMPRESSED status without a history changes nothing, flag included.
      if (!outcome.newHistory) {
        return;
      }
      chat.setHistory(outcome.newHistory);
      this.hasFailedCompressionAttempt = false;
      return;

    default:
      return;
  }
}
|
||||
|
||||
/**
|
||||
* Calls the generative model with the current context and tools.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user