Mirror of https://github.com/google-gemini/gemini-cli.git, synced 2026-04-24 20:14:44 -07:00
feat(core): Migrate generateJson to resolved model configs. (#12626)
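In short: callers of generateJson no longer pick a model string and an ad-hoc GenerateContentConfig per call; they pass a modelConfigKey, and ModelConfigService.getResolvedConfig supplies the concrete model plus its generateContentConfig. A hedged before/after sketch of one call site, condensed from the hunks below (the 'loop-detection' key and the flash-model constant are taken from this diff; surrounding wiring is omitted):

// Before: call sites hard-coded the model and tuning per call.
const old = await baseLlmClient.generateJson({
  contents,
  schema,
  model: DEFAULT_GEMINI_FLASH_MODEL,
  systemInstruction: LOOP_DETECTION_SYSTEM_PROMPT,
  abortSignal: signal,
  promptId,
});

// After: call sites name a centrally defined config entry instead.
const updated = await baseLlmClient.generateJson({
  modelConfigKey: { model: 'loop-detection' },
  contents,
  schema,
  systemInstruction: LOOP_DETECTION_SYSTEM_PROMPT,
  abortSignal: signal,
  promptId,
});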
@@ -53,6 +53,13 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
         model: 'gemini-2.5-flash-lite',
       },
     },
+    // Bases for the internal model configs.
+    'gemini-2.5-flash-base': {
+      extends: 'base',
+      modelConfig: {
+        model: 'gemini-2.5-flash',
+      },
+    },
     classifier: {
       extends: 'base',
       modelConfig: {
@@ -108,22 +115,32 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
       },
     },
     'web-search-tool': {
-      extends: 'base',
+      extends: 'gemini-2.5-flash-base',
       modelConfig: {
-        model: 'gemini-2.5-flash',
         generateContentConfig: {
           tools: [{ googleSearch: {} }],
         },
       },
     },
     'web-fetch-tool': {
-      extends: 'base',
+      extends: 'gemini-2.5-flash-base',
       modelConfig: {
-        model: 'gemini-2.5-flash',
         generateContentConfig: {
           tools: [{ urlContext: {} }],
         },
       },
     },
+    'loop-detection': {
+      extends: 'gemini-2.5-flash-base',
+      modelConfig: {},
+    },
+    'llm-edit-fixer': {
+      extends: 'gemini-2.5-flash-base',
+      modelConfig: {},
+    },
+    'next-speaker-checker': {
+      extends: 'gemini-2.5-flash-base',
+      modelConfig: {},
+    },
   },
 };
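The new 'gemini-2.5-flash-base' entry exists so the flash-backed utility tasks (loop-detection, llm-edit-fixer, next-speaker-checker) share one definition instead of each repeating model: 'gemini-2.5-flash'. A minimal sketch of how an extends chain could be resolved; this is an illustration only, as the actual ModelConfigService implementation is not part of this diff:

interface ConfigEntry {
  extends?: string;
  modelConfig?: {
    model?: string;
    generateContentConfig?: Record<string, unknown>;
  };
}

interface Resolved {
  model?: string;
  generateContentConfig: Record<string, unknown>;
}

// Hypothetical resolver: walk child -> parent, letting child fields win.
function resolve(key: string, entries: Record<string, ConfigEntry>): Resolved {
  const entry = entries[key];
  const parent: Resolved = entry.extends
    ? resolve(entry.extends, entries)
    : { generateContentConfig: {} };
  return {
    model: entry.modelConfig?.model ?? parent.model,
    generateContentConfig: {
      ...parent.generateContentConfig,
      ...(entry.modelConfig?.generateContentConfig ?? {}),
    },
  };
}

Resolving 'loop-detection' this way would yield { model: 'gemini-2.5-flash', generateContentConfig: { temperature: 0, topP: 1 } }, which matches the resolved-config documentation further down.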
@@ -75,6 +75,15 @@ const mockConfig = {
     .fn()
     .mockReturnValue({ authType: AuthType.USE_GEMINI }),
   getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
+  modelConfigService: {
+    getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({
+      model,
+      generateContentConfig: {
+        temperature: 0,
+        topP: 1,
+      },
+    })),
+  },
 } as unknown as Mocked<Config>;
 
 // Helper to create a mock GenerateContentResponse
@@ -97,9 +106,9 @@ describe('BaseLlmClient', () => {
     client = new BaseLlmClient(mockContentGenerator, mockConfig);
     abortController = new AbortController();
     defaultOptions = {
+      modelConfigKey: { model: 'test-model' },
       contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }],
       schema: { type: 'object', properties: { color: { type: 'string' } } },
-      model: 'test-model',
       abortSignal: abortController.signal,
       promptId: 'test-prompt-id',
     };
@@ -135,10 +144,10 @@ describe('BaseLlmClient', () => {
         contents: defaultOptions.contents,
         config: {
           abortSignal: defaultOptions.abortSignal,
+          temperature: 0,
+          topP: 1,
           responseJsonSchema: defaultOptions.schema,
           responseMimeType: 'application/json',
-          temperature: 0,
-          topP: 1,
           // Crucial: systemInstruction should NOT be in the config object if not provided
         },
       },
@@ -146,29 +155,6 @@ describe('BaseLlmClient', () => {
     );
   });
 
-  it('should respect configuration overrides', async () => {
-    const mockResponse = createMockResponse('{"color": "red"}');
-    mockGenerateContent.mockResolvedValue(mockResponse);
-
-    const options: GenerateJsonOptions = {
-      ...defaultOptions,
-      config: { temperature: 0.8, topK: 10 },
-    };
-
-    await client.generateJson(options);
-
-    expect(mockGenerateContent).toHaveBeenCalledWith(
-      expect.objectContaining({
-        config: expect.objectContaining({
-          temperature: 0.8,
-          topP: 1, // Default should remain if not overridden
-          topK: 10,
-        }),
-      }),
-      expect.any(String),
-    );
-  });
-
   it('should include system instructions when provided', async () => {
     const mockResponse = createMockResponse('{"color": "green"}');
     mockGenerateContent.mockResolvedValue(mockResponse);
@@ -296,7 +282,7 @@ describe('BaseLlmClient', () => {
     const calls = vi.mocked(logMalformedJsonResponse).mock.calls;
     const lastCall = calls[calls.length - 1];
     const event = lastCall[1] as MalformedJsonResponseEvent;
-    expect(event.model).toBe('test-model');
+    expect(event.model).toBe(defaultOptions.modelConfigKey.model);
   });
 
   it('should handle extra whitespace correctly without logging malformed telemetry', async () => {
@@ -19,6 +19,7 @@ import { getErrorMessage } from '../utils/errors.js';
 import { logMalformedJsonResponse } from '../telemetry/loggers.js';
 import { MalformedJsonResponseEvent } from '../telemetry/types.js';
 import { retryWithBackoff } from '../utils/retry.js';
+import type { ModelConfigKey } from '../services/modelConfigService.js';
 
 const DEFAULT_MAX_ATTEMPTS = 5;
 
@@ -26,28 +27,17 @@ const DEFAULT_MAX_ATTEMPTS = 5;
  * Options for the generateJson utility function.
  */
 export interface GenerateJsonOptions {
+  /** The desired model config. */
+  modelConfigKey: ModelConfigKey;
   /** The input prompt or history. */
   contents: Content[];
   /** The required JSON schema for the output. */
   schema: Record<string, unknown>;
-  /** The specific model to use for this task. */
-  model: string;
   /**
    * Task-specific system instructions.
    * If omitted, no system instruction is sent.
    */
   systemInstruction?: string | Part | Part[] | Content;
-  /**
-   * Overrides for generation configuration (e.g., temperature).
-   */
-  config?: Omit<
-    GenerateContentConfig,
-    | 'systemInstruction'
-    | 'responseJsonSchema'
-    | 'responseMimeType'
-    | 'tools'
-    | 'abortSignal'
-  >;
   /** Signal for cancellation. */
   abortSignal: AbortSignal;
   /**
@@ -64,12 +54,6 @@ export interface GenerateJsonOptions {
  * A client dedicated to stateless, utility-focused LLM calls.
  */
 export class BaseLlmClient {
-  // Default configuration for utility tasks
-  private readonly defaultUtilityConfig: GenerateContentConfig = {
-    temperature: 0,
-    topP: 1,
-  };
-
   constructor(
     private readonly contentGenerator: ContentGenerator,
     private readonly config: Config,
@@ -79,19 +63,20 @@ export class BaseLlmClient {
     options: GenerateJsonOptions,
   ): Promise<Record<string, unknown>> {
     const {
+      modelConfigKey,
       contents,
       schema,
-      model,
       abortSignal,
       systemInstruction,
       promptId,
       maxAttempts,
     } = options;
 
+    const { model, generateContentConfig } =
+      this.config.modelConfigService.getResolvedConfig(modelConfigKey);
     const requestConfig: GenerateContentConfig = {
       abortSignal,
-      ...this.defaultUtilityConfig,
-      ...options.config,
+      ...generateContentConfig,
       ...(systemInstruction && { systemInstruction }),
       responseJsonSchema: schema,
       responseMimeType: 'application/json',
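Note the merge order inside generateJson after this change: defaultUtilityConfig and the per-call config override hook are gone, the resolved generateContentConfig is spread in their place, and responseJsonSchema plus responseMimeType stay last so no config entry can switch off structured JSON output. A small sketch of that merge under the shapes shown above (the function wrapper is just for illustration):

import type { GenerateContentConfig } from '@google/genai';

// Mirrors the requestConfig construction in the hunk above.
function buildRequestConfig(
  resolved: GenerateContentConfig,
  schema: Record<string, unknown>,
  abortSignal: AbortSignal,
  systemInstruction?: string,
): GenerateContentConfig {
  return {
    abortSignal,
    ...resolved, // tuning from the resolved model config (e.g. temperature)
    ...(systemInstruction && { systemInstruction }), // only when provided
    responseJsonSchema: schema, // pinned after the spreads so they always win
    responseMimeType: 'application/json',
  };
}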
@@ -15,11 +15,11 @@ import {
 } from '../../utils/messageInspectors.js';
 import {
   DEFAULT_GEMINI_FLASH_MODEL,
-  DEFAULT_GEMINI_FLASH_LITE_MODEL,
   DEFAULT_GEMINI_MODEL,
 } from '../../config/models.js';
 import { promptIdContext } from '../../utils/promptIdContext.js';
 import type { Content } from '@google/genai';
+import type { ResolvedModelConfig } from '../../services/modelConfigService.js';
 
 vi.mock('../../core/baseLlmClient.js');
 vi.mock('../../utils/promptIdContext.js');
@@ -29,6 +29,7 @@ describe('ClassifierStrategy', () => {
   let mockContext: RoutingContext;
   let mockConfig: Config;
   let mockBaseLlmClient: BaseLlmClient;
+  let mockResolvedConfig: ResolvedModelConfig;
 
   beforeEach(() => {
     vi.clearAllMocks();
@@ -39,7 +40,15 @@ describe('ClassifierStrategy', () => {
       request: [{ text: 'simple task' }],
       signal: new AbortController().signal,
     };
-    mockConfig = {} as Config;
+    mockResolvedConfig = {
+      model: 'classifier',
+      generateContentConfig: {},
+    } as unknown as ResolvedModelConfig;
+    mockConfig = {
+      modelConfigService: {
+        getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
+      },
+    } as unknown as Config;
     mockBaseLlmClient = {
       generateJson: vi.fn(),
     } as unknown as BaseLlmClient;
@@ -60,14 +69,7 @@ describe('ClassifierStrategy', () => {
 
     expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
       expect.objectContaining({
-        model: DEFAULT_GEMINI_FLASH_LITE_MODEL,
-        config: expect.objectContaining({
-          temperature: 0,
-          maxOutputTokens: 1024,
-          thinkingConfig: {
-            thinkingBudget: 512,
-          },
-        }),
+        modelConfigKey: { model: mockResolvedConfig.model },
         promptId: 'test-prompt-id',
       }),
     );
@@ -14,14 +14,9 @@ import type {
 } from '../routingStrategy.js';
 import {
   DEFAULT_GEMINI_FLASH_MODEL,
-  DEFAULT_GEMINI_FLASH_LITE_MODEL,
   DEFAULT_GEMINI_MODEL,
 } from '../../config/models.js';
-import {
-  type GenerateContentConfig,
-  createUserContent,
-  Type,
-} from '@google/genai';
+import { createUserContent, Type } from '@google/genai';
 import type { Config } from '../../config/config.js';
 import {
   isFunctionCall,
@@ -29,14 +24,6 @@ import {
 } from '../../utils/messageInspectors.js';
 import { debugLogger } from '../../utils/debugLogger.js';
 
-const CLASSIFIER_GENERATION_CONFIG: GenerateContentConfig = {
-  temperature: 0,
-  maxOutputTokens: 1024,
-  thinkingConfig: {
-    thinkingBudget: 512, // This counts towards output max, so we don't want -1.
-  },
-};
-
 // The number of recent history turns to provide to the router for context.
 const HISTORY_TURNS_FOR_CONTEXT = 4;
 const HISTORY_SEARCH_WINDOW = 20;
@@ -171,11 +158,10 @@ export class ClassifierStrategy implements RoutingStrategy {
     const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT);
 
     const jsonResponse = await baseLlmClient.generateJson({
+      modelConfigKey: { model: 'classifier' },
       contents: [...finalHistory, createUserContent(context.request)],
      schema: RESPONSE_SCHEMA,
-      model: DEFAULT_GEMINI_FLASH_LITE_MODEL,
       systemInstruction: CLASSIFIER_SYSTEM_PROMPT,
-      config: CLASSIFIER_GENERATION_CONFIG,
       abortSignal: context.signal,
       promptId,
     });
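The deleted CLASSIFIER_GENERATION_CONFIG constant is not lost: under the new scheme its tuning presumably lives in the 'classifier' entry of DEFAULT_MODEL_CONFIGS, which the earlier hunk and the resolved-config documentation below show only in truncated form. A hypothetical sketch of that entry, reusing the constant's values verbatim:

// Assumed shape of the 'classifier' config entry after the migration.
const classifier = {
  extends: 'base',
  modelConfig: {
    model: 'gemini-2.5-flash-lite',
    generateContentConfig: {
      temperature: 0,
      maxOutputTokens: 1024,
      thinkingConfig: {
        thinkingBudget: 512, // counts towards the output max, so not -1
      },
    },
  },
};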
@@ -735,6 +735,12 @@ describe('LoopDetectionService LLM Checks', () => {
       getBaseLlmClient: () => mockBaseLlmClient,
       getDebugMode: () => false,
       getTelemetryEnabled: () => true,
+      modelConfigService: {
+        getResolvedConfig: vi.fn().mockReturnValue({
+          model: 'cognitive-loop-v1',
+          generateContentConfig: {},
+        }),
+      },
     } as unknown as Config;
 
     service = new LoopDetectionService(mockConfig);
@@ -765,9 +771,9 @@ describe('LoopDetectionService LLM Checks', () => {
     expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
     expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
       expect.objectContaining({
+        modelConfigKey: expect.any(Object),
         systemInstruction: expect.any(String),
         contents: expect.any(Array),
-        model: expect.any(String),
        schema: expect.any(Object),
        promptId: expect.any(String),
      }),
@@ -18,7 +18,6 @@ import {
   LoopType,
 } from '../telemetry/types.js';
 import type { Config } from '../config/config.js';
-import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/config.js';
 import {
   isFunctionCall,
   isFunctionResponse,
@@ -436,9 +435,9 @@ export class LoopDetectionService {
     let result;
     try {
       result = await this.config.getBaseLlmClient().generateJson({
+        modelConfigKey: { model: 'loop-detection' },
         contents,
         schema,
-        model: DEFAULT_GEMINI_FLASH_MODEL,
         systemInstruction: LOOP_DETECTION_SYSTEM_PROMPT,
         abortSignal: signal,
         promptId: this.promptId,
@@ -48,6 +48,13 @@
       }
     }
   },
+  "gemini-2.5-flash-base": {
+    "model": "gemini-2.5-flash",
+    "generateContentConfig": {
+      "temperature": 0,
+      "topP": 1
+    }
+  },
   "classifier": {
     "model": "gemini-2.5-flash-lite",
     "generateContentConfig": {
@@ -119,5 +126,26 @@
         }
       ]
     }
   },
+  "loop-detection": {
+    "model": "gemini-2.5-flash",
+    "generateContentConfig": {
+      "temperature": 0,
+      "topP": 1
+    }
+  },
+  "llm-edit-fixer": {
+    "model": "gemini-2.5-flash",
+    "generateContentConfig": {
+      "temperature": 0,
+      "topP": 1
+    }
+  },
+  "next-speaker-checker": {
+    "model": "gemini-2.5-flash",
+    "generateContentConfig": {
+      "temperature": 0,
+      "topP": 1
+    }
+  }
 }
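For reference, every migrated call site passes a key whose model field names an entry in this table, and gets back a concrete model plus its generation config. A sketch of the two types implied by their usage in this diff (the real declarations live in modelConfigService.ts, which this commit does not show in full):

import type { GenerateContentConfig } from '@google/genai';

// Shapes inferred from usage such as `{ model: 'next-speaker-checker' }`
// and `const { model, generateContentConfig } = ...getResolvedConfig(key)`.
interface ModelConfigKey {
  model: string; // a config key like 'classifier' or 'llm-edit-fixer'
}

interface ResolvedModelConfig {
  model: string; // a concrete model name, e.g. 'gemini-2.5-flash'
  generateContentConfig: GenerateContentConfig;
}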
@@ -236,6 +236,14 @@ describe('editCorrector', () => {
     mockGeminiClientInstance.getHistory = vi.fn().mockResolvedValue([]);
     mockBaseLlmClientInstance = {
       generateJson: mockGenerateJson,
+      config: {
+        generationConfigService: {
+          getResolvedConfig: vi.fn().mockReturnValue({
+            model: 'edit-corrector',
+            generateContentConfig: {},
+          }),
+        },
+      },
     } as unknown as Mocked<BaseLlmClient>;
     resetEditCorrectorCaches_TEST_ONLY();
   });
@@ -634,6 +642,14 @@ describe('editCorrector', () => {
 
     mockBaseLlmClientInstance = {
       generateJson: mockGenerateJson,
+      config: {
+        generationConfigService: {
+          getResolvedConfig: vi.fn().mockReturnValue({
+            model: 'edit-corrector',
+            generateContentConfig: {},
+          }),
+        },
+      },
     } as unknown as Mocked<BaseLlmClient>;
     resetEditCorrectorCaches_TEST_ONLY();
   });
@@ -4,7 +4,7 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-import type { Content, GenerateContentConfig } from '@google/genai';
+import type { Content } from '@google/genai';
 import type { GeminiClient } from '../core/client.js';
 import type { BaseLlmClient } from '../core/baseLlmClient.js';
 import type { EditToolParams } from '../tools/edit.js';
@@ -16,7 +16,6 @@ import {
   WRITE_FILE_TOOL_NAME,
 } from '../tools/tool-names.js';
 import { LruCache } from './LruCache.js';
-import { DEFAULT_GEMINI_FLASH_LITE_MODEL } from '../config/models.js';
 import {
   isFunctionResponse,
   isFunctionCall,
@@ -24,13 +23,6 @@ import {
 import * as fs from 'node:fs';
 import { promptIdContext } from './promptIdContext.js';
 
-const EDIT_MODEL = DEFAULT_GEMINI_FLASH_LITE_MODEL;
-const EDIT_CONFIG: GenerateContentConfig = {
-  thinkingConfig: {
-    thinkingBudget: 0,
-  },
-};
-
 const CODE_CORRECTION_SYSTEM_PROMPT = `
 You are an expert code-editing assistant. Your task is to analyze a failed edit attempt and provide a corrected version of the text snippets.
 The correction should be as minimal as possible, staying very close to the original.
@@ -420,11 +412,10 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
 
   try {
     const result = await baseLlmClient.generateJson({
+      modelConfigKey: { model: 'edit-corrector' },
       contents,
       schema: OLD_STRING_CORRECTION_SCHEMA,
       abortSignal,
-      model: EDIT_MODEL,
-      config: EDIT_CONFIG,
       systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
       promptId: getPromptId(),
     });
@@ -510,11 +501,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
 
   try {
     const result = await baseLlmClient.generateJson({
+      modelConfigKey: { model: 'edit-corrector' },
       contents,
       schema: NEW_STRING_CORRECTION_SCHEMA,
       abortSignal,
-      model: EDIT_MODEL,
-      config: EDIT_CONFIG,
       systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
       promptId: getPromptId(),
     });
@@ -581,11 +571,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
 
   try {
     const result = await baseLlmClient.generateJson({
+      modelConfigKey: { model: 'edit-corrector' },
       contents,
       schema: CORRECT_NEW_STRING_ESCAPING_SCHEMA,
       abortSignal,
-      model: EDIT_MODEL,
-      config: EDIT_CONFIG,
       systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
       promptId: getPromptId(),
     });
@@ -649,11 +638,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
 
   try {
     const result = await baseLlmClient.generateJson({
+      modelConfigKey: { model: 'edit-corrector' },
       contents,
       schema: CORRECT_STRING_ESCAPING_SCHEMA,
       abortSignal,
-      model: EDIT_MODEL,
-      config: EDIT_CONFIG,
       systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
       promptId: getPromptId(),
     });
@@ -17,6 +17,14 @@ import type { BaseLlmClient } from '../core/baseLlmClient.js';
 const mockGenerateJson = vi.fn();
 const mockBaseLlmClient = {
   generateJson: mockGenerateJson,
+  config: {
+    generationConfigService: {
+      getResolvedConfig: vi.fn().mockReturnValue({
+        model: 'edit-corrector',
+        generateContentConfig: {},
+      }),
+    },
+  },
 } as unknown as BaseLlmClient;
 
 describe('FixLLMEditWithInstruction', () => {
@@ -8,7 +8,6 @@ import { createHash } from 'node:crypto';
 import { type Content, Type } from '@google/genai';
 import { type BaseLlmClient } from '../core/baseLlmClient.js';
 import { LruCache } from './LruCache.js';
-import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
 import { promptIdContext } from './promptIdContext.js';
 import { debugLogger } from './debugLogger.js';
 
@@ -176,10 +175,10 @@ export async function FixLLMEditWithInstruction(
   const result = await generateJsonWithTimeout<SearchReplaceEdit>(
     baseLlmClient,
     {
+      modelConfigKey: { model: 'llm-edit-fixer' },
       contents,
       schema: SearchReplaceEditSchema,
       abortSignal,
-      model: DEFAULT_GEMINI_FLASH_MODEL,
       systemInstruction: EDIT_SYS_PROMPT,
       promptId,
       maxAttempts: 1,
@@ -7,7 +7,6 @@
 import type { Mock } from 'vitest';
 import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
 import type { Content } from '@google/genai';
-import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
 import { BaseLlmClient } from '../core/baseLlmClient.js';
 import type { ContentGenerator } from '../core/contentGenerator.js';
 import type { Config } from '../config/config.js';
@@ -54,6 +53,10 @@ describe('checkNextSpeaker', () => {
 
   beforeEach(() => {
     vi.resetAllMocks();
+    const mockResolvedConfig = {
+      model: 'next-speaker-v1',
+      generateContentConfig: {},
+    };
     mockConfig = {
       getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
       getSessionId: vi.fn().mockReturnValue('test-session-id'),
@@ -61,6 +64,9 @@ describe('checkNextSpeaker', () => {
       storage: {
         getProjectTempDir: vi.fn().mockReturnValue('/test/temp'),
       },
+      modelConfigService: {
+        getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
+      },
     } as unknown as Config;
 
     mockBaseLlmClient = new BaseLlmClient(
@@ -265,8 +271,8 @@ describe('checkNextSpeaker', () => {
 
     expect(mockBaseLlmClient.generateJson).toHaveBeenCalled();
     const generateJsonCall = (mockBaseLlmClient.generateJson as Mock).mock
-      .calls[0];
-    expect(generateJsonCall[0].model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
-    expect(generateJsonCall[0].promptId).toBe(promptId);
+      .calls[0][0];
+    expect(generateJsonCall.modelConfigKey.model).toBe('next-speaker-checker');
+    expect(generateJsonCall.promptId).toBe(promptId);
   });
 });
@@ -5,7 +5,6 @@
  */
 
 import type { Content } from '@google/genai';
-import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
 import type { BaseLlmClient } from '../core/baseLlmClient.js';
 import type { GeminiChat } from '../core/geminiChat.js';
 import { isFunctionResponse } from './messageInspectors.js';
@@ -111,9 +110,9 @@ export async function checkNextSpeaker(
 
   try {
     const parsedResponse = (await baseLlmClient.generateJson({
+      modelConfigKey: { model: 'next-speaker-checker' },
       contents,
       schema: RESPONSE_SCHEMA,
-      model: DEFAULT_GEMINI_FLASH_MODEL,
       abortSignal,
       promptId,
     })) as unknown as NextSpeakerResponse;