mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 05:12:55 -07:00
Add compression mechanism to subagent (#12506)
This commit is contained in:
@@ -4,7 +4,15 @@
|
|||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
import {
|
||||||
|
describe,
|
||||||
|
it,
|
||||||
|
expect,
|
||||||
|
vi,
|
||||||
|
beforeEach,
|
||||||
|
afterEach,
|
||||||
|
type Mock,
|
||||||
|
} from 'vitest';
|
||||||
import { AgentExecutor, type ActivityCallback } from './executor.js';
|
import { AgentExecutor, type ActivityCallback } from './executor.js';
|
||||||
import { makeFakeConfig } from '../test-utils/config.js';
|
import { makeFakeConfig } from '../test-utils/config.js';
|
||||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
import { ToolRegistry } from '../tools/tool-registry.js';
|
||||||
@@ -20,6 +28,7 @@ import {
|
|||||||
type Part,
|
type Part,
|
||||||
type GenerateContentResponse,
|
type GenerateContentResponse,
|
||||||
type GenerateContentConfig,
|
type GenerateContentConfig,
|
||||||
|
type Content,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import type { Config } from '../config/config.js';
|
import type { Config } from '../config/config.js';
|
||||||
import { MockTool } from '../test-utils/mock-tool.js';
|
import { MockTool } from '../test-utils/mock-tool.js';
|
||||||
@@ -44,10 +53,26 @@ import type {
|
|||||||
} from './types.js';
|
} from './types.js';
|
||||||
import { AgentTerminateMode } from './types.js';
|
import { AgentTerminateMode } from './types.js';
|
||||||
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
|
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
|
||||||
|
import { CompressionStatus } from '../core/turn.js';
|
||||||
|
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||||
|
|
||||||
const { mockSendMessageStream, mockExecuteToolCall } = vi.hoisted(() => ({
|
const { mockSendMessageStream, mockExecuteToolCall, mockCompress } = vi.hoisted(
|
||||||
|
() => ({
|
||||||
mockSendMessageStream: vi.fn(),
|
mockSendMessageStream: vi.fn(),
|
||||||
mockExecuteToolCall: vi.fn(),
|
mockExecuteToolCall: vi.fn(),
|
||||||
|
mockCompress: vi.fn(),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mockChatHistory: Content[] = [];
|
||||||
|
const mockSetHistory = vi.fn((newHistory: Content[]) => {
|
||||||
|
mockChatHistory = newHistory;
|
||||||
|
});
|
||||||
|
|
||||||
|
vi.mock('../services/chatCompressionService.js', () => ({
|
||||||
|
ChatCompressionService: vi.fn().mockImplementation(() => ({
|
||||||
|
compress: mockCompress,
|
||||||
|
})),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
||||||
@@ -56,6 +81,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
|||||||
...actual,
|
...actual,
|
||||||
GeminiChat: vi.fn().mockImplementation(() => ({
|
GeminiChat: vi.fn().mockImplementation(() => ({
|
||||||
sendMessageStream: mockSendMessageStream,
|
sendMessageStream: mockSendMessageStream,
|
||||||
|
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||||
|
setHistory: mockSetHistory,
|
||||||
})),
|
})),
|
||||||
};
|
};
|
||||||
});
|
});
|
||||||
@@ -193,6 +220,8 @@ describe('AgentExecutor', () => {
|
|||||||
|
|
||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
vi.resetAllMocks();
|
vi.resetAllMocks();
|
||||||
|
mockCompress.mockClear();
|
||||||
|
mockSetHistory.mockClear();
|
||||||
mockSendMessageStream.mockReset();
|
mockSendMessageStream.mockReset();
|
||||||
mockExecuteToolCall.mockReset();
|
mockExecuteToolCall.mockReset();
|
||||||
mockedLogAgentStart.mockReset();
|
mockedLogAgentStart.mockReset();
|
||||||
@@ -200,10 +229,21 @@ describe('AgentExecutor', () => {
|
|||||||
mockedPromptIdContext.getStore.mockReset();
|
mockedPromptIdContext.getStore.mockReset();
|
||||||
mockedPromptIdContext.run.mockImplementation((_id, fn) => fn());
|
mockedPromptIdContext.run.mockImplementation((_id, fn) => fn());
|
||||||
|
|
||||||
|
(ChatCompressionService as Mock).mockImplementation(() => ({
|
||||||
|
compress: mockCompress,
|
||||||
|
}));
|
||||||
|
mockCompress.mockResolvedValue({
|
||||||
|
newHistory: null,
|
||||||
|
info: { compressionStatus: CompressionStatus.NOOP },
|
||||||
|
});
|
||||||
|
|
||||||
MockedGeminiChat.mockImplementation(
|
MockedGeminiChat.mockImplementation(
|
||||||
() =>
|
() =>
|
||||||
({
|
({
|
||||||
sendMessageStream: mockSendMessageStream,
|
sendMessageStream: mockSendMessageStream,
|
||||||
|
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||||
|
getLastPromptTokenCount: vi.fn(() => 100),
|
||||||
|
setHistory: mockSetHistory,
|
||||||
}) as unknown as GeminiChat,
|
}) as unknown as GeminiChat,
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -1440,4 +1480,205 @@ describe('AgentExecutor', () => {
|
|||||||
expect(recoveryEvent.reason).toBe(AgentTerminateMode.MAX_TURNS);
|
expect(recoveryEvent.reason).toBe(AgentTerminateMode.MAX_TURNS);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
describe('Chat Compression', () => {
|
||||||
|
/**
 * Queues one non-terminal "work" turn: the model asks for a single
 * `LS_TOOL_NAME` call with the given id, and the tool executor resolves
 * that call successfully.
 */
const mockWorkResponse = (id: string) => {
  mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);

  // Request echoed back by the (mocked) tool executor.
  const toolRequest = {
    callId: id,
    name: LS_TOOL_NAME,
    args: { path: '.' },
    isClientInitiated: false,
    prompt_id: 'test-prompt',
  };

  // Successful tool response carrying a single functionResponse part.
  const toolResponse = {
    callId: id,
    resultDisplay: 'ok',
    responseParts: [
      { functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
    ],
    error: undefined,
    errorType: undefined,
    contentLength: undefined,
  };

  mockExecuteToolCall.mockResolvedValueOnce({
    status: 'success',
    request: toolRequest,
    tool: {} as AnyDeclarativeTool,
    invocation: {} as AnyToolInvocation,
    response: toolResponse,
  });
};
|
||||||
|
|
||||||
|
// Verifies that the executor invokes the compression service once per turn,
// even when compression reports that there was nothing to do.
it('should attempt to compress chat history on each turn', async () => {
  const agentDefinition = createTestDefinition();
  const executor = await AgentExecutor.create(
    agentDefinition,
    mockConfig,
    onActivity,
  );

  // Compression is a no-op on every call.
  const noopOutcome = {
    newHistory: null,
    info: { compressionStatus: CompressionStatus.NOOP },
  };
  mockCompress.mockResolvedValue(noopOutcome);

  // Turn 1: regular tool call.
  mockWorkResponse('t1');

  // Turn 2: the agent completes the task.
  const completeTaskCall = {
    name: TASK_COMPLETE_TOOL_NAME,
    args: { finalResult: 'Done' },
    id: 'call2',
  };
  mockModelResponse([completeTaskCall], 'T2');

  await executor.run({ goal: 'Compress test' }, signal);

  // One compression attempt per turn.
  expect(mockCompress).toHaveBeenCalledTimes(2);
});
|
||||||
|
|
||||||
|
// Verifies that a COMPRESSED result causes the returned history to be
// installed on the chat via setHistory.
it('should update chat history when compression is successful', async () => {
  const agentDefinition = createTestDefinition();
  const executor = await AgentExecutor.create(
    agentDefinition,
    mockConfig,
    onActivity,
  );

  const compressedHistory: Content[] = [
    { role: 'user', parts: [{ text: 'compressed' }] },
  ];
  const compressedOutcome = {
    newHistory: compressedHistory,
    info: { compressionStatus: CompressionStatus.COMPRESSED },
  };
  mockCompress.mockResolvedValue(compressedOutcome);

  // Turn 1: the agent completes immediately.
  const completeTaskCall = {
    name: TASK_COMPLETE_TOOL_NAME,
    args: { finalResult: 'Done' },
    id: 'call1',
  };
  mockModelResponse([completeTaskCall], 'T1');

  await executor.run({ goal: 'Compress success' }, signal);

  expect(mockCompress).toHaveBeenCalledTimes(1);
  expect(mockSetHistory).toHaveBeenCalledTimes(1);
  expect(mockSetHistory).toHaveBeenCalledWith(compressedHistory);
});
|
||||||
|
|
||||||
|
// Verifies that once compression reports an inflated token count, the next
// compression call receives hasFailedCompressionAttempt=true.
it('should pass hasFailedCompressionAttempt=true to compression after a failure', async () => {
  // Position of the hasFailedCompressionAttempt argument in compress(...).
  const FAILED_ATTEMPT_ARG = 5;

  const agentDefinition = createTestDefinition();
  const executor = await AgentExecutor.create(
    agentDefinition,
    mockConfig,
    onActivity,
  );

  // Call 1 fails, call 2 is neutral.
  const failedOutcome = {
    newHistory: null,
    info: {
      compressionStatus:
        CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
    },
  };
  const noopOutcome = {
    newHistory: null,
    info: { compressionStatus: CompressionStatus.NOOP },
  };
  mockCompress
    .mockResolvedValueOnce(failedOutcome)
    .mockResolvedValueOnce(noopOutcome);

  // Turn 1: regular tool call.
  mockWorkResponse('t1');

  // Turn 2: the agent completes the task.
  const completeTaskCall = {
    name: TASK_COMPLETE_TOOL_NAME,
    args: { finalResult: 'Done' },
    id: 't2',
  };
  mockModelResponse([completeTaskCall], 'T2');

  await executor.run({ goal: 'Compress fail' }, signal);

  expect(mockCompress).toHaveBeenCalledTimes(2);

  // The flag is false on the first call and true on the second.
  const [firstCallArgs, secondCallArgs] = mockCompress.mock.calls;
  expect(firstCallArgs[FAILED_ATTEMPT_ARG]).toBe(false);
  expect(secondCallArgs[FAILED_ATTEMPT_ARG]).toBe(true);
});
|
||||||
|
|
||||||
|
// Verifies the full flag lifecycle: false -> true after a failed attempt ->
// false again once a later attempt compresses successfully.
it('should reset hasFailedCompressionAttempt flag after a successful compression', async () => {
  // Position of the hasFailedCompressionAttempt argument in compress(...).
  const FAILED_ATTEMPT_ARG = 5;

  const agentDefinition = createTestDefinition();
  const executor = await AgentExecutor.create(
    agentDefinition,
    mockConfig,
    onActivity,
  );
  const compressedHistory: Content[] = [
    { role: 'user', parts: [{ text: 'compressed' }] },
  ];

  // Compression outcomes, in turn order: fail, succeed, neutral.
  const failedOutcome = {
    newHistory: null,
    info: {
      compressionStatus:
        CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
    },
  };
  const compressedOutcome = {
    newHistory: compressedHistory,
    info: { compressionStatus: CompressionStatus.COMPRESSED },
  };
  const noopOutcome = {
    newHistory: null,
    info: { compressionStatus: CompressionStatus.NOOP },
  };
  mockCompress
    .mockResolvedValueOnce(failedOutcome)
    .mockResolvedValueOnce(compressedOutcome)
    .mockResolvedValueOnce(noopOutcome);

  // Turns 1 and 2: regular tool calls.
  mockWorkResponse('t1');
  mockWorkResponse('t2');

  // Turn 3: the agent completes the task.
  const completeTaskCall = {
    name: TASK_COMPLETE_TOOL_NAME,
    args: { finalResult: 'Done' },
    id: 't3',
  };
  mockModelResponse([completeTaskCall], 'T3');

  await executor.run({ goal: 'Compress reset' }, signal);

  expect(mockCompress).toHaveBeenCalledTimes(3);

  // Flag per call: false, then true after the failure, then false again
  // after the successful compression.
  const [firstCallArgs, secondCallArgs, thirdCallArgs] =
    mockCompress.mock.calls;
  expect(firstCallArgs[FAILED_ATTEMPT_ARG]).toBe(false);
  expect(secondCallArgs[FAILED_ATTEMPT_ARG]).toBe(true);
  expect(thirdCallArgs[FAILED_ATTEMPT_ARG]).toBe(false);

  expect(mockSetHistory).toHaveBeenCalledTimes(1);
  expect(mockSetHistory).toHaveBeenCalledWith(compressedHistory);
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -18,7 +18,8 @@ import type {
|
|||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
|
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
|
||||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
import { ToolRegistry } from '../tools/tool-registry.js';
|
||||||
import type { ToolCallRequestInfo } from '../core/turn.js';
|
import { type ToolCallRequestInfo, CompressionStatus } from '../core/turn.js';
|
||||||
|
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||||
import { getDirectoryContextString } from '../utils/environmentContext.js';
|
import { getDirectoryContextString } from '../utils/environmentContext.js';
|
||||||
import {
|
import {
|
||||||
GLOB_TOOL_NAME,
|
GLOB_TOOL_NAME,
|
||||||
@@ -84,6 +85,8 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
private readonly toolRegistry: ToolRegistry;
|
private readonly toolRegistry: ToolRegistry;
|
||||||
private readonly runtimeContext: Config;
|
private readonly runtimeContext: Config;
|
||||||
private readonly onActivity?: ActivityCallback;
|
private readonly onActivity?: ActivityCallback;
|
||||||
|
private readonly compressionService: ChatCompressionService;
|
||||||
|
private hasFailedCompressionAttempt = false;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates and validates a new `AgentExecutor` instance.
|
* Creates and validates a new `AgentExecutor` instance.
|
||||||
@@ -159,6 +162,7 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
this.runtimeContext = runtimeContext;
|
this.runtimeContext = runtimeContext;
|
||||||
this.toolRegistry = toolRegistry;
|
this.toolRegistry = toolRegistry;
|
||||||
this.onActivity = onActivity;
|
this.onActivity = onActivity;
|
||||||
|
this.compressionService = new ChatCompressionService();
|
||||||
|
|
||||||
const randomIdPart = Math.random().toString(36).slice(2, 8);
|
const randomIdPart = Math.random().toString(36).slice(2, 8);
|
||||||
// parentPromptId will be undefined if this agent is invoked directly
|
// parentPromptId will be undefined if this agent is invoked directly
|
||||||
@@ -184,6 +188,8 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
): Promise<AgentTurnResult> {
|
): Promise<AgentTurnResult> {
|
||||||
const promptId = `${this.agentId}#${turnCounter}`;
|
const promptId = `${this.agentId}#${turnCounter}`;
|
||||||
|
|
||||||
|
await this.tryCompressChat(chat, promptId);
|
||||||
|
|
||||||
const { functionCalls } = await promptIdContext.run(promptId, async () =>
|
const { functionCalls } = await promptIdContext.run(promptId, async () =>
|
||||||
this.callModel(chat, currentMessage, tools, combinedSignal, promptId),
|
this.callModel(chat, currentMessage, tools, combinedSignal, promptId),
|
||||||
);
|
);
|
||||||
@@ -548,6 +554,34 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Asks the compression service to compress the chat and records the outcome.
 *
 * - On `COMPRESSION_FAILED_INFLATED_TOKEN_COUNT`, sets
 *   `hasFailedCompressionAttempt` so the next call is told about the failure.
 * - On `COMPRESSED` with a non-empty `newHistory`, replaces the chat history
 *   and clears `hasFailedCompressionAttempt`.
 * - Any other status leaves both the chat and the flag untouched.
 *
 * @param chat The chat whose history may be replaced.
 * @param prompt_id Prompt identifier forwarded to the compression service.
 */
private async tryCompressChat(
  chat: GeminiChat,
  prompt_id: string,
): Promise<void> {
  const targetModel = this.definition.modelConfig.model;

  // NOTE(review): the third argument is passed as a literal `false`;
  // presumably the "force" flag of compress() — confirm against
  // ChatCompressionService.
  const { newHistory: replacementHistory, info: outcome } =
    await this.compressionService.compress(
      chat,
      prompt_id,
      false,
      targetModel,
      this.runtimeContext,
      this.hasFailedCompressionAttempt,
    );

  switch (outcome.compressionStatus) {
    case CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT:
      this.hasFailedCompressionAttempt = true;
      return;

    case CompressionStatus.COMPRESSED:
      // A COMPRESSED status without a history is ignored entirely.
      if (!replacementHistory) {
        return;
      }
      chat.setHistory(replacementHistory);
      this.hasFailedCompressionAttempt = false;
      return;

    default:
      return;
  }
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calls the generative model with the current context and tools.
|
* Calls the generative model with the current context and tools.
|
||||||
*
|
*
|
||||||
|
|||||||
Reference in New Issue
Block a user