feat(core): Wire up chat code path for model configs. (#12850)

This commit is contained in:
joshualitt
2025-11-19 20:41:16 -08:00
committed by GitHub
parent 43d6dc3668
commit 257cd07a3a
19 changed files with 485 additions and 347 deletions
+78 -77
View File
@@ -27,8 +27,9 @@ import {
type FunctionCall,
type Part,
type GenerateContentResponse,
type GenerateContentConfig,
type Content,
type PartListUnion,
type Tool,
} from '@google/genai';
import type { Config } from '../config/config.js';
import { MockTool } from '../test-utils/mock-tool.js';
@@ -55,14 +56,22 @@ import { AgentTerminateMode } from './types.js';
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
import { CompressionStatus } from '../core/turn.js';
import { ChatCompressionService } from '../services/chatCompressionService.js';
import type { ModelConfigKey } from '../services/modelConfigService.js';
import { getModelConfigAlias } from './registry.js';
const { mockSendMessageStream, mockExecuteToolCall, mockCompress } = vi.hoisted(
() => ({
mockSendMessageStream: vi.fn(),
mockExecuteToolCall: vi.fn(),
mockCompress: vi.fn(),
}),
);
const {
mockSendMessageStream,
mockExecuteToolCall,
mockSetSystemInstruction,
mockCompress,
mockSetTools,
} = vi.hoisted(() => ({
mockSendMessageStream: vi.fn(),
mockExecuteToolCall: vi.fn(),
mockSetSystemInstruction: vi.fn(),
mockCompress: vi.fn(),
mockSetTools: vi.fn(),
}));
let mockChatHistory: Content[] = [];
const mockSetHistory = vi.fn((newHistory: Content[]) => {
@@ -83,6 +92,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
sendMessageStream: mockSendMessageStream,
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
setHistory: mockSetHistory,
setSystemInstruction: mockSetSystemInstruction,
setTools: mockSetTools,
})),
};
});
@@ -172,8 +183,10 @@ const mockModelResponse = (
const getMockMessageParams = (callIndex: number) => {
const call = mockSendMessageStream.mock.calls[callIndex];
expect(call).toBeDefined();
// Arg 1 of sendMessageStream is the message parameters
return call[1] as { message?: Part[]; config?: GenerateContentConfig };
return {
modelConfigKey: call[0],
message: call[1],
} as { modelConfigKey: ModelConfigKey; message: PartListUnion };
};
let mockConfig: Config;
@@ -223,6 +236,8 @@ describe('AgentExecutor', () => {
mockCompress.mockClear();
mockSetHistory.mockClear();
mockSendMessageStream.mockReset();
mockSetSystemInstruction.mockReset();
mockSetTools.mockReset();
mockExecuteToolCall.mockReset();
mockedLogAgentStart.mockReset();
mockedLogAgentFinish.mockReset();
@@ -241,6 +256,8 @@ describe('AgentExecutor', () => {
() =>
({
sendMessageStream: mockSendMessageStream,
setSystemInstruction: mockSetSystemInstruction,
setTools: mockSetTools,
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
getLastPromptTokenCount: vi.fn(() => 100),
setHistory: mockSetHistory,
@@ -358,7 +375,7 @@ describe('AgentExecutor', () => {
await executor.run(inputs, signal);
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
const startHistory = chatConstructorArgs[2]; // history is the 3rd arg
const startHistory = chatConstructorArgs[3]; // history is the 4th arg
expect(startHistory).toBeDefined();
expect(startHistory).toHaveLength(2);
@@ -459,10 +476,7 @@ describe('AgentExecutor', () => {
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
const chatConfig = chatConstructorArgs[1];
const systemInstruction = chatConfig?.systemInstruction as string;
const systemInstruction = MockedGeminiChat.mock.calls[0][1];
expect(systemInstruction).toContain(
`MUST call the \`${TASK_COMPLETE_TOOL_NAME}\` tool`,
);
@@ -472,18 +486,11 @@ describe('AgentExecutor', () => {
);
expect(systemInstruction).toContain('Always use absolute paths');
const turn1Params = getMockMessageParams(0);
const { modelConfigKey } = getMockMessageParams(0);
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
const firstToolGroup = turn1Params.config?.tools?.[0];
expect(firstToolGroup).toBeDefined();
if (!firstToolGroup || !('functionDeclarations' in firstToolGroup)) {
throw new Error(
'Test expectation failed: Config does not contain functionDeclarations.',
);
}
const sentTools = firstToolGroup.functionDeclarations;
const call = mockSetTools.mock.calls[0];
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
expect(sentTools).toBeDefined();
expect(sentTools).toEqual(
@@ -604,17 +611,11 @@ describe('AgentExecutor', () => {
const output = await executor.run({ goal: 'Do work' }, signal);
const turn1Params = getMockMessageParams(0);
const firstToolGroup = turn1Params.config?.tools?.[0];
const { modelConfigKey } = getMockMessageParams(0);
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
expect(firstToolGroup).toBeDefined();
if (!firstToolGroup || !('functionDeclarations' in firstToolGroup)) {
throw new Error(
'Test expectation failed: Config does not contain functionDeclarations.',
);
}
const sentTools = firstToolGroup.functionDeclarations;
const call = mockSetTools.mock.calls[0];
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
expect(sentTools).toBeDefined();
const completeToolDef = sentTools!.find(
@@ -754,7 +755,7 @@ describe('AgentExecutor', () => {
expect(turn2Parts).toBeDefined();
expect(turn2Parts).toHaveLength(1);
expect(turn2Parts![0]).toEqual(
expect((turn2Parts as Part[])![0]).toEqual(
expect.objectContaining({
functionResponse: expect.objectContaining({
name: TASK_COMPLETE_TOOL_NAME,
@@ -944,7 +945,7 @@ describe('AgentExecutor', () => {
const turn2Params = getMockMessageParams(1);
const parts = turn2Params.message;
expect(parts).toBeDefined();
expect(parts![0]).toEqual(
expect((parts as Part[])![0]).toEqual(
expect.objectContaining({
functionResponse: expect.objectContaining({
id: badCallId,
@@ -1222,18 +1223,18 @@ describe('AgentExecutor', () => {
);
// Mock a model call that is interruptible by an abort signal.
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
const signal = params?.config?.abortSignal;
// eslint-disable-next-line require-yield
return (async function* () {
await new Promise<void>((resolve) => {
// This promise resolves when aborted, ending the generator.
signal?.addEventListener('abort', () => {
resolve();
mockSendMessageStream.mockImplementationOnce(
async (_key, _message, _promptId, signal) =>
// eslint-disable-next-line require-yield
(async function* () {
await new Promise<void>((resolve) => {
// This promise resolves when aborted, ending the generator.
signal?.addEventListener('abort', () => {
resolve();
});
});
});
})();
});
})(),
);
// Recovery turn
mockModelResponse([], 'I give up');
@@ -1534,16 +1535,16 @@ describe('AgentExecutor', () => {
);
// Mock a model call that gets interrupted by the timeout.
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
const signal = params?.config?.abortSignal;
// eslint-disable-next-line require-yield
return (async function* () {
// This promise never resolves, it waits for abort.
await new Promise<void>((resolve) => {
signal?.addEventListener('abort', () => resolve());
});
})();
});
mockSendMessageStream.mockImplementationOnce(
async (_key, _message, _promptId, signal) =>
// eslint-disable-next-line require-yield
(async function* () {
// This promise never resolves, it waits for abort.
await new Promise<void>((resolve) => {
signal?.addEventListener('abort', () => resolve());
});
})(),
);
// Recovery turn (succeeds)
mockModelResponse(
@@ -1588,26 +1589,26 @@ describe('AgentExecutor', () => {
onActivity,
);
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
const signal = params?.config?.abortSignal;
// eslint-disable-next-line require-yield
return (async function* () {
await new Promise<void>((resolve) =>
signal?.addEventListener('abort', () => resolve()),
);
})();
});
mockSendMessageStream.mockImplementationOnce(
async (_key, _message, _promptId, signal) =>
// eslint-disable-next-line require-yield
(async function* () {
await new Promise<void>((resolve) =>
signal?.addEventListener('abort', () => resolve()),
);
})(),
);
// Mock the recovery call to also be long-running
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
const signal = params?.config?.abortSignal;
// eslint-disable-next-line require-yield
return (async function* () {
await new Promise<void>((resolve) =>
signal?.addEventListener('abort', () => resolve()),
);
})();
});
mockSendMessageStream.mockImplementationOnce(
async (_key, _message, _promptId, signal) =>
// eslint-disable-next-line require-yield
(async function* () {
await new Promise<void>((resolve) =>
signal?.addEventListener('abort', () => resolve()),
);
})(),
);
const runPromise = executor.run(
{ goal: 'Timeout recovery fail' },