refactor(core): Require model for utility calls (#7566)

This commit is contained in:
Abhi
2025-09-01 20:28:54 -04:00
committed by GitHub
parent 4fd1113905
commit 5bac855697
4 changed files with 34 additions and 19 deletions

View File

@@ -425,11 +425,16 @@ describe('Gemini Client (client.ts)', () => {
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
await client.generateJson(contents, schema, abortSignal);
await client.generateJson(
contents,
schema,
abortSignal,
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(mockGenerateContentFn).toHaveBeenCalledWith(
{
model: 'test-model', // Should use current model from config
model: DEFAULT_GEMINI_FLASH_MODEL,
config: {
abortSignal,
systemInstruction: getCoreSystemPrompt(''),
@@ -2306,11 +2311,16 @@ ${JSON.stringify(
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
await client.generateContent(contents, generationConfig, abortSignal);
await client.generateContent(
contents,
generationConfig,
abortSignal,
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(mockGenerateContentFn).toHaveBeenCalledWith(
{
model: 'test-model',
model: DEFAULT_GEMINI_FLASH_MODEL,
config: {
abortSignal,
systemInstruction: getCoreSystemPrompt(''),
@@ -2336,7 +2346,12 @@ ${JSON.stringify(
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
await client.generateContent(contents, {}, new AbortController().signal);
await client.generateContent(
contents,
{},
new AbortController().signal,
DEFAULT_GEMINI_FLASH_MODEL,
);
expect(mockGenerateContentFn).not.toHaveBeenCalledWith({
model: initialModel,
@@ -2345,7 +2360,7 @@ ${JSON.stringify(
});
expect(mockGenerateContentFn).toHaveBeenCalledWith(
{
model: currentModel,
model: DEFAULT_GEMINI_FLASH_MODEL,
config: expect.any(Object),
contents,
},

View File

@@ -582,12 +582,9 @@ export class GeminiClient {
contents: Content[],
schema: Record<string, unknown>,
abortSignal: AbortSignal,
model?: string,
model: string,
config: GenerateContentConfig = {},
): Promise<Record<string, unknown>> {
// Use current model from config instead of hardcoded Flash model
const modelToUse =
model || this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory);
@@ -600,7 +597,7 @@ export class GeminiClient {
const apiCall = () =>
this.getContentGenerator().generateContent(
{
model: modelToUse,
model,
config: {
...requestConfig,
systemInstruction,
@@ -637,7 +634,7 @@ export class GeminiClient {
if (text.startsWith(prefix) && text.endsWith(suffix)) {
logMalformedJsonResponse(
this.config,
new MalformedJsonResponseEvent(modelToUse),
new MalformedJsonResponseEvent(model),
);
text = text
.substring(prefix.length, text.length - suffix.length)
@@ -691,9 +688,8 @@ export class GeminiClient {
contents: Content[],
generationConfig: GenerateContentConfig,
abortSignal: AbortSignal,
model?: string,
model: string,
): Promise<GenerateContentResponse> {
const modelToUse = model ?? this.config.getModel();
const configToUse: GenerateContentConfig = {
...this.generateContentConfig,
...generationConfig,
@@ -712,7 +708,7 @@ export class GeminiClient {
const apiCall = () =>
this.getContentGenerator().generateContent(
{
model: modelToUse,
model,
config: requestConfig,
contents,
},
@@ -732,7 +728,7 @@ export class GeminiClient {
await reportError(
error,
`Error generating content via API with model ${modelToUse}.`,
`Error generating content via API with model ${model}.`,
{
requestContents: contents,
requestConfig: configToUse,
@@ -740,7 +736,7 @@ export class GeminiClient {
'generateContent-api',
);
throw new Error(
`Failed to generate content with model ${modelToUse}: ${getErrorMessage(error)}`,
`Failed to generate content with model ${model}: ${getErrorMessage(error)}`,
);
}
}

View File

@@ -18,7 +18,7 @@ import {
import { ToolErrorType } from './tool-error.js';
import { getErrorMessage } from '../utils/errors.js';
import type { Config } from '../config/config.js';
import { ApprovalMode } from '../config/config.js';
import { ApprovalMode, DEFAULT_GEMINI_FLASH_MODEL } from '../config/config.js';
import { getResponseText } from '../utils/partUtils.js';
import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js';
import { convert } from 'html-to-text';
@@ -116,6 +116,7 @@ ${textContent}
[{ role: 'user', parts: [{ text: fallbackPrompt }] }],
{},
signal,
DEFAULT_GEMINI_FLASH_MODEL,
);
const resultText = getResponseText(result) || '';
return {
@@ -193,6 +194,7 @@ ${textContent}
[{ role: 'user', parts: [{ text: userPrompt }] }],
{ tools: [{ urlContext: {} }] },
signal, // Pass signal
DEFAULT_GEMINI_FLASH_MODEL,
);
console.debug(

View File

@@ -10,8 +10,9 @@ import { BaseDeclarativeTool, BaseToolInvocation, Kind } from './tools.js';
import { ToolErrorType } from './tool-error.js';
import { getErrorMessage } from '../utils/errors.js';
import type { Config } from '../config/config.js';
import { type Config } from '../config/config.js';
import { getResponseText } from '../utils/partUtils.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
interface GroundingChunkWeb {
uri?: string;
@@ -78,6 +79,7 @@ class WebSearchToolInvocation extends BaseToolInvocation<
[{ role: 'user', parts: [{ text: this.params.query }] }],
{ tools: [{ googleSearch: {} }] },
signal,
DEFAULT_GEMINI_FLASH_MODEL,
);
const responseText = getResponseText(response);