diff --git a/packages/core/src/config/models.test.ts b/packages/core/src/config/models.test.ts index bd8fa9919a..8ece4ce56c 100644 --- a/packages/core/src/config/models.test.ts +++ b/packages/core/src/config/models.test.ts @@ -8,6 +8,7 @@ import { describe, it, expect } from 'vitest'; import { resolveModel, resolveClassifierModel, + isGemini3Model, isGemini2Model, isAutoModel, getDisplayString, @@ -24,6 +25,29 @@ import { DEFAULT_GEMINI_MODEL_AUTO, } from './models.js'; +describe('isGemini3Model', () => { + it('should return true for gemini-3 models', () => { + expect(isGemini3Model('gemini-3-pro-preview')).toBe(true); + expect(isGemini3Model('gemini-3-flash-preview')).toBe(true); + }); + + it('should return true for aliases that resolve to Gemini 3', () => { + expect(isGemini3Model(GEMINI_MODEL_ALIAS_AUTO)).toBe(true); + expect(isGemini3Model(GEMINI_MODEL_ALIAS_PRO)).toBe(true); + expect(isGemini3Model(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true); + }); + + it('should return false for Gemini 2 models', () => { + expect(isGemini3Model('gemini-2.5-pro')).toBe(false); + expect(isGemini3Model('gemini-2.5-flash')).toBe(false); + expect(isGemini3Model(DEFAULT_GEMINI_MODEL_AUTO)).toBe(false); + }); + + it('should return false for arbitrary strings', () => { + expect(isGemini3Model('gpt-4')).toBe(false); + }); +}); + describe('getDisplayString', () => { it('should return Auto (Gemini 3) for preview auto model', () => { expect(getDisplayString(PREVIEW_GEMINI_MODEL_AUTO)).toBe('Auto (Gemini 3)'); diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index b23fe35dcc..5bdd15e792 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -120,6 +120,17 @@ export function isPreviewModel(model: string): boolean { ); } +/** + * Checks if the model is a Gemini 3 model. + * + * @param model The model name to check. + * @returns True if the model is a Gemini 3 model. + */ +export function isGemini3Model(model: string): boolean { + const resolved = resolveModel(model); + return /^gemini-3(\.|-|$)/.test(resolved); +} + /** * Checks if the model is a Gemini 2.x model. * diff --git a/packages/core/src/routing/strategies/classifierStrategy.test.ts b/packages/core/src/routing/strategies/classifierStrategy.test.ts index a516439557..b2c7a8797e 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.test.ts @@ -17,6 +17,7 @@ import { DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL_AUTO, + PREVIEW_GEMINI_MODEL_AUTO, } from '../../config/models.js'; import { promptIdContext } from '../../utils/promptIdContext.js'; import type { Content } from '@google/genai'; @@ -50,7 +51,7 @@ describe('ClassifierStrategy', () => { modelConfigService: { getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig), }, - getModel: () => DEFAULT_GEMINI_MODEL_AUTO, + getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO), getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false), } as unknown as Config; mockBaseLlmClient = { @@ -60,8 +61,9 @@ describe('ClassifierStrategy', () => { vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id'); }); - it('should return null if numerical routing is enabled', async () => { + it('should return null if numerical routing is enabled and model is Gemini 3', async () => { vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true); + vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO); const decision = await strategy.route( mockContext, @@ -73,6 +75,24 @@ describe('ClassifierStrategy', () => { expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); }); + it('should NOT return null if numerical routing is enabled but model is NOT Gemini 3', async () => { + vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true); + vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO); + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({ + reasoning: 'test', + model_choice: 'flash', + }); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision).not.toBeNull(); + expect(mockBaseLlmClient.generateJson).toHaveBeenCalled(); + }); + it('should call generateJson with the correct parameters', async () => { const mockApiResponse = { reasoning: 'Simple task', diff --git a/packages/core/src/routing/strategies/classifierStrategy.ts b/packages/core/src/routing/strategies/classifierStrategy.ts index 387151046b..b21bb5e471 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.ts @@ -12,7 +12,7 @@ import type { RoutingDecision, RoutingStrategy, } from '../routingStrategy.js'; -import { resolveClassifierModel } from '../../config/models.js'; +import { resolveClassifierModel, isGemini3Model } from '../../config/models.js'; import { createUserContent, Type } from '@google/genai'; import type { Config } from '../../config/config.js'; import { @@ -133,7 +133,11 @@ export class ClassifierStrategy implements RoutingStrategy { ): Promise { const startTime = Date.now(); try { - if (await config.getNumericalRoutingEnabled()) { + const model = context.requestedModel ?? config.getModel(); + if ( + (await config.getNumericalRoutingEnabled()) && + isGemini3Model(model) + ) { return null; } @@ -164,7 +168,7 @@ export class ClassifierStrategy implements RoutingStrategy { const reasoning = routerResponse.reasoning; const latencyMs = Date.now() - startTime; const selectedModel = resolveClassifierModel( - context.requestedModel ?? config.getModel(), + model, routerResponse.model_choice, ); diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts index 73c1d91efc..8767709f68 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts @@ -10,9 +10,11 @@ import type { RoutingContext } from '../routingStrategy.js'; import type { Config } from '../../config/config.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import { - DEFAULT_GEMINI_FLASH_MODEL, - DEFAULT_GEMINI_MODEL, + PREVIEW_GEMINI_FLASH_MODEL, + PREVIEW_GEMINI_MODEL, + PREVIEW_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO, + DEFAULT_GEMINI_MODEL, } from '../../config/models.js'; import { promptIdContext } from '../../utils/promptIdContext.js'; import type { Content } from '@google/genai'; @@ -46,7 +48,7 @@ describe('NumericalClassifierStrategy', () => { modelConfigService: { getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig), }, - getModel: () => DEFAULT_GEMINI_MODEL_AUTO, + getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO), getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50) getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true), getClassifierThreshold: vi.fn().mockResolvedValue(undefined), @@ -75,6 +77,32 @@ describe('NumericalClassifierStrategy', () => { expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); }); + it('should return null if the model is not a Gemini 3 model', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision).toBeNull(); + expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); + }); + + it('should return null if the model is explicitly a Gemini 2 model', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision).toBeNull(); + expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); + }); + it('should call generateJson with the correct parameters and wrapped user content', async () => { const mockApiResponse = { complexity_reasoning: 'Simple task', @@ -119,7 +147,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_FLASH_MODEL, + model: PREVIEW_GEMINI_FLASH_MODEL, metadata: { source: 'NumericalClassifier (Control)', latencyMs: expect.any(Number), @@ -145,7 +173,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_MODEL, + model: PREVIEW_GEMINI_MODEL, metadata: { source: 'NumericalClassifier (Control)', latencyMs: expect.any(Number), @@ -171,7 +199,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80 + model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80 metadata: { source: 'NumericalClassifier (Strict)', latencyMs: expect.any(Number), @@ -197,7 +225,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_MODEL, + model: PREVIEW_GEMINI_MODEL, metadata: { source: 'NumericalClassifier (Strict)', latencyMs: expect.any(Number), @@ -225,7 +253,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70 + model: PREVIEW_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70 metadata: { source: 'NumericalClassifier (Remote)', latencyMs: expect.any(Number), @@ -251,7 +279,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5 + model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5 metadata: { source: 'NumericalClassifier (Remote)', latencyMs: expect.any(Number), @@ -277,7 +305,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_MODEL, // Score 35 >= Threshold 30 + model: PREVIEW_GEMINI_MODEL, // Score 35 >= Threshold 30 metadata: { source: 'NumericalClassifier (Remote)', latencyMs: expect.any(Number), @@ -305,7 +333,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50 + model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50 metadata: { source: 'NumericalClassifier (Control)', latencyMs: expect.any(Number), @@ -332,7 +360,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_FLASH_MODEL, + model: PREVIEW_GEMINI_FLASH_MODEL, metadata: { source: 'NumericalClassifier (Control)', latencyMs: expect.any(Number), @@ -359,7 +387,7 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: DEFAULT_GEMINI_MODEL, + model: PREVIEW_GEMINI_MODEL, metadata: { source: 'NumericalClassifier (Control)', latencyMs: expect.any(Number), diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts index 10ccb6dc4f..5c31fa3057 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts @@ -12,7 +12,7 @@ import type { RoutingDecision, RoutingStrategy, } from '../routingStrategy.js'; -import { resolveClassifierModel } from '../../config/models.js'; +import { resolveClassifierModel, isGemini3Model } from '../../config/models.js'; import { createUserContent, Type } from '@google/genai'; import type { Config } from '../../config/config.js'; import { debugLogger } from '../../utils/debugLogger.js'; @@ -134,10 +134,15 @@ export class NumericalClassifierStrategy implements RoutingStrategy { ): Promise { const startTime = Date.now(); try { + const model = context.requestedModel ?? config.getModel(); if (!(await config.getNumericalRoutingEnabled())) { return null; } + if (!isGemini3Model(model)) { + return null; + } + const promptId = getPromptIdWithFallback('classifier-router'); const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT); @@ -176,10 +181,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy { config.getSessionId() || 'unknown-session', ); - const selectedModel = resolveClassifierModel( - config.getModel(), - modelAlias, - ); + const selectedModel = resolveClassifierModel(model, modelAlias); const latencyMs = Date.now() - startTime;