feat(core): Migrate chatCompressionService to model configs. (#12863)

This commit is contained in:
joshualitt
2025-11-24 12:24:45 -08:00
committed by GitHub
parent c21b6899e1
commit e50bf6adad
6 changed files with 177 additions and 53 deletions

View File

@@ -8,6 +8,7 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
ChatCompressionService,
findCompressSplitPoint,
modelStringToModelConfigAlias,
} from './chatCompressionService.js';
import type { Content, GenerateContentResponse } from '@google/genai';
import { CompressionStatus } from '../core/turn.js';
@@ -15,7 +16,7 @@ import { tokenLimit } from '../core/tokenLimits.js';
import type { GeminiChat } from '../core/geminiChat.js';
import type { Config } from '../config/config.js';
import { getInitialChatHistory } from '../utils/environmentContext.js';
import type { ContentGenerator } from '../core/contentGenerator.js';
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
vi.mock('../core/tokenLimits.js');
vi.mock('../telemetry/loggers.js');
@@ -101,11 +102,34 @@ describe('findCompressSplitPoint', () => {
});
});
describe('modelStringToModelConfigAlias', () => {
it('should return the default model for unexpected aliases', () => {
expect(modelStringToModelConfigAlias('gemini-flash-flash')).toBe(
DEFAULT_GEMINI_MODEL,
);
});
it('should handle valid names', () => {
expect(modelStringToModelConfigAlias('gemini-3-pro-preview')).toBe(
'chat-compression-3-pro',
);
expect(modelStringToModelConfigAlias('gemini-2.5-pro')).toBe(
'chat-compression-2.5-pro',
);
expect(modelStringToModelConfigAlias('gemini-2.5-flash')).toBe(
'chat-compression-2.5-flash',
);
expect(modelStringToModelConfigAlias('gemini-2.5-flash-lite')).toBe(
'chat-compression-2.5-flash-lite',
);
});
});
describe('ChatCompressionService', () => {
let service: ChatCompressionService;
let mockChat: GeminiChat;
let mockConfig: Config;
const mockModel = 'gemini-pro';
const mockModel = 'gemini-2.5-pro';
const mockPromptId = 'test-prompt-id';
beforeEach(() => {
@@ -114,9 +138,22 @@ describe('ChatCompressionService', () => {
getHistory: vi.fn(),
getLastPromptTokenCount: vi.fn().mockReturnValue(500),
} as unknown as GeminiChat;
const mockGenerateContent = vi.fn().mockResolvedValue({
candidates: [
{
content: {
parts: [{ text: 'Summary' }],
},
},
],
} as unknown as GenerateContentResponse);
mockConfig = {
getCompressionThreshold: vi.fn(),
getContentGenerator: vi.fn(),
getBaseLlmClient: vi.fn().mockReturnValue({
generateContent: mockGenerateContent,
}),
isInteractive: vi.fn().mockReturnValue(false),
} as unknown as Config;
@@ -190,18 +227,6 @@ describe('ChatCompressionService', () => {
vi.mocked(mockChat.getHistory).mockReturnValue(history);
vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(800);
vi.mocked(tokenLimit).mockReturnValue(1000);
const mockGenerateContent = vi.fn().mockResolvedValue({
candidates: [
{
content: {
parts: [{ text: 'Summary' }],
},
},
],
} as unknown as GenerateContentResponse);
vi.mocked(mockConfig.getContentGenerator).mockReturnValue({
generateContent: mockGenerateContent,
} as unknown as ContentGenerator);
const result = await service.compress(
mockChat,
@@ -215,7 +240,7 @@ describe('ChatCompressionService', () => {
expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED);
expect(result.newHistory).not.toBeNull();
expect(result.newHistory![0].parts![0].text).toBe('Summary');
expect(mockGenerateContent).toHaveBeenCalled();
expect(mockConfig.getBaseLlmClient().generateContent).toHaveBeenCalled();
});
it('should force compress even if under threshold', async () => {
@@ -229,19 +254,6 @@ describe('ChatCompressionService', () => {
vi.mocked(mockChat.getLastPromptTokenCount).mockReturnValue(100);
vi.mocked(tokenLimit).mockReturnValue(1000);
const mockGenerateContent = vi.fn().mockResolvedValue({
candidates: [
{
content: {
parts: [{ text: 'Summary' }],
},
},
],
} as unknown as GenerateContentResponse);
vi.mocked(mockConfig.getContentGenerator).mockReturnValue({
generateContent: mockGenerateContent,
} as unknown as ContentGenerator);
const result = await service.compress(
mockChat,
mockPromptId,
@@ -265,7 +277,7 @@ describe('ChatCompressionService', () => {
vi.mocked(tokenLimit).mockReturnValue(1000);
const longSummary = 'a'.repeat(1000); // Long summary to inflate token count
const mockGenerateContent = vi.fn().mockResolvedValue({
vi.mocked(mockConfig.getBaseLlmClient().generateContent).mockResolvedValue({
candidates: [
{
content: {
@@ -274,9 +286,6 @@ describe('ChatCompressionService', () => {
},
],
} as unknown as GenerateContentResponse);
vi.mocked(mockConfig.getContentGenerator).mockReturnValue({
generateContent: mockGenerateContent,
} as unknown as ContentGenerator);
const result = await service.compress(
mockChat,