feat(core): Migrate generateContent to model configs. (#12834)

This commit is contained in:
joshualitt
2025-11-11 08:10:50 -08:00
committed by GitHub
parent cbbf565121
commit a4415f15d3
15 changed files with 169 additions and 95 deletions
@@ -114,7 +114,7 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
},
},
},
'web-search-tool': {
'web-search': {
extends: 'gemini-2.5-flash-base',
modelConfig: {
generateContentConfig: {
@@ -122,7 +122,7 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
},
},
},
'web-fetch-tool': {
'web-fetch': {
extends: 'gemini-2.5-flash-base',
modelConfig: {
generateContentConfig: {
@@ -130,6 +130,11 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
},
},
},
// TODO(joshualitt): During cleanup, make modelConfig optional.
'web-fetch-fallback': {
extends: 'gemini-2.5-flash-base',
modelConfig: {},
},
'loop-detection': {
extends: 'gemini-2.5-flash-base',
modelConfig: {},
+19 -9
View File
@@ -42,6 +42,10 @@ import { ideContextStore } from '../ide/ideContext.js';
import type { ModelRouterService } from '../routing/modelRouterService.js';
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
import { ChatCompressionService } from '../services/chatCompressionService.js';
import type {
ModelConfigKey,
ResolvedModelConfig,
} from '../services/modelConfigService.js';
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
vi.mock('../services/chatCompressionService.js');
@@ -262,6 +266,17 @@ describe('Gemini Client (client.ts)', () => {
reasoning: 'test',
}),
}),
// Minimal stand-in for the model config service: echoes the requested model
// and always returns the same generation parameters.
modelConfigService: {
  getResolvedConfig(modelConfigKey: ModelConfigKey) {
    // The cast belongs on the whole returned object, which is what stands in
    // for a ResolvedModelConfig, not on the nested generateContentConfig.
    return {
      model: modelConfigKey.model,
      generateContentConfig: {
        temperature: 0,
        topP: 1,
      },
    } as unknown as ResolvedModelConfig;
  },
},
isInteractive: vi.fn().mockReturnValue(false),
} as unknown as Config;
@@ -2268,14 +2283,12 @@ ${JSON.stringify(
describe('generateContent', () => {
it('should call generateContent with the correct parameters', async () => {
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
const generationConfig = { temperature: 0.5 };
const abortSignal = new AbortController().signal;
await client.generateContent(
{ model: DEFAULT_GEMINI_FLASH_MODEL },
contents,
generationConfig,
abortSignal,
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
@@ -2284,7 +2297,7 @@ ${JSON.stringify(
config: {
abortSignal,
systemInstruction: getCoreSystemPrompt({} as unknown as Config, ''),
temperature: 0.5,
temperature: 0,
topP: 1,
},
contents,
@@ -2301,10 +2314,9 @@ ${JSON.stringify(
vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel);
await client.generateContent(
{ model: DEFAULT_GEMINI_FLASH_MODEL },
contents,
{},
new AbortController().signal,
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(mockContentGenerator.generateContent).not.toHaveBeenCalledWith({
@@ -2324,7 +2336,6 @@ ${JSON.stringify(
it('should use the Flash model when fallback mode is active', async () => {
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
const generationConfig = { temperature: 0.5 };
const abortSignal = new AbortController().signal;
const requestedModel = 'gemini-2.5-pro'; // A non-flash model
@@ -2332,10 +2343,9 @@ ${JSON.stringify(
vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true);
await client.generateContent(
{ model: requestedModel },
contents,
generationConfig,
abortSignal,
requestedModel,
);
expect(mockGenerateContentFn).toHaveBeenCalledWith(
+26 -20
View File
@@ -54,6 +54,7 @@ import type { IdeContext, File } from '../ide/types.js';
import { handleFallback } from '../fallback/handler.js';
import type { RoutingContext } from '../routing/routingStrategy.js';
import { debugLogger } from '../utils/debugLogger.js';
import type { ModelConfigKey } from '../services/modelConfigService.js';
export function isThinkingSupported(model: string) {
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
@@ -602,37 +603,42 @@ export class GeminiClient {
}
async generateContent(
modelConfigKey: ModelConfigKey,
contents: Content[],
generationConfig: GenerateContentConfig,
abortSignal: AbortSignal,
model: string,
): Promise<GenerateContentResponse> {
let currentAttemptModel: string = model;
const configToUse: GenerateContentConfig = {
...this.generateContentConfig,
...generationConfig,
};
const desiredModelConfig =
this.config.modelConfigService.getResolvedConfig(modelConfigKey);
let {
model: currentAttemptModel,
generateContentConfig: currentAttemptGenerateContentConfig,
} = desiredModelConfig;
const fallbackModelConfig =
this.config.modelConfigService.getResolvedConfig({
...modelConfigKey,
model: DEFAULT_GEMINI_FLASH_MODEL,
});
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(this.config, userMemory);
const requestConfig: GenerateContentConfig = {
abortSignal,
...configToUse,
systemInstruction,
};
const apiCall = () => {
const modelToUse = this.config.isInFallbackMode()
? DEFAULT_GEMINI_FLASH_MODEL
: model;
currentAttemptModel = modelToUse;
const modelConfigToUse = this.config.isInFallbackMode()
? fallbackModelConfig
: desiredModelConfig;
currentAttemptModel = modelConfigToUse.model;
currentAttemptGenerateContentConfig =
modelConfigToUse.generateContentConfig;
const requestConfig: GenerateContentConfig = {
...currentAttemptGenerateContentConfig,
abortSignal,
systemInstruction,
};
return this.getContentGeneratorOrFail().generateContent(
{
model: modelToUse,
model: currentAttemptModel,
config: requestConfig,
contents,
},
@@ -661,7 +667,7 @@ export class GeminiClient {
`Error generating content via API with model ${currentAttemptModel}.`,
{
requestContents: contents,
requestConfig: configToUse,
requestConfig: currentAttemptGenerateContentConfig,
},
'generateContent-api',
);
@@ -103,7 +103,7 @@
"maxOutputTokens": 2000
}
},
"web-search-tool": {
"web-search": {
"model": "gemini-2.5-flash",
"generateContentConfig": {
"temperature": 0,
@@ -115,7 +115,7 @@
]
}
},
"web-fetch-tool": {
"web-fetch": {
"model": "gemini-2.5-flash",
"generateContentConfig": {
"temperature": 0,
@@ -127,6 +127,13 @@
]
}
},
"web-fetch-fallback": {
"model": "gemini-2.5-flash",
"generateContentConfig": {
"temperature": 0,
"topP": 1
}
},
"loop-detection": {
"model": "gemini-2.5-flash",
"generateContentConfig": {
+2 -1
View File
@@ -366,10 +366,11 @@ describe('ShellTool', () => {
const result = await promise;
expect(summarizer.summarizeToolOutput).toHaveBeenCalledWith(
mockConfig,
{ model: 'summarizer-shell' },
expect.any(String),
mockConfig.getGeminiClient(),
mockAbortSignal,
1000,
);
expect(result.llmContent).toBe('summarized output');
expect(result.returnDisplay).toBe('long output');
+2 -1
View File
@@ -308,10 +308,11 @@ export class ShellToolInvocation extends BaseToolInvocation<
: {};
if (summarizeConfig && summarizeConfig[SHELL_TOOL_NAME]) {
const summary = await summarizeToolOutput(
this.config,
{ model: 'summarizer-shell' },
llmContent,
this.config.getGeminiClient(),
signal,
summarizeConfig[SHELL_TOOL_NAME].tokenBudget,
);
return {
llmContent: summary,
+10 -4
View File
@@ -142,6 +142,12 @@ describe('WebFetchTool', () => {
setApprovalMode: vi.fn(),
getProxy: vi.fn(),
getGeminiClient: mockGetGeminiClient,
modelConfigService: {
getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({
model,
generateContentConfig: {},
})),
},
isInteractive: () => false,
} as unknown as Config;
});
@@ -270,7 +276,7 @@ describe('WebFetchTool', () => {
} as Response);
// Mock fallback LLM call to return the content passed to it
mockGenerateContent.mockImplementationOnce(async (req) => ({
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
candidates: [{ content: { parts: [{ text: req[0].parts[0].text }] } }],
}));
@@ -298,7 +304,7 @@ describe('WebFetchTool', () => {
} as Response);
// Mock fallback LLM call to return the content passed to it
mockGenerateContent.mockImplementationOnce(async (req) => ({
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
candidates: [{ content: { parts: [{ text: req[0].parts[0].text }] } }],
}));
@@ -320,7 +326,7 @@ describe('WebFetchTool', () => {
} as Response);
// Mock fallback LLM call to return the content passed to it
mockGenerateContent.mockImplementationOnce(async (req) => ({
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
candidates: [{ content: { parts: [{ text: req[0].parts[0].text }] } }],
}));
@@ -342,7 +348,7 @@ describe('WebFetchTool', () => {
} as Response);
// Mock fallback LLM call to return the content passed to it
mockGenerateContent.mockImplementationOnce(async (req) => ({
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
candidates: [{ content: { parts: [{ text: req[0].parts[0].text }] } }],
}));
+2 -6
View File
@@ -19,9 +19,7 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { ToolErrorType } from './tool-error.js';
import { getErrorMessage } from '../utils/errors.js';
import type { Config } from '../config/config.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/config.js';
import { ApprovalMode } from '../policy/types.js';
import { getResponseText } from '../utils/partUtils.js';
import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js';
import { convert } from 'html-to-text';
@@ -171,10 +169,9 @@ ${textContent}
---
`;
const result = await geminiClient.generateContent(
{ model: 'web-fetch-fallback' },
[{ role: 'user', parts: [{ text: fallbackPrompt }] }],
{},
signal,
DEFAULT_GEMINI_FLASH_MODEL,
);
const resultText = getResponseText(result) || '';
return {
@@ -255,10 +252,9 @@ ${textContent}
try {
const response = await geminiClient.generateContent(
{ model: 'web-fetch' },
[{ role: 'user', parts: [{ text: userPrompt }] }],
{ tools: [{ urlContext: {} }] },
signal, // Pass signal
DEFAULT_GEMINI_FLASH_MODEL,
);
debugLogger.debug(
@@ -25,6 +25,12 @@ describe('WebSearchTool', () => {
const mockConfigInstance = {
getGeminiClient: () => mockGeminiClient,
getProxy: () => undefined,
// GeminiClient resolves configs through `config.modelConfigService` and reads
// `generateContentConfig` from the result, so the mock uses those names.
modelConfigService: {
  getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({
    model,
    generateContentConfig: {},
  })),
},
} as unknown as Config;
mockGeminiClient = new GeminiClient(mockConfigInstance);
tool = new WebSearchTool(mockConfigInstance);
+1 -3
View File
@@ -14,7 +14,6 @@ import { ToolErrorType } from './tool-error.js';
import { getErrorMessage } from '../utils/errors.js';
import { type Config } from '../config/config.js';
import { getResponseText } from '../utils/partUtils.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
interface GroundingChunkWeb {
uri?: string;
@@ -81,10 +80,9 @@ class WebSearchToolInvocation extends BaseToolInvocation<
try {
const response = await geminiClient.generateContent(
{ model: 'web-search' },
[{ role: 'user', parts: [{ text: this.params.query }] }],
{ tools: [{ googleSearch: {} }] },
signal,
DEFAULT_GEMINI_FLASH_MODEL,
);
const responseText = getResponseText(response);
+46 -9
View File
@@ -14,6 +14,11 @@ import {
defaultSummarizer,
} from './summarizer.js';
import type { ToolResult } from '../tools/tools.js';
import type {
ModelConfigService,
ResolvedModelConfig,
} from '../services/modelConfigService.js';
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
// Mock GeminiClient and Config constructor
vi.mock('../core/client.js');
@@ -22,11 +27,18 @@ vi.mock('../config/config.js');
describe('summarizers', () => {
let mockGeminiClient: GeminiClient;
let MockConfig: Mock;
let mockConfigInstance: Config;
const abortSignal = new AbortController().signal;
// Default resolved model config handed back by the mocked
// modelConfigService.getResolvedConfig that beforeEach installs.
const mockResolvedConfig = {
model: 'gemini-pro',
generateContentConfig: {
maxOutputTokens: 2000,
},
} as unknown as ResolvedModelConfig;
beforeEach(() => {
MockConfig = vi.mocked(Config);
const mockConfigInstance = new MockConfig(
mockConfigInstance = new MockConfig(
'test-api-key',
'gemini-pro',
false,
@@ -38,6 +50,9 @@ describe('summarizers', () => {
undefined,
undefined,
);
(mockConfigInstance.modelConfigService as unknown) = {
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
} as unknown as ModelConfigService;
mockGeminiClient = new GeminiClient(mockConfigInstance);
(mockGeminiClient.generateContent as Mock) = vi.fn();
@@ -54,10 +69,11 @@ describe('summarizers', () => {
it('should return original text if it is shorter than maxLength', async () => {
const shortText = 'This is a short text.';
const result = await summarizeToolOutput(
mockConfigInstance,
{ model: DEFAULT_GEMINI_MODEL },
shortText,
mockGeminiClient,
abortSignal,
2000,
);
expect(result).toBe(shortText);
expect(mockGeminiClient.generateContent).not.toHaveBeenCalled();
@@ -66,10 +82,11 @@ describe('summarizers', () => {
it('should return original text if it is empty', async () => {
const emptyText = '';
const result = await summarizeToolOutput(
mockConfigInstance,
{ model: DEFAULT_GEMINI_MODEL },
emptyText,
mockGeminiClient,
abortSignal,
2000,
);
expect(result).toBe(emptyText);
expect(mockGeminiClient.generateContent).not.toHaveBeenCalled();
@@ -81,12 +98,12 @@ describe('summarizers', () => {
(mockGeminiClient.generateContent as Mock).mockResolvedValue({
candidates: [{ content: { parts: [{ text: summary }] } }],
});
const result = await summarizeToolOutput(
mockConfigInstance,
{ model: DEFAULT_GEMINI_MODEL },
longText,
mockGeminiClient,
abortSignal,
2000,
);
expect(mockGeminiClient.generateContent).toHaveBeenCalledTimes(1);
@@ -99,10 +116,11 @@ describe('summarizers', () => {
(mockGeminiClient.generateContent as Mock).mockRejectedValue(error);
const result = await summarizeToolOutput(
mockConfigInstance,
{ model: DEFAULT_GEMINI_MODEL },
longText,
mockGeminiClient,
abortSignal,
2000,
);
expect(mockGeminiClient.generateContent).toHaveBeenCalledTimes(1);
@@ -115,8 +133,24 @@ describe('summarizers', () => {
(mockGeminiClient.generateContent as Mock).mockResolvedValue({
candidates: [{ content: { parts: [{ text: summary }] } }],
});
(mockConfigInstance.modelConfigService as unknown) = {
getResolvedConfig() {
return {
model: 'gemini-pro-limited',
generateContentConfig: {
maxOutputTokens: 1000,
},
};
},
};
await summarizeToolOutput(longText, mockGeminiClient, abortSignal, 1000);
await summarizeToolOutput(
mockConfigInstance,
{ model: 'gemini-pro-limited' },
longText,
mockGeminiClient,
abortSignal,
);
const expectedPrompt = `Summarize the following tool output to be a maximum of 1000 tokens. The summary should be concise and capture the main points of the tool output.
@@ -133,7 +167,7 @@ Return the summary string which should first contain an overall summarization of
`;
const calledWith = (mockGeminiClient.generateContent as Mock).mock
.calls[0];
const contents = calledWith[0];
const contents = calledWith[1];
expect(contents[0].parts[0].text).toBe(expectedPrompt);
});
});
@@ -150,6 +184,7 @@ Return the summary string which should first contain an overall summarization of
});
const result = await llmSummarizer(
mockConfigInstance,
toolResult,
mockGeminiClient,
abortSignal,
@@ -171,6 +206,7 @@ Return the summary string which should first contain an overall summarization of
});
const result = await llmSummarizer(
mockConfigInstance,
toolResult,
mockGeminiClient,
abortSignal,
@@ -179,7 +215,7 @@ Return the summary string which should first contain an overall summarization of
expect(mockGeminiClient.generateContent).toHaveBeenCalledTimes(1);
const calledWith = (mockGeminiClient.generateContent as Mock).mock
.calls[0];
const contents = calledWith[0];
const contents = calledWith[1];
expect(contents[0].parts[0].text).toContain(`"${longText}"`);
expect(result).toBe(summary);
});
@@ -193,6 +229,7 @@ Return the summary string which should first contain an overall summarization of
};
const result = await defaultSummarizer(
mockConfigInstance,
toolResult,
mockGeminiClient,
abortSignal,
+21 -15
View File
@@ -5,15 +5,12 @@
*/
import type { ToolResult } from '../tools/tools.js';
import type {
Content,
GenerateContentConfig,
GenerateContentResponse,
} from '@google/genai';
import type { Content } from '@google/genai';
import type { GeminiClient } from '../core/client.js';
import { DEFAULT_GEMINI_FLASH_LITE_MODEL } from '../config/models.js';
import { getResponseText, partToString } from './partUtils.js';
import { debugLogger } from './debugLogger.js';
import type { ModelConfigKey } from '../services/modelConfigService.js';
import type { Config } from '../config/config.js';
/**
* A function that summarizes the result of a tool execution.
@@ -22,6 +19,7 @@ import { debugLogger } from './debugLogger.js';
* @returns The summary of the result.
*/
export type Summarizer = (
config: Config,
result: ToolResult,
geminiClient: GeminiClient,
abortSignal: AbortSignal,
@@ -36,6 +34,7 @@ export type Summarizer = (
* @returns The summary of the result.
*/
export const defaultSummarizer: Summarizer = (
_config: Config,
result: ToolResult,
_geminiClient: GeminiClient,
_abortSignal: AbortSignal,
@@ -55,19 +54,30 @@ Text to summarize:
Return the summary string which should first contain an overall summarization of text followed by the full stack trace of errors and warnings in the tool output.
`;
export const llmSummarizer: Summarizer = (result, geminiClient, abortSignal) =>
/**
 * Summarizer that converts a tool result's `llmContent` to a string and
 * passes it to `summarizeToolOutput` under the 'summarizer-default' model
 * config key.
 */
export const llmSummarizer: Summarizer = async (
  config,
  result,
  geminiClient,
  abortSignal,
) => {
  const textToSummarize = partToString(result.llmContent);
  return summarizeToolOutput(
    config,
    { model: 'summarizer-default' },
    textToSummarize,
    geminiClient,
    abortSignal,
  );
};
export async function summarizeToolOutput(
config: Config,
modelConfigKey: ModelConfigKey,
textToSummarize: string,
geminiClient: GeminiClient,
abortSignal: AbortSignal,
maxOutputTokens: number = 2000,
): Promise<string> {
const maxOutputTokens =
config.modelConfigService.getResolvedConfig(modelConfigKey)
.generateContentConfig.maxOutputTokens ?? 2000;
// There is going to be a slight difference here since we are comparing length of string with maxOutputTokens.
// This is meant to be a ballpark estimation of if we need to summarize the tool output.
if (!textToSummarize || textToSummarize.length < maxOutputTokens) {
@@ -79,16 +89,12 @@ export async function summarizeToolOutput(
).replace('{textToSummarize}', textToSummarize);
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
const toolOutputSummarizerConfig: GenerateContentConfig = {
maxOutputTokens,
};
try {
const parsedResponse = (await geminiClient.generateContent(
const parsedResponse = await geminiClient.generateContent(
modelConfigKey,
contents,
toolOutputSummarizerConfig,
abortSignal,
DEFAULT_GEMINI_FLASH_LITE_MODEL,
)) as unknown as GenerateContentResponse;
);
return getResponseText(parsedResponse) || textToSummarize;
} catch (error) {
debugLogger.warn('Failed to summarize tool output.', error);