From 36ce2ba96e33a6c32d7e11935fee765bc33ba1b8 Mon Sep 17 00:00:00 2001 From: Sehoon Shon Date: Wed, 11 Mar 2026 14:54:52 -0400 Subject: [PATCH] fix(core): enable numerical routing for api key users (#21977) --- packages/core/src/config/config.test.ts | 89 +++++++++++++ packages/core/src/config/config.ts | 26 +++- .../src/routing/modelRouterService.test.ts | 13 +- .../core/src/routing/modelRouterService.ts | 5 +- .../numericalClassifierStrategy.test.ts | 119 ++++++------------ .../strategies/numericalClassifierStrategy.ts | 57 +-------- 6 files changed, 163 insertions(+), 146 deletions(-) diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index fc262e2b13..822898b444 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -492,6 +492,95 @@ describe('Server Config (config.ts)', () => { expect(await config.getUserCaching()).toBeUndefined(); }); }); + + describe('getNumericalRoutingEnabled', () => { + it('should return true by default if there are no experiments', async () => { + const config = new Config(baseParams); + expect(await config.getNumericalRoutingEnabled()).toBe(true); + }); + + it('should return true if the remote flag is set to true', async () => { + const config = new Config({ + ...baseParams, + experiments: { + flags: { + [ExperimentFlags.ENABLE_NUMERICAL_ROUTING]: { + boolValue: true, + }, + }, + experimentIds: [], + }, + } as unknown as ConfigParameters); + expect(await config.getNumericalRoutingEnabled()).toBe(true); + }); + + it('should return false if the remote flag is explicitly set to false', async () => { + const config = new Config({ + ...baseParams, + experiments: { + flags: { + [ExperimentFlags.ENABLE_NUMERICAL_ROUTING]: { + boolValue: false, + }, + }, + experimentIds: [], + }, + } as unknown as ConfigParameters); + expect(await config.getNumericalRoutingEnabled()).toBe(false); + }); + }); + + describe('getResolvedClassifierThreshold', () => { + it('should return 90 by default if there are no experiments', async () => { + const config = new Config(baseParams); + expect(await config.getResolvedClassifierThreshold()).toBe(90); + }); + + it('should return the remote flag value if it is within range (0-100)', async () => { + const config = new Config({ + ...baseParams, + experiments: { + flags: { + [ExperimentFlags.CLASSIFIER_THRESHOLD]: { + intValue: '75', + }, + }, + experimentIds: [], + }, + } as unknown as ConfigParameters); + expect(await config.getResolvedClassifierThreshold()).toBe(75); + }); + + it('should return 90 if the remote flag is out of range (less than 0)', async () => { + const config = new Config({ + ...baseParams, + experiments: { + flags: { + [ExperimentFlags.CLASSIFIER_THRESHOLD]: { + intValue: '-10', + }, + }, + experimentIds: [], + }, + } as unknown as ConfigParameters); + expect(await config.getResolvedClassifierThreshold()).toBe(90); + }); + + it('should return 90 if the remote flag is out of range (greater than 100)', async () => { + const config = new Config({ + ...baseParams, + experiments: { + flags: { + [ExperimentFlags.CLASSIFIER_THRESHOLD]: { + intValue: '110', + }, + }, + experimentIds: [], + }, + } as unknown as ConfigParameters); + expect(await config.getResolvedClassifierThreshold()).toBe(90); + }); + }); }); describe('refreshAuth', () => { diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 1f2c578f29..a07264f430 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -2512,8 +2512,30 @@ export class Config implements McpContext, AgentLoopContext { async getNumericalRoutingEnabled(): Promise { await this.ensureExperimentsLoaded(); - return !!this.experiments?.flags[ExperimentFlags.ENABLE_NUMERICAL_ROUTING] - ?.boolValue; + const flag = + this.experiments?.flags[ExperimentFlags.ENABLE_NUMERICAL_ROUTING]; + return flag?.boolValue ?? true; + } + + /** + * Returns the resolved complexity threshold for routing. + * If a remote threshold is provided and within range (0-100), it is returned. + * Otherwise, the default threshold (90) is returned. + */ + async getResolvedClassifierThreshold(): Promise { + const remoteValue = await this.getClassifierThreshold(); + const defaultValue = 90; + + if ( + remoteValue !== undefined && + !isNaN(remoteValue) && + remoteValue >= 0 && + remoteValue <= 100 + ) { + return remoteValue; + } + + return defaultValue; } async getClassifierThreshold(): Promise { diff --git a/packages/core/src/routing/modelRouterService.test.ts b/packages/core/src/routing/modelRouterService.test.ts index ad0e3c890e..4e0c32c62f 100644 --- a/packages/core/src/routing/modelRouterService.test.ts +++ b/packages/core/src/routing/modelRouterService.test.ts @@ -54,7 +54,10 @@ describe('ModelRouterService', () => { vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue( mockLocalLiteRtLmClient, ); - vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(false); + vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(true); + vi.spyOn(mockConfig, 'getResolvedClassifierThreshold').mockResolvedValue( + 90, + ); vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined); vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({ enabled: false, @@ -182,8 +185,8 @@ describe('ModelRouterService', () => { false, undefined, ApprovalMode.DEFAULT, - false, - undefined, + true, + '90', ); expect(logModelRouting).toHaveBeenCalledWith( mockConfig, @@ -209,8 +212,8 @@ describe('ModelRouterService', () => { true, 'Strategy failed', ApprovalMode.DEFAULT, - false, - undefined, + true, + '90', ); expect(logModelRouting).toHaveBeenCalledWith( mockConfig, diff --git a/packages/core/src/routing/modelRouterService.ts b/packages/core/src/routing/modelRouterService.ts index 1bd19f3622..a62deacd31 100644 --- a/packages/core/src/routing/modelRouterService.ts +++ b/packages/core/src/routing/modelRouterService.ts @@ -78,10 +78,9 @@ export class ModelRouterService { const [enableNumericalRouting, thresholdValue] = await Promise.all([ this.config.getNumericalRoutingEnabled(), - this.config.getClassifierThreshold(), + this.config.getResolvedClassifierThreshold(), ]); - const classifierThreshold = - thresholdValue !== undefined ? String(thresholdValue) : undefined; + const classifierThreshold = String(thresholdValue); let failed = false; let error_message: string | undefined; diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts index 7a0439bd19..d8a9c48ed1 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts @@ -56,6 +56,7 @@ describe('NumericalClassifierStrategy', () => { getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO), getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50) getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true), + getResolvedClassifierThreshold: vi.fn().mockResolvedValue(90), getClassifierThreshold: vi.fn().mockResolvedValue(undefined), getGemini31Launched: vi.fn().mockResolvedValue(false), getUseCustomToolModel: vi.fn().mockImplementation(async () => { @@ -152,12 +153,11 @@ describe('NumericalClassifierStrategy', () => { expect(textPart?.text).toBe('simple task'); }); - describe('A/B Testing Logic (Deterministic)', () => { - it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 40 -> FLASH', async () => { - vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Hash 71 -> Control + describe('Default Logic', () => { + it('should route to FLASH when score is below 90', async () => { const mockApiResponse = { complexity_reasoning: 'Standard task', - complexity_score: 40, + complexity_score: 80, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -173,72 +173,17 @@ describe('NumericalClassifierStrategy', () => { expect(decision).toEqual({ model: PREVIEW_GEMINI_FLASH_MODEL, metadata: { - source: 'NumericalClassifier (Control)', + source: 'NumericalClassifier (Default)', latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 40 / Threshold: 50'), + reasoning: expect.stringContaining('Score: 80 / Threshold: 90'), }, }); }); - it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 60 -> PRO', async () => { - vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); - const mockApiResponse = { - complexity_reasoning: 'Complex task', - complexity_score: 60, - }; - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( - mockApiResponse, - ); - - const decision = await strategy.route( - mockContext, - mockConfig, - mockBaseLlmClient, - mockLocalLiteRtLmClient, - ); - - expect(decision).toEqual({ - model: PREVIEW_GEMINI_MODEL, - metadata: { - source: 'NumericalClassifier (Control)', - latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 60 / Threshold: 50'), - }, - }); - }); - - it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 60 -> FLASH', async () => { - vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1'); // FNV Normalized 18 < 50 -> Strict - const mockApiResponse = { - complexity_reasoning: 'Complex task', - complexity_score: 60, - }; - vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( - mockApiResponse, - ); - - const decision = await strategy.route( - mockContext, - mockConfig, - mockBaseLlmClient, - mockLocalLiteRtLmClient, - ); - - expect(decision).toEqual({ - model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80 - metadata: { - source: 'NumericalClassifier (Strict)', - latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 60 / Threshold: 80'), - }, - }); - }); - - it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 90 -> PRO', async () => { - vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1'); + it('should route to PRO when score is 90 or above', async () => { const mockApiResponse = { complexity_reasoning: 'Extreme task', - complexity_score: 90, + complexity_score: 95, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -254,9 +199,9 @@ describe('NumericalClassifierStrategy', () => { expect(decision).toEqual({ model: PREVIEW_GEMINI_MODEL, metadata: { - source: 'NumericalClassifier (Strict)', + source: 'NumericalClassifier (Default)', latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 90 / Threshold: 80'), + reasoning: expect.stringContaining('Score: 95 / Threshold: 90'), }, }); }); @@ -265,6 +210,9 @@ describe('NumericalClassifierStrategy', () => { describe('Remote Threshold Logic', () => { it('should use the remote CLASSIFIER_THRESHOLD if provided (int value)', async () => { vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(70); + vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue( + 70, + ); const mockApiResponse = { complexity_reasoning: 'Test task', complexity_score: 60, @@ -292,6 +240,9 @@ describe('NumericalClassifierStrategy', () => { it('should use the remote CLASSIFIER_THRESHOLD if provided (float value)', async () => { vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(45.5); + vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue( + 45.5, + ); const mockApiResponse = { complexity_reasoning: 'Test task', complexity_score: 40, @@ -319,6 +270,9 @@ describe('NumericalClassifierStrategy', () => { it('should use PRO model if score >= remote CLASSIFIER_THRESHOLD', async () => { vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(30); + vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue( + 30, + ); const mockApiResponse = { complexity_reasoning: 'Test task', complexity_score: 35, @@ -344,13 +298,12 @@ describe('NumericalClassifierStrategy', () => { }); }); - it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is not present in experiments', async () => { + it('should fall back to default logic if CLASSIFIER_THRESHOLD is not present in experiments', async () => { // Mock getClassifierThreshold to return undefined vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(undefined); - vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Should resolve to Control (50) const mockApiResponse = { complexity_reasoning: 'Test task', - complexity_score: 40, + complexity_score: 80, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -364,21 +317,20 @@ describe('NumericalClassifierStrategy', () => { ); expect(decision).toEqual({ - model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50 + model: PREVIEW_GEMINI_FLASH_MODEL, // Score 80 < Default Threshold 90 metadata: { - source: 'NumericalClassifier (Control)', + source: 'NumericalClassifier (Default)', latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 40 / Threshold: 50'), + reasoning: expect.stringContaining('Score: 80 / Threshold: 90'), }, }); }); - it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => { + it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => { vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(-10); - vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); const mockApiResponse = { complexity_reasoning: 'Test task', - complexity_score: 40, + complexity_score: 80, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -394,19 +346,18 @@ describe('NumericalClassifierStrategy', () => { expect(decision).toEqual({ model: PREVIEW_GEMINI_FLASH_MODEL, metadata: { - source: 'NumericalClassifier (Control)', + source: 'NumericalClassifier (Default)', latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 40 / Threshold: 50'), + reasoning: expect.stringContaining('Score: 80 / Threshold: 90'), }, }); }); - it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => { + it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => { vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(110); - vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); const mockApiResponse = { complexity_reasoning: 'Test task', - complexity_score: 60, + complexity_score: 95, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -422,9 +373,9 @@ describe('NumericalClassifierStrategy', () => { expect(decision).toEqual({ model: PREVIEW_GEMINI_MODEL, metadata: { - source: 'NumericalClassifier (Control)', + source: 'NumericalClassifier (Default)', latencyMs: expect.any(Number), - reasoning: expect.stringContaining('Score: 60 / Threshold: 50'), + reasoning: expect.stringContaining('Score: 95 / Threshold: 90'), }, }); }); @@ -591,7 +542,7 @@ describe('NumericalClassifierStrategy', () => { vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); const mockApiResponse = { complexity_reasoning: 'Complex task', - complexity_score: 80, + complexity_score: 95, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -613,7 +564,7 @@ describe('NumericalClassifierStrategy', () => { }); const mockApiResponse = { complexity_reasoning: 'Complex task', - complexity_score: 80, + complexity_score: 95, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, @@ -636,7 +587,7 @@ describe('NumericalClassifierStrategy', () => { }); const mockApiResponse = { complexity_reasoning: 'Complex task', - complexity_score: 80, + complexity_score: 95, }; vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( mockApiResponse, diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts index 1b5b67aac4..c86576d6ce 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts @@ -93,39 +93,6 @@ const ClassifierResponseSchema = z.object({ complexity_score: z.number().min(1).max(100), }); -/** - * Deterministically calculates the routing threshold based on the session ID. - * This ensures a consistent experience for the user within a session. - * - * This implementation uses the FNV-1a hash algorithm (32-bit). - * @see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function - * - * @param sessionId The unique session identifier. - * @returns The threshold (50 or 80). - */ -function getComplexityThreshold(sessionId: string): number { - const FNV_OFFSET_BASIS_32 = 0x811c9dc5; - const FNV_PRIME_32 = 0x01000193; - - let hash = FNV_OFFSET_BASIS_32; - - for (let i = 0; i < sessionId.length; i++) { - hash ^= sessionId.charCodeAt(i); - // Multiply by prime (simulate 32-bit overflow with bitwise shift) - hash = Math.imul(hash, FNV_PRIME_32); - } - - // Ensure positive integer - hash = hash >>> 0; - - // Normalize to 0-99 - const normalized = hash % 100; - // 50% split: - // 0-49: Strict (80) - // 50-99: Control (50) - return normalized < 50 ? 80 : 50; -} - export class NumericalClassifierStrategy implements RoutingStrategy { readonly name = 'numerical_classifier'; @@ -179,11 +146,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy { const score = routerResponse.complexity_score; const { threshold, groupLabel, modelAlias } = - await this.getRoutingDecision( - score, - config, - config.getSessionId() || 'unknown-session', - ); + await this.getRoutingDecision(score, config); const [useGemini3_1, useCustomToolModel] = await Promise.all([ config.getGemini31Launched(), config.getUseCustomToolModel(), @@ -214,29 +177,19 @@ export class NumericalClassifierStrategy implements RoutingStrategy { private async getRoutingDecision( score: number, config: Config, - sessionId: string, ): Promise<{ threshold: number; groupLabel: string; modelAlias: typeof FLASH_MODEL | typeof PRO_MODEL; }> { - let threshold: number; - let groupLabel: string; - + const threshold = await config.getResolvedClassifierThreshold(); const remoteThresholdValue = await config.getClassifierThreshold(); - if ( - remoteThresholdValue !== undefined && - !isNaN(remoteThresholdValue) && - remoteThresholdValue >= 0 && - remoteThresholdValue <= 100 - ) { - threshold = remoteThresholdValue; + let groupLabel: string; + if (threshold === remoteThresholdValue) { groupLabel = 'Remote'; } else { - // Fallback to deterministic A/B test - threshold = getComplexityThreshold(sessionId); - groupLabel = threshold === 80 ? 'Strict' : 'Control'; + groupLabel = 'Default'; } const modelAlias = score >= threshold ? PRO_MODEL : FLASH_MODEL;