feat(core): Wire up chat code path for model configs. (#12850)

This commit is contained in:
joshualitt
2025-11-19 20:41:16 -08:00
committed by GitHub
parent 43d6dc3668
commit 257cd07a3a
19 changed files with 485 additions and 347 deletions

View File

@@ -27,8 +27,9 @@ import {
type FunctionCall,
type Part,
type GenerateContentResponse,
type GenerateContentConfig,
type Content,
type PartListUnion,
type Tool,
} from '@google/genai';
import type { Config } from '../config/config.js';
import { MockTool } from '../test-utils/mock-tool.js';
@@ -55,14 +56,22 @@ import { AgentTerminateMode } from './types.js';
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
import { CompressionStatus } from '../core/turn.js';
import { ChatCompressionService } from '../services/chatCompressionService.js';
import type { ModelConfigKey } from '../services/modelConfigService.js';
import { getModelConfigAlias } from './registry.js';
const { mockSendMessageStream, mockExecuteToolCall, mockCompress } = vi.hoisted(
() => ({
mockSendMessageStream: vi.fn(),
mockExecuteToolCall: vi.fn(),
mockCompress: vi.fn(),
}),
);
const {
mockSendMessageStream,
mockExecuteToolCall,
mockSetSystemInstruction,
mockCompress,
mockSetTools,
} = vi.hoisted(() => ({
mockSendMessageStream: vi.fn(),
mockExecuteToolCall: vi.fn(),
mockSetSystemInstruction: vi.fn(),
mockCompress: vi.fn(),
mockSetTools: vi.fn(),
}));
let mockChatHistory: Content[] = [];
const mockSetHistory = vi.fn((newHistory: Content[]) => {
@@ -83,6 +92,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
sendMessageStream: mockSendMessageStream,
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
setHistory: mockSetHistory,
setSystemInstruction: mockSetSystemInstruction,
setTools: mockSetTools,
})),
};
});
@@ -172,8 +183,10 @@ const mockModelResponse = (
const getMockMessageParams = (callIndex: number) => {
const call = mockSendMessageStream.mock.calls[callIndex];
expect(call).toBeDefined();
// Arg 1 of sendMessageStream is the message parameters
return call[1] as { message?: Part[]; config?: GenerateContentConfig };
return {
modelConfigKey: call[0],
message: call[1],
} as { modelConfigKey: ModelConfigKey; message: PartListUnion };
};
let mockConfig: Config;
@@ -223,6 +236,8 @@ describe('AgentExecutor', () => {
mockCompress.mockClear();
mockSetHistory.mockClear();
mockSendMessageStream.mockReset();
mockSetSystemInstruction.mockReset();
mockSetTools.mockReset();
mockExecuteToolCall.mockReset();
mockedLogAgentStart.mockReset();
mockedLogAgentFinish.mockReset();
@@ -241,6 +256,8 @@ describe('AgentExecutor', () => {
() =>
({
sendMessageStream: mockSendMessageStream,
setSystemInstruction: mockSetSystemInstruction,
setTools: mockSetTools,
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
getLastPromptTokenCount: vi.fn(() => 100),
setHistory: mockSetHistory,
@@ -358,7 +375,7 @@ describe('AgentExecutor', () => {
await executor.run(inputs, signal);
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
const startHistory = chatConstructorArgs[2]; // history is the 3rd arg
const startHistory = chatConstructorArgs[3]; // history is the 4th arg
expect(startHistory).toBeDefined();
expect(startHistory).toHaveLength(2);
@@ -459,10 +476,7 @@ describe('AgentExecutor', () => {
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
const chatConfig = chatConstructorArgs[1];
const systemInstruction = chatConfig?.systemInstruction as string;
const systemInstruction = MockedGeminiChat.mock.calls[0][1];
expect(systemInstruction).toContain(
`MUST call the \`${TASK_COMPLETE_TOOL_NAME}\` tool`,
);
@@ -472,18 +486,11 @@ describe('AgentExecutor', () => {
);
expect(systemInstruction).toContain('Always use absolute paths');
const turn1Params = getMockMessageParams(0);
const { modelConfigKey } = getMockMessageParams(0);
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
const firstToolGroup = turn1Params.config?.tools?.[0];
expect(firstToolGroup).toBeDefined();
if (!firstToolGroup || !('functionDeclarations' in firstToolGroup)) {
throw new Error(
'Test expectation failed: Config does not contain functionDeclarations.',
);
}
const sentTools = firstToolGroup.functionDeclarations;
const call = mockSetTools.mock.calls[0];
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
expect(sentTools).toBeDefined();
expect(sentTools).toEqual(
@@ -604,17 +611,11 @@ describe('AgentExecutor', () => {
const output = await executor.run({ goal: 'Do work' }, signal);
const turn1Params = getMockMessageParams(0);
const firstToolGroup = turn1Params.config?.tools?.[0];
const { modelConfigKey } = getMockMessageParams(0);
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
expect(firstToolGroup).toBeDefined();
if (!firstToolGroup || !('functionDeclarations' in firstToolGroup)) {
throw new Error(
'Test expectation failed: Config does not contain functionDeclarations.',
);
}
const sentTools = firstToolGroup.functionDeclarations;
const call = mockSetTools.mock.calls[0];
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
expect(sentTools).toBeDefined();
const completeToolDef = sentTools!.find(
@@ -754,7 +755,7 @@ describe('AgentExecutor', () => {
expect(turn2Parts).toBeDefined();
expect(turn2Parts).toHaveLength(1);
expect(turn2Parts![0]).toEqual(
expect((turn2Parts as Part[])![0]).toEqual(
expect.objectContaining({
functionResponse: expect.objectContaining({
name: TASK_COMPLETE_TOOL_NAME,
@@ -944,7 +945,7 @@ describe('AgentExecutor', () => {
const turn2Params = getMockMessageParams(1);
const parts = turn2Params.message;
expect(parts).toBeDefined();
expect(parts![0]).toEqual(
expect((parts as Part[])![0]).toEqual(
expect.objectContaining({
functionResponse: expect.objectContaining({
id: badCallId,
@@ -1222,18 +1223,18 @@ describe('AgentExecutor', () => {
);
// Mock a model call that is interruptible by an abort signal.
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
const signal = params?.config?.abortSignal;
// eslint-disable-next-line require-yield
return (async function* () {
await new Promise<void>((resolve) => {
// This promise resolves when aborted, ending the generator.
signal?.addEventListener('abort', () => {
resolve();
mockSendMessageStream.mockImplementationOnce(
async (_key, _message, _promptId, signal) =>
// eslint-disable-next-line require-yield
(async function* () {
await new Promise<void>((resolve) => {
// This promise resolves when aborted, ending the generator.
signal?.addEventListener('abort', () => {
resolve();
});
});
});
})();
});
})(),
);
// Recovery turn
mockModelResponse([], 'I give up');
@@ -1534,16 +1535,16 @@ describe('AgentExecutor', () => {
);
// Mock a model call that gets interrupted by the timeout.
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
const signal = params?.config?.abortSignal;
// eslint-disable-next-line require-yield
return (async function* () {
// This promise never resolves, it waits for abort.
await new Promise<void>((resolve) => {
signal?.addEventListener('abort', () => resolve());
});
})();
});
mockSendMessageStream.mockImplementationOnce(
async (_key, _message, _promptId, signal) =>
// eslint-disable-next-line require-yield
(async function* () {
// This promise never resolves, it waits for abort.
await new Promise<void>((resolve) => {
signal?.addEventListener('abort', () => resolve());
});
})(),
);
// Recovery turn (succeeds)
mockModelResponse(
@@ -1588,26 +1589,26 @@ describe('AgentExecutor', () => {
onActivity,
);
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
const signal = params?.config?.abortSignal;
// eslint-disable-next-line require-yield
return (async function* () {
await new Promise<void>((resolve) =>
signal?.addEventListener('abort', () => resolve()),
);
})();
});
mockSendMessageStream.mockImplementationOnce(
async (_key, _message, _promptId, signal) =>
// eslint-disable-next-line require-yield
(async function* () {
await new Promise<void>((resolve) =>
signal?.addEventListener('abort', () => resolve()),
);
})(),
);
// Mock the recovery call to also be long-running
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
const signal = params?.config?.abortSignal;
// eslint-disable-next-line require-yield
return (async function* () {
await new Promise<void>((resolve) =>
signal?.addEventListener('abort', () => resolve()),
);
})();
});
mockSendMessageStream.mockImplementationOnce(
async (_key, _message, _promptId, signal) =>
// eslint-disable-next-line require-yield
(async function* () {
await new Promise<void>((resolve) =>
signal?.addEventListener('abort', () => resolve()),
);
})(),
);
const runPromise = executor.run(
{ goal: 'Timeout recovery fail' },

View File

@@ -12,7 +12,6 @@ import type {
Content,
Part,
FunctionCall,
GenerateContentConfig,
FunctionDeclaration,
Schema,
} from '@google/genai';
@@ -53,6 +52,7 @@ import { parseThought } from '../utils/thoughtUtils.js';
import { type z } from 'zod';
import { zodToJsonSchema } from 'zod-to-json-schema';
import { debugLogger } from '../utils/debugLogger.js';
import { getModelConfigAlias } from './registry.js';
/** A callback function to report on agent activity. */
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
@@ -595,18 +595,19 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
signal: AbortSignal,
promptId: string,
): Promise<{ functionCalls: FunctionCall[]; textResponse: string }> {
const messageParams = {
message: message.parts || [],
config: {
abortSignal: signal,
tools: tools.length > 0 ? [{ functionDeclarations: tools }] : undefined,
},
};
if (tools.length > 0) {
// TODO(12622): Move tools back to config.
chat.setTools([{ functionDeclarations: tools }]);
}
const responseStream = await chat.sendMessageStream(
this.definition.modelConfig.model,
messageParams,
{
model: getModelConfigAlias(this.definition),
overrideScope: this.definition.name,
},
message.parts || [],
promptId,
signal,
);
const functionCalls: FunctionCall[] = [];
@@ -650,7 +651,7 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
/** Initializes a `GeminiChat` instance for the agent run. */
private async createChatObject(inputs: AgentInputs): Promise<GeminiChat> {
const { promptConfig, modelConfig } = this.definition;
const { promptConfig } = this.definition;
if (!promptConfig.systemPrompt && !promptConfig.initialMessages) {
throw new Error(
@@ -669,22 +670,10 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
: undefined;
try {
const generationConfig: GenerateContentConfig = {
temperature: modelConfig.temp,
topP: modelConfig.top_p,
thinkingConfig: {
includeThoughts: true,
thinkingBudget: modelConfig.thinkingBudget ?? -1,
},
};
if (systemInstruction) {
generationConfig.systemInstruction = systemInstruction;
}
return new GeminiChat(
this.runtimeContext,
generationConfig,
systemInstruction,
[], // set in `callModel`,
startHistory,
);
} catch (error) {

View File

@@ -5,7 +5,7 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { AgentRegistry } from './registry.js';
import { AgentRegistry, getModelConfigAlias } from './registry.js';
import { makeFakeConfig } from '../test-utils/config.js';
import type { AgentDefinition } from './types.js';
import type { Config } from '../config/config.js';
@@ -77,6 +77,21 @@ describe('AgentRegistry', () => {
it('should register a valid agent definition', () => {
registry.testRegisterAgent(MOCK_AGENT_V1);
expect(registry.getDefinition('MockAgent')).toEqual(MOCK_AGENT_V1);
expect(
mockConfig.modelConfigService.getResolvedConfig({
model: getModelConfigAlias(MOCK_AGENT_V1),
}),
).toStrictEqual({
model: MOCK_AGENT_V1.modelConfig.model,
generateContentConfig: {
temperature: MOCK_AGENT_V1.modelConfig.temp,
topP: MOCK_AGENT_V1.modelConfig.top_p,
thinkingConfig: {
includeThoughts: true,
thinkingBudget: -1,
},
},
});
});
it('should handle special characters in agent names', () => {

View File

@@ -9,6 +9,16 @@ import type { AgentDefinition } from './types.js';
import { CodebaseInvestigatorAgent } from './codebase-investigator.js';
import { type z } from 'zod';
import { debugLogger } from '../utils/debugLogger.js';
import type { ModelConfigAlias } from '../services/modelConfigService.js';
/**
* Returns the model config alias for a given agent definition.
*/
export function getModelConfigAlias<TOutput extends z.ZodTypeAny>(
definition: AgentDefinition<TOutput>,
): string {
return `${definition.name}-config`;
}
/**
* Manages the discovery, loading, validation, and registration of
@@ -84,6 +94,29 @@ export class AgentRegistry {
}
this.agents.set(definition.name, definition);
// Register model config.
// TODO(12916): Migrate sub-agents where possible to static configs.
const modelConfig = definition.modelConfig;
const runtimeAlias: ModelConfigAlias = {
modelConfig: {
model: modelConfig.model,
generateContentConfig: {
temperature: modelConfig.temp,
topP: modelConfig.top_p,
thinkingConfig: {
includeThoughts: true,
thinkingBudget: modelConfig.thinkingBudget ?? -1,
},
},
},
};
this.config.modelConfigService.registerRuntimeModelConfig(
getModelConfigAlias(definition),
runtimeAlias,
);
}
/**

View File

@@ -5,6 +5,7 @@
*/
import type { ModelConfigServiceConfig } from '../services/modelConfigService.js';
import { DEFAULT_THINKING_MODE } from './models.js';
// The default model configs. We use `base` as the parent for all of our model
// configs, while `chat-base`, a child of `base`, is the parent of the models
@@ -25,7 +26,9 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
generateContentConfig: {
thinkingConfig: {
includeThoughts: true,
thinkingBudget: -1,
// TODO(joshualitt): Introduce new bases for Gemini 3 models to use
// thinkingLevel instead.
thinkingBudget: DEFAULT_THINKING_MODE,
},
temperature: 1,
topP: 0.95,
@@ -38,6 +41,12 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
// ensure these model configs can be used interactively.
// TODO(joshualitt): Introduce internal base configs for the various models,
// note: we will have to think carefully about names.
'gemini-3-pro-preview': {
extends: 'chat-base',
modelConfig: {
model: 'gemini-3-pro-preview',
},
},
'gemini-2.5-pro': {
extends: 'chat-base',
modelConfig: {

View File

@@ -15,7 +15,7 @@ import {
} from 'vitest';
import type { Content, GenerateContentResponse, Part } from '@google/genai';
import { isThinkingSupported, GeminiClient } from './client.js';
import { GeminiClient } from './client.js';
import {
AuthType,
type ContentGenerator,
@@ -134,30 +134,6 @@ async function fromAsync<T>(promise: AsyncGenerator<T>): Promise<readonly T[]> {
return results;
}
describe('isThinkingSupported', () => {
it('should return true for gemini-2.5', () => {
expect(isThinkingSupported('gemini-2.5')).toBe(true);
expect(isThinkingSupported('gemini-2.5-flash')).toBe(true);
});
it('should return true for gemini-2.5-pro', () => {
expect(isThinkingSupported('gemini-2.5-pro')).toBe(true);
});
it('should return true for gemini-3-pro', () => {
expect(isThinkingSupported('gemini-3-pro')).toBe(true);
});
it('should return false for gemini-2.0 models', () => {
expect(isThinkingSupported('gemini-2.0-flash')).toBe(false);
expect(isThinkingSupported('gemini-2.0-pro')).toBe(false);
});
it('should return true for other models', () => {
expect(isThinkingSupported('some-other-model')).toBe(true);
});
});
describe('Gemini Client (client.ts)', () => {
let mockContentGenerator: ContentGenerator;
let mockConfig: Config;
@@ -766,14 +742,10 @@ ${JSON.stringify(
// Assert
expect(ideContextStore.get).toHaveBeenCalled();
// The `turn.run` method is now called with the model name as the first
// argument. We use `expect.any(String)` because this test is
// concerned with the IDE context logic, not the model routing,
// which is tested in its own dedicated suite.
expect(mockTurnRunFn).toHaveBeenCalledWith(
expect.any(String),
{ model: 'default-routed-model' },
initialRequest,
expect.any(Object),
expect.any(AbortSignal),
);
});
@@ -1302,9 +1274,9 @@ ${JSON.stringify(
expect(mockConfig.getModelRouterService).toHaveBeenCalled();
expect(mockRouterService.route).toHaveBeenCalled();
expect(mockTurnRunFn).toHaveBeenCalledWith(
'routed-model', // The model from the router
{ model: 'routed-model' },
[{ text: 'Hi' }],
expect.any(Object),
expect.any(AbortSignal),
);
});
@@ -1319,9 +1291,9 @@ ${JSON.stringify(
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
expect(mockTurnRunFn).toHaveBeenCalledWith(
'routed-model',
{ model: 'routed-model' },
[{ text: 'Hi' }],
expect.any(Object),
expect.any(AbortSignal),
);
// Second turn
@@ -1336,9 +1308,9 @@ ${JSON.stringify(
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
// Should stick to the first model
expect(mockTurnRunFn).toHaveBeenCalledWith(
'routed-model',
{ model: 'routed-model' },
[{ text: 'Continue' }],
expect.any(Object),
expect.any(AbortSignal),
);
});
@@ -1353,9 +1325,9 @@ ${JSON.stringify(
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
expect(mockTurnRunFn).toHaveBeenCalledWith(
'routed-model',
{ model: 'routed-model' },
[{ text: 'Hi' }],
expect.any(Object),
expect.any(AbortSignal),
);
// New prompt
@@ -1374,9 +1346,9 @@ ${JSON.stringify(
expect(mockRouterService.route).toHaveBeenCalledTimes(2);
// Should use the newly routed model
expect(mockTurnRunFn).toHaveBeenCalledWith(
'new-routed-model',
{ model: 'new-routed-model' },
[{ text: 'A new topic' }],
expect.any(Object),
expect.any(AbortSignal),
);
});
@@ -1395,9 +1367,9 @@ ${JSON.stringify(
await fromAsync(stream);
expect(mockTurnRunFn).toHaveBeenCalledWith(
DEFAULT_GEMINI_FLASH_MODEL,
{ model: DEFAULT_GEMINI_FLASH_MODEL },
[{ text: 'Hi' }],
expect.any(Object),
expect.any(AbortSignal),
);
});
@@ -1417,9 +1389,9 @@ ${JSON.stringify(
// First call should use fallback model
expect(mockTurnRunFn).toHaveBeenCalledWith(
DEFAULT_GEMINI_FLASH_MODEL,
{ model: DEFAULT_GEMINI_FLASH_MODEL },
[{ text: 'Hi' }],
expect.any(Object),
expect.any(AbortSignal),
);
// End fallback mode
@@ -1436,9 +1408,9 @@ ${JSON.stringify(
// Router should still not be called, and it should stick to the fallback model
expect(mockTurnRunFn).toHaveBeenCalledTimes(2); // Ensure it was called again
expect(mockTurnRunFn).toHaveBeenLastCalledWith(
DEFAULT_GEMINI_FLASH_MODEL, // Still the fallback model
{ model: DEFAULT_GEMINI_FLASH_MODEL }, // Still the fallback model
[{ text: 'Continue' }],
expect.any(Object),
expect.any(AbortSignal),
);
});
});
@@ -1487,17 +1459,17 @@ ${JSON.stringify(
// First call with original request
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
1,
expect.any(String),
{ model: 'default-routed-model' },
initialRequest,
expect.any(Object),
expect.any(AbortSignal),
);
// Second call with "Please continue."
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
2,
expect.any(String),
{ model: 'default-routed-model' },
[{ text: 'System: Please continue.' }],
expect.any(Object),
expect.any(AbortSignal),
);
});
@@ -2264,7 +2236,7 @@ ${JSON.stringify(
.mockReturnValueOnce(true);
let capturedSignal: AbortSignal;
mockTurnRunFn.mockImplementation((model, request, signal) => {
mockTurnRunFn.mockImplementation((_modelConfigKey, _request, signal) => {
capturedSignal = signal;
return (async function* () {
yield { type: 'content', value: 'First event' };

View File

@@ -33,7 +33,6 @@ import type {
import type { ContentGenerator } from './contentGenerator.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_THINKING_MODE,
getEffectiveModel,
} from '../config/models.js';
import { LoopDetectionService } from '../services/loopDetectionService.js';
@@ -54,19 +53,10 @@ import type { RoutingContext } from '../routing/routingStrategy.js';
import { debugLogger } from '../utils/debugLogger.js';
import type { ModelConfigKey } from '../services/modelConfigService.js';
export function isThinkingSupported(model: string) {
return !model.startsWith('gemini-2.0');
}
const MAX_TURNS = 100;
export class GeminiClient {
private chat?: GeminiChat;
private readonly generateContentConfig: GenerateContentConfig = {
temperature: 1,
topP: 0.95,
topK: 64,
};
private sessionTurnCount = 0;
private readonly loopDetector: LoopDetectionService;
@@ -194,24 +184,10 @@ export class GeminiClient {
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(this.config, userMemory);
const model = this.config.getModel();
const config: GenerateContentConfig = { ...this.generateContentConfig };
if (isThinkingSupported(model)) {
config.thinkingConfig = {
includeThoughts: true,
thinkingBudget: DEFAULT_THINKING_MODE,
};
}
return new GeminiChat(
this.config,
{
systemInstruction,
...config,
tools,
},
systemInstruction,
tools,
history,
resumedSessionData,
);
@@ -515,7 +491,7 @@ export class GeminiClient {
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
}
const resultStream = turn.run(modelToUse, request, linkedSignal);
const resultStream = turn.run({ model: modelToUse }, request, linkedSignal);
for await (const event of resultStream) {
if (this.loopDetector.addAndCheck(event)) {
yield { type: GeminiEventType.LoopDetected };

View File

@@ -5,11 +5,7 @@
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import type {
Content,
GenerateContentConfig,
GenerateContentResponse,
} from '@google/genai';
import type { Content, GenerateContentResponse } from '@google/genai';
import { ApiError } from '@google/genai';
import type { ContentGenerator } from '../core/contentGenerator.js';
import {
@@ -94,7 +90,6 @@ describe('GeminiChat', () => {
let mockContentGenerator: ContentGenerator;
let chat: GeminiChat;
let mockConfig: Config;
const config: GenerateContentConfig = {};
beforeEach(() => {
vi.clearAllMocks();
@@ -135,6 +130,14 @@ describe('GeminiChat', () => {
}),
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
getRetryFetchErrors: vi.fn().mockReturnValue(false),
modelConfigService: {
getResolvedConfig: vi.fn().mockImplementation((modelConfigKey) => ({
model: modelConfigKey.model,
generateContentConfig: {
temperature: 0,
},
})),
},
isPreviewModelBypassMode: vi.fn().mockReturnValue(false),
setPreviewModelBypassMode: vi.fn(),
isPreviewModelFallbackMode: vi.fn().mockReturnValue(false),
@@ -145,7 +148,7 @@ describe('GeminiChat', () => {
// Disable 429 simulation for tests
setSimulate429(false);
// Reset history for each test by creating a new instance
chat = new GeminiChat(mockConfig, config, []);
chat = new GeminiChat(mockConfig);
});
afterEach(() => {
@@ -159,13 +162,13 @@ describe('GeminiChat', () => {
{ role: 'user', parts: [{ text: 'Hello' }] },
{ role: 'model', parts: [{ text: 'Hi there' }] },
];
const chatWithHistory = new GeminiChat(mockConfig, config, history);
const chatWithHistory = new GeminiChat(mockConfig, '', [], history);
const estimatedTokens = Math.ceil(JSON.stringify(history).length / 4);
expect(chatWithHistory.getLastPromptTokenCount()).toBe(estimatedTokens);
});
it('should initialize lastPromptTokenCount for empty history', () => {
const chatEmpty = new GeminiChat(mockConfig, config, []);
const chatEmpty = new GeminiChat(mockConfig);
expect(chatEmpty.getLastPromptTokenCount()).toBe(
Math.ceil(JSON.stringify([]).length / 4),
);
@@ -206,9 +209,10 @@ describe('GeminiChat', () => {
// 2. Action & Assert: The stream processing should complete without throwing an error
// because the presence of a tool call makes the empty final chunk acceptable.
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test message' },
{ model: 'test-model' },
'test message',
'prompt-id-tool-call-empty-end',
new AbortController().signal,
);
await expect(
(async () => {
@@ -258,9 +262,10 @@ describe('GeminiChat', () => {
// 2. Action & Assert: The stream should fail because there's no finish reason.
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ message: 'test message' },
{ model: 'gemini-2.0-flash' },
'test message',
'prompt-id-no-finish-empty-end',
new AbortController().signal,
);
await expect(
(async () => {
@@ -304,9 +309,10 @@ describe('GeminiChat', () => {
// 2. Action & Assert: The stream should complete without throwing an error.
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test message' },
{ model: 'test-model' },
'test message',
'prompt-id-valid-then-invalid-end',
new AbortController().signal,
);
await expect(
(async () => {
@@ -351,9 +357,10 @@ describe('GeminiChat', () => {
// 2. Action: Send a message and consume the stream.
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test message' },
{ model: 'test-model' },
'test message',
'prompt-id-empty-chunk-consolidation',
new AbortController().signal,
);
for await (const _ of stream) {
// Consume the stream
@@ -409,9 +416,10 @@ describe('GeminiChat', () => {
// 2. Action: Send a message and consume the stream.
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test message' },
{ model: 'test-model' },
'test message',
'prompt-id-multi-chunk',
new AbortController().signal,
);
for await (const _ of stream) {
// Consume the stream to trigger history recording.
@@ -457,9 +465,10 @@ describe('GeminiChat', () => {
// 2. Action: Send a message and fully consume the stream to trigger history recording.
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test message' },
{ model: 'test-model' },
'test message',
'prompt-id-mixed-chunk',
new AbortController().signal,
);
for await (const _ of stream) {
// This loop consumes the stream.
@@ -499,9 +508,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
PREVIEW_GEMINI_MODEL,
{ message: 'test' },
{ model: PREVIEW_GEMINI_MODEL },
'test',
'prompt-id-fast-retry',
new AbortController().signal,
);
for await (const _ of stream) {
// consume stream
@@ -531,9 +541,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
DEFAULT_GEMINI_FLASH_MODEL,
{ message: 'test' },
{ model: DEFAULT_GEMINI_FLASH_MODEL },
'test',
'prompt-id-normal-retry',
new AbortController().signal,
);
for await (const _ of stream) {
// consume stream
@@ -577,9 +588,10 @@ describe('GeminiChat', () => {
// ACT
const consumeStream = async () => {
const stream = await chat.sendMessageStream(
PREVIEW_GEMINI_MODEL,
{ message: 'test' },
{ model: PREVIEW_GEMINI_MODEL },
'test',
'prompt-id-bypass',
new AbortController().signal,
);
// Consume the stream to trigger execution
for await (const _ of stream) {
@@ -639,16 +651,15 @@ describe('GeminiChat', () => {
// 3. Action: Send the function response back to the model and consume the stream.
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ model: 'gemini-2.0-flash' },
{
message: {
functionResponse: {
name: 'find_restaurant',
response: { name: 'Vesuvio' },
},
functionResponse: {
name: 'find_restaurant',
response: { name: 'Vesuvio' },
},
},
'prompt-id-stream-1',
new AbortController().signal,
);
// 4. Assert: The stream processing should throw an InvalidStreamError.
@@ -689,9 +700,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test' },
{ model: 'test-model' },
'test message',
'prompt-id-1',
new AbortController().signal,
);
// Should not throw an error
@@ -725,9 +737,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ message: 'test' },
{ model: 'gemini-2.0-flash' },
'test message',
'prompt-id-1',
new AbortController().signal,
);
await expect(
@@ -760,9 +773,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ message: 'test' },
{ model: 'gemini-2.0-flash' },
'test message',
'prompt-id-1',
new AbortController().signal,
);
await expect(
@@ -795,9 +809,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test' },
{ model: 'test-model' },
'test message',
'prompt-id-1',
new AbortController().signal,
);
// Should not throw an error
@@ -831,9 +846,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'gemini-2.5-pro',
{ message: 'test' },
{ model: 'gemini-2.5-pro' },
'test',
'prompt-id-malformed',
new AbortController().signal,
);
// Should throw an error
@@ -877,9 +893,10 @@ describe('GeminiChat', () => {
// 2. Send a message
const stream = await chat.sendMessageStream(
'gemini-2.5-pro',
{ message: 'test retry' },
{ model: 'gemini-2.5-pro' },
'test retry',
'prompt-id-retry-malformed',
new AbortController().signal,
);
const events: StreamEvent[] = [];
for await (const event of stream) {
@@ -933,9 +950,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'hello' },
{ model: 'test-model' },
'hello',
'prompt-id-1',
new AbortController().signal,
);
for await (const _ of stream) {
// consume stream
@@ -950,7 +968,12 @@ describe('GeminiChat', () => {
parts: [{ text: 'hello' }],
},
],
config: {},
config: {
systemInstruction: '',
tools: [],
temperature: 0,
abortSignal: expect.any(AbortSignal),
},
},
'prompt-id-1',
);
@@ -1000,9 +1023,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'gemini-1.5-pro',
{ message: 'test' },
{ model: 'gemini-1.5-pro' },
'test',
'prompt-id-no-retry',
new AbortController().signal,
);
await expect(
@@ -1047,9 +1071,10 @@ describe('GeminiChat', () => {
// ACT: Send a message and collect all events from the stream.
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ message: 'test' },
{ model: 'gemini-2.0-flash' },
'test message',
'prompt-id-yield-retry',
new AbortController().signal,
);
const events: StreamEvent[] = [];
for await (const event of stream) {
@@ -1088,9 +1113,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ message: 'test' },
{ model: 'gemini-2.0-flash' },
'test',
'prompt-id-retry-success',
new AbortController().signal,
);
const chunks: StreamEvent[] = [];
for await (const chunk of stream) {
@@ -1159,9 +1185,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ message: 'test', config: { temperature: 0.5 } },
{ model: 'gemini-2.0-flash' },
'test message',
'prompt-id-retry-temperature',
new AbortController().signal,
);
for await (const _ of stream) {
@@ -1179,7 +1206,7 @@ describe('GeminiChat', () => {
1,
expect.objectContaining({
config: expect.objectContaining({
temperature: 0.5,
temperature: 0,
}),
}),
'prompt-id-retry-temperature',
@@ -1217,9 +1244,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ message: 'test' },
{ model: 'gemini-2.0-flash' },
'test',
'prompt-id-retry-fail',
new AbortController().signal,
);
await expect(async () => {
for await (const _ of stream) {
@@ -1282,9 +1310,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ message: 'test' },
{ model: 'gemini-2.0-flash' },
'test message',
'prompt-id-400',
new AbortController().signal,
);
await expect(
@@ -1320,9 +1349,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test' },
{ model: 'test-model' },
'test message',
'prompt-id-429-retry',
new AbortController().signal,
);
const events: StreamEvent[] = [];
@@ -1368,9 +1398,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test' },
{ model: 'test-model' },
'test message',
'prompt-id-500-retry',
new AbortController().signal,
);
const events: StreamEvent[] = [];
@@ -1424,9 +1455,10 @@ describe('GeminiChat', () => {
});
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test' },
{ model: 'test-model' },
'test message',
'prompt-id-fetch-error-retry',
new AbortController().signal,
);
const events: StreamEvent[] = [];
@@ -1487,9 +1519,10 @@ describe('GeminiChat', () => {
// 3. Send a new message
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ message: 'Second question' },
{ model: 'gemini-2.0-flash' },
'Second question',
'prompt-id-retry-existing',
new AbortController().signal,
);
for await (const _ of stream) {
// consume stream
@@ -1558,9 +1591,10 @@ describe('GeminiChat', () => {
// 2. Call the method and consume the stream.
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ message: 'test empty stream' },
{ model: 'gemini-2.0-flash' },
'test empty stream',
'prompt-id-empty-stream',
new AbortController().signal,
);
const chunks: StreamEvent[] = [];
for await (const chunk of stream) {
@@ -1638,18 +1672,20 @@ describe('GeminiChat', () => {
// 3. Start the first stream and consume only the first chunk to pause it
const firstStream = await chat.sendMessageStream(
'test-model',
{ message: 'first' },
{ model: 'test-model' },
'first',
'prompt-1',
new AbortController().signal,
);
const firstStreamIterator = firstStream[Symbol.asyncIterator]();
await firstStreamIterator.next();
// 4. While the first stream is paused, start the second call. It will block.
const secondStreamPromise = chat.sendMessageStream(
'test-model',
{ message: 'second' },
{ model: 'test-model' },
'second',
'prompt-2',
new AbortController().signal,
);
// 5. Assert that only one API call has been made so far.
@@ -1707,9 +1743,10 @@ describe('GeminiChat', () => {
);
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test' },
{ model: 'test-model' },
'test message',
'prompt-id-res3',
new AbortController().signal,
);
for await (const _ of stream) {
// consume stream
@@ -1793,9 +1830,10 @@ describe('GeminiChat', () => {
});
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'trigger 429' },
{ model: 'test-model' },
'trigger 429',
'prompt-id-fb1',
new AbortController().signal,
);
// Consume stream to trigger logic
@@ -1827,9 +1865,10 @@ describe('GeminiChat', () => {
mockHandleFallback.mockResolvedValue(false);
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ message: 'test stop' },
{ model: 'gemini-2.0-flash' },
'test stop',
'prompt-id-fb2',
new AbortController().signal,
);
await expect(
@@ -1885,9 +1924,10 @@ describe('GeminiChat', () => {
// Send a message and consume the stream
const stream = await chat.sendMessageStream(
'gemini-2.0-flash',
{ message: 'test' },
{ model: 'gemini-2.0-flash' },
'test message',
'prompt-id-discard-test',
new AbortController().signal,
);
const events: StreamEvent[] = [];
for await (const event of stream) {
@@ -1965,9 +2005,10 @@ describe('GeminiChat', () => {
);
await chat.sendMessageStream(
'test-model',
{ message: 'test' },
{ model: 'test-model' },
'test',
'prompt-id-preview-model-reset',
new AbortController().signal,
);
expect(mockConfig.setPreviewModelBypassMode).toHaveBeenCalledWith(false);
@@ -1989,9 +2030,10 @@ describe('GeminiChat', () => {
);
const resultStream = await chat.sendMessageStream(
PREVIEW_GEMINI_MODEL,
{ message: 'test' },
{ model: PREVIEW_GEMINI_MODEL },
'test',
'prompt-id-preview-model-healing',
new AbortController().signal,
);
for await (const _ of resultStream) {
// consume stream
@@ -2019,9 +2061,10 @@ describe('GeminiChat', () => {
vi.mocked(mockConfig.isPreviewModelBypassMode).mockReturnValue(true);
const resultStream = await chat.sendMessageStream(
PREVIEW_GEMINI_MODEL,
{ message: 'test' },
{ model: PREVIEW_GEMINI_MODEL },
'test',
'prompt-id-bypass-no-healing',
new AbortController().signal,
);
for await (const _ of resultStream) {
// consume stream
@@ -2033,7 +2076,7 @@ describe('GeminiChat', () => {
describe('ensureActiveLoopHasThoughtSignatures', () => {
it('should add thoughtSignature to the first functionCall in each model turn of the active loop', () => {
const chat = new GeminiChat(mockConfig, {}, []);
const chat = new GeminiChat(mockConfig, '', [], []);
const history: Content[] = [
{ role: 'user', parts: [{ text: 'Old message' }] },
{
@@ -2090,7 +2133,7 @@ describe('GeminiChat', () => {
});
it('should not modify contents if there is no user text message', () => {
const chat = new GeminiChat(mockConfig, {}, []);
const chat = new GeminiChat(mockConfig, '', [], []);
const history: Content[] = [
{
role: 'user',
@@ -2107,14 +2150,14 @@ describe('GeminiChat', () => {
});
it('should handle an empty history', () => {
const chat = new GeminiChat(mockConfig, {}, []);
const chat = new GeminiChat(mockConfig, '', []);
const history: Content[] = [];
const newContents = chat.ensureActiveLoopHasThoughtSignatures(history);
expect(newContents).toEqual([]);
});
it('should handle history with only a user message', () => {
const chat = new GeminiChat(mockConfig, {}, []);
const chat = new GeminiChat(mockConfig, '', []);
const history: Content[] = [{ role: 'user', parts: [{ text: 'Hello' }] }];
const newContents = chat.ensureActiveLoopHasThoughtSignatures(history);
expect(newContents).toEqual(history);

View File

@@ -10,10 +10,10 @@
import type {
GenerateContentResponse,
Content,
GenerateContentConfig,
SendMessageParameters,
Part,
Tool,
PartListUnion,
GenerateContentConfig,
} from '@google/genai';
import { toParts } from '../code_assist/converter.js';
import { createUserContent, FinishReason } from '@google/genai';
@@ -43,6 +43,7 @@ import {
import { handleFallback } from '../fallback/handler.js';
import { isFunctionResponse } from '../utils/messageInspectors.js';
import { partListUnionToString } from './geminiRequest.js';
import type { ModelConfigKey } from '../services/modelConfigService.js';
export enum StreamEventType {
/** A regular content chunk from the API. */
@@ -202,7 +203,8 @@ export class GeminiChat {
constructor(
private readonly config: Config,
private readonly generationConfig: GenerateContentConfig = {},
private systemInstruction: string = '',
private tools: Tool[] = [],
private history: Content[] = [],
resumedSessionData?: ResumedSessionData,
) {
@@ -215,7 +217,7 @@ export class GeminiChat {
}
setSystemInstruction(sysInstr: string) {
this.generationConfig.systemInstruction = sysInstr;
this.systemInstruction = sysInstr;
}
/**
@@ -226,7 +228,10 @@ export class GeminiChat {
* sending the next message.
*
* @see {@link Chat#sendMessage} for non-streaming method.
* @param params - parameters for sending the message.
* @param modelConfigKey - The key for the model config.
* @param message - The list of messages to send.
* @param prompt_id - The ID of the prompt.
* @param signal - An abort signal for this message.
* @return The model's response.
*
* @example
@@ -241,9 +246,10 @@ export class GeminiChat {
* ```
*/
async sendMessageStream(
model: string,
params: SendMessageParameters,
modelConfigKey: ModelConfigKey,
message: PartListUnion,
prompt_id: string,
signal: AbortSignal,
): Promise<AsyncGenerator<StreamEvent>> {
await this.sendPromise;
@@ -251,21 +257,21 @@ export class GeminiChat {
// This ensures that we attempt to use Preview Model for every new user turn
// (unless the "Always" fallback mode is active, which is handled separately).
this.config.setPreviewModelBypassMode(false);
let streamDoneResolver: () => void;
const streamDonePromise = new Promise<void>((resolve) => {
streamDoneResolver = resolve;
});
this.sendPromise = streamDonePromise;
const userContent = createUserContent(params.message);
const userContent = createUserContent(message);
const { model, generateContentConfig } =
this.config.modelConfigService.getResolvedConfig(modelConfigKey);
generateContentConfig.abortSignal = signal;
// Record user input - capture complete message with all parts (text, files, images, etc.)
// but skip recording function responses (tool call results) as they should be stored in tool call records
if (!isFunctionResponse(userContent)) {
const userMessage = Array.isArray(params.message)
? params.message
: [params.message];
const userMessage = Array.isArray(message) ? message : [message];
const userMessageContent = partListUnionToString(toParts(userMessage));
this.chatRecordingService.recordMessage({
model,
@@ -301,18 +307,14 @@ export class GeminiChat {
}
// If this is a retry, set temperature to 1 to encourage different output.
const currentParams = { ...params };
if (attempt > 0) {
currentParams.config = {
...currentParams.config,
temperature: 1,
};
generateContentConfig.temperature = 1;
}
const stream = await self.makeApiCallAndProcessStream(
model,
generateContentConfig,
requestContents,
currentParams,
prompt_id,
);
@@ -385,8 +387,8 @@ export class GeminiChat {
private async makeApiCallAndProcessStream(
model: string,
generateContentConfig: GenerateContentConfig,
requestContents: Content[],
params: SendMessageParameters,
prompt_id: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
let effectiveModel = model;
@@ -418,7 +420,13 @@ export class GeminiChat {
modelToUse === PREVIEW_GEMINI_MODEL
? contentsForPreviewModel
: requestContents,
config: { ...this.generationConfig, ...params.config },
config: {
...generateContentConfig,
// TODO(12622): Ensure we don't overrwrite these when they are
// passed via config.
systemInstruction: this.systemInstruction,
tools: this.tools,
},
},
prompt_id,
);
@@ -433,7 +441,7 @@ export class GeminiChat {
onPersistent429: onPersistent429Callback,
authType: this.config.getContentGeneratorConfig()?.authType,
retryFetchErrors: this.config.getRetryFetchErrors(),
signal: params.config?.abortSignal,
signal: generateContentConfig.abortSignal,
maxAttempts:
this.config.isPreviewModelFallbackMode() &&
model === PREVIEW_GEMINI_MODEL
@@ -561,7 +569,7 @@ export class GeminiChat {
}
setTools(tools: Tool[]): void {
this.generationConfig.tools = tools;
this.tools = tools;
}
async maybeIncludeSchemaDepthContext(error: StructuredError): Promise<void> {

View File

@@ -97,7 +97,7 @@ describe('Turn', () => {
const events = [];
const reqParts: Part[] = [{ text: 'Hi' }];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
reqParts,
new AbortController().signal,
)) {
@@ -105,12 +105,10 @@ describe('Turn', () => {
}
expect(mockSendMessageStream).toHaveBeenCalledWith(
'test-model',
{
message: reqParts,
config: { abortSignal: expect.any(AbortSignal) },
},
{ model: 'gemini' },
reqParts,
'prompt-id-1',
expect.any(AbortSignal),
);
expect(events).toEqual([
@@ -146,7 +144,7 @@ describe('Turn', () => {
const events = [];
const reqParts: Part[] = [{ text: 'Use tools' }];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
reqParts,
new AbortController().signal,
)) {
@@ -210,7 +208,7 @@ describe('Turn', () => {
const events = [];
const reqParts: Part[] = [{ text: 'Test abort' }];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
reqParts,
abortController.signal,
)) {
@@ -233,7 +231,7 @@ describe('Turn', () => {
const events = [];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
reqParts,
new AbortController().signal,
)) {
@@ -256,7 +254,7 @@ describe('Turn', () => {
mockMaybeIncludeSchemaDepthContext.mockResolvedValue(undefined);
const events = [];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
reqParts,
new AbortController().signal,
)) {
@@ -297,7 +295,7 @@ describe('Turn', () => {
const events = [];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
[{ text: 'Test undefined tool parts' }],
new AbortController().signal,
)) {
@@ -374,7 +372,7 @@ describe('Turn', () => {
const events = [];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
[{ text: 'Test' }],
new AbortController().signal,
)) {
@@ -411,7 +409,7 @@ describe('Turn', () => {
const events = [];
const reqParts: Part[] = [{ text: 'Test no finish reason' }];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
reqParts,
new AbortController().signal,
)) {
@@ -456,7 +454,7 @@ describe('Turn', () => {
const events = [];
const reqParts: Part[] = [{ text: 'Test multiple responses' }];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
reqParts,
new AbortController().signal,
)) {
@@ -499,7 +497,7 @@ describe('Turn', () => {
const events = [];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
[{ text: 'Test citations' }],
new AbortController().signal,
)) {
@@ -549,7 +547,7 @@ describe('Turn', () => {
const events = [];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
[{ text: 'test' }],
new AbortController().signal,
)) {
@@ -596,7 +594,7 @@ describe('Turn', () => {
const events = [];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
[{ text: 'test' }],
new AbortController().signal,
)) {
@@ -642,7 +640,7 @@ describe('Turn', () => {
const events = [];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
[{ text: 'test' }],
new AbortController().signal,
)) {
@@ -680,7 +678,7 @@ describe('Turn', () => {
const reqParts: Part[] = [{ text: 'Test malformed error handling' }];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
reqParts,
abortController.signal,
)) {
@@ -706,7 +704,7 @@ describe('Turn', () => {
const events = [];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
[],
new AbortController().signal,
)) {
@@ -754,7 +752,7 @@ describe('Turn', () => {
const events = [];
for await (const event of turn.run(
'test-model',
{ model: 'gemini' },
[{ text: 'Hi' }],
new AbortController().signal,
)) {
@@ -780,7 +778,7 @@ describe('Turn', () => {
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const reqParts: Part[] = [{ text: 'Hi' }];
for await (const _ of turn.run(
'test-model',
{ model: 'gemini' },
reqParts,
new AbortController().signal,
)) {

View File

@@ -30,6 +30,7 @@ import type { GeminiChat } from './geminiChat.js';
import { InvalidStreamError } from './geminiChat.js';
import { parseThought, type ThoughtSummary } from '../utils/thoughtUtils.js';
import { createUserContent } from '@google/genai';
import type { ModelConfigKey } from '../services/modelConfigService.js';
// Define a structure for tools passed to the server
export interface ServerTool {
@@ -232,9 +233,10 @@ export class Turn {
private readonly chat: GeminiChat,
private readonly prompt_id: string,
) {}
// The run method yields simpler events suitable for server logic
async *run(
model: string,
modelConfigKey: ModelConfigKey,
req: PartListUnion,
signal: AbortSignal,
): AsyncGenerator<ServerGeminiStreamEvent> {
@@ -242,14 +244,10 @@ export class Turn {
// Note: This assumes `sendMessageStream` yields events like
// { type: StreamEventType.RETRY } or { type: StreamEventType.CHUNK, value: GenerateContentResponse }
const responseStream = await this.chat.sendMessageStream(
model,
{
message: req,
config: {
abortSignal: signal,
},
},
modelConfigKey,
req,
this.prompt_id,
signal,
);
for await (const streamEvent of responseStream) {

View File

@@ -231,4 +231,42 @@ describe('ModelConfigService Integration', () => {
topP: 0.95, // from base
});
});
it('should correctly merge static aliases, runtime aliases, and overrides', () => {
// Re-instantiate service for this isolated test to not pollute other tests
const service = new ModelConfigService(complexConfig);
// Register a runtime alias, simulating what AgentExecutor does.
// This alias extends a static base and provides its own settings.
service.registerRuntimeModelConfig('agent-runtime:my-agent', {
extends: 'creative-writer', // extends a multi-level alias
modelConfig: {
generateContentConfig: {
temperature: 0.1, // Overrides parent
maxOutputTokens: 8192, // Adds a new property
},
},
});
// Resolve the configuration for the runtime alias, with a matching agent scope
const resolved = service.getResolvedConfig({
model: 'agent-runtime:my-agent',
overrideScope: 'core',
});
// Assert the final merged configuration.
expect(resolved.model).toBe('gemini-1.5-pro-latest'); // from 'default-text-model'
expect(resolved.generateContentConfig).toEqual({
// from 'core' agent override, wins over runtime alias's 0.1 and creative-writer's 0.9
temperature: 0.5,
// from 'base' alias
topP: 0.95,
// from 'creative-writer' alias
topK: 50,
// from runtime alias
maxOutputTokens: 8192,
// from 'core' agent override
stopSequences: ['AGENT_STOP'],
});
});
});

View File

@@ -550,4 +550,30 @@ describe('ModelConfigService', () => {
]);
});
});
describe('runtime aliases', () => {
it('should resolve a simple runtime-registered alias', () => {
const config: ModelConfigServiceConfig = {
aliases: {},
overrides: [],
};
const service = new ModelConfigService(config);
service.registerRuntimeModelConfig('runtime-alias', {
modelConfig: {
model: 'gemini-runtime-model',
generateContentConfig: {
temperature: 0.123,
},
},
});
const resolved = service.getResolvedConfig({ model: 'runtime-alias' });
expect(resolved.model).toBe('gemini-runtime-model');
expect(resolved.generateContentConfig).toEqual({
temperature: 0.123,
});
});
});
});

View File

@@ -56,9 +56,15 @@ export interface _ResolvedModelConfig {
}
export class ModelConfigService {
private readonly runtimeAliases: Record<string, ModelConfigAlias> = {};
// TODO(12597): Process config to build a typed alias hierarchy.
constructor(private readonly config: ModelConfigServiceConfig) {}
registerRuntimeModelConfig(aliasName: string, alias: ModelConfigAlias): void {
this.runtimeAliases[aliasName] = alias;
}
private resolveAlias(
aliasName: string,
aliases: Record<string, ModelConfigAlias>,
@@ -99,12 +105,13 @@ export class ModelConfigService {
} {
const config = this.config || {};
const { aliases = {}, overrides = [] } = config;
const allAliases = { ...aliases, ...this.runtimeAliases };
let baseModel: string | undefined = context.model;
let resolvedConfig: GenerateContentConfig = {};
// Step 1: Alias Resolution
if (aliases[context.model]) {
const resolvedAlias = this.resolveAlias(context.model, aliases);
if (allAliases[context.model]) {
const resolvedAlias = this.resolveAlias(context.model, allAliases);
baseModel = resolvedAlias.modelConfig.model; // This can now be undefined
resolvedConfig = this.deepMerge(
resolvedConfig,

View File

@@ -11,7 +11,19 @@
"topP": 0.95,
"thinkingConfig": {
"includeThoughts": true,
"thinkingBudget": -1
"thinkingBudget": 8192
},
"topK": 64
}
},
"gemini-3-pro-preview": {
"model": "gemini-3-pro-preview",
"generateContentConfig": {
"temperature": 1,
"topP": 0.95,
"thinkingConfig": {
"includeThoughts": true,
"thinkingBudget": 8192
},
"topK": 64
}
@@ -23,7 +35,7 @@
"topP": 0.95,
"thinkingConfig": {
"includeThoughts": true,
"thinkingBudget": -1
"thinkingBudget": 8192
},
"topK": 64
}
@@ -35,7 +47,7 @@
"topP": 0.95,
"thinkingConfig": {
"includeThoughts": true,
"thinkingBudget": -1
"thinkingBudget": 8192
},
"topK": 64
}
@@ -47,7 +59,7 @@
"topP": 0.95,
"thinkingConfig": {
"includeThoughts": true,
"thinkingBudget": -1
"thinkingBudget": 8192
},
"topK": 64
}

View File

@@ -82,7 +82,8 @@ describe('checkNextSpeaker', () => {
// GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor
chatInstance = new GeminiChat(
mockConfig,
{},
'', // empty system instruction
[], // no tools
[], // initial history
);