feat(core): Migrate generateJson to resolved model configs. (#12626)

Author: joshualitt
Date: 2025-11-07 14:18:45 -08:00
Committed by: GitHub
Parent: f3a8b73717
Commit: fdb6088603
16 changed files with 175 additions and 118 deletions
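At a glance: callers of `BaseLlmClient.generateJson` used to pass a concrete `model` string plus ad-hoc `config` overrides; they now pass a `modelConfigKey` that `Config.modelConfigService.getResolvedConfig` turns into a model plus generation settings. A before/after sketch assembled from the call sites in this diff; the surrounding variables (`baseLlmClient`, `contents`, `RESPONSE_SCHEMA`, `abortSignal`, `promptId`) are assumed to be in scope:

```ts
// Before: each call site pinned a model and tuned generation inline.
await baseLlmClient.generateJson({
  contents,
  schema: RESPONSE_SCHEMA,
  model: DEFAULT_GEMINI_FLASH_MODEL, // hard-coded per call site
  config: { temperature: 0, topP: 1 }, // ad-hoc overrides
  abortSignal,
  promptId,
});

// After: call sites name a logical config key; the model and generation
// settings are resolved centrally from DEFAULT_MODEL_CONFIGS.
await baseLlmClient.generateJson({
  modelConfigKey: { model: 'next-speaker-checker' },
  contents,
  schema: RESPONSE_SCHEMA,
  abortSignal,
  promptId,
});
```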
@@ -53,6 +53,13 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
model: 'gemini-2.5-flash-lite',
},
},
+ // Bases for the internal model configs.
+ 'gemini-2.5-flash-base': {
+ extends: 'base',
+ modelConfig: {
+ model: 'gemini-2.5-flash',
+ },
+ },
classifier: {
extends: 'base',
modelConfig: {
@@ -108,22 +115,32 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
},
},
'web-search-tool': {
- extends: 'base',
+ extends: 'gemini-2.5-flash-base',
modelConfig: {
- model: 'gemini-2.5-flash',
generateContentConfig: {
tools: [{ googleSearch: {} }],
},
},
},
'web-fetch-tool': {
- extends: 'base',
+ extends: 'gemini-2.5-flash-base',
modelConfig: {
- model: 'gemini-2.5-flash',
generateContentConfig: {
tools: [{ urlContext: {} }],
},
},
},
+ 'loop-detection': {
+ extends: 'gemini-2.5-flash-base',
+ modelConfig: {},
+ },
+ 'llm-edit-fixer': {
+ extends: 'gemini-2.5-flash-base',
+ modelConfig: {},
+ },
+ 'next-speaker-checker': {
+ extends: 'gemini-2.5-flash-base',
+ modelConfig: {},
+ },
},
};
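The resolver itself is not part of this diff, but the fixture output later in the commit (e.g. `loop-detection` resolving to `gemini-2.5-flash` with `temperature: 0, topP: 1`) is consistent with a straightforward merge along the `extends` chain. A rough sketch of that assumed behavior; only the input/output shapes are confirmed by the commit, the merge logic and names below are inferred:

```ts
// Hypothetical resolver shape, not the commit's implementation.
interface ModelConfigEntry {
  extends?: string;
  modelConfig: {
    model?: string;
    generateContentConfig?: Record<string, unknown>;
  };
}

function resolve(
  entries: Record<string, ModelConfigEntry>,
  key: string,
): { model?: string; generateContentConfig: Record<string, unknown> } {
  const entry = entries[key];
  // Resolve the parent first so the child's values win on conflict.
  const parent = entry.extends
    ? resolve(entries, entry.extends)
    : { model: undefined, generateContentConfig: {} };
  return {
    model: entry.modelConfig.model ?? parent.model,
    generateContentConfig: {
      ...parent.generateContentConfig,
      ...entry.modelConfig.generateContentConfig,
    },
  };
}
```

Under this reading, `loop-detection` with an empty `modelConfig` inherits its model from `gemini-2.5-flash-base`, which in turn inherits `temperature`/`topP` from `base`, matching the JSON fixture below.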
@@ -75,6 +75,15 @@ const mockConfig = {
.fn()
.mockReturnValue({ authType: AuthType.USE_GEMINI }),
getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
+ modelConfigService: {
+ getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({
+ model,
+ generateContentConfig: {
+ temperature: 0,
+ topP: 1,
+ },
+ })),
+ },
} as unknown as Mocked<Config>;
// Helper to create a mock GenerateContentResponse
@@ -97,9 +106,9 @@ describe('BaseLlmClient', () => {
client = new BaseLlmClient(mockContentGenerator, mockConfig);
abortController = new AbortController();
defaultOptions = {
+ modelConfigKey: { model: 'test-model' },
contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }],
schema: { type: 'object', properties: { color: { type: 'string' } } },
- model: 'test-model',
abortSignal: abortController.signal,
promptId: 'test-prompt-id',
};
@@ -135,10 +144,10 @@ describe('BaseLlmClient', () => {
contents: defaultOptions.contents,
config: {
abortSignal: defaultOptions.abortSignal,
+ temperature: 0,
+ topP: 1,
responseJsonSchema: defaultOptions.schema,
responseMimeType: 'application/json',
- temperature: 0,
- topP: 1,
// Crucial: systemInstruction should NOT be in the config object if not provided
},
},
@@ -146,29 +155,6 @@ describe('BaseLlmClient', () => {
);
});
- it('should respect configuration overrides', async () => {
- const mockResponse = createMockResponse('{"color": "red"}');
- mockGenerateContent.mockResolvedValue(mockResponse);
- const options: GenerateJsonOptions = {
- ...defaultOptions,
- config: { temperature: 0.8, topK: 10 },
- };
- await client.generateJson(options);
- expect(mockGenerateContent).toHaveBeenCalledWith(
- expect.objectContaining({
- config: expect.objectContaining({
- temperature: 0.8,
- topP: 1, // Default should remain if not overridden
- topK: 10,
- }),
- }),
- expect.any(String),
- );
- });
it('should include system instructions when provided', async () => {
const mockResponse = createMockResponse('{"color": "green"}');
mockGenerateContent.mockResolvedValue(mockResponse);
@@ -296,7 +282,7 @@ describe('BaseLlmClient', () => {
const calls = vi.mocked(logMalformedJsonResponse).mock.calls;
const lastCall = calls[calls.length - 1];
const event = lastCall[1] as MalformedJsonResponseEvent;
- expect(event.model).toBe('test-model');
+ expect(event.model).toBe(defaultOptions.modelConfigKey.model);
});
it('should handle extra whitespace correctly without logging malformed telemetry', async () => {
@@ -19,6 +19,7 @@ import { getErrorMessage } from '../utils/errors.js';
import { logMalformedJsonResponse } from '../telemetry/loggers.js';
import { MalformedJsonResponseEvent } from '../telemetry/types.js';
import { retryWithBackoff } from '../utils/retry.js';
+ import type { ModelConfigKey } from '../services/modelConfigService.js';
const DEFAULT_MAX_ATTEMPTS = 5;
@@ -26,28 +27,17 @@ const DEFAULT_MAX_ATTEMPTS = 5;
* Options for the generateJson utility function.
*/
export interface GenerateJsonOptions {
+ /** The desired model config. */
+ modelConfigKey: ModelConfigKey;
/** The input prompt or history. */
contents: Content[];
/** The required JSON schema for the output. */
schema: Record<string, unknown>;
- /** The specific model to use for this task. */
- model: string;
/**
* Task-specific system instructions.
* If omitted, no system instruction is sent.
*/
systemInstruction?: string | Part | Part[] | Content;
- /**
- * Overrides for generation configuration (e.g., temperature).
- */
- config?: Omit<
- GenerateContentConfig,
- | 'systemInstruction'
- | 'responseJsonSchema'
- | 'responseMimeType'
- | 'tools'
- | 'abortSignal'
- >;
/** Signal for cancellation. */
abortSignal: AbortSignal;
/**
@@ -64,12 +54,6 @@ export interface GenerateJsonOptions {
* A client dedicated to stateless, utility-focused LLM calls.
*/
export class BaseLlmClient {
- // Default configuration for utility tasks
- private readonly defaultUtilityConfig: GenerateContentConfig = {
- temperature: 0,
- topP: 1,
- };
constructor(
private readonly contentGenerator: ContentGenerator,
private readonly config: Config,
@@ -79,19 +63,20 @@ export class BaseLlmClient {
options: GenerateJsonOptions,
): Promise<Record<string, unknown>> {
const {
+ modelConfigKey,
contents,
schema,
- model,
abortSignal,
systemInstruction,
promptId,
maxAttempts,
} = options;
+ const { model, generateContentConfig } =
+ this.config.modelConfigService.getResolvedConfig(modelConfigKey);
const requestConfig: GenerateContentConfig = {
abortSignal,
- ...this.defaultUtilityConfig,
- ...options.config,
+ ...generateContentConfig,
...(systemInstruction && { systemInstruction }),
responseJsonSchema: schema,
responseMimeType: 'application/json',
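After this hunk, the only per-call knobs left on `GenerateJsonOptions` are the prompt itself, `systemInstruction`, `abortSignal`, `promptId`, and `maxAttempts`; everything model-shaped flows from the resolved config. The spread order also guarantees that no config entry can unset JSON output, since the schema and MIME type are applied last. Worked through for the fixture's `loop-detection` entry (variable names from the destructure above):

```ts
// getResolvedConfig({ model: 'loop-detection' }) yields, per the fixture:
//   { model: 'gemini-2.5-flash',
//     generateContentConfig: { temperature: 0, topP: 1 } }
const requestConfig = {
  abortSignal,
  ...generateContentConfig, // temperature: 0, topP: 1
  ...(systemInstruction && { systemInstruction }),
  responseJsonSchema: schema, // spread last: a resolved config
  responseMimeType: 'application/json', // cannot override JSON mode
};
```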
@@ -15,11 +15,11 @@ import {
} from '../../utils/messageInspectors.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
- DEFAULT_GEMINI_FLASH_LITE_MODEL,
DEFAULT_GEMINI_MODEL,
} from '../../config/models.js';
import { promptIdContext } from '../../utils/promptIdContext.js';
import type { Content } from '@google/genai';
+ import type { ResolvedModelConfig } from '../../services/modelConfigService.js';
vi.mock('../../core/baseLlmClient.js');
vi.mock('../../utils/promptIdContext.js');
@@ -29,6 +29,7 @@ describe('ClassifierStrategy', () => {
let mockContext: RoutingContext;
let mockConfig: Config;
let mockBaseLlmClient: BaseLlmClient;
+ let mockResolvedConfig: ResolvedModelConfig;
beforeEach(() => {
vi.clearAllMocks();
@@ -39,7 +40,15 @@ describe('ClassifierStrategy', () => {
request: [{ text: 'simple task' }],
signal: new AbortController().signal,
};
- mockConfig = {} as Config;
+ mockResolvedConfig = {
+ model: 'classifier',
+ generateContentConfig: {},
+ } as unknown as ResolvedModelConfig;
+ mockConfig = {
+ modelConfigService: {
+ getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
+ },
+ } as unknown as Config;
mockBaseLlmClient = {
generateJson: vi.fn(),
} as unknown as BaseLlmClient;
@@ -60,14 +69,7 @@ describe('ClassifierStrategy', () => {
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect.objectContaining({
- model: DEFAULT_GEMINI_FLASH_LITE_MODEL,
- config: expect.objectContaining({
- temperature: 0,
- maxOutputTokens: 1024,
- thinkingConfig: {
- thinkingBudget: 512,
- },
- }),
+ modelConfigKey: { model: mockResolvedConfig.model },
promptId: 'test-prompt-id',
}),
);
@@ -14,14 +14,9 @@ import type {
} from '../routingStrategy.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
- DEFAULT_GEMINI_FLASH_LITE_MODEL,
DEFAULT_GEMINI_MODEL,
} from '../../config/models.js';
- import {
- type GenerateContentConfig,
- createUserContent,
- Type,
- } from '@google/genai';
+ import { createUserContent, Type } from '@google/genai';
import type { Config } from '../../config/config.js';
import {
isFunctionCall,
@@ -29,14 +24,6 @@ import {
} from '../../utils/messageInspectors.js';
import { debugLogger } from '../../utils/debugLogger.js';
- const CLASSIFIER_GENERATION_CONFIG: GenerateContentConfig = {
- temperature: 0,
- maxOutputTokens: 1024,
- thinkingConfig: {
- thinkingBudget: 512, // This counts towards output max, so we don't want -1.
- },
- };
// The number of recent history turns to provide to the router for context.
const HISTORY_TURNS_FOR_CONTEXT = 4;
const HISTORY_SEARCH_WINDOW = 20;
@@ -171,11 +158,10 @@ export class ClassifierStrategy implements RoutingStrategy {
const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT);
const jsonResponse = await baseLlmClient.generateJson({
+ modelConfigKey: { model: 'classifier' },
contents: [...finalHistory, createUserContent(context.request)],
schema: RESPONSE_SCHEMA,
- model: DEFAULT_GEMINI_FLASH_LITE_MODEL,
systemInstruction: CLASSIFIER_SYSTEM_PROMPT,
- config: CLASSIFIER_GENERATION_CONFIG,
abortSignal: context.signal,
promptId,
});
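The deleted `CLASSIFIER_GENERATION_CONFIG` values do not disappear: the `classifier` entry in `DEFAULT_MODEL_CONFIGS`, which the first hunk truncates, presumably absorbs them. A hypothetical reconstruction of that entry; only the key, `extends: 'base'`, and the flash-lite model are actually visible in the hunks shown:

```ts
// Hypothetical: where the deleted constants plausibly land.
classifier: {
  extends: 'base',
  modelConfig: {
    model: 'gemini-2.5-flash-lite',
    generateContentConfig: {
      temperature: 0,
      maxOutputTokens: 1024,
      thinkingConfig: {
        thinkingBudget: 512, // counts toward the output max, so not -1
      },
    },
  },
},
```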
@@ -735,6 +735,12 @@ describe('LoopDetectionService LLM Checks', () => {
getBaseLlmClient: () => mockBaseLlmClient,
getDebugMode: () => false,
getTelemetryEnabled: () => true,
+ modelConfigService: {
+ getResolvedConfig: vi.fn().mockReturnValue({
+ model: 'cognitive-loop-v1',
+ generateContentConfig: {},
+ }),
+ },
} as unknown as Config;
service = new LoopDetectionService(mockConfig);
@@ -765,9 +771,9 @@ describe('LoopDetectionService LLM Checks', () => {
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect.objectContaining({
+ modelConfigKey: expect.any(Object),
systemInstruction: expect.any(String),
contents: expect.any(Array),
- model: expect.any(String),
schema: expect.any(Object),
promptId: expect.any(String),
}),
@@ -18,7 +18,6 @@ import {
LoopType,
} from '../telemetry/types.js';
import type { Config } from '../config/config.js';
- import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/config.js';
import {
isFunctionCall,
isFunctionResponse,
@@ -436,9 +435,9 @@ export class LoopDetectionService {
let result;
try {
result = await this.config.getBaseLlmClient().generateJson({
+ modelConfigKey: { model: 'loop-detection' },
contents,
schema,
- model: DEFAULT_GEMINI_FLASH_MODEL,
systemInstruction: LOOP_DETECTION_SYSTEM_PROMPT,
abortSignal: signal,
promptId: this.promptId,
@@ -48,6 +48,13 @@
}
}
},
"gemini-2.5-flash-base": {
"model": "gemini-2.5-flash",
"generateContentConfig": {
"temperature": 0,
"topP": 1
}
},
"classifier": {
"model": "gemini-2.5-flash-lite",
"generateContentConfig": {
@@ -119,5 +126,26 @@
}
]
}
- }
+ },
"loop-detection": {
"model": "gemini-2.5-flash",
"generateContentConfig": {
"temperature": 0,
"topP": 1
}
},
"llm-edit-fixer": {
"model": "gemini-2.5-flash",
"generateContentConfig": {
"temperature": 0,
"topP": 1
}
},
"next-speaker-checker": {
"model": "gemini-2.5-flash",
"generateContentConfig": {
"temperature": 0,
"topP": 1
}
}
}
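Neither `ModelConfigKey` nor `ResolvedModelConfig` is defined in the hunks shown. From their usage (`{ model: 'loop-detection' }` going in, `{ model, generateContentConfig }` coming out of `getResolvedConfig`), they plausibly look like this; both shapes are inferred, not quoted from the commit:

```ts
import type { GenerateContentConfig } from '@google/genai';

// Assumed shapes, reconstructed from call sites in this commit.
export interface ModelConfigKey {
  model: string; // a logical key such as 'classifier', not a model ID
}

export interface ResolvedModelConfig {
  model: string; // the concrete model, e.g. 'gemini-2.5-flash'
  generateContentConfig: GenerateContentConfig;
}
```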
@@ -236,6 +236,14 @@ describe('editCorrector', () => {
mockGeminiClientInstance.getHistory = vi.fn().mockResolvedValue([]);
mockBaseLlmClientInstance = {
generateJson: mockGenerateJson,
+ config: {
+ generationConfigService: {
+ getResolvedConfig: vi.fn().mockReturnValue({
+ model: 'edit-corrector',
+ generateContentConfig: {},
+ }),
+ },
+ },
} as unknown as Mocked<BaseLlmClient>;
resetEditCorrectorCaches_TEST_ONLY();
});
@@ -634,6 +642,14 @@ describe('editCorrector', () => {
mockBaseLlmClientInstance = {
generateJson: mockGenerateJson,
+ config: {
+ generationConfigService: {
+ getResolvedConfig: vi.fn().mockReturnValue({
+ model: 'edit-corrector',
+ generateContentConfig: {},
+ }),
+ },
+ },
} as unknown as Mocked<BaseLlmClient>;
resetEditCorrectorCaches_TEST_ONLY();
});
@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
- import type { Content, GenerateContentConfig } from '@google/genai';
+ import type { Content } from '@google/genai';
import type { GeminiClient } from '../core/client.js';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
import type { EditToolParams } from '../tools/edit.js';
@@ -16,7 +16,6 @@ import {
WRITE_FILE_TOOL_NAME,
} from '../tools/tool-names.js';
import { LruCache } from './LruCache.js';
- import { DEFAULT_GEMINI_FLASH_LITE_MODEL } from '../config/models.js';
import {
isFunctionResponse,
isFunctionCall,
@@ -24,13 +23,6 @@ import {
import * as fs from 'node:fs';
import { promptIdContext } from './promptIdContext.js';
- const EDIT_MODEL = DEFAULT_GEMINI_FLASH_LITE_MODEL;
- const EDIT_CONFIG: GenerateContentConfig = {
- thinkingConfig: {
- thinkingBudget: 0,
- },
- };
const CODE_CORRECTION_SYSTEM_PROMPT = `
You are an expert code-editing assistant. Your task is to analyze a failed edit attempt and provide a corrected version of the text snippets.
The correction should be as minimal as possible, staying very close to the original.
@@ -420,11 +412,10 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
try {
const result = await baseLlmClient.generateJson({
+ modelConfigKey: { model: 'edit-corrector' },
contents,
schema: OLD_STRING_CORRECTION_SCHEMA,
abortSignal,
- model: EDIT_MODEL,
- config: EDIT_CONFIG,
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
promptId: getPromptId(),
});
@@ -510,11 +501,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
try {
const result = await baseLlmClient.generateJson({
+ modelConfigKey: { model: 'edit-corrector' },
contents,
schema: NEW_STRING_CORRECTION_SCHEMA,
abortSignal,
- model: EDIT_MODEL,
- config: EDIT_CONFIG,
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
promptId: getPromptId(),
});
@@ -581,11 +571,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
try {
const result = await baseLlmClient.generateJson({
+ modelConfigKey: { model: 'edit-corrector' },
contents,
schema: CORRECT_NEW_STRING_ESCAPING_SCHEMA,
abortSignal,
- model: EDIT_MODEL,
- config: EDIT_CONFIG,
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
promptId: getPromptId(),
});
@@ -649,11 +638,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
try {
const result = await baseLlmClient.generateJson({
+ modelConfigKey: { model: 'edit-corrector' },
contents,
schema: CORRECT_STRING_ESCAPING_SCHEMA,
abortSignal,
- model: EDIT_MODEL,
- config: EDIT_CONFIG,
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
promptId: getPromptId(),
});
@@ -17,6 +17,14 @@ import type { BaseLlmClient } from '../core/baseLlmClient.js';
const mockGenerateJson = vi.fn();
const mockBaseLlmClient = {
generateJson: mockGenerateJson,
+ config: {
+ generationConfigService: {
+ getResolvedConfig: vi.fn().mockReturnValue({
+ model: 'edit-corrector',
+ generateContentConfig: {},
+ }),
+ },
+ },
} as unknown as BaseLlmClient;
describe('FixLLMEditWithInstruction', () => {
@@ -8,7 +8,6 @@ import { createHash } from 'node:crypto';
import { type Content, Type } from '@google/genai';
import { type BaseLlmClient } from '../core/baseLlmClient.js';
import { LruCache } from './LruCache.js';
- import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { promptIdContext } from './promptIdContext.js';
import { debugLogger } from './debugLogger.js';
@@ -176,10 +175,10 @@ export async function FixLLMEditWithInstruction(
const result = await generateJsonWithTimeout<SearchReplaceEdit>(
baseLlmClient,
{
+ modelConfigKey: { model: 'llm-edit-fixer' },
contents,
schema: SearchReplaceEditSchema,
abortSignal,
- model: DEFAULT_GEMINI_FLASH_MODEL,
systemInstruction: EDIT_SYS_PROMPT,
promptId,
maxAttempts: 1,
@@ -7,7 +7,6 @@
import type { Mock } from 'vitest';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import type { Content } from '@google/genai';
- import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { BaseLlmClient } from '../core/baseLlmClient.js';
import type { ContentGenerator } from '../core/contentGenerator.js';
import type { Config } from '../config/config.js';
@@ -54,6 +53,10 @@ describe('checkNextSpeaker', () => {
beforeEach(() => {
vi.resetAllMocks();
+ const mockResolvedConfig = {
+ model: 'next-speaker-v1',
+ generateContentConfig: {},
+ };
mockConfig = {
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
getSessionId: vi.fn().mockReturnValue('test-session-id'),
@@ -61,6 +64,9 @@ describe('checkNextSpeaker', () => {
storage: {
getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
},
+ modelConfigService: {
+ getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
+ },
} as unknown as Config;
mockBaseLlmClient = new BaseLlmClient(
@@ -265,8 +271,8 @@ describe('checkNextSpeaker', () => {
expect(mockBaseLlmClient.generateJson).toHaveBeenCalled();
const generateJsonCall = (mockBaseLlmClient.generateJson as Mock).mock
- .calls[0];
- expect(generateJsonCall[0].model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
- expect(generateJsonCall[0].promptId).toBe(promptId);
+ .calls[0][0];
+ expect(generateJsonCall.modelConfigKey.model).toBe('next-speaker-checker');
+ expect(generateJsonCall.promptId).toBe(promptId);
});
});
@@ -5,7 +5,6 @@
*/
import type { Content } from '@google/genai';
- import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
import type { GeminiChat } from '../core/geminiChat.js';
import { isFunctionResponse } from './messageInspectors.js';
@@ -111,9 +110,9 @@ export async function checkNextSpeaker(
try {
const parsedResponse = (await baseLlmClient.generateJson({
+ modelConfigKey: { model: 'next-speaker-checker' },
contents,
schema: RESPONSE_SCHEMA,
- model: DEFAULT_GEMINI_FLASH_MODEL,
abortSignal,
promptId,
})) as unknown as NextSpeakerResponse;
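Every test touched by this commit follows the same migration recipe: a mocked `Config` whose `BaseLlmClient` actually reaches `generateJson` needs a `modelConfigService` stub, and assertions move from `options.model` to `options.modelConfigKey.model`. A minimal stub in the vitest style the diff already uses; the import path mirrors the test files above and is otherwise an assumption:

```ts
import { vi } from 'vitest';
import type { Config } from '../config/config.js';

// Echo the requested key back as the resolved model with an empty
// generation config, as several tests in this diff do.
const mockConfig = {
  modelConfigService: {
    getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({
      model,
      generateContentConfig: {},
    })),
  },
} as unknown as Config;
```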