mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-28 15:01:14 -07:00
fix(core): enable numerical routing for api key users (#21977)
This commit is contained in:
@@ -54,7 +54,10 @@ describe('ModelRouterService', () => {
|
||||
vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue(
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(false);
|
||||
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(true);
|
||||
vi.spyOn(mockConfig, 'getResolvedClassifierThreshold').mockResolvedValue(
|
||||
90,
|
||||
);
|
||||
vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined);
|
||||
vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({
|
||||
enabled: false,
|
||||
@@ -182,8 +185,8 @@ describe('ModelRouterService', () => {
|
||||
false,
|
||||
undefined,
|
||||
ApprovalMode.DEFAULT,
|
||||
false,
|
||||
undefined,
|
||||
true,
|
||||
'90',
|
||||
);
|
||||
expect(logModelRouting).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
@@ -209,8 +212,8 @@ describe('ModelRouterService', () => {
|
||||
true,
|
||||
'Strategy failed',
|
||||
ApprovalMode.DEFAULT,
|
||||
false,
|
||||
undefined,
|
||||
true,
|
||||
'90',
|
||||
);
|
||||
expect(logModelRouting).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
|
||||
@@ -78,10 +78,9 @@ export class ModelRouterService {
|
||||
|
||||
const [enableNumericalRouting, thresholdValue] = await Promise.all([
|
||||
this.config.getNumericalRoutingEnabled(),
|
||||
this.config.getClassifierThreshold(),
|
||||
this.config.getResolvedClassifierThreshold(),
|
||||
]);
|
||||
const classifierThreshold =
|
||||
thresholdValue !== undefined ? String(thresholdValue) : undefined;
|
||||
const classifierThreshold = String(thresholdValue);
|
||||
|
||||
let failed = false;
|
||||
let error_message: string | undefined;
|
||||
|
||||
@@ -56,6 +56,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
|
||||
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
|
||||
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
|
||||
getResolvedClassifierThreshold: vi.fn().mockResolvedValue(90),
|
||||
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
|
||||
getGemini31Launched: vi.fn().mockResolvedValue(false),
|
||||
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
|
||||
@@ -152,12 +153,11 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(textPart?.text).toBe('simple task');
|
||||
});
|
||||
|
||||
describe('A/B Testing Logic (Deterministic)', () => {
|
||||
it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 40 -> FLASH', async () => {
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Hash 71 -> Control
|
||||
describe('Default Logic', () => {
|
||||
it('should route to FLASH when score is below 90', async () => {
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Standard task',
|
||||
complexity_score: 40,
|
||||
complexity_score: 80,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -173,72 +173,17 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
source: 'NumericalClassifier (Default)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
|
||||
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 60 -> PRO', async () => {
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Complex task',
|
||||
complexity_score: 60,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 60 / Threshold: 50'),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 60 -> FLASH', async () => {
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1'); // FNV Normalized 18 < 50 -> Strict
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Complex task',
|
||||
complexity_score: 60,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Strict)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 60 / Threshold: 80'),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 90 -> PRO', async () => {
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1');
|
||||
it('should route to PRO when score is 90 or above', async () => {
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Extreme task',
|
||||
complexity_score: 90,
|
||||
complexity_score: 95,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -254,9 +199,9 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Strict)',
|
||||
source: 'NumericalClassifier (Default)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 90 / Threshold: 80'),
|
||||
reasoning: expect.stringContaining('Score: 95 / Threshold: 90'),
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -265,6 +210,9 @@ describe('NumericalClassifierStrategy', () => {
|
||||
describe('Remote Threshold Logic', () => {
|
||||
it('should use the remote CLASSIFIER_THRESHOLD if provided (int value)', async () => {
|
||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(70);
|
||||
vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue(
|
||||
70,
|
||||
);
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Test task',
|
||||
complexity_score: 60,
|
||||
@@ -292,6 +240,9 @@ describe('NumericalClassifierStrategy', () => {
|
||||
|
||||
it('should use the remote CLASSIFIER_THRESHOLD if provided (float value)', async () => {
|
||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(45.5);
|
||||
vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue(
|
||||
45.5,
|
||||
);
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Test task',
|
||||
complexity_score: 40,
|
||||
@@ -319,6 +270,9 @@ describe('NumericalClassifierStrategy', () => {
|
||||
|
||||
it('should use PRO model if score >= remote CLASSIFIER_THRESHOLD', async () => {
|
||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(30);
|
||||
vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue(
|
||||
30,
|
||||
);
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Test task',
|
||||
complexity_score: 35,
|
||||
@@ -344,13 +298,12 @@ describe('NumericalClassifierStrategy', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is not present in experiments', async () => {
|
||||
it('should fall back to default logic if CLASSIFIER_THRESHOLD is not present in experiments', async () => {
|
||||
// Mock getClassifierThreshold to return undefined
|
||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(undefined);
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Should resolve to Control (50)
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Test task',
|
||||
complexity_score: 40,
|
||||
complexity_score: 80,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -364,21 +317,20 @@ describe('NumericalClassifierStrategy', () => {
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 80 < Default Threshold 90
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
source: 'NumericalClassifier (Default)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
|
||||
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => {
|
||||
it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => {
|
||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(-10);
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Test task',
|
||||
complexity_score: 40,
|
||||
complexity_score: 80,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -394,19 +346,18 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
source: 'NumericalClassifier (Default)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
|
||||
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => {
|
||||
it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => {
|
||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(110);
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Test task',
|
||||
complexity_score: 60,
|
||||
complexity_score: 95,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -422,9 +373,9 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
source: 'NumericalClassifier (Default)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 60 / Threshold: 50'),
|
||||
reasoning: expect.stringContaining('Score: 95 / Threshold: 90'),
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -591,7 +542,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Complex task',
|
||||
complexity_score: 80,
|
||||
complexity_score: 95,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -613,7 +564,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
});
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Complex task',
|
||||
complexity_score: 80,
|
||||
complexity_score: 95,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -636,7 +587,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
});
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Complex task',
|
||||
complexity_score: 80,
|
||||
complexity_score: 95,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
|
||||
@@ -93,39 +93,6 @@ const ClassifierResponseSchema = z.object({
|
||||
complexity_score: z.number().min(1).max(100),
|
||||
});
|
||||
|
||||
/**
|
||||
* Deterministically calculates the routing threshold based on the session ID.
|
||||
* This ensures a consistent experience for the user within a session.
|
||||
*
|
||||
* This implementation uses the FNV-1a hash algorithm (32-bit).
|
||||
* @see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
|
||||
*
|
||||
* @param sessionId The unique session identifier.
|
||||
* @returns The threshold (50 or 80).
|
||||
*/
|
||||
function getComplexityThreshold(sessionId: string): number {
|
||||
const FNV_OFFSET_BASIS_32 = 0x811c9dc5;
|
||||
const FNV_PRIME_32 = 0x01000193;
|
||||
|
||||
let hash = FNV_OFFSET_BASIS_32;
|
||||
|
||||
for (let i = 0; i < sessionId.length; i++) {
|
||||
hash ^= sessionId.charCodeAt(i);
|
||||
// Multiply by prime (simulate 32-bit overflow with bitwise shift)
|
||||
hash = Math.imul(hash, FNV_PRIME_32);
|
||||
}
|
||||
|
||||
// Ensure positive integer
|
||||
hash = hash >>> 0;
|
||||
|
||||
// Normalize to 0-99
|
||||
const normalized = hash % 100;
|
||||
// 50% split:
|
||||
// 0-49: Strict (80)
|
||||
// 50-99: Control (50)
|
||||
return normalized < 50 ? 80 : 50;
|
||||
}
|
||||
|
||||
export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
readonly name = 'numerical_classifier';
|
||||
|
||||
@@ -179,11 +146,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
const score = routerResponse.complexity_score;
|
||||
|
||||
const { threshold, groupLabel, modelAlias } =
|
||||
await this.getRoutingDecision(
|
||||
score,
|
||||
config,
|
||||
config.getSessionId() || 'unknown-session',
|
||||
);
|
||||
await this.getRoutingDecision(score, config);
|
||||
const [useGemini3_1, useCustomToolModel] = await Promise.all([
|
||||
config.getGemini31Launched(),
|
||||
config.getUseCustomToolModel(),
|
||||
@@ -214,29 +177,19 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
private async getRoutingDecision(
|
||||
score: number,
|
||||
config: Config,
|
||||
sessionId: string,
|
||||
): Promise<{
|
||||
threshold: number;
|
||||
groupLabel: string;
|
||||
modelAlias: typeof FLASH_MODEL | typeof PRO_MODEL;
|
||||
}> {
|
||||
let threshold: number;
|
||||
let groupLabel: string;
|
||||
|
||||
const threshold = await config.getResolvedClassifierThreshold();
|
||||
const remoteThresholdValue = await config.getClassifierThreshold();
|
||||
|
||||
if (
|
||||
remoteThresholdValue !== undefined &&
|
||||
!isNaN(remoteThresholdValue) &&
|
||||
remoteThresholdValue >= 0 &&
|
||||
remoteThresholdValue <= 100
|
||||
) {
|
||||
threshold = remoteThresholdValue;
|
||||
let groupLabel: string;
|
||||
if (threshold === remoteThresholdValue) {
|
||||
groupLabel = 'Remote';
|
||||
} else {
|
||||
// Fallback to deterministic A/B test
|
||||
threshold = getComplexityThreshold(sessionId);
|
||||
groupLabel = threshold === 80 ? 'Strict' : 'Control';
|
||||
groupLabel = 'Default';
|
||||
}
|
||||
|
||||
const modelAlias = score >= threshold ? PRO_MODEL : FLASH_MODEL;
|
||||
|
||||
Reference in New Issue
Block a user