mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-23 11:34:44 -07:00
fix(core): enable numerical routing by default and set threshold fallback to 90
This commit is contained in:
@@ -492,6 +492,43 @@ describe('Server Config (config.ts)', () => {
|
||||
expect(await config.getUserCaching()).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getNumericalRoutingEnabled', () => {
|
||||
it('should return true by default if there are no experiments', async () => {
|
||||
const config = new Config(baseParams);
|
||||
expect(await config.getNumericalRoutingEnabled()).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true if the remote flag is set to true', async () => {
|
||||
const config = new Config({
|
||||
...baseParams,
|
||||
experiments: {
|
||||
flags: {
|
||||
[ExperimentFlags.ENABLE_NUMERICAL_ROUTING]: {
|
||||
boolValue: true,
|
||||
},
|
||||
},
|
||||
experimentIds: [],
|
||||
},
|
||||
} as unknown as ConfigParameters);
|
||||
expect(await config.getNumericalRoutingEnabled()).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false if the remote flag is explicitly set to false', async () => {
|
||||
const config = new Config({
|
||||
...baseParams,
|
||||
experiments: {
|
||||
flags: {
|
||||
[ExperimentFlags.ENABLE_NUMERICAL_ROUTING]: {
|
||||
boolValue: false,
|
||||
},
|
||||
},
|
||||
experimentIds: [],
|
||||
},
|
||||
} as unknown as ConfigParameters);
|
||||
expect(await config.getNumericalRoutingEnabled()).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('refreshAuth', () => {
|
||||
|
||||
@@ -2508,8 +2508,12 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
async getNumericalRoutingEnabled(): Promise<boolean> {
|
||||
await this.ensureExperimentsLoaded();
|
||||
|
||||
return !!this.experiments?.flags[ExperimentFlags.ENABLE_NUMERICAL_ROUTING]
|
||||
?.boolValue;
|
||||
const flag =
|
||||
this.experiments?.flags[ExperimentFlags.ENABLE_NUMERICAL_ROUTING];
|
||||
if (flag?.boolValue !== undefined) {
|
||||
return flag.boolValue;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
async getClassifierThreshold(): Promise<number | undefined> {
|
||||
|
||||
@@ -54,7 +54,7 @@ describe('ModelRouterService', () => {
|
||||
vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue(
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(false);
|
||||
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(true);
|
||||
vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined);
|
||||
vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({
|
||||
enabled: false,
|
||||
@@ -182,8 +182,8 @@ describe('ModelRouterService', () => {
|
||||
false,
|
||||
undefined,
|
||||
ApprovalMode.DEFAULT,
|
||||
false,
|
||||
undefined,
|
||||
true,
|
||||
'90',
|
||||
);
|
||||
expect(logModelRouting).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
@@ -209,8 +209,8 @@ describe('ModelRouterService', () => {
|
||||
true,
|
||||
'Strategy failed',
|
||||
ApprovalMode.DEFAULT,
|
||||
false,
|
||||
undefined,
|
||||
true,
|
||||
'90',
|
||||
);
|
||||
expect(logModelRouting).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
|
||||
@@ -14,7 +14,10 @@ import type {
|
||||
} from './routingStrategy.js';
|
||||
import { DefaultStrategy } from './strategies/defaultStrategy.js';
|
||||
import { ClassifierStrategy } from './strategies/classifierStrategy.js';
|
||||
import { NumericalClassifierStrategy } from './strategies/numericalClassifierStrategy.js';
|
||||
import {
|
||||
NumericalClassifierStrategy,
|
||||
DEFAULT_CLASSIFIER_THRESHOLD,
|
||||
} from './strategies/numericalClassifierStrategy.js';
|
||||
import { CompositeStrategy } from './strategies/compositeStrategy.js';
|
||||
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
|
||||
import { OverrideStrategy } from './strategies/overrideStrategy.js';
|
||||
@@ -80,8 +83,9 @@ export class ModelRouterService {
|
||||
this.config.getNumericalRoutingEnabled(),
|
||||
this.config.getClassifierThreshold(),
|
||||
]);
|
||||
const classifierThreshold =
|
||||
thresholdValue !== undefined ? String(thresholdValue) : undefined;
|
||||
const classifierThreshold = String(
|
||||
thresholdValue ?? DEFAULT_CLASSIFIER_THRESHOLD,
|
||||
);
|
||||
|
||||
let failed = false;
|
||||
let error_message: string | undefined;
|
||||
|
||||
@@ -152,12 +152,11 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(textPart?.text).toBe('simple task');
|
||||
});
|
||||
|
||||
describe('A/B Testing Logic (Deterministic)', () => {
|
||||
it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 40 -> FLASH', async () => {
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Hash 71 -> Control
|
||||
describe('Default Logic', () => {
|
||||
it('should route to FLASH when score is below 90', async () => {
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Standard task',
|
||||
complexity_score: 40,
|
||||
complexity_score: 80,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -173,72 +172,17 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
source: 'NumericalClassifier (Default)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
|
||||
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 60 -> PRO', async () => {
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Complex task',
|
||||
complexity_score: 60,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 60 / Threshold: 50'),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 60 -> FLASH', async () => {
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1'); // FNV Normalized 18 < 50 -> Strict
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Complex task',
|
||||
complexity_score: 60,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Strict)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 60 / Threshold: 80'),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 90 -> PRO', async () => {
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1');
|
||||
it('should route to PRO when score is 90 or above', async () => {
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Extreme task',
|
||||
complexity_score: 90,
|
||||
complexity_score: 95,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -254,9 +198,9 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Strict)',
|
||||
source: 'NumericalClassifier (Default)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 90 / Threshold: 80'),
|
||||
reasoning: expect.stringContaining('Score: 95 / Threshold: 90'),
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -344,13 +288,12 @@ describe('NumericalClassifierStrategy', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is not present in experiments', async () => {
|
||||
it('should fall back to default logic if CLASSIFIER_THRESHOLD is not present in experiments', async () => {
|
||||
// Mock getClassifierThreshold to return undefined
|
||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(undefined);
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Should resolve to Control (50)
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Test task',
|
||||
complexity_score: 40,
|
||||
complexity_score: 80,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -364,21 +307,20 @@ describe('NumericalClassifierStrategy', () => {
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 80 < Default Threshold 90
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
source: 'NumericalClassifier (Default)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
|
||||
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => {
|
||||
it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => {
|
||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(-10);
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Test task',
|
||||
complexity_score: 40,
|
||||
complexity_score: 80,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -394,19 +336,18 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
source: 'NumericalClassifier (Default)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
|
||||
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => {
|
||||
it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => {
|
||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(110);
|
||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Test task',
|
||||
complexity_score: 60,
|
||||
complexity_score: 95,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -422,9 +363,9 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
source: 'NumericalClassifier (Default)',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: expect.stringContaining('Score: 60 / Threshold: 50'),
|
||||
reasoning: expect.stringContaining('Score: 95 / Threshold: 90'),
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -591,7 +532,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Complex task',
|
||||
complexity_score: 80,
|
||||
complexity_score: 95,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -613,7 +554,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
});
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Complex task',
|
||||
complexity_score: 80,
|
||||
complexity_score: 95,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
@@ -636,7 +577,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
});
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Complex task',
|
||||
complexity_score: 80,
|
||||
complexity_score: 95,
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
|
||||
@@ -25,6 +25,12 @@ const HISTORY_TURNS_FOR_CONTEXT = 8;
|
||||
const FLASH_MODEL = 'flash';
|
||||
const PRO_MODEL = 'pro';
|
||||
|
||||
/**
|
||||
* The default complexity threshold for routing.
|
||||
* If the score is greater than or equal to this threshold, the Pro model is used.
|
||||
*/
|
||||
export const DEFAULT_CLASSIFIER_THRESHOLD = 90;
|
||||
|
||||
const RESPONSE_SCHEMA = {
|
||||
type: Type.OBJECT,
|
||||
properties: {
|
||||
@@ -93,39 +99,6 @@ const ClassifierResponseSchema = z.object({
|
||||
complexity_score: z.number().min(1).max(100),
|
||||
});
|
||||
|
||||
/**
|
||||
* Deterministically calculates the routing threshold based on the session ID.
|
||||
* This ensures a consistent experience for the user within a session.
|
||||
*
|
||||
* This implementation uses the FNV-1a hash algorithm (32-bit).
|
||||
* @see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
|
||||
*
|
||||
* @param sessionId The unique session identifier.
|
||||
* @returns The threshold (50 or 80).
|
||||
*/
|
||||
function getComplexityThreshold(sessionId: string): number {
|
||||
const FNV_OFFSET_BASIS_32 = 0x811c9dc5;
|
||||
const FNV_PRIME_32 = 0x01000193;
|
||||
|
||||
let hash = FNV_OFFSET_BASIS_32;
|
||||
|
||||
for (let i = 0; i < sessionId.length; i++) {
|
||||
hash ^= sessionId.charCodeAt(i);
|
||||
// Multiply by the FNV prime (Math.imul performs 32-bit integer multiplication with wraparound)
|
||||
hash = Math.imul(hash, FNV_PRIME_32);
|
||||
}
|
||||
|
||||
// Reinterpret as an unsigned 32-bit integer so the value is non-negative
|
||||
hash = hash >>> 0;
|
||||
|
||||
// Normalize to 0-99
|
||||
const normalized = hash % 100;
|
||||
// 50% split:
|
||||
// 0-49: Strict (80)
|
||||
// 50-99: Control (50)
|
||||
return normalized < 50 ? 80 : 50;
|
||||
}
|
||||
|
||||
export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
readonly name = 'numerical_classifier';
|
||||
|
||||
@@ -179,11 +152,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
const score = routerResponse.complexity_score;
|
||||
|
||||
const { threshold, groupLabel, modelAlias } =
|
||||
await this.getRoutingDecision(
|
||||
score,
|
||||
config,
|
||||
config.getSessionId() || 'unknown-session',
|
||||
);
|
||||
await this.getRoutingDecision(score, config);
|
||||
const [useGemini3_1, useCustomToolModel] = await Promise.all([
|
||||
config.getGemini31Launched(),
|
||||
config.getUseCustomToolModel(),
|
||||
@@ -214,7 +183,6 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
private async getRoutingDecision(
|
||||
score: number,
|
||||
config: Config,
|
||||
sessionId: string,
|
||||
): Promise<{
|
||||
threshold: number;
|
||||
groupLabel: string;
|
||||
@@ -234,9 +202,8 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
threshold = remoteThresholdValue;
|
||||
groupLabel = 'Remote';
|
||||
} else {
|
||||
// Fallback to deterministic A/B test
|
||||
threshold = getComplexityThreshold(sessionId);
|
||||
groupLabel = threshold === 80 ? 'Strict' : 'Control';
|
||||
threshold = DEFAULT_CLASSIFIER_THRESHOLD;
|
||||
groupLabel = 'Default';
|
||||
}
|
||||
|
||||
const modelAlias = score >= threshold ? PRO_MODEL : FLASH_MODEL;
|
||||
|
||||
Reference in New Issue
Block a user