diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 5c49e50ec1..7bcf1ff941 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -556,29 +556,84 @@ describe('Server Config (config.ts)', () => { }); }); - describe('UseModelRouter Configuration', () => { - it('should default useModelRouter to false when not provided', () => { - const config = new Config(baseParams); + describe('Model Router with Auth', () => { + it('should disable model router by default for oauth-personal', async () => { + const config = new Config({ + ...baseParams, + useModelRouter: true, + }); + await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE); expect(config.getUseModelRouter()).toBe(false); }); - it('should set useModelRouter to true when provided as true', () => { - const paramsWithModelRouter: ConfigParameters = { + it('should enable model router by default for other auth types', async () => { + const config = new Config({ ...baseParams, useModelRouter: true, - }; - const config = new Config(paramsWithModelRouter); + }); + await config.refreshAuth(AuthType.USE_GEMINI); expect(config.getUseModelRouter()).toBe(true); }); - it('should set useModelRouter to false when explicitly provided as false', () => { - const paramsWithModelRouter: ConfigParameters = { + it('should disable model router for specified auth type', async () => { + const config = new Config({ + ...baseParams, + useModelRouter: true, + disableModelRouterForAuth: [AuthType.USE_GEMINI], + }); + await config.refreshAuth(AuthType.USE_GEMINI); + expect(config.getUseModelRouter()).toBe(false); + }); + + it('should enable model router for other auth type', async () => { + const config = new Config({ + ...baseParams, + useModelRouter: true, + disableModelRouterForAuth: [], + }); + await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE); + expect(config.getUseModelRouter()).toBe(true); + }); + + it('should keep model router disabled when useModelRouter is false', async () => { + const config = new Config({ ...baseParams, useModelRouter: false, - }; - const config = new Config(paramsWithModelRouter); + disableModelRouterForAuth: [AuthType.USE_GEMINI], + }); + await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE); expect(config.getUseModelRouter()).toBe(false); }); + + it('should keep the user-chosen model after refreshAuth, even when model router is disabled for the auth type', async () => { + const config = new Config({ + ...baseParams, + useModelRouter: true, + disableModelRouterForAuth: [AuthType.USE_GEMINI], + }); + const chosenModel = 'gemini-1.5-pro-latest'; + config.setModel(chosenModel); + + await config.refreshAuth(AuthType.USE_GEMINI); + + expect(config.getUseModelRouter()).toBe(false); + expect(config.getModel()).toBe(chosenModel); + }); + + it('should keep the user-chosen model after refreshAuth, when model router is enabled for the auth type', async () => { + const config = new Config({ + ...baseParams, + useModelRouter: true, + disableModelRouterForAuth: [AuthType.USE_GEMINI], + }); + const chosenModel = 'gemini-1.5-pro-latest'; + config.setModel(chosenModel); + + await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE); + + expect(config.getUseModelRouter()).toBe(true); + expect(config.getModel()).toBe(chosenModel); + }); }); describe('ContinueOnFailedApiCall Configuration', () => { diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 6b683ac5ac..878c2fe782 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -46,6 +46,7 @@ import { DEFAULT_GEMINI_EMBEDDING_MODEL, DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_THINKING_MODE, } from './models.js'; import { shouldAttemptBrowserLaunch } from '../utils/browser.js'; @@ -279,6 +280,7 @@ export interface ConfigParameters { output?: OutputSettings; useModelRouter?: boolean; enableMessageBusIntegration?: boolean; + disableModelRouterForAuth?: AuthType[]; codebaseInvestigatorSettings?: CodebaseInvestigatorSettings; continueOnFailedApiCall?: boolean; retryFetchErrors?: boolean; @@ -377,7 +379,9 @@ export class Config { private readonly messageBus: MessageBus; private readonly policyEngine: PolicyEngine; private readonly outputSettings: OutputSettings; - private readonly useModelRouter: boolean; + private useModelRouter: boolean; + private readonly initialUseModelRouter: boolean; + private readonly disableModelRouterForAuth?: AuthType[]; private readonly enableMessageBusIntegration: boolean; private readonly codebaseInvestigatorSettings: CodebaseInvestigatorSettings; private readonly continueOnFailedApiCall: boolean; @@ -477,7 +481,11 @@ export class Config { this.enableToolOutputTruncation = params.enableToolOutputTruncation ?? true; this.useSmartEdit = params.useSmartEdit ?? true; this.useWriteTodos = params.useWriteTodos ?? false; - this.useModelRouter = params.useModelRouter ?? false; + this.initialUseModelRouter = params.useModelRouter ?? false; + this.useModelRouter = this.initialUseModelRouter; + this.disableModelRouterForAuth = params.disableModelRouterForAuth ?? [ + AuthType.LOGIN_WITH_GOOGLE, + ]; this.enableMessageBusIntegration = params.enableMessageBusIntegration ?? false; this.codebaseInvestigatorSettings = { @@ -551,6 +559,16 @@ export class Config { } async refreshAuth(authMethod: AuthType) { + this.useModelRouter = this.initialUseModelRouter; + if (this.disableModelRouterForAuth?.includes(authMethod)) { + this.useModelRouter = false; + if (this.model === DEFAULT_GEMINI_MODEL_AUTO) { + this.model = DEFAULT_GEMINI_MODEL; + } + } else if (this.useModelRouter && this.model === DEFAULT_GEMINI_MODEL) { + this.model = DEFAULT_GEMINI_MODEL_AUTO; + } + // Vertex and Genai have incompatible encryption and sending history with // thoughtSignature from Genai to Vertex will fail, we need to strip them if (