mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-20 10:10:56 -07:00
feat(core): Wire up chat code path for model configs. (#12850)
This commit is contained in:
@@ -27,8 +27,9 @@ import {
|
||||
type FunctionCall,
|
||||
type Part,
|
||||
type GenerateContentResponse,
|
||||
type GenerateContentConfig,
|
||||
type Content,
|
||||
type PartListUnion,
|
||||
type Tool,
|
||||
} from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { MockTool } from '../test-utils/mock-tool.js';
|
||||
@@ -55,14 +56,22 @@ import { AgentTerminateMode } from './types.js';
|
||||
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
|
||||
import { CompressionStatus } from '../core/turn.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import { getModelConfigAlias } from './registry.js';
|
||||
|
||||
const { mockSendMessageStream, mockExecuteToolCall, mockCompress } = vi.hoisted(
|
||||
() => ({
|
||||
mockSendMessageStream: vi.fn(),
|
||||
mockExecuteToolCall: vi.fn(),
|
||||
mockCompress: vi.fn(),
|
||||
}),
|
||||
);
|
||||
const {
|
||||
mockSendMessageStream,
|
||||
mockExecuteToolCall,
|
||||
mockSetSystemInstruction,
|
||||
mockCompress,
|
||||
mockSetTools,
|
||||
} = vi.hoisted(() => ({
|
||||
mockSendMessageStream: vi.fn(),
|
||||
mockExecuteToolCall: vi.fn(),
|
||||
mockSetSystemInstruction: vi.fn(),
|
||||
mockCompress: vi.fn(),
|
||||
mockSetTools: vi.fn(),
|
||||
}));
|
||||
|
||||
let mockChatHistory: Content[] = [];
|
||||
const mockSetHistory = vi.fn((newHistory: Content[]) => {
|
||||
@@ -83,6 +92,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
setHistory: mockSetHistory,
|
||||
setSystemInstruction: mockSetSystemInstruction,
|
||||
setTools: mockSetTools,
|
||||
})),
|
||||
};
|
||||
});
|
||||
@@ -172,8 +183,10 @@ const mockModelResponse = (
|
||||
const getMockMessageParams = (callIndex: number) => {
|
||||
const call = mockSendMessageStream.mock.calls[callIndex];
|
||||
expect(call).toBeDefined();
|
||||
// Arg 1 of sendMessageStream is the message parameters
|
||||
return call[1] as { message?: Part[]; config?: GenerateContentConfig };
|
||||
return {
|
||||
modelConfigKey: call[0],
|
||||
message: call[1],
|
||||
} as { modelConfigKey: ModelConfigKey; message: PartListUnion };
|
||||
};
|
||||
|
||||
let mockConfig: Config;
|
||||
@@ -223,6 +236,8 @@ describe('AgentExecutor', () => {
|
||||
mockCompress.mockClear();
|
||||
mockSetHistory.mockClear();
|
||||
mockSendMessageStream.mockReset();
|
||||
mockSetSystemInstruction.mockReset();
|
||||
mockSetTools.mockReset();
|
||||
mockExecuteToolCall.mockReset();
|
||||
mockedLogAgentStart.mockReset();
|
||||
mockedLogAgentFinish.mockReset();
|
||||
@@ -241,6 +256,8 @@ describe('AgentExecutor', () => {
|
||||
() =>
|
||||
({
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
setSystemInstruction: mockSetSystemInstruction,
|
||||
setTools: mockSetTools,
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
getLastPromptTokenCount: vi.fn(() => 100),
|
||||
setHistory: mockSetHistory,
|
||||
@@ -358,7 +375,7 @@ describe('AgentExecutor', () => {
|
||||
await executor.run(inputs, signal);
|
||||
|
||||
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
|
||||
const startHistory = chatConstructorArgs[2]; // history is the 3rd arg
|
||||
const startHistory = chatConstructorArgs[3]; // history is the 4th arg
|
||||
|
||||
expect(startHistory).toBeDefined();
|
||||
expect(startHistory).toHaveLength(2);
|
||||
@@ -459,10 +476,7 @@ describe('AgentExecutor', () => {
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
|
||||
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
|
||||
const chatConfig = chatConstructorArgs[1];
|
||||
const systemInstruction = chatConfig?.systemInstruction as string;
|
||||
|
||||
const systemInstruction = MockedGeminiChat.mock.calls[0][1];
|
||||
expect(systemInstruction).toContain(
|
||||
`MUST call the \`${TASK_COMPLETE_TOOL_NAME}\` tool`,
|
||||
);
|
||||
@@ -472,18 +486,11 @@ describe('AgentExecutor', () => {
|
||||
);
|
||||
expect(systemInstruction).toContain('Always use absolute paths');
|
||||
|
||||
const turn1Params = getMockMessageParams(0);
|
||||
const { modelConfigKey } = getMockMessageParams(0);
|
||||
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
|
||||
|
||||
const firstToolGroup = turn1Params.config?.tools?.[0];
|
||||
expect(firstToolGroup).toBeDefined();
|
||||
|
||||
if (!firstToolGroup || !('functionDeclarations' in firstToolGroup)) {
|
||||
throw new Error(
|
||||
'Test expectation failed: Config does not contain functionDeclarations.',
|
||||
);
|
||||
}
|
||||
|
||||
const sentTools = firstToolGroup.functionDeclarations;
|
||||
const call = mockSetTools.mock.calls[0];
|
||||
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
|
||||
expect(sentTools).toBeDefined();
|
||||
|
||||
expect(sentTools).toEqual(
|
||||
@@ -604,17 +611,11 @@ describe('AgentExecutor', () => {
|
||||
|
||||
const output = await executor.run({ goal: 'Do work' }, signal);
|
||||
|
||||
const turn1Params = getMockMessageParams(0);
|
||||
const firstToolGroup = turn1Params.config?.tools?.[0];
|
||||
const { modelConfigKey } = getMockMessageParams(0);
|
||||
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
|
||||
|
||||
expect(firstToolGroup).toBeDefined();
|
||||
if (!firstToolGroup || !('functionDeclarations' in firstToolGroup)) {
|
||||
throw new Error(
|
||||
'Test expectation failed: Config does not contain functionDeclarations.',
|
||||
);
|
||||
}
|
||||
|
||||
const sentTools = firstToolGroup.functionDeclarations;
|
||||
const call = mockSetTools.mock.calls[0];
|
||||
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
|
||||
expect(sentTools).toBeDefined();
|
||||
|
||||
const completeToolDef = sentTools!.find(
|
||||
@@ -754,7 +755,7 @@ describe('AgentExecutor', () => {
|
||||
expect(turn2Parts).toBeDefined();
|
||||
expect(turn2Parts).toHaveLength(1);
|
||||
|
||||
expect(turn2Parts![0]).toEqual(
|
||||
expect((turn2Parts as Part[])![0]).toEqual(
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
@@ -944,7 +945,7 @@ describe('AgentExecutor', () => {
|
||||
const turn2Params = getMockMessageParams(1);
|
||||
const parts = turn2Params.message;
|
||||
expect(parts).toBeDefined();
|
||||
expect(parts![0]).toEqual(
|
||||
expect((parts as Part[])![0]).toEqual(
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({
|
||||
id: badCallId,
|
||||
@@ -1222,18 +1223,18 @@ describe('AgentExecutor', () => {
|
||||
);
|
||||
|
||||
// Mock a model call that is interruptible by an abort signal.
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
await new Promise<void>((resolve) => {
|
||||
// This promise resolves when aborted, ending the generator.
|
||||
signal?.addEventListener('abort', () => {
|
||||
resolve();
|
||||
mockSendMessageStream.mockImplementationOnce(
|
||||
async (_key, _message, _promptId, signal) =>
|
||||
// eslint-disable-next-line require-yield
|
||||
(async function* () {
|
||||
await new Promise<void>((resolve) => {
|
||||
// This promise resolves when aborted, ending the generator.
|
||||
signal?.addEventListener('abort', () => {
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
});
|
||||
})();
|
||||
});
|
||||
})(),
|
||||
);
|
||||
// Recovery turn
|
||||
mockModelResponse([], 'I give up');
|
||||
|
||||
@@ -1534,16 +1535,16 @@ describe('AgentExecutor', () => {
|
||||
);
|
||||
|
||||
// Mock a model call that gets interrupted by the timeout.
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
// This promise never resolves, it waits for abort.
|
||||
await new Promise<void>((resolve) => {
|
||||
signal?.addEventListener('abort', () => resolve());
|
||||
});
|
||||
})();
|
||||
});
|
||||
mockSendMessageStream.mockImplementationOnce(
|
||||
async (_key, _message, _promptId, signal) =>
|
||||
// eslint-disable-next-line require-yield
|
||||
(async function* () {
|
||||
// This promise never resolves, it waits for abort.
|
||||
await new Promise<void>((resolve) => {
|
||||
signal?.addEventListener('abort', () => resolve());
|
||||
});
|
||||
})(),
|
||||
);
|
||||
|
||||
// Recovery turn (succeeds)
|
||||
mockModelResponse(
|
||||
@@ -1588,26 +1589,26 @@ describe('AgentExecutor', () => {
|
||||
onActivity,
|
||||
);
|
||||
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})();
|
||||
});
|
||||
mockSendMessageStream.mockImplementationOnce(
|
||||
async (_key, _message, _promptId, signal) =>
|
||||
// eslint-disable-next-line require-yield
|
||||
(async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})(),
|
||||
);
|
||||
|
||||
// Mock the recovery call to also be long-running
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})();
|
||||
});
|
||||
mockSendMessageStream.mockImplementationOnce(
|
||||
async (_key, _message, _promptId, signal) =>
|
||||
// eslint-disable-next-line require-yield
|
||||
(async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})(),
|
||||
);
|
||||
|
||||
const runPromise = executor.run(
|
||||
{ goal: 'Timeout recovery fail' },
|
||||
|
||||
Reference in New Issue
Block a user