mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-23 19:44:30 -07:00
feat(core): Migrate chatCompressionService to model configs. (#12863)
This commit is contained in:
@@ -183,5 +183,25 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
extends: 'gemini-2.5-flash-base',
|
||||
modelConfig: {},
|
||||
},
|
||||
'chat-compression-3-pro': {
|
||||
modelConfig: {
|
||||
model: 'gemini-3-pro-preview',
|
||||
},
|
||||
},
|
||||
'chat-compression-2.5-pro': {
|
||||
modelConfig: {
|
||||
model: 'gemini-2.5-pro',
|
||||
},
|
||||
},
|
||||
'chat-compression-2.5-flash': {
|
||||
modelConfig: {
|
||||
model: 'gemini-2.5-flash',
|
||||
},
|
||||
},
|
||||
'chat-compression-2.5-flash-lite': {
|
||||
modelConfig: {
|
||||
model: 'gemini-2.5-flash-lite',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
@@ -8,6 +8,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
ChatCompressionService,
|
||||
findCompressSplitPoint,
|
||||
modelStringToModelConfigAlias,
|
||||
} from './chatCompressionService.js';
|
||||
import type { Content, GenerateContentResponse } from '@google/genai';
|
||||
import { CompressionStatus } from '../core/turn.js';
|
||||
@@ -15,7 +16,7 @@ import { tokenLimit } from '../core/tokenLimits.js';
|
||||
import type { GeminiChat } from '../core/geminiChat.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { getInitialChatHistory } from '../utils/environmentContext.js';
|
||||
import type { ContentGenerator } from '../core/contentGenerator.js';
|
||||
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
|
||||
|
||||
vi.mock('../core/tokenLimits.js');
|
||||
vi.mock('../telemetry/loggers.js');
|
||||
@@ -101,11 +102,34 @@ describe('findCompressSplitPoint', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('modelStringToModelConfigAlias', () => {
  // A string that matches none of the known model names yields the default model.
  it('should return the default model for unexpected aliases', () => {
    const resolved = modelStringToModelConfigAlias('gemini-flash-flash');
    expect(resolved).toBe(DEFAULT_GEMINI_MODEL);
  });

  // Each known model name resolves to its dedicated chat-compression alias.
  it('should handle valid names', () => {
    const expectedAliasByModel: ReadonlyArray<readonly [string, string]> = [
      ['gemini-3-pro-preview', 'chat-compression-3-pro'],
      ['gemini-2.5-pro', 'chat-compression-2.5-pro'],
      ['gemini-2.5-flash', 'chat-compression-2.5-flash'],
      ['gemini-2.5-flash-lite', 'chat-compression-2.5-flash-lite'],
    ];

    for (const [modelName, expectedAlias] of expectedAliasByModel) {
      expect(modelStringToModelConfigAlias(modelName)).toBe(expectedAlias);
    }
  });
});
|
||||
|
||||
describe('ChatCompressionService', () => {
|
||||
let service: ChatCompressionService;
|
||||
let mockChat: GeminiChat;
|
||||
let mockConfig: Config;
|
||||
const mockModel = 'gemini-pro';
|
||||
const mockModel = 'gemini-2.5-pro';
|
||||
const mockPromptId = 'test-prompt-id';
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -114,9 +138,22 @@ describe('ChatCompressionService', () => {
|
||||
getHistory: vi.fn(),
|
||||
getLastPromptTokenCount: vi.fn().mockReturnValue(500),
|
||||
} as unknown as GeminiChat;
|
||||
|
||||
const mockGenerateContent = vi.fn().mockResolvedValue({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'Summary' }],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse);
|
||||
|
||||
mockConfig = {
|
||||
getCompressionThreshold: vi.fn(),
|
||||
getContentGenerator: vi.fn(),
|
||||
getBaseLlmClient: vi.fn().mockReturnValue({
|
||||
generateContent: mockGenerateContent,
|
||||
}),
|
||||
isInteractive: vi.fn().mockReturnValue(false),
|
||||
} as unknown as Config;
|
||||
|
||||
@@ -190,18 +227,6 @@ describe('ChatCompressionService', () => {
|
||||
vi.mocked(mockChat.getHistory).mockReturnValue(history);
|
||||
vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(800);
|
||||
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||
const mockGenerateContent = vi.fn().mockResolvedValue({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'Summary' }],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse);
|
||||
vi.mocked(mockConfig.getContentGenerator).mockReturnValue({
|
||||
generateContent: mockGenerateContent,
|
||||
} as unknown as ContentGenerator);
|
||||
|
||||
const result = await service.compress(
|
||||
mockChat,
|
||||
@@ -215,7 +240,7 @@ describe('ChatCompressionService', () => {
|
||||
expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED);
|
||||
expect(result.newHistory).not.toBeNull();
|
||||
expect(result.newHistory![0].parts![0].text).toBe('Summary');
|
||||
expect(mockGenerateContent).toHaveBeenCalled();
|
||||
expect(mockConfig.getBaseLlmClient().generateContent).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should force compress even if under threshold', async () => {
|
||||
@@ -229,19 +254,6 @@ describe('ChatCompressionService', () => {
|
||||
vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(100);
|
||||
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||
|
||||
const mockGenerateContent = vi.fn().mockResolvedValue({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'Summary' }],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse);
|
||||
vi.mocked(mockConfig.getContentGenerator).mockReturnValue({
|
||||
generateContent: mockGenerateContent,
|
||||
} as unknown as ContentGenerator);
|
||||
|
||||
const result = await service.compress(
|
||||
mockChat,
|
||||
mockPromptId,
|
||||
@@ -265,7 +277,7 @@ describe('ChatCompressionService', () => {
|
||||
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||
|
||||
const longSummary = 'a'.repeat(1000); // Long summary to inflate token count
|
||||
const mockGenerateContent = vi.fn().mockResolvedValue({
|
||||
vi.mocked(mockConfig.getBaseLlmClient().generateContent).mockResolvedValue({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
@@ -274,9 +286,6 @@ describe('ChatCompressionService', () => {
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse);
|
||||
vi.mocked(mockConfig.getContentGenerator).mockReturnValue({
|
||||
generateContent: mockGenerateContent,
|
||||
} as unknown as ContentGenerator);
|
||||
|
||||
const result = await service.compress(
|
||||
mockChat,
|
||||
|
||||
@@ -14,6 +14,12 @@ import { getResponseText } from '../utils/partUtils.js';
|
||||
import { logChatCompression } from '../telemetry/loggers.js';
|
||||
import { makeChatCompressionEvent } from '../telemetry/types.js';
|
||||
import { getInitialChatHistory } from '../utils/environmentContext.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
} from '../config/models.js';
|
||||
|
||||
/**
|
||||
* Default threshold for compression token count as a fraction of the model's
|
||||
@@ -75,6 +81,21 @@ export function findCompressSplitPoint(
|
||||
return lastSplitPoint;
|
||||
}
|
||||
|
||||
/**
 * Resolves the model-config key used for chat compression with a given model.
 *
 * Each recognised model constant maps to its own `chat-compression-*` alias.
 * Any other string yields `DEFAULT_GEMINI_MODEL`.
 *
 * NOTE(review): the fallback is the `DEFAULT_GEMINI_MODEL` constant itself
 * rather than a `chat-compression-*` alias; presumably the consumer of the
 * returned key accepts it. Confirm against the model config lookup.
 *
 * @param model - Model name to look up.
 * @returns The matching alias, or `DEFAULT_GEMINI_MODEL` when nothing matches.
 */
export function modelStringToModelConfigAlias(model: string): string {
  if (model === PREVIEW_GEMINI_MODEL) {
    return 'chat-compression-3-pro';
  }
  if (model === DEFAULT_GEMINI_MODEL) {
    return 'chat-compression-2.5-pro';
  }
  if (model === DEFAULT_GEMINI_FLASH_MODEL) {
    return 'chat-compression-2.5-flash';
  }
  if (model === DEFAULT_GEMINI_FLASH_LITE_MODEL) {
    return 'chat-compression-2.5-flash-lite';
  }
  // No dedicated compression alias for this model.
  return DEFAULT_GEMINI_MODEL;
}
|
||||
|
||||
export class ChatCompressionService {
|
||||
async compress(
|
||||
chat: GeminiChat,
|
||||
@@ -139,26 +160,24 @@ export class ChatCompressionService {
|
||||
};
|
||||
}
|
||||
|
||||
const summaryResponse = await config.getContentGenerator().generateContent(
|
||||
{
|
||||
model,
|
||||
contents: [
|
||||
...historyToCompress,
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
config: {
|
||||
systemInstruction: { text: getCompressionPrompt() },
|
||||
const summaryResponse = await config.getBaseLlmClient().generateContent({
|
||||
modelConfigKey: { model: modelStringToModelConfigAlias(model) },
|
||||
contents: [
|
||||
...historyToCompress,
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
],
|
||||
systemInstruction: { text: getCompressionPrompt() },
|
||||
promptId,
|
||||
);
|
||||
// TODO(joshualitt): wire up a sensible abort signal.
|
||||
abortSignal: new AbortController().signal,
|
||||
});
|
||||
const summary = getResponseText(summaryResponse) ?? '';
|
||||
|
||||
const extraHistory: Content[] = [
|
||||
|
||||
@@ -198,5 +198,21 @@
|
||||
"temperature": 0,
|
||||
"topP": 1
|
||||
}
|
||||
},
|
||||
"chat-compression-3-pro": {
|
||||
"model": "gemini-3-pro-preview",
|
||||
"generateContentConfig": {}
|
||||
},
|
||||
"chat-compression-2.5-pro": {
|
||||
"model": "gemini-2.5-pro",
|
||||
"generateContentConfig": {}
|
||||
},
|
||||
"chat-compression-2.5-flash": {
|
||||
"model": "gemini-2.5-flash",
|
||||
"generateContentConfig": {}
|
||||
},
|
||||
"chat-compression-2.5-flash-lite": {
|
||||
"model": "gemini-2.5-flash-lite",
|
||||
"generateContentConfig": {}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user