From 5bac855697447afdd86da396d57a8130cfdb3564 Mon Sep 17 00:00:00 2001 From: Abhi <43648792+abhipatel12@users.noreply.github.com> Date: Mon, 1 Sep 2025 20:28:54 -0400 Subject: [PATCH] refactor(core): Require model for utility calls (#7566) --- packages/core/src/core/client.test.ts | 27 +++++++++++++++++++++------ packages/core/src/core/client.ts | 18 +++++++----------- packages/core/src/tools/web-fetch.ts | 4 +++- packages/core/src/tools/web-search.ts | 4 +++- 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 3aea7ad9f6..b7ba47ee97 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -425,11 +425,16 @@ describe('Gemini Client (client.ts)', () => { }; client['contentGenerator'] = mockGenerator as ContentGenerator; - await client.generateJson(contents, schema, abortSignal); + await client.generateJson( + contents, + schema, + abortSignal, + DEFAULT_GEMINI_FLASH_MODEL, + ); expect(mockGenerateContentFn).toHaveBeenCalledWith( { - model: 'test-model', // Should use current model from config + model: DEFAULT_GEMINI_FLASH_MODEL, config: { abortSignal, systemInstruction: getCoreSystemPrompt(''), @@ -2306,11 +2311,16 @@ ${JSON.stringify( }; client['contentGenerator'] = mockGenerator as ContentGenerator; - await client.generateContent(contents, generationConfig, abortSignal); + await client.generateContent( + contents, + generationConfig, + abortSignal, + DEFAULT_GEMINI_FLASH_MODEL, + ); expect(mockGenerateContentFn).toHaveBeenCalledWith( { - model: 'test-model', + model: DEFAULT_GEMINI_FLASH_MODEL, config: { abortSignal, systemInstruction: getCoreSystemPrompt(''), @@ -2336,7 +2346,12 @@ ${JSON.stringify( }; client['contentGenerator'] = mockGenerator as ContentGenerator; - await client.generateContent(contents, {}, new AbortController().signal); + await client.generateContent( + contents, + {}, + new AbortController().signal, + DEFAULT_GEMINI_FLASH_MODEL, + ); expect(mockGenerateContentFn).not.toHaveBeenCalledWith({ model: initialModel, @@ -2345,7 +2360,7 @@ ${JSON.stringify( }); expect(mockGenerateContentFn).toHaveBeenCalledWith( { - model: currentModel, + model: DEFAULT_GEMINI_FLASH_MODEL, config: expect.any(Object), contents, }, diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index c9cf597daa..e07ed56e4c 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -582,12 +582,9 @@ export class GeminiClient { contents: Content[], schema: Record, abortSignal: AbortSignal, - model?: string, + model: string, config: GenerateContentConfig = {}, ): Promise> { - // Use current model from config instead of hardcoded Flash model - const modelToUse = - model || this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL; try { const userMemory = this.config.getUserMemory(); const systemInstruction = getCoreSystemPrompt(userMemory); @@ -600,7 +597,7 @@ export class GeminiClient { const apiCall = () => this.getContentGenerator().generateContent( { - model: modelToUse, + model, config: { ...requestConfig, systemInstruction, @@ -637,7 +634,7 @@ export class GeminiClient { if (text.startsWith(prefix) && text.endsWith(suffix)) { logMalformedJsonResponse( this.config, - new MalformedJsonResponseEvent(modelToUse), + new MalformedJsonResponseEvent(model), ); text = text .substring(prefix.length, text.length - suffix.length) @@ -691,9 +688,8 @@ export class GeminiClient { contents: Content[], generationConfig: GenerateContentConfig, abortSignal: AbortSignal, - model?: string, + model: string, ): Promise { - const modelToUse = model ?? this.config.getModel(); const configToUse: GenerateContentConfig = { ...this.generateContentConfig, ...generationConfig, @@ -712,7 +708,7 @@ export class GeminiClient { const apiCall = () => this.getContentGenerator().generateContent( { - model: modelToUse, + model, config: requestConfig, contents, }, @@ -732,7 +728,7 @@ export class GeminiClient { await reportError( error, - `Error generating content via API with model ${modelToUse}.`, + `Error generating content via API with model ${model}.`, { requestContents: contents, requestConfig: configToUse, @@ -740,7 +736,7 @@ export class GeminiClient { 'generateContent-api', ); throw new Error( - `Failed to generate content with model ${modelToUse}: ${getErrorMessage(error)}`, + `Failed to generate content with model ${model}: ${getErrorMessage(error)}`, ); } } diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index 586ce362a2..5e1e9bf531 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -18,7 +18,7 @@ import { import { ToolErrorType } from './tool-error.js'; import { getErrorMessage } from '../utils/errors.js'; import type { Config } from '../config/config.js'; -import { ApprovalMode } from '../config/config.js'; +import { ApprovalMode, DEFAULT_GEMINI_FLASH_MODEL } from '../config/config.js'; import { getResponseText } from '../utils/partUtils.js'; import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js'; import { convert } from 'html-to-text'; @@ -116,6 +116,7 @@ ${textContent} [{ role: 'user', parts: [{ text: fallbackPrompt }] }], {}, signal, + DEFAULT_GEMINI_FLASH_MODEL, ); const resultText = getResponseText(result) || ''; return { @@ -193,6 +194,7 @@ ${textContent} [{ role: 'user', parts: [{ text: userPrompt }] }], { tools: [{ urlContext: {} }] }, signal, // Pass signal + DEFAULT_GEMINI_FLASH_MODEL, ); console.debug( diff --git a/packages/core/src/tools/web-search.ts b/packages/core/src/tools/web-search.ts index b5f57e4bd9..afaad2ecb3 100644 --- a/packages/core/src/tools/web-search.ts +++ b/packages/core/src/tools/web-search.ts @@ -10,8 +10,9 @@ import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js'; import { ToolErrorType } from './tool-error.js'; import { getErrorMessage } from '../utils/errors.js'; -import type { Config } from '../config/config.js'; +import { type Config } from '../config/config.js'; import { getResponseText } from '../utils/partUtils.js'; +import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; interface GroundingChunkWeb { uri?: string; @@ -78,6 +79,7 @@ class WebSearchToolInvocation extends BaseToolInvocation< [{ role: 'user', parts: [{ text: this.params.query }] }], { tools: [{ googleSearch: {} }] }, signal, + DEFAULT_GEMINI_FLASH_MODEL, ); const responseText = getResponseText(response);