mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-24 12:04:56 -07:00
feat(core): Wire up chat code path for model configs. (#12850)
This commit is contained in:
@@ -27,8 +27,9 @@ import {
|
||||
type FunctionCall,
|
||||
type Part,
|
||||
type GenerateContentResponse,
|
||||
type GenerateContentConfig,
|
||||
type Content,
|
||||
type PartListUnion,
|
||||
type Tool,
|
||||
} from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { MockTool } from '../test-utils/mock-tool.js';
|
||||
@@ -55,14 +56,22 @@ import { AgentTerminateMode } from './types.js';
|
||||
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
|
||||
import { CompressionStatus } from '../core/turn.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import { getModelConfigAlias } from './registry.js';
|
||||
|
||||
const { mockSendMessageStream, mockExecuteToolCall, mockCompress } = vi.hoisted(
|
||||
() => ({
|
||||
mockSendMessageStream: vi.fn(),
|
||||
mockExecuteToolCall: vi.fn(),
|
||||
mockCompress: vi.fn(),
|
||||
}),
|
||||
);
|
||||
const {
|
||||
mockSendMessageStream,
|
||||
mockExecuteToolCall,
|
||||
mockSetSystemInstruction,
|
||||
mockCompress,
|
||||
mockSetTools,
|
||||
} = vi.hoisted(() => ({
|
||||
mockSendMessageStream: vi.fn(),
|
||||
mockExecuteToolCall: vi.fn(),
|
||||
mockSetSystemInstruction: vi.fn(),
|
||||
mockCompress: vi.fn(),
|
||||
mockSetTools: vi.fn(),
|
||||
}));
|
||||
|
||||
let mockChatHistory: Content[] = [];
|
||||
const mockSetHistory = vi.fn((newHistory: Content[]) => {
|
||||
@@ -83,6 +92,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
setHistory: mockSetHistory,
|
||||
setSystemInstruction: mockSetSystemInstruction,
|
||||
setTools: mockSetTools,
|
||||
})),
|
||||
};
|
||||
});
|
||||
@@ -172,8 +183,10 @@ const mockModelResponse = (
|
||||
const getMockMessageParams = (callIndex: number) => {
|
||||
const call = mockSendMessageStream.mock.calls[callIndex];
|
||||
expect(call).toBeDefined();
|
||||
// Arg 1 of sendMessageStream is the message parameters
|
||||
return call[1] as { message?: Part[]; config?: GenerateContentConfig };
|
||||
return {
|
||||
modelConfigKey: call[0],
|
||||
message: call[1],
|
||||
} as { modelConfigKey: ModelConfigKey; message: PartListUnion };
|
||||
};
|
||||
|
||||
let mockConfig: Config;
|
||||
@@ -223,6 +236,8 @@ describe('AgentExecutor', () => {
|
||||
mockCompress.mockClear();
|
||||
mockSetHistory.mockClear();
|
||||
mockSendMessageStream.mockReset();
|
||||
mockSetSystemInstruction.mockReset();
|
||||
mockSetTools.mockReset();
|
||||
mockExecuteToolCall.mockReset();
|
||||
mockedLogAgentStart.mockReset();
|
||||
mockedLogAgentFinish.mockReset();
|
||||
@@ -241,6 +256,8 @@ describe('AgentExecutor', () => {
|
||||
() =>
|
||||
({
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
setSystemInstruction: mockSetSystemInstruction,
|
||||
setTools: mockSetTools,
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
getLastPromptTokenCount: vi.fn(() => 100),
|
||||
setHistory: mockSetHistory,
|
||||
@@ -358,7 +375,7 @@ describe('AgentExecutor', () => {
|
||||
await executor.run(inputs, signal);
|
||||
|
||||
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
|
||||
const startHistory = chatConstructorArgs[2]; // history is the 3rd arg
|
||||
const startHistory = chatConstructorArgs[3]; // history is the 4th arg
|
||||
|
||||
expect(startHistory).toBeDefined();
|
||||
expect(startHistory).toHaveLength(2);
|
||||
@@ -459,10 +476,7 @@ describe('AgentExecutor', () => {
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
|
||||
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
|
||||
const chatConfig = chatConstructorArgs[1];
|
||||
const systemInstruction = chatConfig?.systemInstruction as string;
|
||||
|
||||
const systemInstruction = MockedGeminiChat.mock.calls[0][1];
|
||||
expect(systemInstruction).toContain(
|
||||
`MUST call the \`${TASK_COMPLETE_TOOL_NAME}\` tool`,
|
||||
);
|
||||
@@ -472,18 +486,11 @@ describe('AgentExecutor', () => {
|
||||
);
|
||||
expect(systemInstruction).toContain('Always use absolute paths');
|
||||
|
||||
const turn1Params = getMockMessageParams(0);
|
||||
const { modelConfigKey } = getMockMessageParams(0);
|
||||
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
|
||||
|
||||
const firstToolGroup = turn1Params.config?.tools?.[0];
|
||||
expect(firstToolGroup).toBeDefined();
|
||||
|
||||
if (!firstToolGroup || !('functionDeclarations' in firstToolGroup)) {
|
||||
throw new Error(
|
||||
'Test expectation failed: Config does not contain functionDeclarations.',
|
||||
);
|
||||
}
|
||||
|
||||
const sentTools = firstToolGroup.functionDeclarations;
|
||||
const call = mockSetTools.mock.calls[0];
|
||||
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
|
||||
expect(sentTools).toBeDefined();
|
||||
|
||||
expect(sentTools).toEqual(
|
||||
@@ -604,17 +611,11 @@ describe('AgentExecutor', () => {
|
||||
|
||||
const output = await executor.run({ goal: 'Do work' }, signal);
|
||||
|
||||
const turn1Params = getMockMessageParams(0);
|
||||
const firstToolGroup = turn1Params.config?.tools?.[0];
|
||||
const { modelConfigKey } = getMockMessageParams(0);
|
||||
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
|
||||
|
||||
expect(firstToolGroup).toBeDefined();
|
||||
if (!firstToolGroup || !('functionDeclarations' in firstToolGroup)) {
|
||||
throw new Error(
|
||||
'Test expectation failed: Config does not contain functionDeclarations.',
|
||||
);
|
||||
}
|
||||
|
||||
const sentTools = firstToolGroup.functionDeclarations;
|
||||
const call = mockSetTools.mock.calls[0];
|
||||
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
|
||||
expect(sentTools).toBeDefined();
|
||||
|
||||
const completeToolDef = sentTools!.find(
|
||||
@@ -754,7 +755,7 @@ describe('AgentExecutor', () => {
|
||||
expect(turn2Parts).toBeDefined();
|
||||
expect(turn2Parts).toHaveLength(1);
|
||||
|
||||
expect(turn2Parts![0]).toEqual(
|
||||
expect((turn2Parts as Part[])![0]).toEqual(
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
@@ -944,7 +945,7 @@ describe('AgentExecutor', () => {
|
||||
const turn2Params = getMockMessageParams(1);
|
||||
const parts = turn2Params.message;
|
||||
expect(parts).toBeDefined();
|
||||
expect(parts![0]).toEqual(
|
||||
expect((parts as Part[])![0]).toEqual(
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({
|
||||
id: badCallId,
|
||||
@@ -1222,18 +1223,18 @@ describe('AgentExecutor', () => {
|
||||
);
|
||||
|
||||
// Mock a model call that is interruptible by an abort signal.
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
await new Promise<void>((resolve) => {
|
||||
// This promise resolves when aborted, ending the generator.
|
||||
signal?.addEventListener('abort', () => {
|
||||
resolve();
|
||||
mockSendMessageStream.mockImplementationOnce(
|
||||
async (_key, _message, _promptId, signal) =>
|
||||
// eslint-disable-next-line require-yield
|
||||
(async function* () {
|
||||
await new Promise<void>((resolve) => {
|
||||
// This promise resolves when aborted, ending the generator.
|
||||
signal?.addEventListener('abort', () => {
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
});
|
||||
})();
|
||||
});
|
||||
})(),
|
||||
);
|
||||
// Recovery turn
|
||||
mockModelResponse([], 'I give up');
|
||||
|
||||
@@ -1534,16 +1535,16 @@ describe('AgentExecutor', () => {
|
||||
);
|
||||
|
||||
// Mock a model call that gets interrupted by the timeout.
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
// This promise never resolves, it waits for abort.
|
||||
await new Promise<void>((resolve) => {
|
||||
signal?.addEventListener('abort', () => resolve());
|
||||
});
|
||||
})();
|
||||
});
|
||||
mockSendMessageStream.mockImplementationOnce(
|
||||
async (_key, _message, _promptId, signal) =>
|
||||
// eslint-disable-next-line require-yield
|
||||
(async function* () {
|
||||
// This promise never resolves, it waits for abort.
|
||||
await new Promise<void>((resolve) => {
|
||||
signal?.addEventListener('abort', () => resolve());
|
||||
});
|
||||
})(),
|
||||
);
|
||||
|
||||
// Recovery turn (succeeds)
|
||||
mockModelResponse(
|
||||
@@ -1588,26 +1589,26 @@ describe('AgentExecutor', () => {
|
||||
onActivity,
|
||||
);
|
||||
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})();
|
||||
});
|
||||
mockSendMessageStream.mockImplementationOnce(
|
||||
async (_key, _message, _promptId, signal) =>
|
||||
// eslint-disable-next-line require-yield
|
||||
(async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})(),
|
||||
);
|
||||
|
||||
// Mock the recovery call to also be long-running
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})();
|
||||
});
|
||||
mockSendMessageStream.mockImplementationOnce(
|
||||
async (_key, _message, _promptId, signal) =>
|
||||
// eslint-disable-next-line require-yield
|
||||
(async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})(),
|
||||
);
|
||||
|
||||
const runPromise = executor.run(
|
||||
{ goal: 'Timeout recovery fail' },
|
||||
|
||||
@@ -12,7 +12,6 @@ import type {
|
||||
Content,
|
||||
Part,
|
||||
FunctionCall,
|
||||
GenerateContentConfig,
|
||||
FunctionDeclaration,
|
||||
Schema,
|
||||
} from '@google/genai';
|
||||
@@ -53,6 +52,7 @@ import { parseThought } from '../utils/thoughtUtils.js';
|
||||
import { type z } from 'zod';
|
||||
import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { getModelConfigAlias } from './registry.js';
|
||||
|
||||
/** A callback function to report on agent activity. */
|
||||
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
|
||||
@@ -595,18 +595,19 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
signal: AbortSignal,
|
||||
promptId: string,
|
||||
): Promise<{ functionCalls: FunctionCall[]; textResponse: string }> {
|
||||
const messageParams = {
|
||||
message: message.parts || [],
|
||||
config: {
|
||||
abortSignal: signal,
|
||||
tools: tools.length > 0 ? [{ functionDeclarations: tools }] : undefined,
|
||||
},
|
||||
};
|
||||
if (tools.length > 0) {
|
||||
// TODO(12622): Move tools back to config.
|
||||
chat.setTools([{ functionDeclarations: tools }]);
|
||||
}
|
||||
|
||||
const responseStream = await chat.sendMessageStream(
|
||||
this.definition.modelConfig.model,
|
||||
messageParams,
|
||||
{
|
||||
model: getModelConfigAlias(this.definition),
|
||||
overrideScope: this.definition.name,
|
||||
},
|
||||
message.parts || [],
|
||||
promptId,
|
||||
signal,
|
||||
);
|
||||
|
||||
const functionCalls: FunctionCall[] = [];
|
||||
@@ -650,7 +651,7 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
|
||||
/** Initializes a `GeminiChat` instance for the agent run. */
|
||||
private async createChatObject(inputs: AgentInputs): Promise<GeminiChat> {
|
||||
const { promptConfig, modelConfig } = this.definition;
|
||||
const { promptConfig } = this.definition;
|
||||
|
||||
if (!promptConfig.systemPrompt && !promptConfig.initialMessages) {
|
||||
throw new Error(
|
||||
@@ -669,22 +670,10 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
: undefined;
|
||||
|
||||
try {
|
||||
const generationConfig: GenerateContentConfig = {
|
||||
temperature: modelConfig.temp,
|
||||
topP: modelConfig.top_p,
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: modelConfig.thinkingBudget ?? -1,
|
||||
},
|
||||
};
|
||||
|
||||
if (systemInstruction) {
|
||||
generationConfig.systemInstruction = systemInstruction;
|
||||
}
|
||||
|
||||
return new GeminiChat(
|
||||
this.runtimeContext,
|
||||
generationConfig,
|
||||
systemInstruction,
|
||||
[], // set in `callModel`,
|
||||
startHistory,
|
||||
);
|
||||
} catch (error) {
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { AgentRegistry } from './registry.js';
|
||||
import { AgentRegistry, getModelConfigAlias } from './registry.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
import type { AgentDefinition } from './types.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
@@ -77,6 +77,21 @@ describe('AgentRegistry', () => {
|
||||
it('should register a valid agent definition', () => {
|
||||
registry.testRegisterAgent(MOCK_AGENT_V1);
|
||||
expect(registry.getDefinition('MockAgent')).toEqual(MOCK_AGENT_V1);
|
||||
expect(
|
||||
mockConfig.modelConfigService.getResolvedConfig({
|
||||
model: getModelConfigAlias(MOCK_AGENT_V1),
|
||||
}),
|
||||
).toStrictEqual({
|
||||
model: MOCK_AGENT_V1.modelConfig.model,
|
||||
generateContentConfig: {
|
||||
temperature: MOCK_AGENT_V1.modelConfig.temp,
|
||||
topP: MOCK_AGENT_V1.modelConfig.top_p,
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: -1,
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle special characters in agent names', () => {
|
||||
|
||||
@@ -9,6 +9,16 @@ import type { AgentDefinition } from './types.js';
|
||||
import { CodebaseInvestigatorAgent } from './codebase-investigator.js';
|
||||
import { type z } from 'zod';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { ModelConfigAlias } from '../services/modelConfigService.js';
|
||||
|
||||
/**
|
||||
* Returns the model config alias for a given agent definition.
|
||||
*/
|
||||
export function getModelConfigAlias<TOutput extends z.ZodTypeAny>(
|
||||
definition: AgentDefinition<TOutput>,
|
||||
): string {
|
||||
return `${definition.name}-config`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Manages the discovery, loading, validation, and registration of
|
||||
@@ -84,6 +94,29 @@ export class AgentRegistry {
|
||||
}
|
||||
|
||||
this.agents.set(definition.name, definition);
|
||||
|
||||
// Register model config.
|
||||
// TODO(12916): Migrate sub-agents where possible to static configs.
|
||||
const modelConfig = definition.modelConfig;
|
||||
|
||||
const runtimeAlias: ModelConfigAlias = {
|
||||
modelConfig: {
|
||||
model: modelConfig.model,
|
||||
generateContentConfig: {
|
||||
temperature: modelConfig.temp,
|
||||
topP: modelConfig.top_p,
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: modelConfig.thinkingBudget ?? -1,
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
this.config.modelConfigService.registerRuntimeModelConfig(
|
||||
getModelConfigAlias(definition),
|
||||
runtimeAlias,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user