mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-11 22:00:41 -07:00
Add compression mechanism to subagent (#12506)
This commit is contained in:
@@ -4,7 +4,15 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { AgentExecutor, type ActivityCallback } from './executor.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
||||
@@ -20,6 +28,7 @@ import {
|
||||
type Part,
|
||||
type GenerateContentResponse,
|
||||
type GenerateContentConfig,
|
||||
type Content,
|
||||
} from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { MockTool } from '../test-utils/mock-tool.js';
|
||||
@@ -44,10 +53,26 @@ import type {
|
||||
} from './types.js';
|
||||
import { AgentTerminateMode } from './types.js';
|
||||
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
|
||||
import { CompressionStatus } from '../core/turn.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
|
||||
const { mockSendMessageStream, mockExecuteToolCall } = vi.hoisted(() => ({
|
||||
mockSendMessageStream: vi.fn(),
|
||||
mockExecuteToolCall: vi.fn(),
|
||||
// Shared spies are created inside `vi.hoisted` so that they already exist
// when the `vi.mock` factories in this file reference them.
const { mockSendMessageStream, mockExecuteToolCall, mockCompress } = vi.hoisted(
  () => {
    const sendMessageStream = vi.fn();
    const executeToolCall = vi.fn();
    const compress = vi.fn();

    return {
      mockSendMessageStream: sendMessageStream,
      mockExecuteToolCall: executeToolCall,
      mockCompress: compress,
    };
  },
);
|
||||
|
||||
/**
 * History served by the mocked chat. Starts empty and is replaced wholesale
 * whenever `mockSetHistory` is invoked.
 */
let mockChatHistory: Content[] = [];

/**
 * Spy used as the mocked chat's `setHistory`: records the call and stores the
 * supplied history (by reference) in `mockChatHistory`.
 */
const mockSetHistory = vi.fn((replacementHistory: Content[]) => {
  mockChatHistory = replacementHistory;
});
|
||||
|
||||
// Swap the compression service module for a mocked constructor whose
// instances expose only `compress`, backed by the shared `mockCompress` spy.
vi.mock('../services/chatCompressionService.js', () => {
  const MockedCompressionService = vi.fn().mockImplementation(() => {
    return { compress: mockCompress };
  });

  return { ChatCompressionService: MockedCompressionService };
});
|
||||
|
||||
vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
||||
@@ -56,6 +81,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
||||
...actual,
|
||||
GeminiChat: vi.fn().mockImplementation(() => ({
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
setHistory: mockSetHistory,
|
||||
})),
|
||||
};
|
||||
});
|
||||
@@ -193,6 +220,8 @@ describe('AgentExecutor', () => {
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.resetAllMocks();
|
||||
mockCompress.mockClear();
|
||||
mockSetHistory.mockClear();
|
||||
mockSendMessageStream.mockReset();
|
||||
mockExecuteToolCall.mockReset();
|
||||
mockedLogAgentStart.mockReset();
|
||||
@@ -200,10 +229,21 @@ describe('AgentExecutor', () => {
|
||||
mockedPromptIdContext.getStore.mockReset();
|
||||
mockedPromptIdContext.run.mockImplementation((_id, fn) => fn());
|
||||
|
||||
(ChatCompressionService as Mock).mockImplementation(() => ({
|
||||
compress: mockCompress,
|
||||
}));
|
||||
mockCompress.mockResolvedValue({
|
||||
newHistory: null,
|
||||
info: { compressionStatus: CompressionStatus.NOOP },
|
||||
});
|
||||
|
||||
MockedGeminiChat.mockImplementation(
|
||||
() =>
|
||||
({
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
getLastPromptTokenCount: vi.fn(() => 100),
|
||||
setHistory: mockSetHistory,
|
||||
}) as unknown as GeminiChat,
|
||||
);
|
||||
|
||||
@@ -1440,4 +1480,205 @@ describe('AgentExecutor', () => {
|
||||
expect(recoveryEvent.reason).toBe(AgentTerminateMode.MAX_TURNS);
|
||||
});
|
||||
});
|
||||
describe('Chat Compression', () => {
|
||||
/**
 * Sets up one "work" turn for a test.
 *
 * Hands `mockModelResponse` a single `LS_TOOL_NAME` function call carrying
 * `id`, then primes `mockExecuteToolCall` to resolve once with a successful
 * result whose request and response both reference that same `id`.
 *
 * @param id Call id shared by the model's function call and the tool result.
 */
const mockWorkResponse = (id: string) => {
  const toolName = LS_TOOL_NAME;

  mockModelResponse([{ name: toolName, args: { path: '.' }, id }]);

  // The request echoed back by the (mocked) tool executor.
  const request = {
    callId: id,
    name: toolName,
    args: { path: '.' },
    isClientInitiated: false,
    prompt_id: 'test-prompt',
  };

  // A successful tool response with no error details attached.
  const response = {
    callId: id,
    resultDisplay: 'ok',
    responseParts: [{ functionResponse: { name: toolName, response: {}, id } }],
    error: undefined,
    errorType: undefined,
    contentLength: undefined,
  };

  mockExecuteToolCall.mockResolvedValueOnce({
    status: 'success',
    request,
    tool: {} as AnyDeclarativeTool,
    invocation: {} as AnyToolInvocation,
    response,
  });
};
|
||||
|
||||
it('should attempt to compress chat history on each turn', async () => {
  const executor = await AgentExecutor.create(
    createTestDefinition(),
    mockConfig,
    onActivity,
  );

  // Every compression attempt reports NOOP and supplies no new history.
  const noopResult = {
    newHistory: null,
    info: { compressionStatus: CompressionStatus.NOOP },
  };
  mockCompress.mockResolvedValue(noopResult);

  // Turn 1: a single tool call.
  mockWorkResponse('t1');

  // Turn 2: the model calls the task-complete tool.
  const completionCall = {
    name: TASK_COMPLETE_TOOL_NAME,
    args: { finalResult: 'Done' },
    id: 'call2',
  };
  mockModelResponse([completionCall], 'T2');

  await executor.run({ goal: 'Compress test' }, signal);

  // Two turns were scripted, so compression must have been consulted twice.
  expect(mockCompress).toHaveBeenCalledTimes(2);
});
|
||||
|
||||
it('should update chat history when compression is successful', async () => {
  const executor = await AgentExecutor.create(
    createTestDefinition(),
    mockConfig,
    onActivity,
  );

  // The history the compression service claims to have produced.
  const compressedHistory: Content[] = [
    { role: 'user', parts: [{ text: 'compressed' }] },
  ];
  const compressedResult = {
    newHistory: compressedHistory,
    info: { compressionStatus: CompressionStatus.COMPRESSED },
  };
  mockCompress.mockResolvedValue(compressedResult);

  // Turn 1: the model calls the task-complete tool straight away.
  const completionCall = {
    name: TASK_COMPLETE_TOOL_NAME,
    args: { finalResult: 'Done' },
    id: 'call1',
  };
  mockModelResponse([completionCall], 'T1');

  await executor.run({ goal: 'Compress success' }, signal);

  // One turn means one compression attempt, and the compressed history must
  // have been handed to the chat exactly once.
  expect(mockCompress).toHaveBeenCalledTimes(1);
  expect(mockSetHistory).toHaveBeenCalledTimes(1);
  expect(mockSetHistory).toHaveBeenCalledWith(compressedHistory);
});
|
||||
|
||||
it('should pass hasFailedCompressionAttempt=true to compression after a failure', async () => {
  // Positional index at which this test inspects the
  // hasFailedCompressionAttempt argument of `compress`.
  const HAS_FAILED_ARG_INDEX = 5;

  const executor = await AgentExecutor.create(
    createTestDefinition(),
    mockConfig,
    onActivity,
  );

  // Attempt 1 reports an inflated-token-count failure; attempt 2 is a NOOP.
  const failedResult = {
    newHistory: null,
    info: {
      compressionStatus:
        CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
    },
  };
  const noopResult = {
    newHistory: null,
    info: { compressionStatus: CompressionStatus.NOOP },
  };
  mockCompress.mockResolvedValueOnce(failedResult);
  mockCompress.mockResolvedValueOnce(noopResult);

  // Turn 1: a single tool call.
  mockWorkResponse('t1');

  // Turn 2: the model calls the task-complete tool.
  const completionCall = {
    name: TASK_COMPLETE_TOOL_NAME,
    args: { finalResult: 'Done' },
    id: 't2',
  };
  mockModelResponse([completionCall], 'T2');

  await executor.run({ goal: 'Compress fail' }, signal);

  expect(mockCompress).toHaveBeenCalledTimes(2);

  // The flag is false on the first call and true on the call after the failure.
  const expectedFlags = [false, true];
  expectedFlags.forEach((expectedFlag, callIndex) => {
    expect(mockCompress.mock.calls[callIndex][HAS_FAILED_ARG_INDEX]).toBe(
      expectedFlag,
    );
  });
});
|
||||
|
||||
it('should reset hasFailedCompressionAttempt flag after a successful compression', async () => {
  // Positional index at which this test inspects the
  // hasFailedCompressionAttempt argument of `compress`.
  const HAS_FAILED_ARG_INDEX = 5;

  const executor = await AgentExecutor.create(
    createTestDefinition(),
    mockConfig,
    onActivity,
  );

  const compressedHistory: Content[] = [
    { role: 'user', parts: [{ text: 'compressed' }] },
  ];

  // Scripted compression outcomes, one per turn: failure, success, NOOP.
  const failedResult = {
    newHistory: null,
    info: {
      compressionStatus:
        CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
    },
  };
  const compressedResult = {
    newHistory: compressedHistory,
    info: { compressionStatus: CompressionStatus.COMPRESSED },
  };
  const noopResult = {
    newHistory: null,
    info: { compressionStatus: CompressionStatus.NOOP },
  };
  mockCompress.mockResolvedValueOnce(failedResult);
  mockCompress.mockResolvedValueOnce(compressedResult);
  mockCompress.mockResolvedValueOnce(noopResult);

  // Turns 1 and 2: one tool call each.
  mockWorkResponse('t1');
  mockWorkResponse('t2');

  // Turn 3: the model calls the task-complete tool.
  const completionCall = {
    name: TASK_COMPLETE_TOOL_NAME,
    args: { finalResult: 'Done' },
    id: 't3',
  };
  mockModelResponse([completionCall], 'T3');

  await executor.run({ goal: 'Compress reset' }, signal);

  expect(mockCompress).toHaveBeenCalledTimes(3);

  // The flag is false initially, true right after the failure, and false
  // again once a compression has succeeded.
  const expectedFlags = [false, true, false];
  expectedFlags.forEach((expectedFlag, callIndex) => {
    expect(mockCompress.mock.calls[callIndex][HAS_FAILED_ARG_INDEX]).toBe(
      expectedFlag,
    );
  });

  // Only the successful attempt supplies a history for the chat.
  expect(mockSetHistory).toHaveBeenCalledTimes(1);
  expect(mockSetHistory).toHaveBeenCalledWith(compressedHistory);
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user