perf(core): skip model routing classification when redundant

This commit is contained in:
Akhilesh Kumar
2026-04-16 18:38:43 +00:00
parent 44f9b590eb
commit 5f181f96f8
5 changed files with 175 additions and 8 deletions
+3
View File
@@ -7,5 +7,8 @@
},
"general": {
"devtools": true
},
"model": {
"gemma4Variant": "gemma-4-31b-it"
}
}
@@ -383,6 +383,56 @@ describe('ClassifierStrategy', () => {
expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
});
it('should skip classification if both pro and flash resolve to the same model', async () => {
// We mock the config to trigger the fast path by returning a specific model
// that the router will see as identical for both 'pro' and 'flash' tiers.
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
// By overriding the modelConfigService, we can simulate gemma4Variant
// or any other scenario where both tiers resolve to the same target model.
const mockResolveClassifierModelId = vi
.fn()
.mockReturnValue('gemma-4-31b-it');
Object.defineProperty(
mockConfig.modelConfigService,
'resolveClassifierModelId',
{
value: mockResolveClassifierModelId,
writable: true,
},
);
// We also need to mock config.getExperimentalDynamicModelConfiguration()
// if that is what resolveClassifierModel uses. Since resolveClassifierModel
// is a standalone function, we can mock its behavior indirectly via config.
Object.defineProperty(
mockConfig,
'getExperimentalDynamicModelConfiguration',
{
value: vi.fn().mockReturnValue(true),
writable: true,
},
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
mockLocalLiteRtLmClient,
);
expect(decision).toEqual({
model: 'gemma-4-31b-it',
metadata: {
source: 'classifier',
latencyMs: 0,
reasoning:
'Skipped classification because both tiers resolve to the same model: gemma-4-31b-it',
},
});
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});
describe('Gemini 3.1 and Custom Tools Routing', () => {
it('should route to PREVIEW_GEMINI_3_1_MODEL when Gemini 3.1 is launched', async () => {
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
@@ -137,6 +137,47 @@ export class ClassifierStrategy implements RoutingStrategy {
const startTime = Date.now();
try {
const model = context.requestedModel ?? config.getModel();
const [useGemini3_1, useGemini3_1FlashLite, useCustomToolModel] =
await Promise.all([
config.getGemini31Launched(),
config.getGemini31FlashLiteLaunched(),
config.getUseCustomToolModel(),
]);
const hasAccessToPreview = config.getHasAccessToPreviewModel?.() ?? true;
// Check if classification is redundant (i.e., both tiers resolve to the same model)
const proModel = resolveClassifierModel(
model,
'pro',
useGemini3_1,
useGemini3_1FlashLite,
useCustomToolModel,
hasAccessToPreview,
config,
);
const flashModel = resolveClassifierModel(
model,
'flash',
useGemini3_1,
useGemini3_1FlashLite,
useCustomToolModel,
hasAccessToPreview,
config,
);
if (proModel === flashModel) {
return {
model: proModel,
metadata: {
source: this.name,
latencyMs: 0,
reasoning: `Skipped classification because both tiers resolve to the same model: ${proModel}`,
},
};
}
if (
(await config.getNumericalRoutingEnabled()) &&
isGemini3Model(model, config)
@@ -171,19 +212,13 @@ export class ClassifierStrategy implements RoutingStrategy {
const reasoning = routerResponse.reasoning;
const latencyMs = Date.now() - startTime;
const [useGemini3_1, useGemini3_1FlashLite, useCustomToolModel] =
await Promise.all([
config.getGemini31Launched(),
config.getGemini31FlashLiteLaunched(),
config.getUseCustomToolModel(),
]);
const selectedModel = resolveClassifierModel(
model,
routerResponse.model_choice,
useGemini3_1,
useGemini3_1FlashLite,
useCustomToolModel,
config.getHasAccessToPreviewModel?.() ?? true,
hasAccessToPreview,
config,
);
@@ -32,6 +32,9 @@ describe('GemmaClassifierStrategy', () => {
mockGenerateJson = vi.fn();
mockConfig = {
modelConfigService: {
resolveClassifierModelId: vi.fn(),
},
getGemmaModelRouterSettings: vi.fn().mockReturnValue({
enabled: true,
classifier: { model: 'gemma3-1b-gpu-custom' },
@@ -83,6 +86,44 @@ describe('GemmaClassifierStrategy', () => {
).rejects.toThrow('Only gemma3-1b-gpu-custom has been tested');
});
it('should skip classification if both pro and flash resolve to the same model', async () => {
// Setup the mock config to use the dynamic model config service.
Object.defineProperty(
mockConfig,
'getExperimentalDynamicModelConfiguration',
{
value: vi.fn().mockReturnValue(true),
writable: true,
},
);
Object.defineProperty(
mockConfig.modelConfigService,
'resolveClassifierModelId',
{
value: vi.fn().mockReturnValue('gemma-4-31b-it'),
writable: true,
},
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
mockLocalLiteRtLmClient,
);
expect(decision).toEqual({
model: 'gemma-4-31b-it',
metadata: {
source: 'gemma-classifier',
latencyMs: 0,
reasoning:
'Skipped classification because both tiers resolve to the same model: gemma-4-31b-it',
},
});
expect(mockGenerateJson).not.toHaveBeenCalled();
});
it('should call generateJson with the correct parameters', async () => {
const mockApiResponse = {
reasoning: 'Simple task',
@@ -178,6 +178,39 @@ ${formattedHistory}
return null;
}
const model = context.requestedModel ?? config.getModel();
// Check if classification is redundant (i.e., both tiers resolve to the same model)
const proModel = resolveClassifierModel(
model,
'pro',
false, // useGemini3_1
false, // useGemini3_1FlashLite
false, // useCustomToolModel
true, // hasAccessToPreview
config,
);
const flashModel = resolveClassifierModel(
model,
'flash',
false, // useGemini3_1
false, // useGemini3_1FlashLite
false, // useCustomToolModel
true, // hasAccessToPreview
config,
);
if (proModel === flashModel) {
return {
model: proModel,
metadata: {
source: this.name,
latencyMs: 0,
reasoning: `Skipped classification because both tiers resolve to the same model: ${proModel}`,
},
};
}
// Only the gemma3-1b-gpu-custom model has been tested and verified.
if (gemmaRouterSettings.classifier?.model !== 'gemma3-1b-gpu-custom') {
throw new Error('Only gemma3-1b-gpu-custom has been tested');
@@ -210,8 +243,13 @@ ${formattedHistory}
const reasoning = routerResponse.reasoning;
const latencyMs = Date.now() - startTime;
const selectedModel = resolveClassifierModel(
context.requestedModel ?? config.getModel(),
model,
routerResponse.model_choice,
false, // useGemini3_1
false, // useGemini3_1FlashLite
false, // useCustomToolModel
true, // hasAccessToPreview
config,
);
return {