feat(routing): restrict numerical routing to Gemini 3 family (#18478)

This commit is contained in:
matt korwel
2026-02-10 08:25:21 -06:00
committed by GitHub
parent 79753ec5ec
commit 37f128a109
6 changed files with 112 additions and 23 deletions

View File

@@ -8,6 +8,7 @@ import { describe, it, expect } from 'vitest';
import { import {
resolveModel, resolveModel,
resolveClassifierModel, resolveClassifierModel,
isGemini3Model,
isGemini2Model, isGemini2Model,
isAutoModel, isAutoModel,
getDisplayString, getDisplayString,
@@ -24,6 +25,29 @@ import {
DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO,
} from './models.js'; } from './models.js';
describe('isGemini3Model', () => {
it('should return true for gemini-3 models', () => {
expect(isGemini3Model('gemini-3-pro-preview')).toBe(true);
expect(isGemini3Model('gemini-3-flash-preview')).toBe(true);
});
it('should return true for aliases that resolve to Gemini 3', () => {
expect(isGemini3Model(GEMINI_MODEL_ALIAS_AUTO)).toBe(true);
expect(isGemini3Model(GEMINI_MODEL_ALIAS_PRO)).toBe(true);
expect(isGemini3Model(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true);
});
it('should return false for Gemini 2 models', () => {
expect(isGemini3Model('gemini-2.5-pro')).toBe(false);
expect(isGemini3Model('gemini-2.5-flash')).toBe(false);
expect(isGemini3Model(DEFAULT_GEMINI_MODEL_AUTO)).toBe(false);
});
it('should return false for arbitrary strings', () => {
expect(isGemini3Model('gpt-4')).toBe(false);
});
});
describe('getDisplayString', () => { describe('getDisplayString', () => {
it('should return Auto (Gemini 3) for preview auto model', () => { it('should return Auto (Gemini 3) for preview auto model', () => {
expect(getDisplayString(PREVIEW_GEMINI_MODEL_AUTO)).toBe('Auto (Gemini 3)'); expect(getDisplayString(PREVIEW_GEMINI_MODEL_AUTO)).toBe('Auto (Gemini 3)');

View File

@@ -120,6 +120,17 @@ export function isPreviewModel(model: string): boolean {
); );
} }
/**
* Checks if the model is a Gemini 3 model.
*
* @param model The model name to check.
* @returns True if the model is a Gemini 3 model.
*/
export function isGemini3Model(model: string): boolean {
const resolved = resolveModel(model);
return /^gemini-3(\.|-|$)/.test(resolved);
}
/** /**
* Checks if the model is a Gemini 2.x model. * Checks if the model is a Gemini 2.x model.
* *

View File

@@ -17,6 +17,7 @@ import {
DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_MODEL_AUTO,
} from '../../config/models.js'; } from '../../config/models.js';
import { promptIdContext } from '../../utils/promptIdContext.js'; import { promptIdContext } from '../../utils/promptIdContext.js';
import type { Content } from '@google/genai'; import type { Content } from '@google/genai';
@@ -50,7 +51,7 @@ describe('ClassifierStrategy', () => {
modelConfigService: { modelConfigService: {
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig), getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
}, },
getModel: () => DEFAULT_GEMINI_MODEL_AUTO, getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false), getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
} as unknown as Config; } as unknown as Config;
mockBaseLlmClient = { mockBaseLlmClient = {
@@ -60,8 +61,9 @@ describe('ClassifierStrategy', () => {
vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id'); vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id');
}); });
it('should return null if numerical routing is enabled', async () => { it('should return null if numerical routing is enabled and model is Gemini 3', async () => {
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true); vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);
const decision = await strategy.route( const decision = await strategy.route(
mockContext, mockContext,
@@ -73,6 +75,24 @@ describe('ClassifierStrategy', () => {
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
}); });
it('should NOT return null if numerical routing is enabled but model is NOT Gemini 3', async () => {
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({
reasoning: 'test',
model_choice: 'flash',
});
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision).not.toBeNull();
expect(mockBaseLlmClient.generateJson).toHaveBeenCalled();
});
it('should call generateJson with the correct parameters', async () => { it('should call generateJson with the correct parameters', async () => {
const mockApiResponse = { const mockApiResponse = {
reasoning: 'Simple task', reasoning: 'Simple task',

View File

@@ -12,7 +12,7 @@ import type {
RoutingDecision, RoutingDecision,
RoutingStrategy, RoutingStrategy,
} from '../routingStrategy.js'; } from '../routingStrategy.js';
import { resolveClassifierModel } from '../../config/models.js'; import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
import { createUserContent, Type } from '@google/genai'; import { createUserContent, Type } from '@google/genai';
import type { Config } from '../../config/config.js'; import type { Config } from '../../config/config.js';
import { import {
@@ -133,7 +133,11 @@ export class ClassifierStrategy implements RoutingStrategy {
): Promise<RoutingDecision | null> { ): Promise<RoutingDecision | null> {
const startTime = Date.now(); const startTime = Date.now();
try { try {
if (await config.getNumericalRoutingEnabled()) { const model = context.requestedModel ?? config.getModel();
if (
(await config.getNumericalRoutingEnabled()) &&
isGemini3Model(model)
) {
return null; return null;
} }
@@ -164,7 +168,7 @@ export class ClassifierStrategy implements RoutingStrategy {
const reasoning = routerResponse.reasoning; const reasoning = routerResponse.reasoning;
const latencyMs = Date.now() - startTime; const latencyMs = Date.now() - startTime;
const selectedModel = resolveClassifierModel( const selectedModel = resolveClassifierModel(
context.requestedModel ?? config.getModel(), model,
routerResponse.model_choice, routerResponse.model_choice,
); );

View File

@@ -10,9 +10,11 @@ import type { RoutingContext } from '../routingStrategy.js';
import type { Config } from '../../config/config.js'; import type { Config } from '../../config/config.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import { import {
DEFAULT_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL, PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_MODEL,
} from '../../config/models.js'; } from '../../config/models.js';
import { promptIdContext } from '../../utils/promptIdContext.js'; import { promptIdContext } from '../../utils/promptIdContext.js';
import type { Content } from '@google/genai'; import type { Content } from '@google/genai';
@@ -46,7 +48,7 @@ describe('NumericalClassifierStrategy', () => {
modelConfigService: { modelConfigService: {
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig), getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
}, },
getModel: () => DEFAULT_GEMINI_MODEL_AUTO, getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50) getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true), getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
getClassifierThreshold: vi.fn().mockResolvedValue(undefined), getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
@@ -75,6 +77,32 @@ describe('NumericalClassifierStrategy', () => {
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
}); });
it('should return null if the model is not a Gemini 3 model', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision).toBeNull();
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});
it('should return null if the model is explicitly a Gemini 2 model', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision).toBeNull();
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});
it('should call generateJson with the correct parameters and wrapped user content', async () => { it('should call generateJson with the correct parameters and wrapped user content', async () => {
const mockApiResponse = { const mockApiResponse = {
complexity_reasoning: 'Simple task', complexity_reasoning: 'Simple task',
@@ -119,7 +147,7 @@ describe('NumericalClassifierStrategy', () => {
); );
expect(decision).toEqual({ expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: { metadata: {
source: 'NumericalClassifier (Control)', source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number), latencyMs: expect.any(Number),
@@ -145,7 +173,7 @@ describe('NumericalClassifierStrategy', () => {
); );
expect(decision).toEqual({ expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL, model: PREVIEW_GEMINI_MODEL,
metadata: { metadata: {
source: 'NumericalClassifier (Control)', source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number), latencyMs: expect.any(Number),
@@ -171,7 +199,7 @@ describe('NumericalClassifierStrategy', () => {
); );
expect(decision).toEqual({ expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80 model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
metadata: { metadata: {
source: 'NumericalClassifier (Strict)', source: 'NumericalClassifier (Strict)',
latencyMs: expect.any(Number), latencyMs: expect.any(Number),
@@ -197,7 +225,7 @@ describe('NumericalClassifierStrategy', () => {
); );
expect(decision).toEqual({ expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL, model: PREVIEW_GEMINI_MODEL,
metadata: { metadata: {
source: 'NumericalClassifier (Strict)', source: 'NumericalClassifier (Strict)',
latencyMs: expect.any(Number), latencyMs: expect.any(Number),
@@ -225,7 +253,7 @@ describe('NumericalClassifierStrategy', () => {
); );
expect(decision).toEqual({ expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70 model: PREVIEW_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
metadata: { metadata: {
source: 'NumericalClassifier (Remote)', source: 'NumericalClassifier (Remote)',
latencyMs: expect.any(Number), latencyMs: expect.any(Number),
@@ -251,7 +279,7 @@ describe('NumericalClassifierStrategy', () => {
); );
expect(decision).toEqual({ expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5 model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
metadata: { metadata: {
source: 'NumericalClassifier (Remote)', source: 'NumericalClassifier (Remote)',
latencyMs: expect.any(Number), latencyMs: expect.any(Number),
@@ -277,7 +305,7 @@ describe('NumericalClassifierStrategy', () => {
); );
expect(decision).toEqual({ expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL, // Score 35 >= Threshold 30 model: PREVIEW_GEMINI_MODEL, // Score 35 >= Threshold 30
metadata: { metadata: {
source: 'NumericalClassifier (Remote)', source: 'NumericalClassifier (Remote)',
latencyMs: expect.any(Number), latencyMs: expect.any(Number),
@@ -305,7 +333,7 @@ describe('NumericalClassifierStrategy', () => {
); );
expect(decision).toEqual({ expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50 model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
metadata: { metadata: {
source: 'NumericalClassifier (Control)', source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number), latencyMs: expect.any(Number),
@@ -332,7 +360,7 @@ describe('NumericalClassifierStrategy', () => {
); );
expect(decision).toEqual({ expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL, model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: { metadata: {
source: 'NumericalClassifier (Control)', source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number), latencyMs: expect.any(Number),
@@ -359,7 +387,7 @@ describe('NumericalClassifierStrategy', () => {
); );
expect(decision).toEqual({ expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL, model: PREVIEW_GEMINI_MODEL,
metadata: { metadata: {
source: 'NumericalClassifier (Control)', source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number), latencyMs: expect.any(Number),

View File

@@ -12,7 +12,7 @@ import type {
RoutingDecision, RoutingDecision,
RoutingStrategy, RoutingStrategy,
} from '../routingStrategy.js'; } from '../routingStrategy.js';
import { resolveClassifierModel } from '../../config/models.js'; import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
import { createUserContent, Type } from '@google/genai'; import { createUserContent, Type } from '@google/genai';
import type { Config } from '../../config/config.js'; import type { Config } from '../../config/config.js';
import { debugLogger } from '../../utils/debugLogger.js'; import { debugLogger } from '../../utils/debugLogger.js';
@@ -134,10 +134,15 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
): Promise<RoutingDecision | null> { ): Promise<RoutingDecision | null> {
const startTime = Date.now(); const startTime = Date.now();
try { try {
const model = context.requestedModel ?? config.getModel();
if (!(await config.getNumericalRoutingEnabled())) { if (!(await config.getNumericalRoutingEnabled())) {
return null; return null;
} }
if (!isGemini3Model(model)) {
return null;
}
const promptId = getPromptIdWithFallback('classifier-router'); const promptId = getPromptIdWithFallback('classifier-router');
const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT); const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT);
@@ -176,10 +181,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
config.getSessionId() || 'unknown-session', config.getSessionId() || 'unknown-session',
); );
const selectedModel = resolveClassifierModel( const selectedModel = resolveClassifierModel(model, modelAlias);
config.getModel(),
modelAlias,
);
const latencyMs = Date.now() - startTime; const latencyMs = Date.now() - startTime;