mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 05:12:55 -07:00
fix(core): enable numerical routing for api key users (#21977)
This commit is contained in:
@@ -492,6 +492,95 @@ describe('Server Config (config.ts)', () => {
|
|||||||
expect(await config.getUserCaching()).toBeUndefined();
|
expect(await config.getUserCaching()).toBeUndefined();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('getNumericalRoutingEnabled', () => {
|
||||||
|
it('should return true by default if there are no experiments', async () => {
|
||||||
|
const config = new Config(baseParams);
|
||||||
|
expect(await config.getNumericalRoutingEnabled()).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return true if the remote flag is set to true', async () => {
|
||||||
|
const config = new Config({
|
||||||
|
...baseParams,
|
||||||
|
experiments: {
|
||||||
|
flags: {
|
||||||
|
[ExperimentFlags.ENABLE_NUMERICAL_ROUTING]: {
|
||||||
|
boolValue: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
experimentIds: [],
|
||||||
|
},
|
||||||
|
} as unknown as ConfigParameters);
|
||||||
|
expect(await config.getNumericalRoutingEnabled()).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return false if the remote flag is explicitly set to false', async () => {
|
||||||
|
const config = new Config({
|
||||||
|
...baseParams,
|
||||||
|
experiments: {
|
||||||
|
flags: {
|
||||||
|
[ExperimentFlags.ENABLE_NUMERICAL_ROUTING]: {
|
||||||
|
boolValue: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
experimentIds: [],
|
||||||
|
},
|
||||||
|
} as unknown as ConfigParameters);
|
||||||
|
expect(await config.getNumericalRoutingEnabled()).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('getResolvedClassifierThreshold', () => {
|
||||||
|
it('should return 90 by default if there are no experiments', async () => {
|
||||||
|
const config = new Config(baseParams);
|
||||||
|
expect(await config.getResolvedClassifierThreshold()).toBe(90);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return the remote flag value if it is within range (0-100)', async () => {
|
||||||
|
const config = new Config({
|
||||||
|
...baseParams,
|
||||||
|
experiments: {
|
||||||
|
flags: {
|
||||||
|
[ExperimentFlags.CLASSIFIER_THRESHOLD]: {
|
||||||
|
intValue: '75',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
experimentIds: [],
|
||||||
|
},
|
||||||
|
} as unknown as ConfigParameters);
|
||||||
|
expect(await config.getResolvedClassifierThreshold()).toBe(75);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 90 if the remote flag is out of range (less than 0)', async () => {
|
||||||
|
const config = new Config({
|
||||||
|
...baseParams,
|
||||||
|
experiments: {
|
||||||
|
flags: {
|
||||||
|
[ExperimentFlags.CLASSIFIER_THRESHOLD]: {
|
||||||
|
intValue: '-10',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
experimentIds: [],
|
||||||
|
},
|
||||||
|
} as unknown as ConfigParameters);
|
||||||
|
expect(await config.getResolvedClassifierThreshold()).toBe(90);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return 90 if the remote flag is out of range (greater than 100)', async () => {
|
||||||
|
const config = new Config({
|
||||||
|
...baseParams,
|
||||||
|
experiments: {
|
||||||
|
flags: {
|
||||||
|
[ExperimentFlags.CLASSIFIER_THRESHOLD]: {
|
||||||
|
intValue: '110',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
experimentIds: [],
|
||||||
|
},
|
||||||
|
} as unknown as ConfigParameters);
|
||||||
|
expect(await config.getResolvedClassifierThreshold()).toBe(90);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('refreshAuth', () => {
|
describe('refreshAuth', () => {
|
||||||
|
|||||||
@@ -2512,8 +2512,30 @@ export class Config implements McpContext, AgentLoopContext {
|
|||||||
async getNumericalRoutingEnabled(): Promise<boolean> {
|
async getNumericalRoutingEnabled(): Promise<boolean> {
|
||||||
await this.ensureExperimentsLoaded();
|
await this.ensureExperimentsLoaded();
|
||||||
|
|
||||||
return !!this.experiments?.flags[ExperimentFlags.ENABLE_NUMERICAL_ROUTING]
|
const flag =
|
||||||
?.boolValue;
|
this.experiments?.flags[ExperimentFlags.ENABLE_NUMERICAL_ROUTING];
|
||||||
|
return flag?.boolValue ?? true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the resolved complexity threshold for routing.
|
||||||
|
* If a remote threshold is provided and within range (0-100), it is returned.
|
||||||
|
* Otherwise, the default threshold (90) is returned.
|
||||||
|
*/
|
||||||
|
async getResolvedClassifierThreshold(): Promise<number> {
|
||||||
|
const remoteValue = await this.getClassifierThreshold();
|
||||||
|
const defaultValue = 90;
|
||||||
|
|
||||||
|
if (
|
||||||
|
remoteValue !== undefined &&
|
||||||
|
!isNaN(remoteValue) &&
|
||||||
|
remoteValue >= 0 &&
|
||||||
|
remoteValue <= 100
|
||||||
|
) {
|
||||||
|
return remoteValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
return defaultValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
async getClassifierThreshold(): Promise<number | undefined> {
|
async getClassifierThreshold(): Promise<number | undefined> {
|
||||||
|
|||||||
@@ -54,7 +54,10 @@ describe('ModelRouterService', () => {
|
|||||||
vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue(
|
vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue(
|
||||||
mockLocalLiteRtLmClient,
|
mockLocalLiteRtLmClient,
|
||||||
);
|
);
|
||||||
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(false);
|
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(true);
|
||||||
|
vi.spyOn(mockConfig, 'getResolvedClassifierThreshold').mockResolvedValue(
|
||||||
|
90,
|
||||||
|
);
|
||||||
vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined);
|
vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined);
|
||||||
vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({
|
vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({
|
||||||
enabled: false,
|
enabled: false,
|
||||||
@@ -182,8 +185,8 @@ describe('ModelRouterService', () => {
|
|||||||
false,
|
false,
|
||||||
undefined,
|
undefined,
|
||||||
ApprovalMode.DEFAULT,
|
ApprovalMode.DEFAULT,
|
||||||
false,
|
true,
|
||||||
undefined,
|
'90',
|
||||||
);
|
);
|
||||||
expect(logModelRouting).toHaveBeenCalledWith(
|
expect(logModelRouting).toHaveBeenCalledWith(
|
||||||
mockConfig,
|
mockConfig,
|
||||||
@@ -209,8 +212,8 @@ describe('ModelRouterService', () => {
|
|||||||
true,
|
true,
|
||||||
'Strategy failed',
|
'Strategy failed',
|
||||||
ApprovalMode.DEFAULT,
|
ApprovalMode.DEFAULT,
|
||||||
false,
|
true,
|
||||||
undefined,
|
'90',
|
||||||
);
|
);
|
||||||
expect(logModelRouting).toHaveBeenCalledWith(
|
expect(logModelRouting).toHaveBeenCalledWith(
|
||||||
mockConfig,
|
mockConfig,
|
||||||
|
|||||||
@@ -78,10 +78,9 @@ export class ModelRouterService {
|
|||||||
|
|
||||||
const [enableNumericalRouting, thresholdValue] = await Promise.all([
|
const [enableNumericalRouting, thresholdValue] = await Promise.all([
|
||||||
this.config.getNumericalRoutingEnabled(),
|
this.config.getNumericalRoutingEnabled(),
|
||||||
this.config.getClassifierThreshold(),
|
this.config.getResolvedClassifierThreshold(),
|
||||||
]);
|
]);
|
||||||
const classifierThreshold =
|
const classifierThreshold = String(thresholdValue);
|
||||||
thresholdValue !== undefined ? String(thresholdValue) : undefined;
|
|
||||||
|
|
||||||
let failed = false;
|
let failed = false;
|
||||||
let error_message: string | undefined;
|
let error_message: string | undefined;
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
|
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
|
||||||
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
|
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
|
||||||
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
|
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
|
||||||
|
getResolvedClassifierThreshold: vi.fn().mockResolvedValue(90),
|
||||||
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
|
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
|
||||||
getGemini31Launched: vi.fn().mockResolvedValue(false),
|
getGemini31Launched: vi.fn().mockResolvedValue(false),
|
||||||
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
|
getUseCustomToolModel: vi.fn().mockImplementation(async () => {
|
||||||
@@ -152,12 +153,11 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
expect(textPart?.text).toBe('simple task');
|
expect(textPart?.text).toBe('simple task');
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('A/B Testing Logic (Deterministic)', () => {
|
describe('Default Logic', () => {
|
||||||
it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 40 -> FLASH', async () => {
|
it('should route to FLASH when score is below 90', async () => {
|
||||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Hash 71 -> Control
|
|
||||||
const mockApiResponse = {
|
const mockApiResponse = {
|
||||||
complexity_reasoning: 'Standard task',
|
complexity_reasoning: 'Standard task',
|
||||||
complexity_score: 40,
|
complexity_score: 80,
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
mockApiResponse,
|
mockApiResponse,
|
||||||
@@ -173,72 +173,17 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: PREVIEW_GEMINI_FLASH_MODEL,
|
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Control)',
|
source: 'NumericalClassifier (Default)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
|
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('Control Group (SessionID "control-group-id" -> Threshold 50): Score 60 -> PRO', async () => {
|
it('should route to PRO when score is 90 or above', async () => {
|
||||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
|
|
||||||
const mockApiResponse = {
|
|
||||||
complexity_reasoning: 'Complex task',
|
|
||||||
complexity_score: 60,
|
|
||||||
};
|
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
|
||||||
mockApiResponse,
|
|
||||||
);
|
|
||||||
|
|
||||||
const decision = await strategy.route(
|
|
||||||
mockContext,
|
|
||||||
mockConfig,
|
|
||||||
mockBaseLlmClient,
|
|
||||||
mockLocalLiteRtLmClient,
|
|
||||||
);
|
|
||||||
|
|
||||||
expect(decision).toEqual({
|
|
||||||
model: PREVIEW_GEMINI_MODEL,
|
|
||||||
metadata: {
|
|
||||||
source: 'NumericalClassifier (Control)',
|
|
||||||
latencyMs: expect.any(Number),
|
|
||||||
reasoning: expect.stringContaining('Score: 60 / Threshold: 50'),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 60 -> FLASH', async () => {
|
|
||||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1'); // FNV Normalized 18 < 50 -> Strict
|
|
||||||
const mockApiResponse = {
|
|
||||||
complexity_reasoning: 'Complex task',
|
|
||||||
complexity_score: 60,
|
|
||||||
};
|
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
|
||||||
mockApiResponse,
|
|
||||||
);
|
|
||||||
|
|
||||||
const decision = await strategy.route(
|
|
||||||
mockContext,
|
|
||||||
mockConfig,
|
|
||||||
mockBaseLlmClient,
|
|
||||||
mockLocalLiteRtLmClient,
|
|
||||||
);
|
|
||||||
|
|
||||||
expect(decision).toEqual({
|
|
||||||
model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
|
|
||||||
metadata: {
|
|
||||||
source: 'NumericalClassifier (Strict)',
|
|
||||||
latencyMs: expect.any(Number),
|
|
||||||
reasoning: expect.stringContaining('Score: 60 / Threshold: 80'),
|
|
||||||
},
|
|
||||||
});
|
|
||||||
});
|
|
||||||
|
|
||||||
it('Strict Group (SessionID "test-session-1" -> Threshold 80): Score 90 -> PRO', async () => {
|
|
||||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('test-session-1');
|
|
||||||
const mockApiResponse = {
|
const mockApiResponse = {
|
||||||
complexity_reasoning: 'Extreme task',
|
complexity_reasoning: 'Extreme task',
|
||||||
complexity_score: 90,
|
complexity_score: 95,
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
mockApiResponse,
|
mockApiResponse,
|
||||||
@@ -254,9 +199,9 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: PREVIEW_GEMINI_MODEL,
|
model: PREVIEW_GEMINI_MODEL,
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Strict)',
|
source: 'NumericalClassifier (Default)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
reasoning: expect.stringContaining('Score: 90 / Threshold: 80'),
|
reasoning: expect.stringContaining('Score: 95 / Threshold: 90'),
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -265,6 +210,9 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
describe('Remote Threshold Logic', () => {
|
describe('Remote Threshold Logic', () => {
|
||||||
it('should use the remote CLASSIFIER_THRESHOLD if provided (int value)', async () => {
|
it('should use the remote CLASSIFIER_THRESHOLD if provided (int value)', async () => {
|
||||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(70);
|
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(70);
|
||||||
|
vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue(
|
||||||
|
70,
|
||||||
|
);
|
||||||
const mockApiResponse = {
|
const mockApiResponse = {
|
||||||
complexity_reasoning: 'Test task',
|
complexity_reasoning: 'Test task',
|
||||||
complexity_score: 60,
|
complexity_score: 60,
|
||||||
@@ -292,6 +240,9 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
|
|
||||||
it('should use the remote CLASSIFIER_THRESHOLD if provided (float value)', async () => {
|
it('should use the remote CLASSIFIER_THRESHOLD if provided (float value)', async () => {
|
||||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(45.5);
|
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(45.5);
|
||||||
|
vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue(
|
||||||
|
45.5,
|
||||||
|
);
|
||||||
const mockApiResponse = {
|
const mockApiResponse = {
|
||||||
complexity_reasoning: 'Test task',
|
complexity_reasoning: 'Test task',
|
||||||
complexity_score: 40,
|
complexity_score: 40,
|
||||||
@@ -319,6 +270,9 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
|
|
||||||
it('should use PRO model if score >= remote CLASSIFIER_THRESHOLD', async () => {
|
it('should use PRO model if score >= remote CLASSIFIER_THRESHOLD', async () => {
|
||||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(30);
|
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(30);
|
||||||
|
vi.mocked(mockConfig.getResolvedClassifierThreshold).mockResolvedValue(
|
||||||
|
30,
|
||||||
|
);
|
||||||
const mockApiResponse = {
|
const mockApiResponse = {
|
||||||
complexity_reasoning: 'Test task',
|
complexity_reasoning: 'Test task',
|
||||||
complexity_score: 35,
|
complexity_score: 35,
|
||||||
@@ -344,13 +298,12 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is not present in experiments', async () => {
|
it('should fall back to default logic if CLASSIFIER_THRESHOLD is not present in experiments', async () => {
|
||||||
// Mock getClassifierThreshold to return undefined
|
// Mock getClassifierThreshold to return undefined
|
||||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(undefined);
|
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(undefined);
|
||||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id'); // Should resolve to Control (50)
|
|
||||||
const mockApiResponse = {
|
const mockApiResponse = {
|
||||||
complexity_reasoning: 'Test task',
|
complexity_reasoning: 'Test task',
|
||||||
complexity_score: 40,
|
complexity_score: 80,
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
mockApiResponse,
|
mockApiResponse,
|
||||||
@@ -364,21 +317,20 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
|
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 80 < Default Threshold 90
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Control)',
|
source: 'NumericalClassifier (Default)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
|
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => {
|
it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (less than 0)', async () => {
|
||||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(-10);
|
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(-10);
|
||||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
|
|
||||||
const mockApiResponse = {
|
const mockApiResponse = {
|
||||||
complexity_reasoning: 'Test task',
|
complexity_reasoning: 'Test task',
|
||||||
complexity_score: 40,
|
complexity_score: 80,
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
mockApiResponse,
|
mockApiResponse,
|
||||||
@@ -394,19 +346,18 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: PREVIEW_GEMINI_FLASH_MODEL,
|
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Control)',
|
source: 'NumericalClassifier (Default)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
reasoning: expect.stringContaining('Score: 40 / Threshold: 50'),
|
reasoning: expect.stringContaining('Score: 80 / Threshold: 90'),
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should fall back to A/B testing if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => {
|
it('should fall back to default logic if CLASSIFIER_THRESHOLD is out of range (greater than 100)', async () => {
|
||||||
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(110);
|
vi.mocked(mockConfig.getClassifierThreshold).mockResolvedValue(110);
|
||||||
vi.mocked(mockConfig.getSessionId).mockReturnValue('control-group-id');
|
|
||||||
const mockApiResponse = {
|
const mockApiResponse = {
|
||||||
complexity_reasoning: 'Test task',
|
complexity_reasoning: 'Test task',
|
||||||
complexity_score: 60,
|
complexity_score: 95,
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
mockApiResponse,
|
mockApiResponse,
|
||||||
@@ -422,9 +373,9 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: PREVIEW_GEMINI_MODEL,
|
model: PREVIEW_GEMINI_MODEL,
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Control)',
|
source: 'NumericalClassifier (Default)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
reasoning: expect.stringContaining('Score: 60 / Threshold: 50'),
|
reasoning: expect.stringContaining('Score: 95 / Threshold: 90'),
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -591,7 +542,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
|
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
|
||||||
const mockApiResponse = {
|
const mockApiResponse = {
|
||||||
complexity_reasoning: 'Complex task',
|
complexity_reasoning: 'Complex task',
|
||||||
complexity_score: 80,
|
complexity_score: 95,
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
mockApiResponse,
|
mockApiResponse,
|
||||||
@@ -613,7 +564,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
});
|
});
|
||||||
const mockApiResponse = {
|
const mockApiResponse = {
|
||||||
complexity_reasoning: 'Complex task',
|
complexity_reasoning: 'Complex task',
|
||||||
complexity_score: 80,
|
complexity_score: 95,
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
mockApiResponse,
|
mockApiResponse,
|
||||||
@@ -636,7 +587,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
});
|
});
|
||||||
const mockApiResponse = {
|
const mockApiResponse = {
|
||||||
complexity_reasoning: 'Complex task',
|
complexity_reasoning: 'Complex task',
|
||||||
complexity_score: 80,
|
complexity_score: 95,
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
mockApiResponse,
|
mockApiResponse,
|
||||||
|
|||||||
@@ -93,39 +93,6 @@ const ClassifierResponseSchema = z.object({
|
|||||||
complexity_score: z.number().min(1).max(100),
|
complexity_score: z.number().min(1).max(100),
|
||||||
});
|
});
|
||||||
|
|
||||||
/**
|
|
||||||
* Deterministically calculates the routing threshold based on the session ID.
|
|
||||||
* This ensures a consistent experience for the user within a session.
|
|
||||||
*
|
|
||||||
* This implementation uses the FNV-1a hash algorithm (32-bit).
|
|
||||||
* @see https://en.wikipedia.org/wiki/Fowler%E2%80%93Noll%E2%80%93Vo_hash_function
|
|
||||||
*
|
|
||||||
* @param sessionId The unique session identifier.
|
|
||||||
* @returns The threshold (50 or 80).
|
|
||||||
*/
|
|
||||||
function getComplexityThreshold(sessionId: string): number {
|
|
||||||
const FNV_OFFSET_BASIS_32 = 0x811c9dc5;
|
|
||||||
const FNV_PRIME_32 = 0x01000193;
|
|
||||||
|
|
||||||
let hash = FNV_OFFSET_BASIS_32;
|
|
||||||
|
|
||||||
for (let i = 0; i < sessionId.length; i++) {
|
|
||||||
hash ^= sessionId.charCodeAt(i);
|
|
||||||
// Multiply by prime (simulate 32-bit overflow with bitwise shift)
|
|
||||||
hash = Math.imul(hash, FNV_PRIME_32);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ensure positive integer
|
|
||||||
hash = hash >>> 0;
|
|
||||||
|
|
||||||
// Normalize to 0-99
|
|
||||||
const normalized = hash % 100;
|
|
||||||
// 50% split:
|
|
||||||
// 0-49: Strict (80)
|
|
||||||
// 50-99: Control (50)
|
|
||||||
return normalized < 50 ? 80 : 50;
|
|
||||||
}
|
|
||||||
|
|
||||||
export class NumericalClassifierStrategy implements RoutingStrategy {
|
export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||||
readonly name = 'numerical_classifier';
|
readonly name = 'numerical_classifier';
|
||||||
|
|
||||||
@@ -179,11 +146,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
|||||||
const score = routerResponse.complexity_score;
|
const score = routerResponse.complexity_score;
|
||||||
|
|
||||||
const { threshold, groupLabel, modelAlias } =
|
const { threshold, groupLabel, modelAlias } =
|
||||||
await this.getRoutingDecision(
|
await this.getRoutingDecision(score, config);
|
||||||
score,
|
|
||||||
config,
|
|
||||||
config.getSessionId() || 'unknown-session',
|
|
||||||
);
|
|
||||||
const [useGemini3_1, useCustomToolModel] = await Promise.all([
|
const [useGemini3_1, useCustomToolModel] = await Promise.all([
|
||||||
config.getGemini31Launched(),
|
config.getGemini31Launched(),
|
||||||
config.getUseCustomToolModel(),
|
config.getUseCustomToolModel(),
|
||||||
@@ -214,29 +177,19 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
|||||||
private async getRoutingDecision(
|
private async getRoutingDecision(
|
||||||
score: number,
|
score: number,
|
||||||
config: Config,
|
config: Config,
|
||||||
sessionId: string,
|
|
||||||
): Promise<{
|
): Promise<{
|
||||||
threshold: number;
|
threshold: number;
|
||||||
groupLabel: string;
|
groupLabel: string;
|
||||||
modelAlias: typeof FLASH_MODEL | typeof PRO_MODEL;
|
modelAlias: typeof FLASH_MODEL | typeof PRO_MODEL;
|
||||||
}> {
|
}> {
|
||||||
let threshold: number;
|
const threshold = await config.getResolvedClassifierThreshold();
|
||||||
let groupLabel: string;
|
|
||||||
|
|
||||||
const remoteThresholdValue = await config.getClassifierThreshold();
|
const remoteThresholdValue = await config.getClassifierThreshold();
|
||||||
|
|
||||||
if (
|
let groupLabel: string;
|
||||||
remoteThresholdValue !== undefined &&
|
if (threshold === remoteThresholdValue) {
|
||||||
!isNaN(remoteThresholdValue) &&
|
|
||||||
remoteThresholdValue >= 0 &&
|
|
||||||
remoteThresholdValue <= 100
|
|
||||||
) {
|
|
||||||
threshold = remoteThresholdValue;
|
|
||||||
groupLabel = 'Remote';
|
groupLabel = 'Remote';
|
||||||
} else {
|
} else {
|
||||||
// Fallback to deterministic A/B test
|
groupLabel = 'Default';
|
||||||
threshold = getComplexityThreshold(sessionId);
|
|
||||||
groupLabel = threshold === 80 ? 'Strict' : 'Control';
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelAlias = score >= threshold ? PRO_MODEL : FLASH_MODEL;
|
const modelAlias = score >= threshold ? PRO_MODEL : FLASH_MODEL;
|
||||||
|
|||||||
Reference in New Issue
Block a user