mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-20 10:10:56 -07:00
feat(core): Wire up chat code path for model configs. (#12850)
This commit is contained in:
@@ -27,8 +27,9 @@ import {
|
||||
type FunctionCall,
|
||||
type Part,
|
||||
type GenerateContentResponse,
|
||||
type GenerateContentConfig,
|
||||
type Content,
|
||||
type PartListUnion,
|
||||
type Tool,
|
||||
} from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { MockTool } from '../test-utils/mock-tool.js';
|
||||
@@ -55,14 +56,22 @@ import { AgentTerminateMode } from './types.js';
|
||||
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
|
||||
import { CompressionStatus } from '../core/turn.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import { getModelConfigAlias } from './registry.js';
|
||||
|
||||
const { mockSendMessageStream, mockExecuteToolCall, mockCompress } = vi.hoisted(
|
||||
() => ({
|
||||
mockSendMessageStream: vi.fn(),
|
||||
mockExecuteToolCall: vi.fn(),
|
||||
mockCompress: vi.fn(),
|
||||
}),
|
||||
);
|
||||
const {
|
||||
mockSendMessageStream,
|
||||
mockExecuteToolCall,
|
||||
mockSetSystemInstruction,
|
||||
mockCompress,
|
||||
mockSetTools,
|
||||
} = vi.hoisted(() => ({
|
||||
mockSendMessageStream: vi.fn(),
|
||||
mockExecuteToolCall: vi.fn(),
|
||||
mockSetSystemInstruction: vi.fn(),
|
||||
mockCompress: vi.fn(),
|
||||
mockSetTools: vi.fn(),
|
||||
}));
|
||||
|
||||
let mockChatHistory: Content[] = [];
|
||||
const mockSetHistory = vi.fn((newHistory: Content[]) => {
|
||||
@@ -83,6 +92,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
setHistory: mockSetHistory,
|
||||
setSystemInstruction: mockSetSystemInstruction,
|
||||
setTools: mockSetTools,
|
||||
})),
|
||||
};
|
||||
});
|
||||
@@ -172,8 +183,10 @@ const mockModelResponse = (
|
||||
const getMockMessageParams = (callIndex: number) => {
|
||||
const call = mockSendMessageStream.mock.calls[callIndex];
|
||||
expect(call).toBeDefined();
|
||||
// Arg 1 of sendMessageStream is the message parameters
|
||||
return call[1] as { message?: Part[]; config?: GenerateContentConfig };
|
||||
return {
|
||||
modelConfigKey: call[0],
|
||||
message: call[1],
|
||||
} as { modelConfigKey: ModelConfigKey; message: PartListUnion };
|
||||
};
|
||||
|
||||
let mockConfig: Config;
|
||||
@@ -223,6 +236,8 @@ describe('AgentExecutor', () => {
|
||||
mockCompress.mockClear();
|
||||
mockSetHistory.mockClear();
|
||||
mockSendMessageStream.mockReset();
|
||||
mockSetSystemInstruction.mockReset();
|
||||
mockSetTools.mockReset();
|
||||
mockExecuteToolCall.mockReset();
|
||||
mockedLogAgentStart.mockReset();
|
||||
mockedLogAgentFinish.mockReset();
|
||||
@@ -241,6 +256,8 @@ describe('AgentExecutor', () => {
|
||||
() =>
|
||||
({
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
setSystemInstruction: mockSetSystemInstruction,
|
||||
setTools: mockSetTools,
|
||||
getHistory: vi.fn((_curated?: boolean) => [...mockChatHistory]),
|
||||
getLastPromptTokenCount: vi.fn(() => 100),
|
||||
setHistory: mockSetHistory,
|
||||
@@ -358,7 +375,7 @@ describe('AgentExecutor', () => {
|
||||
await executor.run(inputs, signal);
|
||||
|
||||
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
|
||||
const startHistory = chatConstructorArgs[2]; // history is the 3rd arg
|
||||
const startHistory = chatConstructorArgs[3]; // history is the 4th arg
|
||||
|
||||
expect(startHistory).toBeDefined();
|
||||
expect(startHistory).toHaveLength(2);
|
||||
@@ -459,10 +476,7 @@ describe('AgentExecutor', () => {
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
|
||||
const chatConstructorArgs = MockedGeminiChat.mock.calls[0];
|
||||
const chatConfig = chatConstructorArgs[1];
|
||||
const systemInstruction = chatConfig?.systemInstruction as string;
|
||||
|
||||
const systemInstruction = MockedGeminiChat.mock.calls[0][1];
|
||||
expect(systemInstruction).toContain(
|
||||
`MUST call the \`${TASK_COMPLETE_TOOL_NAME}\` tool`,
|
||||
);
|
||||
@@ -472,18 +486,11 @@ describe('AgentExecutor', () => {
|
||||
);
|
||||
expect(systemInstruction).toContain('Always use absolute paths');
|
||||
|
||||
const turn1Params = getMockMessageParams(0);
|
||||
const { modelConfigKey } = getMockMessageParams(0);
|
||||
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
|
||||
|
||||
const firstToolGroup = turn1Params.config?.tools?.[0];
|
||||
expect(firstToolGroup).toBeDefined();
|
||||
|
||||
if (!firstToolGroup || !('functionDeclarations' in firstToolGroup)) {
|
||||
throw new Error(
|
||||
'Test expectation failed: Config does not contain functionDeclarations.',
|
||||
);
|
||||
}
|
||||
|
||||
const sentTools = firstToolGroup.functionDeclarations;
|
||||
const call = mockSetTools.mock.calls[0];
|
||||
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
|
||||
expect(sentTools).toBeDefined();
|
||||
|
||||
expect(sentTools).toEqual(
|
||||
@@ -604,17 +611,11 @@ describe('AgentExecutor', () => {
|
||||
|
||||
const output = await executor.run({ goal: 'Do work' }, signal);
|
||||
|
||||
const turn1Params = getMockMessageParams(0);
|
||||
const firstToolGroup = turn1Params.config?.tools?.[0];
|
||||
const { modelConfigKey } = getMockMessageParams(0);
|
||||
expect(modelConfigKey.model).toBe(getModelConfigAlias(definition));
|
||||
|
||||
expect(firstToolGroup).toBeDefined();
|
||||
if (!firstToolGroup || !('functionDeclarations' in firstToolGroup)) {
|
||||
throw new Error(
|
||||
'Test expectation failed: Config does not contain functionDeclarations.',
|
||||
);
|
||||
}
|
||||
|
||||
const sentTools = firstToolGroup.functionDeclarations;
|
||||
const call = mockSetTools.mock.calls[0];
|
||||
const sentTools = (call[0] as Tool[])[0].functionDeclarations;
|
||||
expect(sentTools).toBeDefined();
|
||||
|
||||
const completeToolDef = sentTools!.find(
|
||||
@@ -754,7 +755,7 @@ describe('AgentExecutor', () => {
|
||||
expect(turn2Parts).toBeDefined();
|
||||
expect(turn2Parts).toHaveLength(1);
|
||||
|
||||
expect(turn2Parts![0]).toEqual(
|
||||
expect((turn2Parts as Part[])![0]).toEqual(
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
@@ -944,7 +945,7 @@ describe('AgentExecutor', () => {
|
||||
const turn2Params = getMockMessageParams(1);
|
||||
const parts = turn2Params.message;
|
||||
expect(parts).toBeDefined();
|
||||
expect(parts![0]).toEqual(
|
||||
expect((parts as Part[])![0]).toEqual(
|
||||
expect.objectContaining({
|
||||
functionResponse: expect.objectContaining({
|
||||
id: badCallId,
|
||||
@@ -1222,18 +1223,18 @@ describe('AgentExecutor', () => {
|
||||
);
|
||||
|
||||
// Mock a model call that is interruptible by an abort signal.
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
await new Promise<void>((resolve) => {
|
||||
// This promise resolves when aborted, ending the generator.
|
||||
signal?.addEventListener('abort', () => {
|
||||
resolve();
|
||||
mockSendMessageStream.mockImplementationOnce(
|
||||
async (_key, _message, _promptId, signal) =>
|
||||
// eslint-disable-next-line require-yield
|
||||
(async function* () {
|
||||
await new Promise<void>((resolve) => {
|
||||
// This promise resolves when aborted, ending the generator.
|
||||
signal?.addEventListener('abort', () => {
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
});
|
||||
})();
|
||||
});
|
||||
})(),
|
||||
);
|
||||
// Recovery turn
|
||||
mockModelResponse([], 'I give up');
|
||||
|
||||
@@ -1534,16 +1535,16 @@ describe('AgentExecutor', () => {
|
||||
);
|
||||
|
||||
// Mock a model call that gets interrupted by the timeout.
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
// This promise never resolves, it waits for abort.
|
||||
await new Promise<void>((resolve) => {
|
||||
signal?.addEventListener('abort', () => resolve());
|
||||
});
|
||||
})();
|
||||
});
|
||||
mockSendMessageStream.mockImplementationOnce(
|
||||
async (_key, _message, _promptId, signal) =>
|
||||
// eslint-disable-next-line require-yield
|
||||
(async function* () {
|
||||
// This promise never resolves, it waits for abort.
|
||||
await new Promise<void>((resolve) => {
|
||||
signal?.addEventListener('abort', () => resolve());
|
||||
});
|
||||
})(),
|
||||
);
|
||||
|
||||
// Recovery turn (succeeds)
|
||||
mockModelResponse(
|
||||
@@ -1588,26 +1589,26 @@ describe('AgentExecutor', () => {
|
||||
onActivity,
|
||||
);
|
||||
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})();
|
||||
});
|
||||
mockSendMessageStream.mockImplementationOnce(
|
||||
async (_key, _message, _promptId, signal) =>
|
||||
// eslint-disable-next-line require-yield
|
||||
(async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})(),
|
||||
);
|
||||
|
||||
// Mock the recovery call to also be long-running
|
||||
mockSendMessageStream.mockImplementationOnce(async (_model, params) => {
|
||||
const signal = params?.config?.abortSignal;
|
||||
// eslint-disable-next-line require-yield
|
||||
return (async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})();
|
||||
});
|
||||
mockSendMessageStream.mockImplementationOnce(
|
||||
async (_key, _message, _promptId, signal) =>
|
||||
// eslint-disable-next-line require-yield
|
||||
(async function* () {
|
||||
await new Promise<void>((resolve) =>
|
||||
signal?.addEventListener('abort', () => resolve()),
|
||||
);
|
||||
})(),
|
||||
);
|
||||
|
||||
const runPromise = executor.run(
|
||||
{ goal: 'Timeout recovery fail' },
|
||||
|
||||
@@ -12,7 +12,6 @@ import type {
|
||||
Content,
|
||||
Part,
|
||||
FunctionCall,
|
||||
GenerateContentConfig,
|
||||
FunctionDeclaration,
|
||||
Schema,
|
||||
} from '@google/genai';
|
||||
@@ -53,6 +52,7 @@ import { parseThought } from '../utils/thoughtUtils.js';
|
||||
import { type z } from 'zod';
|
||||
import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { getModelConfigAlias } from './registry.js';
|
||||
|
||||
/** A callback function to report on agent activity. */
|
||||
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
|
||||
@@ -595,18 +595,19 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
signal: AbortSignal,
|
||||
promptId: string,
|
||||
): Promise<{ functionCalls: FunctionCall[]; textResponse: string }> {
|
||||
const messageParams = {
|
||||
message: message.parts || [],
|
||||
config: {
|
||||
abortSignal: signal,
|
||||
tools: tools.length > 0 ? [{ functionDeclarations: tools }] : undefined,
|
||||
},
|
||||
};
|
||||
if (tools.length > 0) {
|
||||
// TODO(12622): Move tools back to config.
|
||||
chat.setTools([{ functionDeclarations: tools }]);
|
||||
}
|
||||
|
||||
const responseStream = await chat.sendMessageStream(
|
||||
this.definition.modelConfig.model,
|
||||
messageParams,
|
||||
{
|
||||
model: getModelConfigAlias(this.definition),
|
||||
overrideScope: this.definition.name,
|
||||
},
|
||||
message.parts || [],
|
||||
promptId,
|
||||
signal,
|
||||
);
|
||||
|
||||
const functionCalls: FunctionCall[] = [];
|
||||
@@ -650,7 +651,7 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
|
||||
/** Initializes a `GeminiChat` instance for the agent run. */
|
||||
private async createChatObject(inputs: AgentInputs): Promise<GeminiChat> {
|
||||
const { promptConfig, modelConfig } = this.definition;
|
||||
const { promptConfig } = this.definition;
|
||||
|
||||
if (!promptConfig.systemPrompt && !promptConfig.initialMessages) {
|
||||
throw new Error(
|
||||
@@ -669,22 +670,10 @@ export class AgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
: undefined;
|
||||
|
||||
try {
|
||||
const generationConfig: GenerateContentConfig = {
|
||||
temperature: modelConfig.temp,
|
||||
topP: modelConfig.top_p,
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: modelConfig.thinkingBudget ?? -1,
|
||||
},
|
||||
};
|
||||
|
||||
if (systemInstruction) {
|
||||
generationConfig.systemInstruction = systemInstruction;
|
||||
}
|
||||
|
||||
return new GeminiChat(
|
||||
this.runtimeContext,
|
||||
generationConfig,
|
||||
systemInstruction,
|
||||
[], // set in `callModel`,
|
||||
startHistory,
|
||||
);
|
||||
} catch (error) {
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { AgentRegistry } from './registry.js';
|
||||
import { AgentRegistry, getModelConfigAlias } from './registry.js';
|
||||
import { makeFakeConfig } from '../test-utils/config.js';
|
||||
import type { AgentDefinition } from './types.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
@@ -77,6 +77,21 @@ describe('AgentRegistry', () => {
|
||||
it('should register a valid agent definition', () => {
|
||||
registry.testRegisterAgent(MOCK_AGENT_V1);
|
||||
expect(registry.getDefinition('MockAgent')).toEqual(MOCK_AGENT_V1);
|
||||
expect(
|
||||
mockConfig.modelConfigService.getResolvedConfig({
|
||||
model: getModelConfigAlias(MOCK_AGENT_V1),
|
||||
}),
|
||||
).toStrictEqual({
|
||||
model: MOCK_AGENT_V1.modelConfig.model,
|
||||
generateContentConfig: {
|
||||
temperature: MOCK_AGENT_V1.modelConfig.temp,
|
||||
topP: MOCK_AGENT_V1.modelConfig.top_p,
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: -1,
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle special characters in agent names', () => {
|
||||
|
||||
@@ -9,6 +9,16 @@ import type { AgentDefinition } from './types.js';
|
||||
import { CodebaseInvestigatorAgent } from './codebase-investigator.js';
|
||||
import { type z } from 'zod';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { ModelConfigAlias } from '../services/modelConfigService.js';
|
||||
|
||||
/**
|
||||
* Returns the model config alias for a given agent definition.
|
||||
*/
|
||||
export function getModelConfigAlias<TOutput extends z.ZodTypeAny>(
|
||||
definition: AgentDefinition<TOutput>,
|
||||
): string {
|
||||
return `${definition.name}-config`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Manages the discovery, loading, validation, and registration of
|
||||
@@ -84,6 +94,29 @@ export class AgentRegistry {
|
||||
}
|
||||
|
||||
this.agents.set(definition.name, definition);
|
||||
|
||||
// Register model config.
|
||||
// TODO(12916): Migrate sub-agents where possible to static configs.
|
||||
const modelConfig = definition.modelConfig;
|
||||
|
||||
const runtimeAlias: ModelConfigAlias = {
|
||||
modelConfig: {
|
||||
model: modelConfig.model,
|
||||
generateContentConfig: {
|
||||
temperature: modelConfig.temp,
|
||||
topP: modelConfig.top_p,
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: modelConfig.thinkingBudget ?? -1,
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
this.config.modelConfigService.registerRuntimeModelConfig(
|
||||
getModelConfigAlias(definition),
|
||||
runtimeAlias,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import type { ModelConfigServiceConfig } from '../services/modelConfigService.js';
|
||||
import { DEFAULT_THINKING_MODE } from './models.js';
|
||||
|
||||
// The default model configs. We use `base` as the parent for all of our model
|
||||
// configs, while `chat-base`, a child of `base`, is the parent of the models
|
||||
@@ -25,7 +26,9 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
generateContentConfig: {
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: -1,
|
||||
// TODO(joshualitt): Introduce new bases for Gemini 3 models to use
|
||||
// thinkingLevel instead.
|
||||
thinkingBudget: DEFAULT_THINKING_MODE,
|
||||
},
|
||||
temperature: 1,
|
||||
topP: 0.95,
|
||||
@@ -38,6 +41,12 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
// ensure these model configs can be used interactively.
|
||||
// TODO(joshualitt): Introduce internal base configs for the various models,
|
||||
// note: we will have to think carefully about names.
|
||||
'gemini-3-pro-preview': {
|
||||
extends: 'chat-base',
|
||||
modelConfig: {
|
||||
model: 'gemini-3-pro-preview',
|
||||
},
|
||||
},
|
||||
'gemini-2.5-pro': {
|
||||
extends: 'chat-base',
|
||||
modelConfig: {
|
||||
|
||||
@@ -15,7 +15,7 @@ import {
|
||||
} from 'vitest';
|
||||
|
||||
import type { Content, GenerateContentResponse, Part } from '@google/genai';
|
||||
import { isThinkingSupported, GeminiClient } from './client.js';
|
||||
import { GeminiClient } from './client.js';
|
||||
import {
|
||||
AuthType,
|
||||
type ContentGenerator,
|
||||
@@ -134,30 +134,6 @@ async function fromAsync<T>(promise: AsyncGenerator<T>): Promise<readonly T[]> {
|
||||
return results;
|
||||
}
|
||||
|
||||
describe('isThinkingSupported', () => {
|
||||
it('should return true for gemini-2.5', () => {
|
||||
expect(isThinkingSupported('gemini-2.5')).toBe(true);
|
||||
expect(isThinkingSupported('gemini-2.5-flash')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for gemini-2.5-pro', () => {
|
||||
expect(isThinkingSupported('gemini-2.5-pro')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for gemini-3-pro', () => {
|
||||
expect(isThinkingSupported('gemini-3-pro')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for gemini-2.0 models', () => {
|
||||
expect(isThinkingSupported('gemini-2.0-flash')).toBe(false);
|
||||
expect(isThinkingSupported('gemini-2.0-pro')).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true for other models', () => {
|
||||
expect(isThinkingSupported('some-other-model')).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Gemini Client (client.ts)', () => {
|
||||
let mockContentGenerator: ContentGenerator;
|
||||
let mockConfig: Config;
|
||||
@@ -766,14 +742,10 @@ ${JSON.stringify(
|
||||
|
||||
// Assert
|
||||
expect(ideContextStore.get).toHaveBeenCalled();
|
||||
// The `turn.run` method is now called with the model name as the first
|
||||
// argument. We use `expect.any(String)` because this test is
|
||||
// concerned with the IDE context logic, not the model routing,
|
||||
// which is tested in its own dedicated suite.
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
{ model: 'default-routed-model' },
|
||||
initialRequest,
|
||||
expect.any(Object),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1302,9 +1274,9 @@ ${JSON.stringify(
|
||||
expect(mockConfig.getModelRouterService).toHaveBeenCalled();
|
||||
expect(mockRouterService.route).toHaveBeenCalled();
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
'routed-model', // The model from the router
|
||||
{ model: 'routed-model' },
|
||||
[{ text: 'Hi' }],
|
||||
expect.any(Object),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1319,9 +1291,9 @@ ${JSON.stringify(
|
||||
|
||||
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
'routed-model',
|
||||
{ model: 'routed-model' },
|
||||
[{ text: 'Hi' }],
|
||||
expect.any(Object),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
|
||||
// Second turn
|
||||
@@ -1336,9 +1308,9 @@ ${JSON.stringify(
|
||||
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
|
||||
// Should stick to the first model
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
'routed-model',
|
||||
{ model: 'routed-model' },
|
||||
[{ text: 'Continue' }],
|
||||
expect.any(Object),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1353,9 +1325,9 @@ ${JSON.stringify(
|
||||
|
||||
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
'routed-model',
|
||||
{ model: 'routed-model' },
|
||||
[{ text: 'Hi' }],
|
||||
expect.any(Object),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
|
||||
// New prompt
|
||||
@@ -1374,9 +1346,9 @@ ${JSON.stringify(
|
||||
expect(mockRouterService.route).toHaveBeenCalledTimes(2);
|
||||
// Should use the newly routed model
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
'new-routed-model',
|
||||
{ model: 'new-routed-model' },
|
||||
[{ text: 'A new topic' }],
|
||||
expect.any(Object),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1395,9 +1367,9 @@ ${JSON.stringify(
|
||||
await fromAsync(stream);
|
||||
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
{ model: DEFAULT_GEMINI_FLASH_MODEL },
|
||||
[{ text: 'Hi' }],
|
||||
expect.any(Object),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1417,9 +1389,9 @@ ${JSON.stringify(
|
||||
|
||||
// First call should use fallback model
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
{ model: DEFAULT_GEMINI_FLASH_MODEL },
|
||||
[{ text: 'Hi' }],
|
||||
expect.any(Object),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
|
||||
// End fallback mode
|
||||
@@ -1436,9 +1408,9 @@ ${JSON.stringify(
|
||||
// Router should still not be called, and it should stick to the fallback model
|
||||
expect(mockTurnRunFn).toHaveBeenCalledTimes(2); // Ensure it was called again
|
||||
expect(mockTurnRunFn).toHaveBeenLastCalledWith(
|
||||
DEFAULT_GEMINI_FLASH_MODEL, // Still the fallback model
|
||||
{ model: DEFAULT_GEMINI_FLASH_MODEL }, // Still the fallback model
|
||||
[{ text: 'Continue' }],
|
||||
expect.any(Object),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -1487,17 +1459,17 @@ ${JSON.stringify(
|
||||
// First call with original request
|
||||
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.any(String),
|
||||
{ model: 'default-routed-model' },
|
||||
initialRequest,
|
||||
expect.any(Object),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
|
||||
// Second call with "Please continue."
|
||||
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.any(String),
|
||||
{ model: 'default-routed-model' },
|
||||
[{ text: 'System: Please continue.' }],
|
||||
expect.any(Object),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -2264,7 +2236,7 @@ ${JSON.stringify(
|
||||
.mockReturnValueOnce(true);
|
||||
|
||||
let capturedSignal: AbortSignal;
|
||||
mockTurnRunFn.mockImplementation((model, request, signal) => {
|
||||
mockTurnRunFn.mockImplementation((_modelConfigKey, _request, signal) => {
|
||||
capturedSignal = signal;
|
||||
return (async function* () {
|
||||
yield { type: 'content', value: 'First event' };
|
||||
|
||||
@@ -33,7 +33,6 @@ import type {
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_THINKING_MODE,
|
||||
getEffectiveModel,
|
||||
} from '../config/models.js';
|
||||
import { LoopDetectionService } from '../services/loopDetectionService.js';
|
||||
@@ -54,19 +53,10 @@ import type { RoutingContext } from '../routing/routingStrategy.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
|
||||
export function isThinkingSupported(model: string) {
|
||||
return !model.startsWith('gemini-2.0');
|
||||
}
|
||||
|
||||
const MAX_TURNS = 100;
|
||||
|
||||
export class GeminiClient {
|
||||
private chat?: GeminiChat;
|
||||
private readonly generateContentConfig: GenerateContentConfig = {
|
||||
temperature: 1,
|
||||
topP: 0.95,
|
||||
topK: 64,
|
||||
};
|
||||
private sessionTurnCount = 0;
|
||||
|
||||
private readonly loopDetector: LoopDetectionService;
|
||||
@@ -194,24 +184,10 @@ export class GeminiClient {
|
||||
try {
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const systemInstruction = getCoreSystemPrompt(this.config, userMemory);
|
||||
const model = this.config.getModel();
|
||||
|
||||
const config: GenerateContentConfig = { ...this.generateContentConfig };
|
||||
|
||||
if (isThinkingSupported(model)) {
|
||||
config.thinkingConfig = {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: DEFAULT_THINKING_MODE,
|
||||
};
|
||||
}
|
||||
|
||||
return new GeminiChat(
|
||||
this.config,
|
||||
{
|
||||
systemInstruction,
|
||||
...config,
|
||||
tools,
|
||||
},
|
||||
systemInstruction,
|
||||
tools,
|
||||
history,
|
||||
resumedSessionData,
|
||||
);
|
||||
@@ -515,7 +491,7 @@ export class GeminiClient {
|
||||
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
|
||||
}
|
||||
|
||||
const resultStream = turn.run(modelToUse, request, linkedSignal);
|
||||
const resultStream = turn.run({ model: modelToUse }, request, linkedSignal);
|
||||
for await (const event of resultStream) {
|
||||
if (this.loopDetector.addAndCheck(event)) {
|
||||
yield { type: GeminiEventType.LoopDetected };
|
||||
|
||||
@@ -5,11 +5,7 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import type {
|
||||
Content,
|
||||
GenerateContentConfig,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import type { Content, GenerateContentResponse } from '@google/genai';
|
||||
import { ApiError } from '@google/genai';
|
||||
import type { ContentGenerator } from '../core/contentGenerator.js';
|
||||
import {
|
||||
@@ -94,7 +90,6 @@ describe('GeminiChat', () => {
|
||||
let mockContentGenerator: ContentGenerator;
|
||||
let chat: GeminiChat;
|
||||
let mockConfig: Config;
|
||||
const config: GenerateContentConfig = {};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
@@ -135,6 +130,14 @@ describe('GeminiChat', () => {
|
||||
}),
|
||||
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
|
||||
getRetryFetchErrors: vi.fn().mockReturnValue(false),
|
||||
modelConfigService: {
|
||||
getResolvedConfig: vi.fn().mockImplementation((modelConfigKey) => ({
|
||||
model: modelConfigKey.model,
|
||||
generateContentConfig: {
|
||||
temperature: 0,
|
||||
},
|
||||
})),
|
||||
},
|
||||
isPreviewModelBypassMode: vi.fn().mockReturnValue(false),
|
||||
setPreviewModelBypassMode: vi.fn(),
|
||||
isPreviewModelFallbackMode: vi.fn().mockReturnValue(false),
|
||||
@@ -145,7 +148,7 @@ describe('GeminiChat', () => {
|
||||
// Disable 429 simulation for tests
|
||||
setSimulate429(false);
|
||||
// Reset history for each test by creating a new instance
|
||||
chat = new GeminiChat(mockConfig, config, []);
|
||||
chat = new GeminiChat(mockConfig);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -159,13 +162,13 @@ describe('GeminiChat', () => {
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
{ role: 'model', parts: [{ text: 'Hi there' }] },
|
||||
];
|
||||
const chatWithHistory = new GeminiChat(mockConfig, config, history);
|
||||
const chatWithHistory = new GeminiChat(mockConfig, '', [], history);
|
||||
const estimatedTokens = Math.ceil(JSON.stringify(history).length / 4);
|
||||
expect(chatWithHistory.getLastPromptTokenCount()).toBe(estimatedTokens);
|
||||
});
|
||||
|
||||
it('should initialize lastPromptTokenCount for empty history', () => {
|
||||
const chatEmpty = new GeminiChat(mockConfig, config, []);
|
||||
const chatEmpty = new GeminiChat(mockConfig);
|
||||
expect(chatEmpty.getLastPromptTokenCount()).toBe(
|
||||
Math.ceil(JSON.stringify([]).length / 4),
|
||||
);
|
||||
@@ -206,9 +209,10 @@ describe('GeminiChat', () => {
|
||||
// 2. Action & Assert: The stream processing should complete without throwing an error
|
||||
// because the presence of a tool call makes the empty final chunk acceptable.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test message' },
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-tool-call-empty-end',
|
||||
new AbortController().signal,
|
||||
);
|
||||
await expect(
|
||||
(async () => {
|
||||
@@ -258,9 +262,10 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Action & Assert: The stream should fail because there's no finish reason.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ message: 'test message' },
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
'test message',
|
||||
'prompt-id-no-finish-empty-end',
|
||||
new AbortController().signal,
|
||||
);
|
||||
await expect(
|
||||
(async () => {
|
||||
@@ -304,9 +309,10 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Action & Assert: The stream should complete without throwing an error.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test message' },
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-valid-then-invalid-end',
|
||||
new AbortController().signal,
|
||||
);
|
||||
await expect(
|
||||
(async () => {
|
||||
@@ -351,9 +357,10 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Action: Send a message and consume the stream.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test message' },
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-empty-chunk-consolidation',
|
||||
new AbortController().signal,
|
||||
);
|
||||
for await (const _ of stream) {
|
||||
// Consume the stream
|
||||
@@ -409,9 +416,10 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Action: Send a message and consume the stream.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test message' },
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-multi-chunk',
|
||||
new AbortController().signal,
|
||||
);
|
||||
for await (const _ of stream) {
|
||||
// Consume the stream to trigger history recording.
|
||||
@@ -457,9 +465,10 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Action: Send a message and fully consume the stream to trigger history recording.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test message' },
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-mixed-chunk',
|
||||
new AbortController().signal,
|
||||
);
|
||||
for await (const _ of stream) {
|
||||
// This loop consumes the stream.
|
||||
@@ -499,9 +508,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
{ message: 'test' },
|
||||
{ model: PREVIEW_GEMINI_MODEL },
|
||||
'test',
|
||||
'prompt-id-fast-retry',
|
||||
new AbortController().signal,
|
||||
);
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
@@ -531,9 +541,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
{ message: 'test' },
|
||||
{ model: DEFAULT_GEMINI_FLASH_MODEL },
|
||||
'test',
|
||||
'prompt-id-normal-retry',
|
||||
new AbortController().signal,
|
||||
);
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
@@ -577,9 +588,10 @@ describe('GeminiChat', () => {
|
||||
// ACT
|
||||
const consumeStream = async () => {
|
||||
const stream = await chat.sendMessageStream(
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
{ message: 'test' },
|
||||
{ model: PREVIEW_GEMINI_MODEL },
|
||||
'test',
|
||||
'prompt-id-bypass',
|
||||
new AbortController().signal,
|
||||
);
|
||||
// Consume the stream to trigger execution
|
||||
for await (const _ of stream) {
|
||||
@@ -639,16 +651,15 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 3. Action: Send the function response back to the model and consume the stream.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
{
|
||||
message: {
|
||||
functionResponse: {
|
||||
name: 'find_restaurant',
|
||||
response: { name: 'Vesuvio' },
|
||||
},
|
||||
functionResponse: {
|
||||
name: 'find_restaurant',
|
||||
response: { name: 'Vesuvio' },
|
||||
},
|
||||
},
|
||||
'prompt-id-stream-1',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// 4. Assert: The stream processing should throw an InvalidStreamError.
|
||||
@@ -689,9 +700,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-1',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Should not throw an error
|
||||
@@ -725,9 +737,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ message: 'test' },
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
'test message',
|
||||
'prompt-id-1',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
await expect(
|
||||
@@ -760,9 +773,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ message: 'test' },
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
'test message',
|
||||
'prompt-id-1',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
await expect(
|
||||
@@ -795,9 +809,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-1',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Should not throw an error
|
||||
@@ -831,9 +846,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.5-pro',
|
||||
{ message: 'test' },
|
||||
{ model: 'gemini-2.5-pro' },
|
||||
'test',
|
||||
'prompt-id-malformed',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Should throw an error
|
||||
@@ -877,9 +893,10 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Send a message
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.5-pro',
|
||||
{ message: 'test retry' },
|
||||
{ model: 'gemini-2.5-pro' },
|
||||
'test retry',
|
||||
'prompt-id-retry-malformed',
|
||||
new AbortController().signal,
|
||||
);
|
||||
const events: StreamEvent[] = [];
|
||||
for await (const event of stream) {
|
||||
@@ -933,9 +950,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'hello' },
|
||||
{ model: 'test-model' },
|
||||
'hello',
|
||||
'prompt-id-1',
|
||||
new AbortController().signal,
|
||||
);
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
@@ -950,7 +968,12 @@ describe('GeminiChat', () => {
|
||||
parts: [{ text: 'hello' }],
|
||||
},
|
||||
],
|
||||
config: {},
|
||||
config: {
|
||||
systemInstruction: '',
|
||||
tools: [],
|
||||
temperature: 0,
|
||||
abortSignal: expect.any(AbortSignal),
|
||||
},
|
||||
},
|
||||
'prompt-id-1',
|
||||
);
|
||||
@@ -1000,9 +1023,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-1.5-pro',
|
||||
{ message: 'test' },
|
||||
{ model: 'gemini-1.5-pro' },
|
||||
'test',
|
||||
'prompt-id-no-retry',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
await expect(
|
||||
@@ -1047,9 +1071,10 @@ describe('GeminiChat', () => {
|
||||
|
||||
// ACT: Send a message and collect all events from the stream.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ message: 'test' },
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
'test message',
|
||||
'prompt-id-yield-retry',
|
||||
new AbortController().signal,
|
||||
);
|
||||
const events: StreamEvent[] = [];
|
||||
for await (const event of stream) {
|
||||
@@ -1088,9 +1113,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ message: 'test' },
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
'test',
|
||||
'prompt-id-retry-success',
|
||||
new AbortController().signal,
|
||||
);
|
||||
const chunks: StreamEvent[] = [];
|
||||
for await (const chunk of stream) {
|
||||
@@ -1159,9 +1185,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ message: 'test', config: { temperature: 0.5 } },
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
'test message',
|
||||
'prompt-id-retry-temperature',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
for await (const _ of stream) {
|
||||
@@ -1179,7 +1206,7 @@ describe('GeminiChat', () => {
|
||||
1,
|
||||
expect.objectContaining({
|
||||
config: expect.objectContaining({
|
||||
temperature: 0.5,
|
||||
temperature: 0,
|
||||
}),
|
||||
}),
|
||||
'prompt-id-retry-temperature',
|
||||
@@ -1217,9 +1244,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ message: 'test' },
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
'test',
|
||||
'prompt-id-retry-fail',
|
||||
new AbortController().signal,
|
||||
);
|
||||
await expect(async () => {
|
||||
for await (const _ of stream) {
|
||||
@@ -1282,9 +1310,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ message: 'test' },
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
'test message',
|
||||
'prompt-id-400',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
await expect(
|
||||
@@ -1320,9 +1349,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-429-retry',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
const events: StreamEvent[] = [];
|
||||
@@ -1368,9 +1398,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-500-retry',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
const events: StreamEvent[] = [];
|
||||
@@ -1424,9 +1455,10 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-fetch-error-retry',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
const events: StreamEvent[] = [];
|
||||
@@ -1487,9 +1519,10 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 3. Send a new message
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ message: 'Second question' },
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
'Second question',
|
||||
'prompt-id-retry-existing',
|
||||
new AbortController().signal,
|
||||
);
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
@@ -1558,9 +1591,10 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Call the method and consume the stream.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ message: 'test empty stream' },
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
'test empty stream',
|
||||
'prompt-id-empty-stream',
|
||||
new AbortController().signal,
|
||||
);
|
||||
const chunks: StreamEvent[] = [];
|
||||
for await (const chunk of stream) {
|
||||
@@ -1638,18 +1672,20 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 3. Start the first stream and consume only the first chunk to pause it
|
||||
const firstStream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'first' },
|
||||
{ model: 'test-model' },
|
||||
'first',
|
||||
'prompt-1',
|
||||
new AbortController().signal,
|
||||
);
|
||||
const firstStreamIterator = firstStream[Symbol.asyncIterator]();
|
||||
await firstStreamIterator.next();
|
||||
|
||||
// 4. While the first stream is paused, start the second call. It will block.
|
||||
const secondStreamPromise = chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'second' },
|
||||
{ model: 'test-model' },
|
||||
'second',
|
||||
'prompt-2',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// 5. Assert that only one API call has been made so far.
|
||||
@@ -1707,9 +1743,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
{ model: 'test-model' },
|
||||
'test message',
|
||||
'prompt-id-res3',
|
||||
new AbortController().signal,
|
||||
);
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
@@ -1793,9 +1830,10 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'trigger 429' },
|
||||
{ model: 'test-model' },
|
||||
'trigger 429',
|
||||
'prompt-id-fb1',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Consume stream to trigger logic
|
||||
@@ -1827,9 +1865,10 @@ describe('GeminiChat', () => {
|
||||
mockHandleFallback.mockResolvedValue(false);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ message: 'test stop' },
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
'test stop',
|
||||
'prompt-id-fb2',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
await expect(
|
||||
@@ -1885,9 +1924,10 @@ describe('GeminiChat', () => {
|
||||
|
||||
// Send a message and consume the stream
|
||||
const stream = await chat.sendMessageStream(
|
||||
'gemini-2.0-flash',
|
||||
{ message: 'test' },
|
||||
{ model: 'gemini-2.0-flash' },
|
||||
'test message',
|
||||
'prompt-id-discard-test',
|
||||
new AbortController().signal,
|
||||
);
|
||||
const events: StreamEvent[] = [];
|
||||
for await (const event of stream) {
|
||||
@@ -1965,9 +2005,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
{ model: 'test-model' },
|
||||
'test',
|
||||
'prompt-id-preview-model-reset',
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
expect(mockConfig.setPreviewModelBypassMode).toHaveBeenCalledWith(false);
|
||||
@@ -1989,9 +2030,10 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const resultStream = await chat.sendMessageStream(
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
{ message: 'test' },
|
||||
{ model: PREVIEW_GEMINI_MODEL },
|
||||
'test',
|
||||
'prompt-id-preview-model-healing',
|
||||
new AbortController().signal,
|
||||
);
|
||||
for await (const _ of resultStream) {
|
||||
// consume stream
|
||||
@@ -2019,9 +2061,10 @@ describe('GeminiChat', () => {
|
||||
vi.mocked(mockConfig.isPreviewModelBypassMode).mockReturnValue(true);
|
||||
|
||||
const resultStream = await chat.sendMessageStream(
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
{ message: 'test' },
|
||||
{ model: PREVIEW_GEMINI_MODEL },
|
||||
'test',
|
||||
'prompt-id-bypass-no-healing',
|
||||
new AbortController().signal,
|
||||
);
|
||||
for await (const _ of resultStream) {
|
||||
// consume stream
|
||||
@@ -2033,7 +2076,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
describe('ensureActiveLoopHasThoughtSignatures', () => {
|
||||
it('should add thoughtSignature to the first functionCall in each model turn of the active loop', () => {
|
||||
const chat = new GeminiChat(mockConfig, {}, []);
|
||||
const chat = new GeminiChat(mockConfig, '', [], []);
|
||||
const history: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'Old message' }] },
|
||||
{
|
||||
@@ -2090,7 +2133,7 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
|
||||
it('should not modify contents if there is no user text message', () => {
|
||||
const chat = new GeminiChat(mockConfig, {}, []);
|
||||
const chat = new GeminiChat(mockConfig, '', [], []);
|
||||
const history: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
@@ -2107,14 +2150,14 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
|
||||
it('should handle an empty history', () => {
|
||||
const chat = new GeminiChat(mockConfig, {}, []);
|
||||
const chat = new GeminiChat(mockConfig, '', []);
|
||||
const history: Content[] = [];
|
||||
const newContents = chat.ensureActiveLoopHasThoughtSignatures(history);
|
||||
expect(newContents).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle history with only a user message', () => {
|
||||
const chat = new GeminiChat(mockConfig, {}, []);
|
||||
const chat = new GeminiChat(mockConfig, '', []);
|
||||
const history: Content[] = [{ role: 'user', parts: [{ text: 'Hello' }] }];
|
||||
const newContents = chat.ensureActiveLoopHasThoughtSignatures(history);
|
||||
expect(newContents).toEqual(history);
|
||||
|
||||
@@ -10,10 +10,10 @@
|
||||
import type {
|
||||
GenerateContentResponse,
|
||||
Content,
|
||||
GenerateContentConfig,
|
||||
SendMessageParameters,
|
||||
Part,
|
||||
Tool,
|
||||
PartListUnion,
|
||||
GenerateContentConfig,
|
||||
} from '@google/genai';
|
||||
import { toParts } from '../code_assist/converter.js';
|
||||
import { createUserContent, FinishReason } from '@google/genai';
|
||||
@@ -43,6 +43,7 @@ import {
|
||||
import { handleFallback } from '../fallback/handler.js';
|
||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { partListUnionToString } from './geminiRequest.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
|
||||
export enum StreamEventType {
|
||||
/** A regular content chunk from the API. */
|
||||
@@ -202,7 +203,8 @@ export class GeminiChat {
|
||||
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
private readonly generationConfig: GenerateContentConfig = {},
|
||||
private systemInstruction: string = '',
|
||||
private tools: Tool[] = [],
|
||||
private history: Content[] = [],
|
||||
resumedSessionData?: ResumedSessionData,
|
||||
) {
|
||||
@@ -215,7 +217,7 @@ export class GeminiChat {
|
||||
}
|
||||
|
||||
setSystemInstruction(sysInstr: string) {
|
||||
this.generationConfig.systemInstruction = sysInstr;
|
||||
this.systemInstruction = sysInstr;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -226,7 +228,10 @@ export class GeminiChat {
|
||||
* sending the next message.
|
||||
*
|
||||
* @see {@link Chat#sendMessage} for non-streaming method.
|
||||
* @param params - parameters for sending the message.
|
||||
* @param modelConfigKey - The key for the model config.
|
||||
* @param message - The list of messages to send.
|
||||
* @param prompt_id - The ID of the prompt.
|
||||
* @param signal - An abort signal for this message.
|
||||
* @return The model's response.
|
||||
*
|
||||
* @example
|
||||
@@ -241,9 +246,10 @@ export class GeminiChat {
|
||||
* ```
|
||||
*/
|
||||
async sendMessageStream(
|
||||
model: string,
|
||||
params: SendMessageParameters,
|
||||
modelConfigKey: ModelConfigKey,
|
||||
message: PartListUnion,
|
||||
prompt_id: string,
|
||||
signal: AbortSignal,
|
||||
): Promise<AsyncGenerator<StreamEvent>> {
|
||||
await this.sendPromise;
|
||||
|
||||
@@ -251,21 +257,21 @@ export class GeminiChat {
|
||||
// This ensures that we attempt to use Preview Model for every new user turn
|
||||
// (unless the "Always" fallback mode is active, which is handled separately).
|
||||
this.config.setPreviewModelBypassMode(false);
|
||||
|
||||
let streamDoneResolver: () => void;
|
||||
const streamDonePromise = new Promise<void>((resolve) => {
|
||||
streamDoneResolver = resolve;
|
||||
});
|
||||
this.sendPromise = streamDonePromise;
|
||||
|
||||
const userContent = createUserContent(params.message);
|
||||
const userContent = createUserContent(message);
|
||||
const { model, generateContentConfig } =
|
||||
this.config.modelConfigService.getResolvedConfig(modelConfigKey);
|
||||
generateContentConfig.abortSignal = signal;
|
||||
|
||||
// Record user input - capture complete message with all parts (text, files, images, etc.)
|
||||
// but skip recording function responses (tool call results) as they should be stored in tool call records
|
||||
if (!isFunctionResponse(userContent)) {
|
||||
const userMessage = Array.isArray(params.message)
|
||||
? params.message
|
||||
: [params.message];
|
||||
const userMessage = Array.isArray(message) ? message : [message];
|
||||
const userMessageContent = partListUnionToString(toParts(userMessage));
|
||||
this.chatRecordingService.recordMessage({
|
||||
model,
|
||||
@@ -301,18 +307,14 @@ export class GeminiChat {
|
||||
}
|
||||
|
||||
// If this is a retry, set temperature to 1 to encourage different output.
|
||||
const currentParams = { ...params };
|
||||
if (attempt > 0) {
|
||||
currentParams.config = {
|
||||
...currentParams.config,
|
||||
temperature: 1,
|
||||
};
|
||||
generateContentConfig.temperature = 1;
|
||||
}
|
||||
|
||||
const stream = await self.makeApiCallAndProcessStream(
|
||||
model,
|
||||
generateContentConfig,
|
||||
requestContents,
|
||||
currentParams,
|
||||
prompt_id,
|
||||
);
|
||||
|
||||
@@ -385,8 +387,8 @@ export class GeminiChat {
|
||||
|
||||
private async makeApiCallAndProcessStream(
|
||||
model: string,
|
||||
generateContentConfig: GenerateContentConfig,
|
||||
requestContents: Content[],
|
||||
params: SendMessageParameters,
|
||||
prompt_id: string,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
let effectiveModel = model;
|
||||
@@ -418,7 +420,13 @@ export class GeminiChat {
|
||||
modelToUse === PREVIEW_GEMINI_MODEL
|
||||
? contentsForPreviewModel
|
||||
: requestContents,
|
||||
config: { ...this.generationConfig, ...params.config },
|
||||
config: {
|
||||
...generateContentConfig,
|
||||
// TODO(12622): Ensure we don't overrwrite these when they are
|
||||
// passed via config.
|
||||
systemInstruction: this.systemInstruction,
|
||||
tools: this.tools,
|
||||
},
|
||||
},
|
||||
prompt_id,
|
||||
);
|
||||
@@ -433,7 +441,7 @@ export class GeminiChat {
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
retryFetchErrors: this.config.getRetryFetchErrors(),
|
||||
signal: params.config?.abortSignal,
|
||||
signal: generateContentConfig.abortSignal,
|
||||
maxAttempts:
|
||||
this.config.isPreviewModelFallbackMode() &&
|
||||
model === PREVIEW_GEMINI_MODEL
|
||||
@@ -561,7 +569,7 @@ export class GeminiChat {
|
||||
}
|
||||
|
||||
setTools(tools: Tool[]): void {
|
||||
this.generationConfig.tools = tools;
|
||||
this.tools = tools;
|
||||
}
|
||||
|
||||
async maybeIncludeSchemaDepthContext(error: StructuredError): Promise<void> {
|
||||
|
||||
@@ -97,7 +97,7 @@ describe('Turn', () => {
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Hi' }];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -105,12 +105,10 @@ describe('Turn', () => {
|
||||
}
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledWith(
|
||||
'test-model',
|
||||
{
|
||||
message: reqParts,
|
||||
config: { abortSignal: expect.any(AbortSignal) },
|
||||
},
|
||||
{ model: 'gemini' },
|
||||
reqParts,
|
||||
'prompt-id-1',
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
|
||||
expect(events).toEqual([
|
||||
@@ -146,7 +144,7 @@ describe('Turn', () => {
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Use tools' }];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -210,7 +208,7 @@ describe('Turn', () => {
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test abort' }];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
reqParts,
|
||||
abortController.signal,
|
||||
)) {
|
||||
@@ -233,7 +231,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -256,7 +254,7 @@ describe('Turn', () => {
|
||||
mockMaybeIncludeSchemaDepthContext.mockResolvedValue(undefined);
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -297,7 +295,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
[{ text: 'Test undefined tool parts' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -374,7 +372,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
[{ text: 'Test' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -411,7 +409,7 @@ describe('Turn', () => {
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test no finish reason' }];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -456,7 +454,7 @@ describe('Turn', () => {
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test multiple responses' }];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -499,7 +497,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
[{ text: 'Test citations' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -549,7 +547,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
[{ text: 'test' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -596,7 +594,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
[{ text: 'test' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -642,7 +640,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
[{ text: 'test' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -680,7 +678,7 @@ describe('Turn', () => {
|
||||
const reqParts: Part[] = [{ text: 'Test malformed error handling' }];
|
||||
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
reqParts,
|
||||
abortController.signal,
|
||||
)) {
|
||||
@@ -706,7 +704,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
[],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -754,7 +752,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
[{ text: 'Hi' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -780,7 +778,7 @@ describe('Turn', () => {
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
const reqParts: Part[] = [{ text: 'Hi' }];
|
||||
for await (const _ of turn.run(
|
||||
'test-model',
|
||||
{ model: 'gemini' },
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
|
||||
@@ -30,6 +30,7 @@ import type { GeminiChat } from './geminiChat.js';
|
||||
import { InvalidStreamError } from './geminiChat.js';
|
||||
import { parseThought, type ThoughtSummary } from '../utils/thoughtUtils.js';
|
||||
import { createUserContent } from '@google/genai';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
|
||||
// Define a structure for tools passed to the server
|
||||
export interface ServerTool {
|
||||
@@ -232,9 +233,10 @@ export class Turn {
|
||||
private readonly chat: GeminiChat,
|
||||
private readonly prompt_id: string,
|
||||
) {}
|
||||
|
||||
// The run method yields simpler events suitable for server logic
|
||||
async *run(
|
||||
model: string,
|
||||
modelConfigKey: ModelConfigKey,
|
||||
req: PartListUnion,
|
||||
signal: AbortSignal,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||
@@ -242,14 +244,10 @@ export class Turn {
|
||||
// Note: This assumes `sendMessageStream` yields events like
|
||||
// { type: StreamEventType.RETRY } or { type: StreamEventType.CHUNK, value: GenerateContentResponse }
|
||||
const responseStream = await this.chat.sendMessageStream(
|
||||
model,
|
||||
{
|
||||
message: req,
|
||||
config: {
|
||||
abortSignal: signal,
|
||||
},
|
||||
},
|
||||
modelConfigKey,
|
||||
req,
|
||||
this.prompt_id,
|
||||
signal,
|
||||
);
|
||||
|
||||
for await (const streamEvent of responseStream) {
|
||||
|
||||
@@ -231,4 +231,42 @@ describe('ModelConfigService Integration', () => {
|
||||
topP: 0.95, // from base
|
||||
});
|
||||
});
|
||||
|
||||
it('should correctly merge static aliases, runtime aliases, and overrides', () => {
|
||||
// Re-instantiate service for this isolated test to not pollute other tests
|
||||
const service = new ModelConfigService(complexConfig);
|
||||
|
||||
// Register a runtime alias, simulating what AgentExecutor does.
|
||||
// This alias extends a static base and provides its own settings.
|
||||
service.registerRuntimeModelConfig('agent-runtime:my-agent', {
|
||||
extends: 'creative-writer', // extends a multi-level alias
|
||||
modelConfig: {
|
||||
generateContentConfig: {
|
||||
temperature: 0.1, // Overrides parent
|
||||
maxOutputTokens: 8192, // Adds a new property
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Resolve the configuration for the runtime alias, with a matching agent scope
|
||||
const resolved = service.getResolvedConfig({
|
||||
model: 'agent-runtime:my-agent',
|
||||
overrideScope: 'core',
|
||||
});
|
||||
|
||||
// Assert the final merged configuration.
|
||||
expect(resolved.model).toBe('gemini-1.5-pro-latest'); // from 'default-text-model'
|
||||
expect(resolved.generateContentConfig).toEqual({
|
||||
// from 'core' agent override, wins over runtime alias's 0.1 and creative-writer's 0.9
|
||||
temperature: 0.5,
|
||||
// from 'base' alias
|
||||
topP: 0.95,
|
||||
// from 'creative-writer' alias
|
||||
topK: 50,
|
||||
// from runtime alias
|
||||
maxOutputTokens: 8192,
|
||||
// from 'core' agent override
|
||||
stopSequences: ['AGENT_STOP'],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -550,4 +550,30 @@ describe('ModelConfigService', () => {
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('runtime aliases', () => {
|
||||
it('should resolve a simple runtime-registered alias', () => {
|
||||
const config: ModelConfigServiceConfig = {
|
||||
aliases: {},
|
||||
overrides: [],
|
||||
};
|
||||
const service = new ModelConfigService(config);
|
||||
|
||||
service.registerRuntimeModelConfig('runtime-alias', {
|
||||
modelConfig: {
|
||||
model: 'gemini-runtime-model',
|
||||
generateContentConfig: {
|
||||
temperature: 0.123,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const resolved = service.getResolvedConfig({ model: 'runtime-alias' });
|
||||
|
||||
expect(resolved.model).toBe('gemini-runtime-model');
|
||||
expect(resolved.generateContentConfig).toEqual({
|
||||
temperature: 0.123,
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -56,9 +56,15 @@ export interface _ResolvedModelConfig {
|
||||
}
|
||||
|
||||
export class ModelConfigService {
|
||||
private readonly runtimeAliases: Record<string, ModelConfigAlias> = {};
|
||||
|
||||
// TODO(12597): Process config to build a typed alias hierarchy.
|
||||
constructor(private readonly config: ModelConfigServiceConfig) {}
|
||||
|
||||
registerRuntimeModelConfig(aliasName: string, alias: ModelConfigAlias): void {
|
||||
this.runtimeAliases[aliasName] = alias;
|
||||
}
|
||||
|
||||
private resolveAlias(
|
||||
aliasName: string,
|
||||
aliases: Record<string, ModelConfigAlias>,
|
||||
@@ -99,12 +105,13 @@ export class ModelConfigService {
|
||||
} {
|
||||
const config = this.config || {};
|
||||
const { aliases = {}, overrides = [] } = config;
|
||||
const allAliases = { ...aliases, ...this.runtimeAliases };
|
||||
let baseModel: string | undefined = context.model;
|
||||
let resolvedConfig: GenerateContentConfig = {};
|
||||
|
||||
// Step 1: Alias Resolution
|
||||
if (aliases[context.model]) {
|
||||
const resolvedAlias = this.resolveAlias(context.model, aliases);
|
||||
if (allAliases[context.model]) {
|
||||
const resolvedAlias = this.resolveAlias(context.model, allAliases);
|
||||
baseModel = resolvedAlias.modelConfig.model; // This can now be undefined
|
||||
resolvedConfig = this.deepMerge(
|
||||
resolvedConfig,
|
||||
|
||||
@@ -11,7 +11,19 @@
|
||||
"topP": 0.95,
|
||||
"thinkingConfig": {
|
||||
"includeThoughts": true,
|
||||
"thinkingBudget": -1
|
||||
"thinkingBudget": 8192
|
||||
},
|
||||
"topK": 64
|
||||
}
|
||||
},
|
||||
"gemini-3-pro-preview": {
|
||||
"model": "gemini-3-pro-preview",
|
||||
"generateContentConfig": {
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"thinkingConfig": {
|
||||
"includeThoughts": true,
|
||||
"thinkingBudget": 8192
|
||||
},
|
||||
"topK": 64
|
||||
}
|
||||
@@ -23,7 +35,7 @@
|
||||
"topP": 0.95,
|
||||
"thinkingConfig": {
|
||||
"includeThoughts": true,
|
||||
"thinkingBudget": -1
|
||||
"thinkingBudget": 8192
|
||||
},
|
||||
"topK": 64
|
||||
}
|
||||
@@ -35,7 +47,7 @@
|
||||
"topP": 0.95,
|
||||
"thinkingConfig": {
|
||||
"includeThoughts": true,
|
||||
"thinkingBudget": -1
|
||||
"thinkingBudget": 8192
|
||||
},
|
||||
"topK": 64
|
||||
}
|
||||
@@ -47,7 +59,7 @@
|
||||
"topP": 0.95,
|
||||
"thinkingConfig": {
|
||||
"includeThoughts": true,
|
||||
"thinkingBudget": -1
|
||||
"thinkingBudget": 8192
|
||||
},
|
||||
"topK": 64
|
||||
}
|
||||
|
||||
@@ -82,7 +82,8 @@ describe('checkNextSpeaker', () => {
|
||||
// GeminiChat will receive the mocked instances via the mocked GoogleGenAI constructor
|
||||
chatInstance = new GeminiChat(
|
||||
mockConfig,
|
||||
{},
|
||||
'', // empty system instruction
|
||||
[], // no tools
|
||||
[], // initial history
|
||||
);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user