feat(core): Wire up chat code path for model configs. (#12850)

This commit is contained in:
joshualitt
2025-11-19 20:41:16 -08:00
committed by GitHub
parent 43d6dc3668
commit 257cd07a3a
19 changed files with 485 additions and 347 deletions
+78 -77
View File
@@ -27,8 +27,9 @@ import {
type FunctionCall,
type Part,
type GenerateContentResponse,
type GenerateContentConfig,
type Content,
type PartListUnion,
type Tool,
} from '@google/genai';
import type { Config } from '../config/config.js';
import { MockTool } from '../test-utils/mock-tool.js';
@@ -55,14 +56,22 @@ import { AgentTerminateMode } from './types.js';
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
import { CompressionStatus } from '../core/turn.js';
import { ChatCompressionService } from '../services/chatCompressionService.js';
import type { ModelConfigKey } from '../services/modelConfigService.js';
import { getModelConfigAlias } from './registry.js';
const { mockSendMessageStream, mockExecuteToolCall, mockCompress } = vi.hoisted(
() => ({
mockSendMessageStream: vi.fn(),
mockExecuteToolCall: vi.fn(),
mockCompress: vi.fn(),
}),
);
const {
mockSendMessageStream,
mockExecuteToolCall,
mockSetSystemInstruction,
mockCompress,
mockSetTools,
} = vi.hoisted(() => ({
mockSendMessageStream: vi.fn(),
mockExecuteToolCall: vi.fn(),
mockSetSystemInstruction: vi.fn(),
mockCompress: vi.fn(),
mockSetTools: vi.fn(),
}));
let mockChatHistory: Content[] = [];
const mockSetHistory = vi.fn((newHistory: Content[]) => {
@@ -83,6 +92,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
sendMessageStream: mockSendMessageStream,
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
setHistory: mockSetHistory,
setSystemInstruction: mockSetSystemInstruction,
setTools: mockSetTools,
})),
};
});
@@ -172,8 +183,10 @@ const mockModelResponse = (
const getMockMessageParams = (callIndex: number) => {
const call = mockSendMessageStream.mock.calls[callIndex];
expect(call).toBeDefined();
// Arg 1 of sendMessageStream is the message parameters
return call[1] as { message?: Part[]; config?: GenerateContentConfig };
return {
modelConfigKey: call[0],
message: call[1],
} as { modelConfigKey: ModelConfigKey; message: PartListUnion };
};
let mockConfig: Config;
@@ -223,6 +236,8 @@ describe('AgentExecutor', () => {
mockCompress.mockClear();
mockSetHistory.mockClear();
mockSendMessageStream.mockReset();
mockSetSystemInstruction.mockReset();
mockSetTools.mockReset();
mockExecuteToolCall.mockReset();
mockedLogAgentStart.mockReset();
mockedLogAgentFinish.mockReset();
@@ -241,6 +256,8 @@ describe('AgentExecutor', () => {
() =>
({
sendMessageStream: mockSendMessageStream,
setSystemInstruction: mockSetSystemInstruction,
setTools: mockSetTools,
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
getLastPromptTokenCount: vi.fn(() => 100),
setHistory: mockSetHistory,
@@ -358,7 +375,7 @@ describe('AgentExecutor', () => {
await executor.run(inputs, signal);
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
const startHistory = chatConstructorArgs[2]; // history is the 3rd arg
const startHistory = chatConstructorArgs[3]; // history is the 4th arg
expect(startHistory).toBeDefined();
expect(startHistory).toHaveLength(2);
@@ -459,10 +476,7 @@ describe('AgentExecutor', () => {
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
const chatConfig = chatConstructorArgs[1];
const systemInstruction = chatConfig?.systemInstruction as string;
const systemInstruction = MockedGeminiChat.mock.calls[0][1];
expect(systemInstruction).toContain(
`MUST call the \`${TASK_COMPLETE_TOOL_NAME}\` tool`,
);
@@ -472,18 +486,11 @@ describe('AgentExecutor', () => {
);
expect(systemInstruction).toContain('Always use absolute paths');
const turn1Params = getMockMessageParams(0);
const { modelConfigKey } = getMockMessageParams(0);
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
const firstToolGroup = turn1Params.config?.tools?.[0];
expect(firstToolGroup).toBeDefined();
if (!firstToolGroup || !('functionDeclarations' in firstToolGroup)) {
throw new Error(
'Test expectation failed: Config does not contain functionDeclarations.',
);
}
const sentTools = firstToolGroup.functionDeclarations;
const call = mockSetTools.mock.calls[0];
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
expect(sentTools).toBeDefined();
expect(sentTools).toEqual(
@@ -604,17 +611,11 @@ describe('AgentExecutor', () => {
const output = await executor.run({ goal: 'Do work' }, signal);
const turn1Params = getMockMessageParams(0);
const firstToolGroup = turn1Params.config?.tools?.[0];
const { modelConfigKey } = getMockMessageParams(0);
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
expect(firstToolGroup).toBeDefined();
if (!firstToolGroup || !('functionDeclarations' in firstToolGroup)) {
throw new Error(
'Test expectation failed: Config does not contain functionDeclarations.',
);
}
const sentTools = firstToolGroup.functionDeclarations;
const call = mockSetTools.mock.calls[0];
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
expect(sentTools).toBeDefined();
const completeToolDef = sentTools!.find(
@@ -754,7 +755,7 @@ describe('AgentExecutor', () => {
expect(turn2Parts).toBeDefined();
expect(turn2Parts).toHaveLength(1);
expect(turn2Parts![0]).toEqual(
expect((turn2Parts as Part[])![0]).toEqual(
expect.objectContaining({
functionResponse: expect.objectContaining({
name: TASK_COMPLETE_TOOL_NAME,
@@ -944,7 +945,7 @@ describe('AgentExecutor', () => {
const turn2Params = getMockMessageParams(1);
const parts = turn2Params.message;
expect(parts).toBeDefined();
expect(parts![0]).toEqual(
expect((parts as Part[])![0]).toEqual(
expect.objectContaining({
functionResponse: expect.objectContaining({
id: badCallId,
@@ -1222,18 +1223,18 @@ describe('AgentExecutor', () => {
);
// Mock a model call that is interruptible by an abort signal.
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
const signal = params?.config?.abortSignal;
// eslint-disable-next-line require-yield
return (async function* () {
await new Promise<void>((resolve) => {
// This promise resolves when aborted, ending the generator.
signal?.addEventListener('abort', () => {
resolve();
mockSendMessageStream.mockImplementationOnce(
async (_key, _message, _promptId, signal) =>
// eslint-disable-next-line require-yield
(async function* () {
await new Promise<void>((resolve) => {
// This promise resolves when aborted, ending the generator.
signal?.addEventListener('abort', () => {
resolve();
});
});
});
})();
});
})(),
);
// Recovery turn
mockModelResponse([], 'I give up');
@@ -1534,16 +1535,16 @@ describe('AgentExecutor', () => {
);
// Mock a model call that gets interrupted by the timeout.
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
const signal = params?.config?.abortSignal;
// eslint-disable-next-line require-yield
return (async function* () {
// This promise never resolves, it waits for abort.
await new Promise<void>((resolve) => {
signal?.addEventListener('abort', () => resolve());
});
})();
});
mockSendMessageStream.mockImplementationOnce(
async (_key, _message, _promptId, signal) =>
// eslint-disable-next-line require-yield
(async function* () {
// This promise never resolves, it waits for abort.
await new Promise<void>((resolve) => {
signal?.addEventListener('abort', () => resolve());
});
})(),
);
// Recovery turn (succeeds)
mockModelResponse(
@@ -1588,26 +1589,26 @@ describe('AgentExecutor', () => {
onActivity,
);
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
const signal = params?.config?.abortSignal;
// eslint-disable-next-line require-yield
return (async function* () {
await new Promise<void>((resolve) =>
signal?.addEventListener('abort', () => resolve()),
);
})();
});
mockSendMessageStream.mockImplementationOnce(
async (_key, _message, _promptId, signal) =>
// eslint-disable-next-line require-yield
(async function* () {
await new Promise<void>((resolve) =>
signal?.addEventListener('abort', () => resolve()),
);
})(),
);
// Mock the recovery call to also be long-running
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
const signal = params?.config?.abortSignal;
// eslint-disable-next-line require-yield
return (async function* () {
await new Promise<void>((resolve) =>
signal?.addEventListener('abort', () => resolve()),
);
})();
});
mockSendMessageStream.mockImplementationOnce(
async (_key, _message, _promptId, signal) =>
// eslint-disable-next-line require-yield
(async function* () {
await new Promise<void>((resolve) =>
signal?.addEventListener('abort', () => resolve()),
);
})(),
);
const runPromise = executor.run(
{ goal: 'Timeout recovery fail' },
+14 -25
View File
@@ -12,7 +12,6 @@ import type {
Content,
Part,
FunctionCall,
GenerateContentConfig,
FunctionDeclaration,
Schema,
} from '@google/genai';
@@ -53,6 +52,7 @@ import { parseThought } from '../utils/thoughtUtils.js';
import { type z } from 'zod';
import { zodToJsonSchema } from 'zod-to-json-schema';
import { debugLogger } from '../utils/debugLogger.js';
import { getModelConfigAlias } from './registry.js';
/** A callback function to report on agent activity. */
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
@@ -595,18 +595,19 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
signal: AbortSignal,
promptId: string,
): Promise<{ functionCalls: FunctionCall[]; textResponse: string }> {
const messageParams = {
message: message.parts || [],
config: {
abortSignal: signal,
tools: tools.length > 0 ? [{ functionDeclarations: tools }] : undefined,
},
};
if (tools.length > 0) {
// TODO(12622): Move tools back to config.
chat.setTools([{ functionDeclarations: tools }]);
}
const responseStream = await chat.sendMessageStream(
this.definition.modelConfig.model,
messageParams,
{
model: getModelConfigAlias(this.definition),
overrideScope: this.definition.name,
},
message.parts || [],
promptId,
signal,
);
const functionCalls: FunctionCall[] = [];
@@ -650,7 +651,7 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
/** Initializes a `GeminiChat` instance for the agent run. */
private async createChatObject(inputs: AgentInputs): Promise<GeminiChat> {
const { promptConfig, modelConfig } = this.definition;
const { promptConfig } = this.definition;
if (!promptConfig.systemPrompt && !promptConfig.initialMessages) {
throw new Error(
@@ -669,22 +670,10 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
: undefined;
try {
const generationConfig: GenerateContentConfig = {
temperature: modelConfig.temp,
topP: modelConfig.top_p,
thinkingConfig: {
includeThoughts: true,
thinkingBudget: modelConfig.thinkingBudget ?? -1,
},
};
if (systemInstruction) {
generationConfig.systemInstruction = systemInstruction;
}
return new GeminiChat(
this.runtimeContext,
generationConfig,
systemInstruction,
[], // set in `callModel`,
startHistory,
);
} catch (error) {
+16 -1
View File
@@ -5,7 +5,7 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { AgentRegistry } from './registry.js';
import { AgentRegistry, getModelConfigAlias } from './registry.js';
import { makeFakeConfig } from '../test-utils/config.js';
import type { AgentDefinition } from './types.js';
import type { Config } from '../config/config.js';
@@ -77,6 +77,21 @@ describe('AgentRegistry', () => {
it('should register a valid agent definition', () => {
registry.testRegisterAgent(MOCK_AGENT_V1);
expect(registry.getDefinition('MockAgent')).toEqual(MOCK_AGENT_V1);
expect(
mockConfig.modelConfigService.getResolvedConfig({
model: getModelConfigAlias(MOCK_AGENT_V1),
}),
).toStrictEqual({
model: MOCK_AGENT_V1.modelConfig.model,
generateContentConfig: {
temperature: MOCK_AGENT_V1.modelConfig.temp,
topP: MOCK_AGENT_V1.modelConfig.top_p,
thinkingConfig: {
includeThoughts: true,
thinkingBudget: -1,
},
},
});
});
it('should handle special characters in agent names', () => {
+33
View File
@@ -9,6 +9,16 @@ import type { AgentDefinition } from './types.js';
import { CodebaseInvestigatorAgent } from './codebase-investigator.js';
import { type z } from 'zod';
import { debugLogger } from '../utils/debugLogger.js';
import type { ModelConfigAlias } from '../services/modelConfigService.js';
/**
* Returns the model config alias for a given agent definition.
*/
export function getModelConfigAlias<TOutput extends z.ZodTypeAny>(
definition: AgentDefinition<TOutput>,
): string {
return `${definition.name}-config`;
}
/**
* Manages the discovery, loading, validation, and registration of
@@ -84,6 +94,29 @@ export class AgentRegistry {
}
this.agents.set(definition.name, definition);
// Register model config.
// TODO(12916): Migrate sub-agents where possible to static configs.
const modelConfig = definition.modelConfig;
const runtimeAlias: ModelConfigAlias = {
modelConfig: {
model: modelConfig.model,
generateContentConfig: {
temperature: modelConfig.temp,
topP: modelConfig.top_p,
thinkingConfig: {
includeThoughts: true,
thinkingBudget: modelConfig.thinkingBudget ?? -1,
},
},
},
};
this.config.modelConfigService.registerRuntimeModelConfig(
getModelConfigAlias(definition),
runtimeAlias,
);
}
/**