fix(core): enable numerical routing by default and set threshold fallback to 90

This commit is contained in:
Sehoon Shon
2026-03-11 01:12:55 -04:00
parent f8ad3a200a
commit 30b6e987c6
6 changed files with 89 additions and 136 deletions
+37
View File
@@ -492,6 +492,43 @@ describe('Server Config (config.ts)', () => {
expect(await config.getUserCaching()).toBeUndefined();
});
});
describe('getNumericalRoutingEnabled', () => {
  // Builds a Config whose experiments payload carries only the
  // ENABLE_NUMERICAL_ROUTING flag, set to the given boolean value.
  const configWithRoutingFlag = (boolValue: boolean) =>
    new Config({
      ...baseParams,
      experiments: {
        flags: {
          [ExperimentFlags.ENABLE_NUMERICAL_ROUTING]: { boolValue },
        },
        experimentIds: [],
      },
    } as unknown as ConfigParameters);

  it('should return true by default if there are no experiments', async () => {
    // No experiments supplied at all: the getter must fall back to enabled.
    const enabled = await new Config(baseParams).getNumericalRoutingEnabled();
    expect(enabled).toBe(true);
  });

  it('should return true if the remote flag is set to true', async () => {
    const enabled = await configWithRoutingFlag(true).getNumericalRoutingEnabled();
    expect(enabled).toBe(true);
  });

  it('should return false if the remote flag is explicitly set to false', async () => {
    // An explicit false from the remote flag must win over the default.
    const enabled = await configWithRoutingFlag(false).getNumericalRoutingEnabled();
    expect(enabled).toBe(false);
  });
});
});
describe('refreshAuth', () => {
+6 -2
View File
@@ -2508,8 +2508,12 @@ export class Config implements McpContext, AgentLoopContext {
/**
 * Reports whether numerical routing is enabled.
 *
 * Waits for experiments to be loaded, then reads the
 * ENABLE_NUMERICAL_ROUTING experiment flag. An explicit boolean on the flag
 * is returned as-is; when the experiments, the flag, or its value are
 * absent, numerical routing defaults to enabled.
 *
 * @returns `false` only when the flag is explicitly set to `false`,
 *   otherwise `true`.
 */
async getNumericalRoutingEnabled(): Promise<boolean> {
  await this.ensureExperimentsLoaded();
  const flag =
    this.experiments?.flags[ExperimentFlags.ENABLE_NUMERICAL_ROUTING];
  // `??` (not `||` or `!!`) so an explicit `false` is honoured while a
  // missing value falls back to the enabled-by-default behaviour.
  return flag?.boolValue ?? true;
}
async getClassifierThreshold(): Promise<number | undefined> {
@@ -54,7 +54,7 @@ describe('ModelRouterService', () => {
vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue(
mockLocalLiteRtLmClient,
);
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(false);
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(true);
vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined);
vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({
enabled: false,
@@ -182,8 +182,8 @@ describe('ModelRouterService', () => {
false,
undefined,
ApprovalMode.DEFAULT,
false,
undefined,
true,
'90',
);
expect(logModelRouting).toHaveBeenCalledWith(
mockConfig,
@@ -209,8 +209,8 @@ describe('ModelRouterService', () => {
true,
'Strategy failed',
ApprovalMode.DEFAULT,
false,
undefined,
true,
'90',
);
expect(logModelRouting).toHaveBeenCalledWith(
mockConfig,
@@ -14,7 +14,10 @@ import type {
} from './routingStrategy.js';
import { DefaultStrategy } from './strategies/defaultStrategy.js';
import { ClassifierStrategy } from './strategies/classifierStrategy.js';
import { NumericalClassifierStrategy } from './strategies/numericalClassifierStrategy.js';
import {
NumericalClassifierStrategy,
DEFAULT_CLASSIFIER_THRESHOLD,
} from './strategies/numericalClassifierStrategy.js';
import { CompositeStrategy } from './strategies/compositeStrategy.js';
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
import { OverrideStrategy } from './strategies/overrideStrategy.js';
@@ -80,8 +83,9 @@ export class ModelRouterService {
this.config.getNumericalRoutingEnabled(),
this.config.getClassifierThreshold(),
]);
const classifierThreshold =
thresholdValue !== undefined ? String(thresholdValue) : undefined;
const classifierThreshold = String(
thresholdValue ?? DEFAULT_CLASSIFIER_THRESHOLD,
);
let failed = false;
let error_message: string | undefined;
@@ -152,12 +152,11 @@ describe('NumericalClassifierStrategy', () => {
expect(textPart?.text).toBe('simple task');
});
describe('A/B Testing Logic (Deterministic)', () => {
it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 40 -> FLASH', async () => {
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Hash 71 -> Control
describe('Default Logic', () => {
it('should route to FLASH when score is below 90', async () => {
const mockApiResponse = {
complexity_reasoning: 'Standard task',
complexity_score: 40,
complexity_score: 80,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -173,72 +172,17 @@ describe('NumericalClassifierStrategy', () => {
expect(decision).toEqual({
model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
source: 'NumericalClassifier (Default)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
},
});
});
it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 60 -> PRO', async () => {
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 60,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
mockLocalLiteRtLmClient,
);
expect(decision).toEqual({
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 60 / Threshold: 50'),
},
});
});
it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 60 -> FLASH', async () => {
vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1'); // FNV Normalized 18 < 50 -> Strict
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 60,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
mockLocalLiteRtLmClient,
);
expect(decision).toEqual({
model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
metadata: {
source: 'NumericalClassifier (Strict)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 60 / Threshold: 80'),
},
});
});
it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 90 -> PRO', async () => {
vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1');
it('should route to PRO when score is 90 or above', async () => {
const mockApiResponse = {
complexity_reasoning: 'Extreme task',
complexity_score: 90,
complexity_score: 95,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -254,9 +198,9 @@ describe('NumericalClassifierStrategy', () => {
expect(decision).toEqual({
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Strict)',
source: 'NumericalClassifier (Default)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 90 / Threshold: 80'),
reasoning: expect.stringContaining('Score: 95 / Threshold: 90'),
},
});
});
@@ -344,13 +288,12 @@ describe('NumericalClassifierStrategy', () => {
});
});
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is not present in experiments', async () => {
it('should fall back to default logic if CLASSIFIER_THRESHOLD is not present in experiments', async () => {
// Mock getClassifierThreshold to return undefined
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(undefined);
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Should resolve to Control (50)
const mockApiResponse = {
complexity_reasoning: 'Test task',
complexity_score: 40,
complexity_score: 80,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -364,21 +307,20 @@ describe('NumericalClassifierStrategy', () => {
);
expect(decision).toEqual({
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 80 < Default Threshold 90
metadata: {
source: 'NumericalClassifier (Control)',
source: 'NumericalClassifier (Default)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
},
});
});
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => {
it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => {
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(-10);
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
const mockApiResponse = {
complexity_reasoning: 'Test task',
complexity_score: 40,
complexity_score: 80,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -394,19 +336,18 @@ describe('NumericalClassifierStrategy', () => {
expect(decision).toEqual({
model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
source: 'NumericalClassifier (Default)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
},
});
});
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => {
it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => {
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(110);
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
const mockApiResponse = {
complexity_reasoning: 'Test task',
complexity_score: 60,
complexity_score: 95,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -422,9 +363,9 @@ describe('NumericalClassifierStrategy', () => {
expect(decision).toEqual({
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
source: 'NumericalClassifier (Default)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 60 / Threshold: 50'),
reasoning: expect.stringContaining('Score: 95 / Threshold: 90'),
},
});
});
@@ -591,7 +532,7 @@ describe('NumericalClassifierStrategy', () => {
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 80,
complexity_score: 95,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -613,7 +554,7 @@ describe('NumericalClassifierStrategy', () => {
});
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 80,
complexity_score: 95,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -636,7 +577,7 @@ describe('NumericalClassifierStrategy', () => {
});
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 80,
complexity_score: 95,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -25,6 +25,12 @@ const HISTORY_TURNS_FOR_CONTEXT = 8;
const FLASH_MODEL = 'flash';
const PRO_MODEL = 'pro';
/**
 * The default complexity threshold for routing.
 *
 * A complexity score greater than or equal to this threshold selects the Pro
 * model; a lower score selects the Flash model. Exported so other modules can
 * reference the same fallback value instead of repeating the literal.
 */
export const DEFAULT_CLASSIFIER_THRESHOLD = 90;
const RESPONSE_SCHEMA = {
type: Type.OBJECT,
properties: {
@@ -93,39 +99,6 @@ const ClassifierResponseSchema = z.object({
complexity_score: z.number().min(1).max(100),
});
/**
 * Maps a session ID to one of two routing thresholds in a deterministic way,
 * so every request within the same session sees the same threshold.
 *
 * The session ID is hashed with 32-bit FNV-1a (applied to the string's UTF-16
 * code units), the unsigned hash is reduced to a bucket in 0-99, and the
 * bucket decides the group:
 *   - buckets 0-49  -> Strict  (threshold 80)
 *   - buckets 50-99 -> Control (threshold 50)
 *
 * @see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
 *
 * @param sessionId The unique session identifier.
 * @returns The threshold (50 or 80).
 */
function getComplexityThreshold(sessionId: string): number {
  const OFFSET_BASIS = 0x811c9dc5;
  const PRIME = 0x01000193;
  const STRICT_THRESHOLD = 80;
  const CONTROL_THRESHOLD = 50;

  // XOR each code unit into the running hash, then multiply by the prime.
  // Math.imul keeps the product within 32 bits, emulating integer overflow.
  const signedDigest = sessionId
    .split('')
    .reduce(
      (acc, unit) => Math.imul(acc ^ unit.charCodeAt(0), PRIME),
      OFFSET_BASIS,
    );

  // Reinterpret as an unsigned 32-bit value before bucketing into 0-99.
  const bucket = (signedDigest >>> 0) % 100;

  if (bucket >= 50) {
    return CONTROL_THRESHOLD;
  }
  return STRICT_THRESHOLD;
}
export class NumericalClassifierStrategy implements RoutingStrategy {
readonly name = 'numerical_classifier';
@@ -179,11 +152,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
const score = routerResponse.complexity_score;
const { threshold, groupLabel, modelAlias } =
await this.getRoutingDecision(
score,
config,
config.getSessionId() || 'unknown-session',
);
await this.getRoutingDecision(score, config);
const [useGemini3_1, useCustomToolModel] = await Promise.all([
config.getGemini31Launched(),
config.getUseCustomToolModel(),
@@ -214,7 +183,6 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
private async getRoutingDecision(
score: number,
config: Config,
sessionId: string,
): Promise<{
threshold: number;
groupLabel: string;
@@ -234,9 +202,8 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
threshold = remoteThresholdValue;
groupLabel = 'Remote';
} else {
// Fallback to deterministic A/B test
threshold = getComplexityThreshold(sessionId);
groupLabel = threshold === 80 ? 'Strict' : 'Control';
threshold = DEFAULT_CLASSIFIER_THRESHOLD;
groupLabel = 'Default';
}
const modelAlias = score >= threshold ? PRO_MODEL : FLASH_MODEL;