fix(core): enable numerical routing for api key users (#21977)

This commit is contained in:
Sehoon Shon
2026-03-11 14:54:52 -04:00
committed by GitHub
parent 08e174a05c
commit 36ce2ba96e
6 changed files with 163 additions and 146 deletions

View File

@@ -54,7 +54,10 @@ describe('ModelRouterService', () => {
vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue(
mockLocalLiteRtLmClient,
);
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(false);
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(true);
vi.spyOn(mockConfig, 'getResolvedClassifierThreshold').mockResolvedValue(
90,
);
vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined);
vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({
enabled: false,
@@ -182,8 +185,8 @@ describe('ModelRouterService', () => {
false,
undefined,
ApprovalMode.DEFAULT,
false,
undefined,
true,
'90',
);
expect(logModelRouting).toHaveBeenCalledWith(
mockConfig,
@@ -209,8 +212,8 @@ describe('ModelRouterService', () => {
true,
'Strategy failed',
ApprovalMode.DEFAULT,
false,
undefined,
true,
'90',
);
expect(logModelRouting).toHaveBeenCalledWith(
mockConfig,

View File

@@ -78,10 +78,9 @@ export class ModelRouterService {
const [enableNumericalRouting, thresholdValue] = await Promise.all([
this.config.getNumericalRoutingEnabled(),
this.config.getClassifierThreshold(),
this.config.getResolvedClassifierThreshold(),
]);
const classifierThreshold =
thresholdValue !== undefined ? String(thresholdValue) : undefined;
const classifierThreshold = String(thresholdValue);
let failed = false;
let error_message: string | undefined;

View File

@@ -56,6 +56,7 @@ describe('NumericalClassifierStrategy', () => {
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
getResolvedClassifierThreshold: vi.fn().mockResolvedValue(90),
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
getGemini31Launched: vi.fn().mockResolvedValue(false),
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
@@ -152,12 +153,11 @@ describe('NumericalClassifierStrategy', () => {
expect(textPart?.text).toBe('simple task');
});
describe('A/B Testing Logic (Deterministic)', () => {
it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 40 -> FLASH', async () => {
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Hash 71 -> Control
describe('Default Logic', () => {
it('should route to FLASH when score is below 90', async () => {
const mockApiResponse = {
complexity_reasoning: 'Standard task',
complexity_score: 40,
complexity_score: 80,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -173,72 +173,17 @@ describe('NumericalClassifierStrategy', () => {
expect(decision).toEqual({
model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
source: 'NumericalClassifier (Default)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
},
});
});
it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 60 -> PRO', async () => {
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 60,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
mockLocalLiteRtLmClient,
);
expect(decision).toEqual({
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 60 / Threshold: 50'),
},
});
});
it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 60 -> FLASH', async () => {
vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1'); // FNV Normalized 18 < 50 -> Strict
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 60,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
mockLocalLiteRtLmClient,
);
expect(decision).toEqual({
model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
metadata: {
source: 'NumericalClassifier (Strict)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 60 / Threshold: 80'),
},
});
});
it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 90 -> PRO', async () => {
vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1');
it('should route to PRO when score is 90 or above', async () => {
const mockApiResponse = {
complexity_reasoning: 'Extreme task',
complexity_score: 90,
complexity_score: 95,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -254,9 +199,9 @@ describe('NumericalClassifierStrategy', () => {
expect(decision).toEqual({
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Strict)',
source: 'NumericalClassifier (Default)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 90 / Threshold: 80'),
reasoning: expect.stringContaining('Score: 95 / Threshold: 90'),
},
});
});
@@ -265,6 +210,9 @@ describe('NumericalClassifierStrategy', () => {
describe('Remote Threshold Logic', () => {
it('should use the remote CLASSIFIER_THRESHOLD if provided (int value)', async () => {
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(70);
vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue(
70,
);
const mockApiResponse = {
complexity_reasoning: 'Test task',
complexity_score: 60,
@@ -292,6 +240,9 @@ describe('NumericalClassifierStrategy', () => {
it('should use the remote CLASSIFIER_THRESHOLD if provided (float value)', async () => {
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(45.5);
vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue(
45.5,
);
const mockApiResponse = {
complexity_reasoning: 'Test task',
complexity_score: 40,
@@ -319,6 +270,9 @@ describe('NumericalClassifierStrategy', () => {
it('should use PRO model if score >= remote CLASSIFIER_THRESHOLD', async () => {
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(30);
vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue(
30,
);
const mockApiResponse = {
complexity_reasoning: 'Test task',
complexity_score: 35,
@@ -344,13 +298,12 @@ describe('NumericalClassifierStrategy', () => {
});
});
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is not present in experiments', async () => {
it('should fall back to default logic if CLASSIFIER_THRESHOLD is not present in experiments', async () => {
// Mock getClassifierThreshold to return undefined
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(undefined);
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Should resolve to Control (50)
const mockApiResponse = {
complexity_reasoning: 'Test task',
complexity_score: 40,
complexity_score: 80,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -364,21 +317,20 @@ describe('NumericalClassifierStrategy', () => {
);
expect(decision).toEqual({
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 80 < Default Threshold 90
metadata: {
source: 'NumericalClassifier (Control)',
source: 'NumericalClassifier (Default)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
},
});
});
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => {
it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => {
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(-10);
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
const mockApiResponse = {
complexity_reasoning: 'Test task',
complexity_score: 40,
complexity_score: 80,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -394,19 +346,18 @@ describe('NumericalClassifierStrategy', () => {
expect(decision).toEqual({
model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
source: 'NumericalClassifier (Default)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
},
});
});
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => {
it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => {
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(110);
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
const mockApiResponse = {
complexity_reasoning: 'Test task',
complexity_score: 60,
complexity_score: 95,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -422,9 +373,9 @@ describe('NumericalClassifierStrategy', () => {
expect(decision).toEqual({
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'NumericalClassifier (Control)',
source: 'NumericalClassifier (Default)',
latencyMs: expect.any(Number),
reasoning: expect.stringContaining('Score: 60 / Threshold: 50'),
reasoning: expect.stringContaining('Score: 95 / Threshold: 90'),
},
});
});
@@ -591,7 +542,7 @@ describe('NumericalClassifierStrategy', () => {
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 80,
complexity_score: 95,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -613,7 +564,7 @@ describe('NumericalClassifierStrategy', () => {
});
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 80,
complexity_score: 95,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
@@ -636,7 +587,7 @@ describe('NumericalClassifierStrategy', () => {
});
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 80,
complexity_score: 95,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,

View File

@@ -93,39 +93,6 @@ const ClassifierResponseSchema = z.object({
complexity_score: z.number().min(1).max(100),
});
/**
* Deterministically calculates the routing threshold based on the session ID.
* This ensures a consistent experience for the user within a session.
*
* This implementation uses the FNV-1a hash algorithm (32-bit).
* @see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
*
* @param sessionId The unique session identifier.
* @returns The threshold (50 or 80).
*/
function getComplexityThreshold(sessionId: string): number {
const FNV_OFFSET_BASIS_32 = 0x811c9dc5;
const FNV_PRIME_32 = 0x01000193;
let hash = FNV_OFFSET_BASIS_32;
for (let i = 0; i < sessionId.length; i++) {
hash ^= sessionId.charCodeAt(i);
// Multiply by prime (simulate 32-bit overflow with bitwise shift)
hash = Math.imul(hash, FNV_PRIME_32);
}
// Ensure positive integer
hash = hash >>> 0;
// Normalize to 0-99
const normalized = hash % 100;
// 50% split:
// 0-49: Strict (80)
// 50-99: Control (50)
return normalized < 50 ? 80 : 50;
}
export class NumericalClassifierStrategy implements RoutingStrategy {
readonly name = 'numerical_classifier';
@@ -179,11 +146,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
const score = routerResponse.complexity_score;
const { threshold, groupLabel, modelAlias } =
await this.getRoutingDecision(
score,
config,
config.getSessionId() || 'unknown-session',
);
await this.getRoutingDecision(score, config);
const [useGemini3_1, useCustomToolModel] = await Promise.all([
config.getGemini31Launched(),
config.getUseCustomToolModel(),
@@ -214,29 +177,19 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
private async getRoutingDecision(
score: number,
config: Config,
sessionId: string,
): Promise<{
threshold: number;
groupLabel: string;
modelAlias: typeof FLASH_MODEL | typeof PRO_MODEL;
}> {
let threshold: number;
let groupLabel: string;
const threshold = await config.getResolvedClassifierThreshold();
const remoteThresholdValue = await config.getClassifierThreshold();
if (
remoteThresholdValue !== undefined &&
!isNaN(remoteThresholdValue) &&
remoteThresholdValue >= 0 &&
remoteThresholdValue <= 100
) {
threshold = remoteThresholdValue;
let groupLabel: string;
if (threshold === remoteThresholdValue) {
groupLabel = 'Remote';
} else {
// Fallback to deterministic A/B test
threshold = getComplexityThreshold(sessionId);
groupLabel = threshold === 80 ? 'Strict' : 'Control';
groupLabel = 'Default';
}
const modelAlias = score >= threshold ? PRO_MODEL : FLASH_MODEL;