diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index 82966d47ed..4de0c8a5c4 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -7,7 +7,13 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import * as os from 'node:os'; import * as path from 'node:path'; -import { ShellTool, EditTool, WriteFileTool } from '@google/gemini-cli-core'; +import { + ShellTool, + EditTool, + WriteFileTool, + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_MODEL_AUTO, +} from '@google/gemini-cli-core'; import { loadCliConfig, parseArguments, type CliArgs } from './config.js'; import type { Settings } from './settings.js'; import type { Extension } from './extension.js'; @@ -1484,6 +1490,97 @@ describe('loadCliConfig model selection', () => { }); }); +describe('loadCliConfig model selection with model router', () => { + it('should use auto model when useModelRouter is true and no model is provided', async () => { + process.argv = ['node', 'script.js']; + const argv = await parseArguments({} as Settings); + const config = await loadCliConfig( + { + experimental: { + useModelRouter: true, + }, + }, + [], + 'test-session', + argv, + ); + + expect(config.getModel()).toBe(DEFAULT_GEMINI_MODEL_AUTO); + }); + + it('should use default model when useModelRouter is false and no model is provided', async () => { + process.argv = ['node', 'script.js']; + const argv = await parseArguments({} as Settings); + const config = await loadCliConfig( + { + experimental: { + useModelRouter: false, + }, + }, + [], + 'test-session', + argv, + ); + + expect(config.getModel()).toBe(DEFAULT_GEMINI_MODEL); + }); + + it('should prioritize argv over useModelRouter', async () => { + process.argv = ['node', 'script.js', '--model', 'gemini-from-argv']; + const argv = await parseArguments({} as Settings); + const config = await loadCliConfig( + { + experimental: { + useModelRouter: true, + }, + }, + [], + 'test-session', + argv, + ); + + expect(config.getModel()).toBe('gemini-from-argv'); + }); + + it('should prioritize settings over useModelRouter', async () => { + process.argv = ['node', 'script.js']; + const argv = await parseArguments({} as Settings); + const config = await loadCliConfig( + { + experimental: { + useModelRouter: true, + }, + model: { + name: 'gemini-from-settings', + }, + }, + [], + 'test-session', + argv, + ); + + expect(config.getModel()).toBe('gemini-from-settings'); + }); + + it('should prioritize environment variable over useModelRouter', async () => { + process.argv = ['node', 'script.js']; + vi.stubEnv('GEMINI_MODEL', 'gemini-from-env'); + const argv = await parseArguments({} as Settings); + const config = await loadCliConfig( + { + experimental: { + useModelRouter: true, + }, + }, + [], + 'test-session', + argv, + ); + + expect(config.getModel()).toBe('gemini-from-env'); + }); +}); + describe('loadCliConfig folderTrust', () => { const originalArgv = process.argv; @@ -1668,6 +1765,32 @@ describe('loadCliConfig useRipgrep', () => { const config = await loadCliConfig(settings, [], 'test-session', argv); expect(config.getUseRipgrep()).toBe(true); }); + + describe('loadCliConfig useModelRouter', () => { + it('should be false by default when useModelRouter is not set in settings', async () => { + process.argv = ['node', 'script.js']; + const argv = await parseArguments({} as Settings); + const settings: Settings = {}; + const config = await loadCliConfig(settings, [], 'test-session', argv); + expect(config.getUseModelRouter()).toBe(false); + }); + + it('should be true when useModelRouter is set to true in settings', async () => { + process.argv = ['node', 'script.js']; + const argv = await parseArguments({} as Settings); + const settings: Settings = { experimental: { useModelRouter: true } }; + const config = await loadCliConfig(settings, [], 'test-session', argv); + expect(config.getUseModelRouter()).toBe(true); + }); + + it('should be false when useModelRouter is explicitly set to false in settings', async () => { + process.argv = ['node', 'script.js']; + const argv = await parseArguments({} as Settings); + const settings: Settings = { experimental: { useModelRouter: false } }; + const config = await loadCliConfig(settings, [], 'test-session', argv); + expect(config.getUseModelRouter()).toBe(false); + }); + }); }); describe('loadCliConfig tool exclusions', () => { diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 3d0ba230d7..440c923aa3 100755 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -25,6 +25,7 @@ import { getCurrentGeminiMdFilename, ApprovalMode, DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_EMBEDDING_MODEL, DEFAULT_MEMORY_FILE_FILTERING_OPTIONS, FileDiscoveryService, @@ -98,7 +99,6 @@ export async function parseArguments(settings: Settings): Promise { alias: 'm', type: 'string', description: `Model`, - default: process.env['GEMINI_MODEL'], }) .option('prompt', { alias: 'p', @@ -550,6 +550,16 @@ export async function loadCliConfig( ); } + const useModelRouter = settings.experimental?.useModelRouter ?? false; + const defaultModel = useModelRouter + ? DEFAULT_GEMINI_MODEL_AUTO + : DEFAULT_GEMINI_MODEL; + const resolvedModel: string = + argv.model || + process.env['GEMINI_MODEL'] || + settings.model?.name || + defaultModel; + const sandboxConfig = await loadSandboxConfig(settings, argv); const screenReader = argv.screenReader !== undefined @@ -611,7 +621,7 @@ export async function loadCliConfig( cwd, fileDiscoveryService: fileService, bugCommand: settings.advanced?.bugCommand, - model: argv.model || settings.model?.name || DEFAULT_GEMINI_MODEL, + model: resolvedModel, extensionContextFilePaths, maxSessionTurns: settings.model?.maxSessionTurns ?? -1, experimentalZedIntegration: argv.experimentalAcp || false, @@ -637,6 +647,7 @@ export async function loadCliConfig( output: { format: (argv.outputFormat ?? settings.output?.format) as OutputFormat, }, + useModelRouter, }); } diff --git a/packages/cli/src/config/settingsSchema.test.ts b/packages/cli/src/config/settingsSchema.test.ts index 47fc91c108..4088151049 100644 --- a/packages/cli/src/config/settingsSchema.test.ts +++ b/packages/cli/src/config/settingsSchema.test.ts @@ -315,5 +315,20 @@ describe('SettingsSchema', () => { .description, ).toBe('Enable debug logging of keystrokes to the console.'); }); + + it('should have useModelRouter setting in schema', () => { + expect( + getSettingsSchema().experimental.properties.useModelRouter, + ).toBeDefined(); + expect( + getSettingsSchema().experimental.properties.useModelRouter.type, + ).toBe('boolean'); + expect( + getSettingsSchema().experimental.properties.useModelRouter.category, + ).toBe('Experimental'); + expect( + getSettingsSchema().experimental.properties.useModelRouter.default, + ).toBe(false); + }); }); }); diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index e8c8032add..4a4b3902e3 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -962,6 +962,16 @@ const SETTINGS_SCHEMA = { description: 'Enable extension management features.', showInDialog: false, }, + useModelRouter: { + type: 'boolean', + label: 'Use Model Router', + category: 'Experimental', + requiresRestart: true, + default: false, + description: + 'Enable model routing to route requests to the best model based on complexity.', + showInDialog: false, + }, }, }, diff --git a/packages/core/index.ts b/packages/core/index.ts index d746c9a082..daa3771150 100644 --- a/packages/core/index.ts +++ b/packages/core/index.ts @@ -8,6 +8,7 @@ export * from './src/index.js'; export { Storage } from './src/config/storage.js'; export { DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_FLASH_LITE_MODEL, DEFAULT_GEMINI_EMBEDDING_MODEL, diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index bb4a7f5827..3a0762094d 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -523,6 +523,31 @@ describe('Server Config (config.ts)', () => { }); }); + describe('UseModelRouter Configuration', () => { + it('should default useModelRouter to false when not provided', () => { + const config = new Config(baseParams); + expect(config.getUseModelRouter()).toBe(false); + }); + + it('should set useModelRouter to true when provided as true', () => { + const paramsWithModelRouter: ConfigParameters = { + ...baseParams, + useModelRouter: true, + }; + const config = new Config(paramsWithModelRouter); + expect(config.getUseModelRouter()).toBe(true); + }); + + it('should set useModelRouter to false when explicitly provided as false', () => { + const paramsWithModelRouter: ConfigParameters = { + ...baseParams, + useModelRouter: false, + }; + const config = new Config(paramsWithModelRouter); + expect(config.getUseModelRouter()).toBe(false); + }); + }); + describe('createToolRegistry', () => { it('should register a tool if coreTools contains an argument-specific pattern', async () => { const params: ConfigParameters = { diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 0fcbc0a61c..19308ebfa2 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -44,7 +44,6 @@ import { StartSessionEvent } from '../telemetry/index.js'; import { DEFAULT_GEMINI_EMBEDDING_MODEL, DEFAULT_GEMINI_FLASH_MODEL, - DEFAULT_GEMINI_MODEL, } from './models.js'; import { shouldAttemptBrowserLaunch } from '../utils/browser.js'; import type { MCPOAuthConfig } from '../mcp/oauth-provider.js'; @@ -240,6 +239,7 @@ export interface ConfigParameters { useSmartEdit?: boolean; policyEngineConfig?: PolicyEngineConfig; output?: OutputSettings; + useModelRouter?: boolean; } export class Config { @@ -327,6 +327,7 @@ export class Config { private readonly messageBus: MessageBus; private readonly policyEngine: PolicyEngine; private readonly outputSettings: OutputSettings; + private readonly useModelRouter: boolean; constructor(params: ConfigParameters) { this.sessionId = params.sessionId; @@ -376,7 +377,7 @@ export class Config { this.cwd = params.cwd ?? process.cwd(); this.fileDiscoveryService = params.fileDiscoveryService ?? null; this.bugCommand = params.bugCommand; - this.model = params.model || DEFAULT_GEMINI_MODEL; + this.model = params.model; this.extensionContextFilePaths = params.extensionContextFilePaths ?? []; this.maxSessionTurns = params.maxSessionTurns ?? -1; this.experimentalZedIntegration = @@ -411,6 +412,7 @@ export class Config { this.enableToolOutputTruncation = params.enableToolOutputTruncation ?? false; this.useSmartEdit = params.useSmartEdit ?? true; + this.useModelRouter = params.useModelRouter ?? false; this.extensionManagement = params.extensionManagement ?? true; this.storage = new Storage(this.targetDir); this.enablePromptCompletion = params.enablePromptCompletion ?? false; @@ -931,6 +933,10 @@ export class Config { : OutputFormat.TEXT; } + getUseModelRouter(): boolean { + return this.useModelRouter; + } + async getGitService(): Promise { if (!this.gitService) { this.gitService = new GitService(this.targetDir, this.storage); diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index a0aa73bfdd..d1f07d60d6 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -8,6 +8,8 @@ export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro'; export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash'; export const DEFAULT_GEMINI_FLASH_LITE_MODEL = 'gemini-2.5-flash-lite'; +export const DEFAULT_GEMINI_MODEL_AUTO = 'auto'; + export const DEFAULT_GEMINI_EMBEDDING_MODEL = 'gemini-embedding-001'; // Some thinking models do not default to dynamic thinking which is done by a value of -1 diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 21392ed12b..86405ba5a8 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -33,7 +33,10 @@ import type { ChatRecordingService } from '../services/chatRecordingService.js'; import type { ContentGenerator } from './contentGenerator.js'; import { DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_THINKING_MODE, + getEffectiveModel, } from '../config/models.js'; import { LoopDetectionService } from '../services/loopDetectionService.js'; import { ideContextStore } from '../ide/ideContext.js'; @@ -52,14 +55,14 @@ import { handleFallback } from '../fallback/handler.js'; import type { RoutingContext } from '../routing/routingStrategy.js'; export function isThinkingSupported(model: string) { - if (model.startsWith('gemini-2.5')) return true; - return false; + return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO; } export function isThinkingDefault(model: string) { - if (model.startsWith('gemini-2.5-flash-lite')) return false; - if (model.startsWith('gemini-2.5')) return true; - return false; + if (model.startsWith('gemini-2.5-flash-lite')) { + return false; + } + return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO; } /** @@ -227,23 +230,21 @@ export class GeminiClient { const userMemory = this.config.getUserMemory(); const systemInstruction = getCoreSystemPrompt(userMemory); const model = this.config.getModel(); - const generateContentConfigWithThinking = isThinkingSupported(model) - ? { - ...this.generateContentConfig, - thinkingConfig: { - thinkingBudget: -1, - includeThoughts: true, - ...(!isThinkingDefault(model) - ? { thinkingBudget: DEFAULT_THINKING_MODE } - : {}), - }, - } - : this.generateContentConfig; + + const config: GenerateContentConfig = { ...this.generateContentConfig }; + + if (isThinkingSupported(model)) { + config.thinkingConfig = { + includeThoughts: true, + thinkingBudget: DEFAULT_THINKING_MODE, + }; + } + return new GeminiChat( this.config, { systemInstruction, - ...generateContentConfigWithThinking, + ...config, tools, }, history, @@ -790,6 +791,18 @@ export class GeminiClient { prompt_id: string, force: boolean = false, ): Promise { + // If the model is 'auto', we will use a placeholder model to check. + // Compression occurs before we choose a model, so calling `count_tokens` + // before the model is chosen would result in an error. + const configModel = this.config.getModel(); + let model: string = + configModel === DEFAULT_GEMINI_MODEL_AUTO + ? DEFAULT_GEMINI_MODEL + : configModel; + + // Check if the model needs to be a fallback + model = getEffectiveModel(this.config.isInFallbackMode(), model); + const curatedHistory = this.getChat().getHistory(true); // Regardless of `force`, don't do anything if the history is empty. @@ -804,8 +817,6 @@ export class GeminiClient { }; } - const model = this.config.getModel(); - const { totalTokens: originalTokenCount } = await this.getContentGeneratorOrFail().countTokens({ model, diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 9f7f19ba9a..38a7485295 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -231,6 +231,7 @@ export class GeminiChat { : [params.message]; const userMessageContent = partListUnionToString(toParts(userMessage)); this.chatRecordingService.recordMessage({ + model, type: 'user', content: userMessageContent, }); @@ -371,7 +372,7 @@ export class GeminiChat { authType: this.config.getContentGeneratorConfig()?.authType, }); - return this.processStreamResponse(streamResponse, userContent); + return this.processStreamResponse(model, streamResponse, userContent); } /** @@ -474,6 +475,7 @@ export class GeminiChat { } private async *processStreamResponse( + model: string, streamResponse: AsyncGenerator, userInput: Content, ): AsyncGenerator { @@ -552,6 +554,7 @@ export class GeminiChat { if (responseText.trim()) { this.chatRecordingService.recordMessage({ + model, type: 'gemini', content: responseText, }); @@ -662,7 +665,10 @@ export class GeminiChat { * Records completed tool calls with full metadata. * This is called by external components when tool calls complete, before sending responses to Gemini. */ - recordCompletedToolCalls(toolCalls: CompletedToolCall[]): void { + recordCompletedToolCalls( + model: string, + toolCalls: CompletedToolCall[], + ): void { const toolCallRecords = toolCalls.map((call) => { const resultDisplayRaw = call.response?.resultDisplay; const resultDisplay = @@ -679,7 +685,7 @@ export class GeminiChat { }; }); - this.chatRecordingService.recordToolCalls(toolCallRecords); + this.chatRecordingService.recordToolCalls(model, toolCallRecords); } /** diff --git a/packages/core/src/routing/strategies/overrideStrategy.test.ts b/packages/core/src/routing/strategies/overrideStrategy.test.ts index 69c4088f8d..bc80a99ad3 100644 --- a/packages/core/src/routing/strategies/overrideStrategy.test.ts +++ b/packages/core/src/routing/strategies/overrideStrategy.test.ts @@ -9,15 +9,16 @@ import { OverrideStrategy } from './overrideStrategy.js'; import type { RoutingContext } from '../routingStrategy.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import type { Config } from '../../config/config.js'; +import { DEFAULT_GEMINI_MODEL_AUTO } from '../../config/models.js'; describe('OverrideStrategy', () => { const strategy = new OverrideStrategy(); const mockContext = {} as RoutingContext; const mockClient = {} as BaseLlmClient; - it('should return null when no override model is specified', async () => { + it('should return null when the override model is auto', async () => { const mockConfig = { - getModel: () => '', // Simulate no model override + getModel: () => DEFAULT_GEMINI_MODEL_AUTO, } as Config; const decision = await strategy.route(mockContext, mockConfig, mockClient); diff --git a/packages/core/src/routing/strategies/overrideStrategy.ts b/packages/core/src/routing/strategies/overrideStrategy.ts index b3aef6c332..06d6b7f3dd 100644 --- a/packages/core/src/routing/strategies/overrideStrategy.ts +++ b/packages/core/src/routing/strategies/overrideStrategy.ts @@ -5,6 +5,7 @@ */ import type { Config } from '../../config/config.js'; +import { DEFAULT_GEMINI_MODEL_AUTO } from '../../config/models.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import type { RoutingContext, @@ -24,17 +25,18 @@ export class OverrideStrategy implements RoutingStrategy { _baseLlmClient: BaseLlmClient, ): Promise { const overrideModel = config.getModel(); - if (overrideModel) { - return { - model: overrideModel, - metadata: { - source: this.name, - latencyMs: 0, - reasoning: `Routing bypassed by forced model directive. Using: ${overrideModel}`, - }, - }; - } - // No override specified, pass to the next strategy. - return null; + + // If the model is 'auto' we should pass to the next strategy. + if (overrideModel === DEFAULT_GEMINI_MODEL_AUTO) return null; + + // Return the overridden model name. + return { + model: overrideModel, + metadata: { + source: this.name, + latencyMs: 0, + reasoning: `Routing bypassed by forced model directive. Using: ${overrideModel}`, + }, + }; } } diff --git a/packages/core/src/services/chatRecordingService.test.ts b/packages/core/src/services/chatRecordingService.test.ts index 50ee04b182..dcd77c986f 100644 --- a/packages/core/src/services/chatRecordingService.test.ts +++ b/packages/core/src/services/chatRecordingService.test.ts @@ -127,7 +127,11 @@ describe('ChatRecordingService', () => { const writeFileSyncSpy = vi .spyOn(fs, 'writeFileSync') .mockImplementation(() => undefined); - chatRecordingService.recordMessage({ type: 'user', content: 'Hello' }); + chatRecordingService.recordMessage({ + type: 'user', + content: 'Hello', + model: 'gemini-pro', + }); expect(mkdirSyncSpy).toHaveBeenCalled(); expect(writeFileSyncSpy).toHaveBeenCalled(); const conversation = JSON.parse( @@ -161,6 +165,7 @@ describe('ChatRecordingService', () => { chatRecordingService.recordMessage({ type: 'user', content: 'World', + model: 'gemini-pro', }); expect(mkdirSyncSpy).toHaveBeenCalled(); @@ -311,7 +316,7 @@ describe('ChatRecordingService', () => { status: 'awaiting_approval', timestamp: new Date().toISOString(), }; - chatRecordingService.recordToolCalls([toolCall]); + chatRecordingService.recordToolCalls('gemini-pro', [toolCall]); expect(mkdirSyncSpy).toHaveBeenCalled(); expect(writeFileSyncSpy).toHaveBeenCalled(); @@ -358,7 +363,7 @@ describe('ChatRecordingService', () => { status: 'awaiting_approval', timestamp: new Date().toISOString(), }; - chatRecordingService.recordToolCalls([toolCall]); + chatRecordingService.recordToolCalls('gemini-pro', [toolCall]); expect(mkdirSyncSpy).toHaveBeenCalled(); expect(writeFileSyncSpy).toHaveBeenCalled(); diff --git a/packages/core/src/services/chatRecordingService.ts b/packages/core/src/services/chatRecordingService.ts index 6179ae7d08..3b27b2736e 100644 --- a/packages/core/src/services/chatRecordingService.ts +++ b/packages/core/src/services/chatRecordingService.ts @@ -195,6 +195,7 @@ export class ChatRecordingService { * Records a message in the conversation. */ recordMessage(message: { + model: string; type: ConversationRecordExtra['type']; content: PartListUnion; }): void { @@ -209,7 +210,7 @@ export class ChatRecordingService { ...msg, thoughts: this.queuedThoughts, tokens: this.queuedTokens, - model: this.config.getModel(), + model: message.model, }); this.queuedThoughts = []; this.queuedTokens = null; @@ -279,7 +280,7 @@ export class ChatRecordingService { * Adds tool calls to the last message in the conversation (which should be by Gemini). * This method enriches tool calls with metadata from the ToolRegistry. */ - recordToolCalls(toolCalls: ToolCallRecord[]): void { + recordToolCalls(model: string, toolCalls: ToolCallRecord[]): void { if (!this.conversationFile) return; // Enrich tool calls with metadata from the ToolRegistry @@ -318,7 +319,7 @@ export class ChatRecordingService { type: 'gemini' as const, toolCalls: enrichedToolCalls, thoughts: this.queuedThoughts, - model: this.config.getModel(), + model, }; // If there are any queued thoughts join them to this message. if (this.queuedThoughts.length > 0) {