mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-27 13:34:15 -07:00
feat(core): Migrate generateContent to model configs. (#12834)
This commit is contained in:
@@ -114,7 +114,7 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
},
|
||||
},
|
||||
},
|
||||
'web-search-tool': {
|
||||
'web-search': {
|
||||
extends: 'gemini-2.5-flash-base',
|
||||
modelConfig: {
|
||||
generateContentConfig: {
|
||||
@@ -122,7 +122,7 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
},
|
||||
},
|
||||
},
|
||||
'web-fetch-tool': {
|
||||
'web-fetch': {
|
||||
extends: 'gemini-2.5-flash-base',
|
||||
modelConfig: {
|
||||
generateContentConfig: {
|
||||
@@ -130,6 +130,11 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
},
|
||||
},
|
||||
},
|
||||
// TODO(joshualitt): During cleanup, make modelConfig optional.
|
||||
'web-fetch-fallback': {
|
||||
extends: 'gemini-2.5-flash-base',
|
||||
modelConfig: {},
|
||||
},
|
||||
'loop-detection': {
|
||||
extends: 'gemini-2.5-flash-base',
|
||||
modelConfig: {},
|
||||
|
||||
@@ -42,6 +42,10 @@ import { ideContextStore } from '../ide/ideContext.js';
|
||||
import type { ModelRouterService } from '../routing/modelRouterService.js';
|
||||
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
import type {
|
||||
ModelConfigKey,
|
||||
ResolvedModelConfig,
|
||||
} from '../services/modelConfigService.js';
|
||||
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
||||
|
||||
vi.mock('../services/chatCompressionService.js');
|
||||
@@ -262,6 +266,17 @@ describe('Gemini Client (client.ts)', () => {
|
||||
reasoning: 'test',
|
||||
}),
|
||||
}),
|
||||
modelConfigService: {
|
||||
getResolvedConfig(modelConfigKey: ModelConfigKey) {
|
||||
return {
|
||||
model: modelConfigKey.model,
|
||||
generateContentConfig: {
|
||||
temperature: 0,
|
||||
topP: 1,
|
||||
} as unknown as ResolvedModelConfig,
|
||||
};
|
||||
},
|
||||
},
|
||||
isInteractive: vi.fn().mockReturnValue(false),
|
||||
} as unknown as Config;
|
||||
|
||||
@@ -2268,14 +2283,12 @@ ${JSON.stringify(
|
||||
describe('generateContent', () => {
|
||||
it('should call generateContent with the correct parameters', async () => {
|
||||
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
|
||||
const generationConfig = { temperature: 0.5 };
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
await client.generateContent(
|
||||
{ model: DEFAULT_GEMINI_FLASH_MODEL },
|
||||
contents,
|
||||
generationConfig,
|
||||
abortSignal,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
@@ -2284,7 +2297,7 @@ ${JSON.stringify(
|
||||
config: {
|
||||
abortSignal,
|
||||
systemInstruction: getCoreSystemPrompt({} as unknown as Config, ''),
|
||||
temperature: 0.5,
|
||||
temperature: 0,
|
||||
topP: 1,
|
||||
},
|
||||
contents,
|
||||
@@ -2301,10 +2314,9 @@ ${JSON.stringify(
|
||||
vi.spyOn(client['config'], 'getModel').mockReturnValueOnce(currentModel);
|
||||
|
||||
await client.generateContent(
|
||||
{ model: DEFAULT_GEMINI_FLASH_MODEL },
|
||||
contents,
|
||||
{},
|
||||
new AbortController().signal,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
|
||||
expect(mockContentGenerator.generateContent).not.toHaveBeenCalledWith({
|
||||
@@ -2324,7 +2336,6 @@ ${JSON.stringify(
|
||||
|
||||
it('should use the Flash model when fallback mode is active', async () => {
|
||||
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
|
||||
const generationConfig = { temperature: 0.5 };
|
||||
const abortSignal = new AbortController().signal;
|
||||
const requestedModel = 'gemini-2.5-pro'; // A non-flash model
|
||||
|
||||
@@ -2332,10 +2343,9 @@ ${JSON.stringify(
|
||||
vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true);
|
||||
|
||||
await client.generateContent(
|
||||
{ model: requestedModel },
|
||||
contents,
|
||||
generationConfig,
|
||||
abortSignal,
|
||||
requestedModel,
|
||||
);
|
||||
|
||||
expect(mockGenerateContentFn).toHaveBeenCalledWith(
|
||||
|
||||
@@ -54,6 +54,7 @@ import type { IdeContext, File } from '../ide/types.js';
|
||||
import { handleFallback } from '../fallback/handler.js';
|
||||
import type { RoutingContext } from '../routing/routingStrategy.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
|
||||
export function isThinkingSupported(model: string) {
|
||||
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
|
||||
@@ -602,37 +603,42 @@ export class GeminiClient {
|
||||
}
|
||||
|
||||
async generateContent(
|
||||
modelConfigKey: ModelConfigKey,
|
||||
contents: Content[],
|
||||
generationConfig: GenerateContentConfig,
|
||||
abortSignal: AbortSignal,
|
||||
model: string,
|
||||
): Promise<GenerateContentResponse> {
|
||||
let currentAttemptModel: string = model;
|
||||
|
||||
const configToUse: GenerateContentConfig = {
|
||||
...this.generateContentConfig,
|
||||
...generationConfig,
|
||||
};
|
||||
const desiredModelConfig =
|
||||
this.config.modelConfigService.getResolvedConfig(modelConfigKey);
|
||||
let {
|
||||
model: currentAttemptModel,
|
||||
generateContentConfig: currentAttemptGenerateContentConfig,
|
||||
} = desiredModelConfig;
|
||||
const fallbackModelConfig =
|
||||
this.config.modelConfigService.getResolvedConfig({
|
||||
...modelConfigKey,
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
});
|
||||
|
||||
try {
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const systemInstruction = getCoreSystemPrompt(this.config, userMemory);
|
||||
|
||||
const requestConfig: GenerateContentConfig = {
|
||||
abortSignal,
|
||||
...configToUse,
|
||||
systemInstruction,
|
||||
};
|
||||
|
||||
const apiCall = () => {
|
||||
const modelToUse = this.config.isInFallbackMode()
|
||||
? DEFAULT_GEMINI_FLASH_MODEL
|
||||
: model;
|
||||
currentAttemptModel = modelToUse;
|
||||
const modelConfigToUse = this.config.isInFallbackMode()
|
||||
? fallbackModelConfig
|
||||
: desiredModelConfig;
|
||||
currentAttemptModel = modelConfigToUse.model;
|
||||
currentAttemptGenerateContentConfig =
|
||||
modelConfigToUse.generateContentConfig;
|
||||
const requestConfig: GenerateContentConfig = {
|
||||
...currentAttemptGenerateContentConfig,
|
||||
abortSignal,
|
||||
systemInstruction,
|
||||
};
|
||||
|
||||
return this.getContentGeneratorOrFail().generateContent(
|
||||
{
|
||||
model: modelToUse,
|
||||
model: currentAttemptModel,
|
||||
config: requestConfig,
|
||||
contents,
|
||||
},
|
||||
@@ -661,7 +667,7 @@ export class GeminiClient {
|
||||
`Error generating content via API with model ${currentAttemptModel}.`,
|
||||
{
|
||||
requestContents: contents,
|
||||
requestConfig: configToUse,
|
||||
requestConfig: currentAttemptGenerateContentConfig,
|
||||
},
|
||||
'generateContent-api',
|
||||
);
|
||||
|
||||
@@ -103,7 +103,7 @@
|
||||
"maxOutputTokens": 2000
|
||||
}
|
||||
},
|
||||
"web-search-tool": {
|
||||
"web-search": {
|
||||
"model": "gemini-2.5-flash",
|
||||
"generateContentConfig": {
|
||||
"temperature": 0,
|
||||
@@ -115,7 +115,7 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"web-fetch-tool": {
|
||||
"web-fetch": {
|
||||
"model": "gemini-2.5-flash",
|
||||
"generateContentConfig": {
|
||||
"temperature": 0,
|
||||
@@ -127,6 +127,13 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"web-fetch-fallback": {
|
||||
"model": "gemini-2.5-flash",
|
||||
"generateContentConfig": {
|
||||
"temperature": 0,
|
||||
"topP": 1
|
||||
}
|
||||
},
|
||||
"loop-detection": {
|
||||
"model": "gemini-2.5-flash",
|
||||
"generateContentConfig": {
|
||||
|
||||
@@ -366,10 +366,11 @@ describe('ShellTool', () => {
|
||||
const result = await promise;
|
||||
|
||||
expect(summarizer.summarizeToolOutput).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
{ model: 'summarizer-shell' },
|
||||
expect.any(String),
|
||||
mockConfig.getGeminiClient(),
|
||||
mockAbortSignal,
|
||||
1000,
|
||||
);
|
||||
expect(result.llmContent).toBe('summarized output');
|
||||
expect(result.returnDisplay).toBe('long output');
|
||||
|
||||
@@ -308,10 +308,11 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
: {};
|
||||
if (summarizeConfig && summarizeConfig[SHELL_TOOL_NAME]) {
|
||||
const summary = await summarizeToolOutput(
|
||||
this.config,
|
||||
{ model: 'summarizer-shell' },
|
||||
llmContent,
|
||||
this.config.getGeminiClient(),
|
||||
signal,
|
||||
summarizeConfig[SHELL_TOOL_NAME].tokenBudget,
|
||||
);
|
||||
return {
|
||||
llmContent: summary,
|
||||
|
||||
@@ -142,6 +142,12 @@ describe('WebFetchTool', () => {
|
||||
setApprovalMode: vi.fn(),
|
||||
getProxy: vi.fn(),
|
||||
getGeminiClient: mockGetGeminiClient,
|
||||
modelConfigService: {
|
||||
getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({
|
||||
model,
|
||||
generateContentConfig: {},
|
||||
})),
|
||||
},
|
||||
isInteractive: () => false,
|
||||
} as unknown as Config;
|
||||
});
|
||||
@@ -270,7 +276,7 @@ describe('WebFetchTool', () => {
|
||||
} as Response);
|
||||
|
||||
// Mock fallback LLM call to return the content passed to it
|
||||
mockGenerateContent.mockImplementationOnce(async (req) => ({
|
||||
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
|
||||
candidates: [{ content: { parts: [{ text: req[0].parts[0].text }] } }],
|
||||
}));
|
||||
|
||||
@@ -298,7 +304,7 @@ describe('WebFetchTool', () => {
|
||||
} as Response);
|
||||
|
||||
// Mock fallback LLM call to return the content passed to it
|
||||
mockGenerateContent.mockImplementationOnce(async (req) => ({
|
||||
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
|
||||
candidates: [{ content: { parts: [{ text: req[0].parts[0].text }] } }],
|
||||
}));
|
||||
|
||||
@@ -320,7 +326,7 @@ describe('WebFetchTool', () => {
|
||||
} as Response);
|
||||
|
||||
// Mock fallback LLM call to return the content passed to it
|
||||
mockGenerateContent.mockImplementationOnce(async (req) => ({
|
||||
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
|
||||
candidates: [{ content: { parts: [{ text: req[0].parts[0].text }] } }],
|
||||
}));
|
||||
|
||||
@@ -342,7 +348,7 @@ describe('WebFetchTool', () => {
|
||||
} as Response);
|
||||
|
||||
// Mock fallback LLM call to return the content passed to it
|
||||
mockGenerateContent.mockImplementationOnce(async (req) => ({
|
||||
mockGenerateContent.mockImplementationOnce(async (_, req) => ({
|
||||
candidates: [{ content: { parts: [{ text: req[0].parts[0].text }] } }],
|
||||
}));
|
||||
|
||||
|
||||
@@ -19,9 +19,7 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/config.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
|
||||
import { getResponseText } from '../utils/partUtils.js';
|
||||
import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js';
|
||||
import { convert } from 'html-to-text';
|
||||
@@ -171,10 +169,9 @@ ${textContent}
|
||||
---
|
||||
`;
|
||||
const result = await geminiClient.generateContent(
|
||||
{ model: 'web-fetch-fallback' },
|
||||
[{ role: 'user', parts: [{ text: fallbackPrompt }] }],
|
||||
{},
|
||||
signal,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
const resultText = getResponseText(result) || '';
|
||||
return {
|
||||
@@ -255,10 +252,9 @@ ${textContent}
|
||||
|
||||
try {
|
||||
const response = await geminiClient.generateContent(
|
||||
{ model: 'web-fetch' },
|
||||
[{ role: 'user', parts: [{ text: userPrompt }] }],
|
||||
{ tools: [{ urlContext: {} }] },
|
||||
signal, // Pass signal
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
|
||||
debugLogger.debug(
|
||||
|
||||
@@ -25,6 +25,12 @@ describe('WebSearchTool', () => {
|
||||
const mockConfigInstance = {
|
||||
getGeminiClient: () => mockGeminiClient,
|
||||
getProxy: () => undefined,
|
||||
generationConfigService: {
|
||||
getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({
|
||||
model,
|
||||
sdkConfig: {},
|
||||
})),
|
||||
},
|
||||
} as unknown as Config;
|
||||
mockGeminiClient = new GeminiClient(mockConfigInstance);
|
||||
tool = new WebSearchTool(mockConfigInstance);
|
||||
|
||||
@@ -14,7 +14,6 @@ import { ToolErrorType } from './tool-error.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { type Config } from '../config/config.js';
|
||||
import { getResponseText } from '../utils/partUtils.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
|
||||
interface GroundingChunkWeb {
|
||||
uri?: string;
|
||||
@@ -81,10 +80,9 @@ class WebSearchToolInvocation extends BaseToolInvocation<
|
||||
|
||||
try {
|
||||
const response = await geminiClient.generateContent(
|
||||
{ model: 'web-search' },
|
||||
[{ role: 'user', parts: [{ text: this.params.query }] }],
|
||||
{ tools: [{ googleSearch: {} }] },
|
||||
signal,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
|
||||
const responseText = getResponseText(response);
|
||||
|
||||
@@ -14,6 +14,11 @@ import {
|
||||
defaultSummarizer,
|
||||
} from './summarizer.js';
|
||||
import type { ToolResult } from '../tools/tools.js';
|
||||
import type {
|
||||
ModelConfigService,
|
||||
ResolvedModelConfig,
|
||||
} from '../services/modelConfigService.js';
|
||||
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
|
||||
|
||||
// Mock GeminiClient and Config constructor
|
||||
vi.mock('../core/client.js');
|
||||
@@ -22,11 +27,18 @@ vi.mock('../config/config.js');
|
||||
describe('summarizers', () => {
|
||||
let mockGeminiClient: GeminiClient;
|
||||
let MockConfig: Mock;
|
||||
let mockConfigInstance: Config;
|
||||
const abortSignal = new AbortController().signal;
|
||||
const mockResolvedConfig = {
|
||||
model: 'gemini-pro',
|
||||
generateContentConfig: {
|
||||
maxOutputTokens: 2000,
|
||||
},
|
||||
} as unknown as ResolvedModelConfig;
|
||||
|
||||
beforeEach(() => {
|
||||
MockConfig = vi.mocked(Config);
|
||||
const mockConfigInstance = new MockConfig(
|
||||
mockConfigInstance = new MockConfig(
|
||||
'test-api-key',
|
||||
'gemini-pro',
|
||||
false,
|
||||
@@ -38,6 +50,9 @@ describe('summarizers', () => {
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
(mockConfigInstance.modelConfigService as unknown) = {
|
||||
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
|
||||
} as unknown as ModelConfigService;
|
||||
|
||||
mockGeminiClient = new GeminiClient(mockConfigInstance);
|
||||
(mockGeminiClient.generateContent as Mock) = vi.fn();
|
||||
@@ -54,10 +69,11 @@ describe('summarizers', () => {
|
||||
it('should return original text if it is shorter than maxLength', async () => {
|
||||
const shortText = 'This is a short text.';
|
||||
const result = await summarizeToolOutput(
|
||||
mockConfigInstance,
|
||||
{ model: DEFAULT_GEMINI_MODEL },
|
||||
shortText,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
2000,
|
||||
);
|
||||
expect(result).toBe(shortText);
|
||||
expect(mockGeminiClient.generateContent).not.toHaveBeenCalled();
|
||||
@@ -66,10 +82,11 @@ describe('summarizers', () => {
|
||||
it('should return original text if it is empty', async () => {
|
||||
const emptyText = '';
|
||||
const result = await summarizeToolOutput(
|
||||
mockConfigInstance,
|
||||
{ model: DEFAULT_GEMINI_MODEL },
|
||||
emptyText,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
2000,
|
||||
);
|
||||
expect(result).toBe(emptyText);
|
||||
expect(mockGeminiClient.generateContent).not.toHaveBeenCalled();
|
||||
@@ -81,12 +98,12 @@ describe('summarizers', () => {
|
||||
(mockGeminiClient.generateContent as Mock).mockResolvedValue({
|
||||
candidates: [{ content: { parts: [{ text: summary }] } }],
|
||||
});
|
||||
|
||||
const result = await summarizeToolOutput(
|
||||
mockConfigInstance,
|
||||
{ model: DEFAULT_GEMINI_MODEL },
|
||||
longText,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
2000,
|
||||
);
|
||||
|
||||
expect(mockGeminiClient.generateContent).toHaveBeenCalledTimes(1);
|
||||
@@ -99,10 +116,11 @@ describe('summarizers', () => {
|
||||
(mockGeminiClient.generateContent as Mock).mockRejectedValue(error);
|
||||
|
||||
const result = await summarizeToolOutput(
|
||||
mockConfigInstance,
|
||||
{ model: DEFAULT_GEMINI_MODEL },
|
||||
longText,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
2000,
|
||||
);
|
||||
|
||||
expect(mockGeminiClient.generateContent).toHaveBeenCalledTimes(1);
|
||||
@@ -115,8 +133,24 @@ describe('summarizers', () => {
|
||||
(mockGeminiClient.generateContent as Mock).mockResolvedValue({
|
||||
candidates: [{ content: { parts: [{ text: summary }] } }],
|
||||
});
|
||||
(mockConfigInstance.modelConfigService as unknown) = {
|
||||
getResolvedConfig() {
|
||||
return {
|
||||
model: 'gemini-pro-limited',
|
||||
generateContentConfig: {
|
||||
maxOutputTokens: 1000,
|
||||
},
|
||||
};
|
||||
},
|
||||
};
|
||||
|
||||
await summarizeToolOutput(longText, mockGeminiClient, abortSignal, 1000);
|
||||
await summarizeToolOutput(
|
||||
mockConfigInstance,
|
||||
{ model: 'gemini-pro-limited' },
|
||||
longText,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
const expectedPrompt = `Summarize the following tool output to be a maximum of 1000 tokens. The summary should be concise and capture the main points of the tool output.
|
||||
|
||||
@@ -133,7 +167,7 @@ Return the summary string which should first contain an overall summarization of
|
||||
`;
|
||||
const calledWith = (mockGeminiClient.generateContent as Mock).mock
|
||||
.calls[0];
|
||||
const contents = calledWith[0];
|
||||
const contents = calledWith[1];
|
||||
expect(contents[0].parts[0].text).toBe(expectedPrompt);
|
||||
});
|
||||
});
|
||||
@@ -150,6 +184,7 @@ Return the summary string which should first contain an overall summarization of
|
||||
});
|
||||
|
||||
const result = await llmSummarizer(
|
||||
mockConfigInstance,
|
||||
toolResult,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
@@ -171,6 +206,7 @@ Return the summary string which should first contain an overall summarization of
|
||||
});
|
||||
|
||||
const result = await llmSummarizer(
|
||||
mockConfigInstance,
|
||||
toolResult,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
@@ -179,7 +215,7 @@ Return the summary string which should first contain an overall summarization of
|
||||
expect(mockGeminiClient.generateContent).toHaveBeenCalledTimes(1);
|
||||
const calledWith = (mockGeminiClient.generateContent as Mock).mock
|
||||
.calls[0];
|
||||
const contents = calledWith[0];
|
||||
const contents = calledWith[1];
|
||||
expect(contents[0].parts[0].text).toContain(`"${longText}"`);
|
||||
expect(result).toBe(summary);
|
||||
});
|
||||
@@ -193,6 +229,7 @@ Return the summary string which should first contain an overall summarization of
|
||||
};
|
||||
|
||||
const result = await defaultSummarizer(
|
||||
mockConfigInstance,
|
||||
toolResult,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
|
||||
@@ -5,15 +5,12 @@
|
||||
*/
|
||||
|
||||
import type { ToolResult } from '../tools/tools.js';
|
||||
import type {
|
||||
Content,
|
||||
GenerateContentConfig,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import type { Content } from '@google/genai';
|
||||
import type { GeminiClient } from '../core/client.js';
|
||||
import { DEFAULT_GEMINI_FLASH_LITE_MODEL } from '../config/models.js';
|
||||
import { getResponseText, partToString } from './partUtils.js';
|
||||
import { debugLogger } from './debugLogger.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
/**
|
||||
* A function that summarizes the result of a tool execution.
|
||||
@@ -22,6 +19,7 @@ import { debugLogger } from './debugLogger.js';
|
||||
* @returns The summary of the result.
|
||||
*/
|
||||
export type Summarizer = (
|
||||
config: Config,
|
||||
result: ToolResult,
|
||||
geminiClient: GeminiClient,
|
||||
abortSignal: AbortSignal,
|
||||
@@ -36,6 +34,7 @@ export type Summarizer = (
|
||||
* @returns The summary of the result.
|
||||
*/
|
||||
export const defaultSummarizer: Summarizer = (
|
||||
_config: Config,
|
||||
result: ToolResult,
|
||||
_geminiClient: GeminiClient,
|
||||
_abortSignal: AbortSignal,
|
||||
@@ -55,19 +54,30 @@ Text to summarize:
|
||||
Return the summary string which should first contain an overall summarization of text followed by the full stack trace of errors and warnings in the tool output.
|
||||
`;
|
||||
|
||||
export const llmSummarizer: Summarizer = (result, geminiClient, abortSignal) =>
|
||||
export const llmSummarizer: Summarizer = async (
|
||||
config,
|
||||
result,
|
||||
geminiClient,
|
||||
abortSignal,
|
||||
) =>
|
||||
summarizeToolOutput(
|
||||
config,
|
||||
{ model: 'summarizer-default' },
|
||||
partToString(result.llmContent),
|
||||
geminiClient,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
export async function summarizeToolOutput(
|
||||
config: Config,
|
||||
modelConfigKey: ModelConfigKey,
|
||||
textToSummarize: string,
|
||||
geminiClient: GeminiClient,
|
||||
abortSignal: AbortSignal,
|
||||
maxOutputTokens: number = 2000,
|
||||
): Promise<string> {
|
||||
const maxOutputTokens =
|
||||
config.modelConfigService.getResolvedConfig(modelConfigKey)
|
||||
.generateContentConfig.maxOutputTokens ?? 2000;
|
||||
// There is going to be a slight difference here since we are comparing length of string with maxOutputTokens.
|
||||
// This is meant to be a ballpark estimation of if we need to summarize the tool output.
|
||||
if (!textToSummarize || textToSummarize.length < maxOutputTokens) {
|
||||
@@ -79,16 +89,12 @@ export async function summarizeToolOutput(
|
||||
).replace('{textToSummarize}', textToSummarize);
|
||||
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
const toolOutputSummarizerConfig: GenerateContentConfig = {
|
||||
maxOutputTokens,
|
||||
};
|
||||
try {
|
||||
const parsedResponse = (await geminiClient.generateContent(
|
||||
const parsedResponse = await geminiClient.generateContent(
|
||||
modelConfigKey,
|
||||
contents,
|
||||
toolOutputSummarizerConfig,
|
||||
abortSignal,
|
||||
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||
)) as unknown as GenerateContentResponse;
|
||||
);
|
||||
return getResponseText(parsedResponse) || textToSummarize;
|
||||
} catch (error) {
|
||||
debugLogger.warn('Failed to summarize tool output.', error);
|
||||
|
||||
Reference in New Issue
Block a user