From 30b6e987c63eae6f95af9d07bb45a1c9d8f0240f Mon Sep 17 00:00:00 2001 From: Sehoon Shon Date: Wed, 11 Mar 2026 01:12:55 -0400 Subject: [PATCH] fix(core): enable numerical routing by default and set threshold fallback to 90 --- packages/core/src/config/config.test.ts | 37 ++++++ packages/core/src/config/config.ts | 8 +- .../src/routing/modelRouterService.test.ts | 10 +- .../core/src/routing/modelRouterService.ts | 10 +- .../numericalClassifierStrategy.test.ts | 109 ++++-------------- .../strategies/numericalClassifierStrategy.ts | 51 ++------ 6 files changed, 89 insertions(+), 136 deletions(-) diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index fc262e2b13..0b58dee810 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -492,6 +492,43 @@ describe('Server Config (config.ts)', () => { expect(await config.getUserCaching()).toBeUndefined(); }); }); + + describe('getNumericalRoutingEnabled', () => { + it('should return true by default if there are no experiments', async () => { + const config = new Config(baseParams); + expect(await config.getNumericalRoutingEnabled()).toBe(true); + }); + + it('should return true if the remote flag is set to true', async () => { + const config = new Config({ + ...baseParams, + experiments: { + flags: { + [ExperimentFlags.ENABLE_NUMERICAL_ROUTING]: { + boolValue: true, + }, + }, + experimentIds: [], + }, + } as unknown as ConfigParameters); + expect(await config.getNumericalRoutingEnabled()).toBe(true); + }); + + it('should return false if the remote flag is explicitly set to false', async () => { + const config = new Config({ + ...baseParams, + experiments: { + flags: { + [ExperimentFlags.ENABLE_NUMERICAL_ROUTING]: { + boolValue: false, + }, + }, + experimentIds: [], + }, + } as unknown as ConfigParameters); + expect(await config.getNumericalRoutingEnabled()).toBe(false); + }); + }); }); describe('refreshAuth', () => { diff --git 
a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index bc52050286..5ca13e88d4 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -2508,8 +2508,12 @@ export class Config implements McpContext, AgentLoopContext { async getNumericalRoutingEnabled(): Promise<boolean> { await this.ensureExperimentsLoaded(); - return !!this.experiments?.flags[ExperimentFlags.ENABLE_NUMERICAL_ROUTING] - ?.boolValue; + const flag = + this.experiments?.flags[ExperimentFlags.ENABLE_NUMERICAL_ROUTING]; + if (flag?.boolValue !== undefined) { + return flag.boolValue; + } + return true; } async getClassifierThreshold(): Promise<number | undefined> { diff --git a/packages/core/src/routing/modelRouterService.test.ts b/packages/core/src/routing/modelRouterService.test.ts index ad0e3c890e..9824088ddf 100644 --- a/packages/core/src/routing/modelRouterService.test.ts +++ b/packages/core/src/routing/modelRouterService.test.ts @@ -54,7 +54,7 @@ describe('ModelRouterService', () => { vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue( mockLocalLiteRtLmClient, ); - vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(false); + vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(true); vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined); vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({ enabled: false, @@ -182,8 +182,8 @@ describe('ModelRouterService', () => { false, undefined, ApprovalMode.DEFAULT, - false, - undefined, + true, + '90', ); expect(logModelRouting).toHaveBeenCalledWith( mockConfig, @@ -209,8 +209,8 @@ describe('ModelRouterService', () => { true, 'Strategy failed', ApprovalMode.DEFAULT, - false, - undefined, + true, + '90', ); expect(logModelRouting).toHaveBeenCalledWith( mockConfig, diff --git a/packages/core/src/routing/modelRouterService.ts b/packages/core/src/routing/modelRouterService.ts index 1bd19f3622..93afb8ecb5 100644 ---
a/packages/core/src/routing/modelRouterService.ts +++ b/packages/core/src/routing/modelRouterService.ts @@ -14,7 +14,10 @@ import type { } from './routingStrategy.js'; import { DefaultStrategy } from './strategies/defaultStrategy.js'; import { ClassifierStrategy } from './strategies/classifierStrategy.js'; -import { NumericalClassifierStrategy } from './strategies/numericalClassifierStrategy.js'; +import { + NumericalClassifierStrategy, + DEFAULT_CLASSIFIER_THRESHOLD, +} from './strategies/numericalClassifierStrategy.js'; import { CompositeStrategy } from './strategies/compositeStrategy.js'; import { FallbackStrategy } from './strategies/fallbackStrategy.js'; import { OverrideStrategy } from './strategies/overrideStrategy.js'; @@ -80,8 +83,9 @@ export class ModelRouterService { this.config.getNumericalRoutingEnabled(), this.config.getClassifierThreshold(), ]); - const classifierThreshold = - thresholdValue !== undefined ? String(thresholdValue) : undefined; + const classifierThreshold = String( + thresholdValue ?? 
DEFAULT_CLASSIFIER_THRESHOLD, + ); let failed = false; let error_message: string | undefined; diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts index 7a0439bd19..f0ff0d8860 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts @@ -152,12 +152,11 @@ describe('NumericalClassifierStrategy', () => { expect(textPart?.text).toBe('simple task'); }); - describe('A/B Testing Logic (Deterministic)', () => { - it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 40 -> FLASH', async () => { - vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Hash 71 -> Control + describe('Default Logic', () => { + it('should route to FLASH when score is below 90', async () => { const mockApiResponse = { complexity_reasoning: 'Standard task', - complexity_score: 40, + complexity_score: 80, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -173,72 +172,17 @@ describe('NumericalClassifierStrategy', () => { expect(decision).toEqual({ model: PREVIEW_GEMINI_FLASH_MODEL, metadata: { - source: 'NumericalClassifier (Control)', + source: 'NumericalClassifier (Default)', latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 40 / Threshold: 50'), + reasoning: expect.stringContaining('Score: 80 / Threshold: 90'), }, }); }); - it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 60 -> PRO', async () => { - vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); - const mockApiResponse = { - complexity_reasoning: 'Complex task', - complexity_score: 60, - }; - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( - mockApiResponse, - ); - - const decision = await strategy.route( - mockContext, - mockConfig, - mockBaseLlmClient, - 
mockLocalLiteRtLmClient, - ); - - expect(decision).toEqual({ - model: PREVIEW_GEMINI_MODEL, - metadata: { - source: 'NumericalClassifier (Control)', - latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 60 / Threshold: 50'), - }, - }); - }); - - it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 60 -> FLASH', async () => { - vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1'); // FNV Normalized 18 < 50 -> Strict - const mockApiResponse = { - complexity_reasoning: 'Complex task', - complexity_score: 60, - }; - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( - mockApiResponse, - ); - - const decision = await strategy.route( - mockContext, - mockConfig, - mockBaseLlmClient, - mockLocalLiteRtLmClient, - ); - - expect(decision).toEqual({ - model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80 - metadata: { - source: 'NumericalClassifier (Strict)', - latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 60 / Threshold: 80'), - }, - }); - }); - - it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 90 -> PRO', async () => { - vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1'); + it('should route to PRO when score is 90 or above', async () => { const mockApiResponse = { complexity_reasoning: 'Extreme task', - complexity_score: 90, + complexity_score: 95, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -254,9 +198,9 @@ describe('NumericalClassifierStrategy', () => { expect(decision).toEqual({ model: PREVIEW_GEMINI_MODEL, metadata: { - source: 'NumericalClassifier (Strict)', + source: 'NumericalClassifier (Default)', latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 90 / Threshold: 80'), + reasoning: expect.stringContaining('Score: 95 / Threshold: 90'), }, }); }); @@ -344,13 +288,12 @@ describe('NumericalClassifierStrategy', () => { }); }); - it('should fall back to A/B 
testing if CLASSIFIER_THRESHOLD is not present in experiments', async () => { + it('should fall back to default logic if CLASSIFIER_THRESHOLD is not present in experiments', async () => { // Mock getClassifierThreshold to return undefined vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(undefined); - vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Should resolve to Control (50) const mockApiResponse = { complexity_reasoning: 'Test task', - complexity_score: 40, + complexity_score: 80, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -364,21 +307,20 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50 + model: PREVIEW_GEMINI_FLASH_MODEL, // Score 80 < Default Threshold 90 metadata: { - source: 'NumericalClassifier (Control)', + source: 'NumericalClassifier (Default)', latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 40 / Threshold: 50'), + reasoning: expect.stringContaining('Score: 80 / Threshold: 90'), }, }); }); - it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => { + it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => { vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(-10); - vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); const mockApiResponse = { complexity_reasoning: 'Test task', - complexity_score: 40, + complexity_score: 80, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -394,19 +336,18 @@ describe('NumericalClassifierStrategy', () => { expect(decision).toEqual({ model: PREVIEW_GEMINI_FLASH_MODEL, metadata: { - source: 'NumericalClassifier (Control)', + source: 'NumericalClassifier (Default)', latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 40 / 
Threshold: 50'), + reasoning: expect.stringContaining('Score: 80 / Threshold: 90'), }, }); }); - it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => { + it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => { vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(110); - vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); const mockApiResponse = { complexity_reasoning: 'Test task', - complexity_score: 60, + complexity_score: 95, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -422,9 +363,9 @@ describe('NumericalClassifierStrategy', () => { expect(decision).toEqual({ model: PREVIEW_GEMINI_MODEL, metadata: { - source: 'NumericalClassifier (Control)', + source: 'NumericalClassifier (Default)', latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 60 / Threshold: 50'), + reasoning: expect.stringContaining('Score: 95 / Threshold: 90'), }, }); }); @@ -591,7 +532,7 @@ describe('NumericalClassifierStrategy', () => { vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); const mockApiResponse = { complexity_reasoning: 'Complex task', - complexity_score: 80, + complexity_score: 95, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -613,7 +554,7 @@ describe('NumericalClassifierStrategy', () => { }); const mockApiResponse = { complexity_reasoning: 'Complex task', - complexity_score: 80, + complexity_score: 95, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -636,7 +577,7 @@ describe('NumericalClassifierStrategy', () => { }); const mockApiResponse = { complexity_reasoning: 'Complex task', - complexity_score: 80, + complexity_score: 95, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts 
b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts index 1b5b67aac4..5673b15169 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts @@ -25,6 +25,12 @@ const HISTORY_TURNS_FOR_CONTEXT = 8; const FLASH_MODEL = 'flash'; const PRO_MODEL = 'pro'; +/** + * The default complexity threshold for routing. + * If the score is greater than or equal to this threshold, the Pro model is used. + */ +export const DEFAULT_CLASSIFIER_THRESHOLD = 90; + const RESPONSE_SCHEMA = { type: Type.OBJECT, properties: { @@ -93,39 +99,6 @@ const ClassifierResponseSchema = z.object({ complexity_score: z.number().min(1).max(100), }); -/** - * Deterministically calculates the routing threshold based on the session ID. - * This ensures a consistent experience for the user within a session. - * - * This implementation uses the FNV-1a hash algorithm (32-bit). - * @see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function - * - * @param sessionId The unique session identifier. - * @returns The threshold (50 or 80). - */ -function getComplexityThreshold(sessionId: string): number { - const FNV_OFFSET_BASIS_32 = 0x811c9dc5; - const FNV_PRIME_32 = 0x01000193; - - let hash = FNV_OFFSET_BASIS_32; - - for (let i = 0; i < sessionId.length; i++) { - hash ^= sessionId.charCodeAt(i); - // Multiply by prime (simulate 32-bit overflow with bitwise shift) - hash = Math.imul(hash, FNV_PRIME_32); - } - - // Ensure positive integer - hash = hash >>> 0; - - // Normalize to 0-99 - const normalized = hash % 100; - // 50% split: - // 0-49: Strict (80) - // 50-99: Control (50) - return normalized < 50 ? 
80 : 50; -} - export class NumericalClassifierStrategy implements RoutingStrategy { readonly name = 'numerical_classifier'; @@ -179,11 +152,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy { const score = routerResponse.complexity_score; const { threshold, groupLabel, modelAlias } = - await this.getRoutingDecision( - score, - config, - config.getSessionId() || 'unknown-session', - ); + await this.getRoutingDecision(score, config); const [useGemini3_1, useCustomToolModel] = await Promise.all([ config.getGemini31Launched(), config.getUseCustomToolModel(), @@ -214,7 +183,6 @@ export class NumericalClassifierStrategy implements RoutingStrategy { private async getRoutingDecision( score: number, config: Config, - sessionId: string, ): Promise<{ threshold: number; groupLabel: string; @@ -234,9 +202,8 @@ export class NumericalClassifierStrategy implements RoutingStrategy { threshold = remoteThresholdValue; groupLabel = 'Remote'; } else { - // Fallback to deterministic A/B test - threshold = getComplexityThreshold(sessionId); - groupLabel = threshold === 80 ? 'Strict' : 'Control'; + threshold = DEFAULT_CLASSIFIER_THRESHOLD; + groupLabel = 'Default'; } const modelAlias = score >= threshold ? PRO_MODEL : FLASH_MODEL;