diff --git a/packages/core/src/availability/policyHelpers.test.ts b/packages/core/src/availability/policyHelpers.test.ts index 2eb6129f61..23c6ef4fd4 100644 --- a/packages/core/src/availability/policyHelpers.test.ts +++ b/packages/core/src/availability/policyHelpers.test.ts @@ -20,14 +20,21 @@ import { } from '../config/models.js'; import { AuthType } from '../core/contentGenerator.js'; -const createMockConfig = (overrides: Partial = {}): Config => - ({ +const createMockConfig = (overrides: Partial = {}): Config => { + const config = { getUserTier: () => undefined, getModel: () => 'gemini-2.5-pro', getGemini31LaunchedSync: () => false, + getUseCustomToolModelSync: () => { + const useGemini31 = config.getGemini31LaunchedSync(); + const authType = config.getContentGeneratorConfig().authType; + return useGemini31 && authType === AuthType.USE_GEMINI; + }, getContentGeneratorConfig: () => ({ authType: undefined }), ...overrides, - }) as unknown as Config; + } as unknown as Config; + return config; +}; describe('policyHelpers', () => { describe('resolvePolicyChain', () => { diff --git a/packages/core/src/availability/policyHelpers.ts b/packages/core/src/availability/policyHelpers.ts index 47c465585c..406abde5e3 100644 --- a/packages/core/src/availability/policyHelpers.ts +++ b/packages/core/src/availability/policyHelpers.ts @@ -6,7 +6,6 @@ import type { GenerateContentConfig } from '@google/genai'; import type { Config } from '../config/config.js'; -import { AuthType } from '../core/contentGenerator.js'; import type { FailureKind, FallbackAction, @@ -46,9 +45,7 @@ export function resolvePolicyChain( let chain; const useGemini31 = config.getGemini31LaunchedSync?.() ?? false; - const useCustomToolModel = - useGemini31 && - config.getContentGeneratorConfig?.()?.authType === AuthType.USE_GEMINI; + const useCustomToolModel = config.getUseCustomToolModelSync?.() ?? false; const hasAccessToPreview = config.getHasAccessToPreviewModel?.() ?? true; const resolvedModel = resolveModel( diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index ba8f5d508b..86cdf584b5 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -2529,6 +2529,26 @@ export class Config implements McpContext, AgentLoopContext { return this.getGemini31LaunchedSync(); } + /** + * Returns whether the custom tool model should be used. + */ + async getUseCustomToolModel(): Promise { + const useGemini3_1 = await this.getGemini31Launched(); + const authType = this.contentGeneratorConfig?.authType; + return useGemini3_1 && authType === AuthType.USE_GEMINI; + } + + /** + * Returns whether the custom tool model should be used. + * + * Note: This method should only be called after startup, once experiments have been loaded. + */ + getUseCustomToolModelSync(): boolean { + const useGemini3_1 = this.getGemini31LaunchedSync(); + const authType = this.contentGeneratorConfig?.authType; + return useGemini3_1 && authType === AuthType.USE_GEMINI; + } + /** * Returns whether Gemini 3.1 has been launched. * diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index 32014d5fbd..ffbf597793 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -168,7 +168,8 @@ export function isPreviewModel(model: string): boolean { model === PREVIEW_GEMINI_3_1_MODEL || model === PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL || model === PREVIEW_GEMINI_FLASH_MODEL || - model === PREVIEW_GEMINI_MODEL_AUTO + model === PREVIEW_GEMINI_MODEL_AUTO || + model === GEMINI_MODEL_ALIAS_AUTO ); } diff --git a/packages/core/src/routing/strategies/approvalModeStrategy.test.ts b/packages/core/src/routing/strategies/approvalModeStrategy.test.ts index 4a332ec77f..123a2329a6 100644 --- a/packages/core/src/routing/strategies/approvalModeStrategy.test.ts +++ b/packages/core/src/routing/strategies/approvalModeStrategy.test.ts @@ -15,7 +15,9 @@ import { PREVIEW_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_MODEL_AUTO, PREVIEW_GEMINI_MODEL_AUTO, + GEMINI_MODEL_ALIAS_AUTO, } from '../../config/models.js'; +import { AuthType } from '../../core/contentGenerator.js'; import { ApprovalMode } from '../../policy/types.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; @@ -40,6 +42,15 @@ describe('ApprovalModeStrategy', () => { getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT), getApprovedPlanPath: vi.fn().mockReturnValue(undefined), getPlanModeRoutingEnabled: vi.fn().mockResolvedValue(true), + getGemini31Launched: vi.fn().mockResolvedValue(false), + getUseCustomToolModel: vi.fn().mockImplementation(async () => { + const launched = await mockConfig.getGemini31Launched(); + const authType = mockConfig.getContentGeneratorConfig?.()?.authType; + return launched && authType === AuthType.USE_GEMINI; + }), + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: AuthType.LOGIN_WITH_GOOGLE, + }), } as unknown as Config; mockBaseLlmClient = {} as BaseLlmClient; @@ -184,4 +195,50 @@ describe('ApprovalModeStrategy', () => { expect(decision?.model).toBe(PREVIEW_GEMINI_MODEL); }); + + it('should route to Preview models when using "auto" alias', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue(GEMINI_MODEL_ALIAS_AUTO); + vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision?.model).toBe(PREVIEW_GEMINI_MODEL); + + vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT); + vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue( + '/path/to/plan.md', + ); + + const implementationDecision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(implementationDecision?.model).toBe(PREVIEW_GEMINI_FLASH_MODEL); + }); + + it('should route to Preview Flash model when an approved plan exists and Gemini 3.1 is launched', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue(GEMINI_MODEL_ALIAS_AUTO); + vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); + + // Exit plan mode with approved plan + vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT); + vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue( + '/path/to/plan.md', + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + // Should resolve to Preview Flash (3.0) because resolveClassifierModel uses preview variants for Gemini 3 + expect(decision?.model).toBe(PREVIEW_GEMINI_FLASH_MODEL); + }); }); diff --git a/packages/core/src/routing/strategies/approvalModeStrategy.ts b/packages/core/src/routing/strategies/approvalModeStrategy.ts index 63b331f5a1..403a4c3176 100644 --- a/packages/core/src/routing/strategies/approvalModeStrategy.ts +++ b/packages/core/src/routing/strategies/approvalModeStrategy.ts @@ -6,12 +6,10 @@ import type { Config } from '../../config/config.js'; import { - DEFAULT_GEMINI_MODEL, - DEFAULT_GEMINI_FLASH_MODEL, - PREVIEW_GEMINI_MODEL, - PREVIEW_GEMINI_FLASH_MODEL, isAutoModel, - isPreviewModel, + resolveClassifierModel, + GEMINI_MODEL_ALIAS_FLASH, + GEMINI_MODEL_ALIAS_PRO, } from '../../config/models.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import { ApprovalMode } from '../../policy/types.js'; @@ -50,11 +48,19 @@ export class ApprovalModeStrategy implements RoutingStrategy { const approvalMode = config.getApprovalMode(); const approvedPlanPath = config.getApprovedPlanPath(); - const isPreview = isPreviewModel(model); + const [useGemini3_1, useCustomToolModel] = await Promise.all([ + config.getGemini31Launched(), + config.getUseCustomToolModel(), + ]); // 1. Planning Phase: If ApprovalMode === PLAN, explicitly route to the Pro model. if (approvalMode === ApprovalMode.PLAN) { - const proModel = isPreview ? PREVIEW_GEMINI_MODEL : DEFAULT_GEMINI_MODEL; + const proModel = resolveClassifierModel( + model, + GEMINI_MODEL_ALIAS_PRO, + useGemini3_1, + useCustomToolModel, + ); return { model: proModel, metadata: { @@ -65,9 +71,12 @@ export class ApprovalModeStrategy implements RoutingStrategy { }; } else if (approvedPlanPath) { // 2. Implementation Phase: If ApprovalMode !== PLAN AND an approved plan path is set, prefer the Flash model. - const flashModel = isPreview - ? PREVIEW_GEMINI_FLASH_MODEL - : DEFAULT_GEMINI_FLASH_MODEL; + const flashModel = resolveClassifierModel( + model, + GEMINI_MODEL_ALIAS_FLASH, + useGemini3_1, + useCustomToolModel, + ); return { model: flashModel, metadata: { diff --git a/packages/core/src/routing/strategies/classifierStrategy.test.ts b/packages/core/src/routing/strategies/classifierStrategy.test.ts index 701e7de932..58908a7d3b 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.test.ts @@ -59,6 +59,11 @@ describe('ClassifierStrategy', () => { getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO), getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false), getGemini31Launched: vi.fn().mockResolvedValue(false), + getUseCustomToolModel: vi.fn().mockImplementation(async () => { + const launched = await mockConfig.getGemini31Launched(); + const authType = mockConfig.getContentGeneratorConfig().authType; + return launched && authType === AuthType.USE_GEMINI; + }), getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: AuthType.LOGIN_WITH_GOOGLE, }), diff --git a/packages/core/src/routing/strategies/classifierStrategy.ts b/packages/core/src/routing/strategies/classifierStrategy.ts index 5fd6208b15..2040e7eccd 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.ts @@ -22,7 +22,6 @@ import { import { debugLogger } from '../../utils/debugLogger.js'; import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; import { LlmRole } from '../../telemetry/types.js'; -import { AuthType } from '../../core/contentGenerator.js'; // The number of recent history turns to provide to the router for context. const HISTORY_TURNS_FOR_CONTEXT = 4; @@ -172,10 +171,10 @@ export class ClassifierStrategy implements RoutingStrategy { const reasoning = routerResponse.reasoning; const latencyMs = Date.now() - startTime; - const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false; - const useCustomToolModel = - useGemini3_1 && - config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI; + const [useGemini3_1, useCustomToolModel] = await Promise.all([ + config.getGemini31Launched(), + config.getUseCustomToolModel(), + ]); const selectedModel = resolveClassifierModel( model, routerResponse.model_choice, diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts index 77fc69a218..7a0439bd19 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts @@ -58,6 +58,11 @@ describe('NumericalClassifierStrategy', () => { getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true), getClassifierThreshold: vi.fn().mockResolvedValue(undefined), getGemini31Launched: vi.fn().mockResolvedValue(false), + getUseCustomToolModel: vi.fn().mockImplementation(async () => { + const launched = await mockConfig.getGemini31Launched(); + const authType = mockConfig.getContentGeneratorConfig().authType; + return launched && authType === AuthType.USE_GEMINI; + }), getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: AuthType.LOGIN_WITH_GOOGLE, }), diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts index 39805fb43c..1b5b67aac4 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts @@ -18,7 +18,6 @@ import type { Config } from '../../config/config.js'; import { debugLogger } from '../../utils/debugLogger.js'; import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; import { LlmRole } from '../../telemetry/types.js'; -import { AuthType } from '../../core/contentGenerator.js'; // The number of recent history turns to provide to the router for context. const HISTORY_TURNS_FOR_CONTEXT = 8; @@ -185,10 +184,10 @@ export class NumericalClassifierStrategy implements RoutingStrategy { config, config.getSessionId() || 'unknown-session', ); - const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false; - const useCustomToolModel = - useGemini3_1 && - config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI; + const [useGemini3_1, useCustomToolModel] = await Promise.all([ + config.getGemini31Launched(), + config.getUseCustomToolModel(), + ]); const selectedModel = resolveClassifierModel( model, modelAlias,