diff --git a/packages/core/src/config/models.test.ts b/packages/core/src/config/models.test.ts index 55b1751484..8e6c3ea895 100644 --- a/packages/core/src/config/models.test.ts +++ b/packages/core/src/config/models.test.ts @@ -9,6 +9,7 @@ import { resolveModel, resolveClassifierModel, isGemini2Model, + isAutoModel, getDisplayString, DEFAULT_GEMINI_MODEL, PREVIEW_GEMINI_MODEL, @@ -18,6 +19,7 @@ import { GEMINI_MODEL_ALIAS_PRO, GEMINI_MODEL_ALIAS_FLASH, GEMINI_MODEL_ALIAS_FLASH_LITE, + GEMINI_MODEL_ALIAS_AUTO, PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO, @@ -171,6 +173,26 @@ describe('isGemini2Model', () => { }); }); +describe('isAutoModel', () => { + it('should return true for "auto"', () => { + expect(isAutoModel(GEMINI_MODEL_ALIAS_AUTO)).toBe(true); + }); + + it('should return true for "auto-gemini-3"', () => { + expect(isAutoModel(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true); + }); + + it('should return true for "auto-gemini-2.5"', () => { + expect(isAutoModel(DEFAULT_GEMINI_MODEL_AUTO)).toBe(true); + }); + + it('should return false for concrete models', () => { + expect(isAutoModel(DEFAULT_GEMINI_MODEL)).toBe(false); + expect(isAutoModel(PREVIEW_GEMINI_MODEL)).toBe(false); + expect(isAutoModel('some-random-model')).toBe(false); + }); +}); + describe('resolveClassifierModel', () => { it('should return flash model when alias is flash', () => { expect( diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index ca87ee2d40..4475a5db97 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -146,6 +146,20 @@ export function isGemini2Model(model: string): boolean { return /^gemini-2(\.|$)/.test(model); } +/** + * Checks if the model is an auto model. + * + * @param model The model name to check. + * @returns True if the model is an auto model. + */ +export function isAutoModel(model: string): boolean { + return ( + model === GEMINI_MODEL_ALIAS_AUTO || + model === PREVIEW_GEMINI_MODEL_AUTO || + model === DEFAULT_GEMINI_MODEL_AUTO + ); +} + /** * Checks if the model supports multimodal function responses (multimodal data nested within function response). * This is supported in Gemini 3. diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index bf70aa2200..3a8603ae65 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -606,6 +606,7 @@ export class GeminiClient { history: this.getChat().getHistory(/*curated=*/ true), request, signal, + requestedModel: this.config.getModel(), }; let modelToUse: string; diff --git a/packages/core/src/routing/routingStrategy.ts b/packages/core/src/routing/routingStrategy.ts index d5d8df8dc9..de8bcf04f1 100644 --- a/packages/core/src/routing/routingStrategy.ts +++ b/packages/core/src/routing/routingStrategy.ts @@ -35,6 +35,8 @@ export interface RoutingContext { request: PartListUnion; /** An abort signal to cancel an LLM call during routing. */ signal: AbortSignal; + /** The model string requested for this turn, if any. */ + requestedModel?: string; } /** diff --git a/packages/core/src/routing/strategies/classifierStrategy.test.ts b/packages/core/src/routing/strategies/classifierStrategy.test.ts index 21d324c1fb..e883b0be45 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.test.ts @@ -281,4 +281,30 @@ describe('ClassifierStrategy', () => { ); consoleWarnSpy.mockRestore(); }); + + it('should respect requestedModel from context in resolveClassifierModel', async () => { + const requestedModel = DEFAULT_GEMINI_MODEL; // Pro model + const mockApiResponse = { + reasoning: 'Choice is flash', + model_choice: 'flash', + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const contextWithRequestedModel = { + ...mockContext, + requestedModel, + } as RoutingContext; + + const decision = await strategy.route( + contextWithRequestedModel, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision).not.toBeNull(); + // Since requestedModel is Pro, and choice is flash, it should resolve to Flash + expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + }); }); diff --git a/packages/core/src/routing/strategies/classifierStrategy.ts b/packages/core/src/routing/strategies/classifierStrategy.ts index 4747bc5352..59c5ff6fca 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.ts @@ -168,7 +168,7 @@ export class ClassifierStrategy implements RoutingStrategy { const reasoning = routerResponse.reasoning; const latencyMs = Date.now() - startTime; const selectedModel = resolveClassifierModel( - config.getModel(), + context.requestedModel ?? config.getModel(), routerResponse.model_choice, config.getPreviewFeatures(), ); diff --git a/packages/core/src/routing/strategies/fallbackStrategy.test.ts b/packages/core/src/routing/strategies/fallbackStrategy.test.ts index 6196e59526..2d30b153e5 100644 --- a/packages/core/src/routing/strategies/fallbackStrategy.test.ts +++ b/packages/core/src/routing/strategies/fallbackStrategy.test.ts @@ -108,4 +108,25 @@ describe('FallbackStrategy', () => { // Important: check that it queried snapshot with the RESOLVED model, not 'auto' expect(mockService.snapshot).toHaveBeenCalledWith(DEFAULT_GEMINI_MODEL); }); + + it('should respect requestedModel from context', async () => { + const requestedModel = 'requested-model'; + const configModel = 'config-model'; + vi.mocked(mockConfig.getModel).mockReturnValue(configModel); + vi.mocked(mockService.snapshot).mockReturnValue({ available: true }); + + const contextWithRequestedModel = { + requestedModel, + } as RoutingContext; + + const decision = await strategy.route( + contextWithRequestedModel, + mockConfig, + mockClient, + ); + + expect(decision).toBeNull(); + // Should check availability of the requested model from context + expect(mockService.snapshot).toHaveBeenCalledWith(requestedModel); + }); }); diff --git a/packages/core/src/routing/strategies/fallbackStrategy.ts b/packages/core/src/routing/strategies/fallbackStrategy.ts index dbaa484094..383f441713 100644 --- a/packages/core/src/routing/strategies/fallbackStrategy.ts +++ b/packages/core/src/routing/strategies/fallbackStrategy.ts @@ -18,11 +18,11 @@ export class FallbackStrategy implements RoutingStrategy { readonly name = 'fallback'; async route( - _context: RoutingContext, + context: RoutingContext, config: Config, _baseLlmClient: BaseLlmClient, ): Promise { - const requestedModel = config.getModel(); + const requestedModel = context.requestedModel ?? config.getModel(); const resolvedModel = resolveModel( requestedModel, config.getPreviewFeatures(), diff --git a/packages/core/src/routing/strategies/overrideStrategy.test.ts b/packages/core/src/routing/strategies/overrideStrategy.test.ts index f1ec54098d..97e9f4915f 100644 --- a/packages/core/src/routing/strategies/overrideStrategy.test.ts +++ b/packages/core/src/routing/strategies/overrideStrategy.test.ts @@ -56,4 +56,25 @@ describe('OverrideStrategy', () => { expect(decision).not.toBeNull(); expect(decision?.model).toBe(overrideModel); }); + + it('should respect requestedModel from context', async () => { + const requestedModel = 'requested-model'; + const configModel = 'config-model'; + const mockConfig = { + getModel: () => configModel, + getPreviewFeatures: () => false, + } as Config; + const contextWithRequestedModel = { + requestedModel, + } as RoutingContext; + + const decision = await strategy.route( + contextWithRequestedModel, + mockConfig, + mockClient, + ); + + expect(decision).not.toBeNull(); + expect(decision?.model).toBe(requestedModel); + }); }); diff --git a/packages/core/src/routing/strategies/overrideStrategy.ts b/packages/core/src/routing/strategies/overrideStrategy.ts index 6a4c2a50d2..c5f632ca3d 100644 --- a/packages/core/src/routing/strategies/overrideStrategy.ts +++ b/packages/core/src/routing/strategies/overrideStrategy.ts @@ -5,11 +5,7 @@ */ import type { Config } from '../../config/config.js'; -import { - DEFAULT_GEMINI_MODEL_AUTO, - PREVIEW_GEMINI_MODEL_AUTO, - resolveModel, -} from '../../config/models.js'; +import { isAutoModel, resolveModel } from '../../config/models.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import type { RoutingContext, @@ -24,18 +20,16 @@ export class OverrideStrategy implements RoutingStrategy { readonly name = 'override'; async route( - _context: RoutingContext, + context: RoutingContext, config: Config, _baseLlmClient: BaseLlmClient, ): Promise { - const overrideModel = config.getModel(); + const overrideModel = context.requestedModel ?? config.getModel(); // If the model is 'auto' we should pass to the next strategy. - if ( - overrideModel === DEFAULT_GEMINI_MODEL_AUTO || - overrideModel === PREVIEW_GEMINI_MODEL_AUTO - ) + if (isAutoModel(overrideModel)) { return null; + } // Return the overridden model name. return { diff --git a/packages/core/src/services/modelConfigService.test.ts b/packages/core/src/services/modelConfigService.test.ts index 8d08e4f775..ee6cd09f40 100644 --- a/packages/core/src/services/modelConfigService.test.ts +++ b/packages/core/src/services/modelConfigService.test.ts @@ -577,6 +577,81 @@ describe('ModelConfigService', () => { }); }); + describe('runtime overrides', () => { + it('should resolve a simple runtime-registered override', () => { + const config: ModelConfigServiceConfig = { + aliases: {}, + overrides: [], + }; + const service = new ModelConfigService(config); + + service.registerRuntimeModelOverride({ + match: { model: 'gemini-pro' }, + modelConfig: { + generateContentConfig: { + temperature: 0.99, + }, + }, + }); + + const resolved = service.getResolvedConfig({ model: 'gemini-pro' }); + + expect(resolved.model).toBe('gemini-pro'); + expect(resolved.generateContentConfig.temperature).toBe(0.99); + }); + + it('should prioritize runtime overrides over default overrides when they have the same specificity', () => { + const config: ModelConfigServiceConfig = { + aliases: {}, + overrides: [ + { + match: { model: 'gemini-pro' }, + modelConfig: { generateContentConfig: { temperature: 0.1 } }, + }, + ], + }; + const service = new ModelConfigService(config); + + service.registerRuntimeModelOverride({ + match: { model: 'gemini-pro' }, + modelConfig: { generateContentConfig: { temperature: 0.9 } }, + }); + + const resolved = service.getResolvedConfig({ model: 'gemini-pro' }); + + // Runtime overrides are appended after overrides/customOverrides, so they should win. + expect(resolved.generateContentConfig.temperature).toBe(0.9); + }); + + it('should still respect specificity with runtime overrides', () => { + const config: ModelConfigServiceConfig = { + aliases: {}, + overrides: [], + }; + const service = new ModelConfigService(config); + + // Register a more specific runtime override + service.registerRuntimeModelOverride({ + match: { model: 'gemini-pro', overrideScope: 'my-agent' }, + modelConfig: { generateContentConfig: { temperature: 0.1 } }, + }); + + // Register a less specific runtime override later + service.registerRuntimeModelOverride({ + match: { model: 'gemini-pro' }, + modelConfig: { generateContentConfig: { temperature: 0.9 } }, + }); + + const resolved = service.getResolvedConfig({ + model: 'gemini-pro', + overrideScope: 'my-agent', + }); + + // Specificity should win over order + expect(resolved.generateContentConfig.temperature).toBe(0.1); + }); + }); + describe('custom aliases', () => { it('should resolve a custom alias', () => { const config: ModelConfigServiceConfig = { diff --git a/packages/core/src/services/modelConfigService.ts b/packages/core/src/services/modelConfigService.ts index 0b86baa4ad..6fb712243c 100644 --- a/packages/core/src/services/modelConfigService.ts +++ b/packages/core/src/services/modelConfigService.ts @@ -65,6 +65,7 @@ export interface _ResolvedModelConfig { export class ModelConfigService { private readonly runtimeAliases: Record = {}; + private readonly runtimeOverrides: ModelConfigOverride[] = []; // TODO(12597): Process config to build a typed alias hierarchy. constructor(private readonly config: ModelConfigServiceConfig) {} @@ -73,6 +74,10 @@ export class ModelConfigService { this.runtimeAliases[aliasName] = alias; } + registerRuntimeModelOverride(override: ModelConfigOverride): void { + this.runtimeOverrides.push(override); + } + private resolveAlias( aliasName: string, aliases: Record, @@ -123,7 +128,11 @@ export class ModelConfigService { ...customAliases, ...this.runtimeAliases, }; - const allOverrides = [...overrides, ...customOverrides]; + const allOverrides = [ + ...overrides, + ...customOverrides, + ...this.runtimeOverrides, + ]; let baseModel: string | undefined = context.model; let resolvedConfig: GenerateContentConfig = {};