From bee5b638dd5a2b2ba6f86dbd484693217a81cd6f Mon Sep 17 00:00:00 2001
From: Abhi <43648792+abhipatel12@users.noreply.github.com>
Date: Mon, 15 Sep 2025 19:51:25 -0400
Subject: [PATCH] feat(routing): Introduce Classifier-based Model Routing
Strategy (#8455)
---
.../src/routing/modelRouterService.test.ts | 14 +-
.../core/src/routing/modelRouterService.ts | 8 +-
.../strategies/classifierStrategy.test.ts | 277 ++++++++++++++++++
.../routing/strategies/classifierStrategy.ts | 213 ++++++++++++++
4 files changed, 508 insertions(+), 4 deletions(-)
create mode 100644 packages/core/src/routing/strategies/classifierStrategy.test.ts
create mode 100644 packages/core/src/routing/strategies/classifierStrategy.ts
diff --git a/packages/core/src/routing/modelRouterService.test.ts b/packages/core/src/routing/modelRouterService.test.ts
index 0f83796787..ad2fda7ae1 100644
--- a/packages/core/src/routing/modelRouterService.test.ts
+++ b/packages/core/src/routing/modelRouterService.test.ts
@@ -13,6 +13,7 @@ import { DefaultStrategy } from './strategies/defaultStrategy.js';
import { CompositeStrategy } from './strategies/compositeStrategy.js';
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
import { OverrideStrategy } from './strategies/overrideStrategy.js';
+import { ClassifierStrategy } from './strategies/classifierStrategy.js';
vi.mock('../config/config.js');
vi.mock('../core/baseLlmClient.js');
@@ -20,6 +21,7 @@ vi.mock('./strategies/defaultStrategy.js');
vi.mock('./strategies/compositeStrategy.js');
vi.mock('./strategies/fallbackStrategy.js');
vi.mock('./strategies/overrideStrategy.js');
+vi.mock('./strategies/classifierStrategy.js');
describe('ModelRouterService', () => {
let service: ModelRouterService;
@@ -36,7 +38,12 @@ describe('ModelRouterService', () => {
vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient);
mockCompositeStrategy = new CompositeStrategy(
- [new FallbackStrategy(), new OverrideStrategy(), new DefaultStrategy()],
+ [
+ new FallbackStrategy(),
+ new OverrideStrategy(),
+ new ClassifierStrategy(),
+ new DefaultStrategy(),
+ ],
'agent-router',
);
vi.mocked(CompositeStrategy).mockImplementation(
@@ -62,10 +69,11 @@ describe('ModelRouterService', () => {
const compositeStrategyArgs = vi.mocked(CompositeStrategy).mock.calls[0];
const childStrategies = compositeStrategyArgs[0];
- expect(childStrategies.length).toBe(3);
+ expect(childStrategies.length).toBe(4);
expect(childStrategies[0]).toBeInstanceOf(FallbackStrategy);
expect(childStrategies[1]).toBeInstanceOf(OverrideStrategy);
- expect(childStrategies[2]).toBeInstanceOf(DefaultStrategy);
+ expect(childStrategies[2]).toBeInstanceOf(ClassifierStrategy);
+ expect(childStrategies[3]).toBeInstanceOf(DefaultStrategy);
expect(compositeStrategyArgs[1]).toBe('agent-router');
});
diff --git a/packages/core/src/routing/modelRouterService.ts b/packages/core/src/routing/modelRouterService.ts
index a984125f89..b4feea5d26 100644
--- a/packages/core/src/routing/modelRouterService.ts
+++ b/packages/core/src/routing/modelRouterService.ts
@@ -11,6 +11,7 @@ import type {
TerminalStrategy,
} from './routingStrategy.js';
import { DefaultStrategy } from './strategies/defaultStrategy.js';
+import { ClassifierStrategy } from './strategies/classifierStrategy.js';
import { CompositeStrategy } from './strategies/compositeStrategy.js';
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
import { OverrideStrategy } from './strategies/overrideStrategy.js';
@@ -31,7 +32,12 @@ export class ModelRouterService {
// Initialize the composite strategy with the desired priority order.
// The strategies are ordered in order of highest priority.
return new CompositeStrategy(
- [new FallbackStrategy(), new OverrideStrategy(), new DefaultStrategy()],
+ [
+ new FallbackStrategy(),
+ new OverrideStrategy(),
+ new ClassifierStrategy(),
+ new DefaultStrategy(),
+ ],
'agent-router',
);
}
diff --git a/packages/core/src/routing/strategies/classifierStrategy.test.ts b/packages/core/src/routing/strategies/classifierStrategy.test.ts
new file mode 100644
index 0000000000..34c6500df1
--- /dev/null
+++ b/packages/core/src/routing/strategies/classifierStrategy.test.ts
@@ -0,0 +1,277 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { describe, it, expect, vi, beforeEach } from 'vitest';
+import { ClassifierStrategy } from './classifierStrategy.js';
+import type { RoutingContext } from '../routingStrategy.js';
+import type { Config } from '../../config/config.js';
+import type { BaseLlmClient } from '../../core/baseLlmClient.js';
+import {
+ isFunctionCall,
+ isFunctionResponse,
+} from '../../utils/messageInspectors.js';
+import {
+ DEFAULT_GEMINI_FLASH_MODEL,
+ DEFAULT_GEMINI_FLASH_LITE_MODEL,
+ DEFAULT_GEMINI_MODEL,
+} from '../../config/models.js';
+import { promptIdContext } from '../../utils/promptIdContext.js';
+import type { Content } from '@google/genai';
+
+vi.mock('../../core/baseLlmClient.js');
+vi.mock('../../utils/promptIdContext.js');
+
+describe('ClassifierStrategy', () => {
+ let strategy: ClassifierStrategy;
+ let mockContext: RoutingContext;
+ let mockConfig: Config;
+ let mockBaseLlmClient: BaseLlmClient;
+
+ beforeEach(() => {
+ vi.clearAllMocks();
+
+ strategy = new ClassifierStrategy();
+ mockContext = {
+ history: [],
+ request: [{ text: 'simple task' }],
+ signal: new AbortController().signal,
+ };
+ mockConfig = {} as Config;
+ mockBaseLlmClient = {
+ generateJson: vi.fn(),
+ } as unknown as BaseLlmClient;
+
+ vi.mocked(promptIdContext.getStore).mockReturnValue('test-prompt-id');
+ });
+
+ it('should call generateJson with the correct parameters', async () => {
+ const mockApiResponse = {
+ reasoning: 'Simple task',
+ model_choice: 'flash',
+ };
+ vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
+ mockApiResponse,
+ );
+
+ await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
+
+ expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
+ expect.objectContaining({
+ model: DEFAULT_GEMINI_FLASH_LITE_MODEL,
+ config: expect.objectContaining({
+ temperature: 0,
+ maxOutputTokens: 1024,
+ thinkingConfig: {
+ thinkingBudget: 512,
+ },
+ }),
+ promptId: 'test-prompt-id',
+ }),
+ );
+ });
+
+ it('should route to FLASH model for a simple task', async () => {
+ const mockApiResponse = {
+ reasoning: 'This is a simple task.',
+ model_choice: 'flash',
+ };
+ vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
+ mockApiResponse,
+ );
+
+ const decision = await strategy.route(
+ mockContext,
+ mockConfig,
+ mockBaseLlmClient,
+ );
+
+ expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce();
+ expect(decision).toEqual({
+ model: DEFAULT_GEMINI_FLASH_MODEL,
+ metadata: {
+ source: 'Classifier',
+ latencyMs: expect.any(Number),
+ reasoning: mockApiResponse.reasoning,
+ },
+ });
+ });
+
+ it('should route to PRO model for a complex task', async () => {
+ const mockApiResponse = {
+ reasoning: 'This is a complex task.',
+ model_choice: 'pro',
+ };
+ vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
+ mockApiResponse,
+ );
+ mockContext.request = [{ text: 'how do I build a spaceship?' }];
+
+ const decision = await strategy.route(
+ mockContext,
+ mockConfig,
+ mockBaseLlmClient,
+ );
+
+ expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce();
+ expect(decision).toEqual({
+ model: DEFAULT_GEMINI_MODEL,
+ metadata: {
+ source: 'Classifier',
+ latencyMs: expect.any(Number),
+ reasoning: mockApiResponse.reasoning,
+ },
+ });
+ });
+
+ it('should return null if the classifier API call fails', async () => {
+ const consoleWarnSpy = vi
+ .spyOn(console, 'warn')
+ .mockImplementation(() => {});
+ const testError = new Error('API Failure');
+ vi.mocked(mockBaseLlmClient.generateJson).mockRejectedValue(testError);
+
+ const decision = await strategy.route(
+ mockContext,
+ mockConfig,
+ mockBaseLlmClient,
+ );
+
+ expect(decision).toBeNull();
+ expect(consoleWarnSpy).toHaveBeenCalled();
+ consoleWarnSpy.mockRestore();
+ });
+
+ it('should return null if the classifier returns a malformed JSON object', async () => {
+ const consoleWarnSpy = vi
+ .spyOn(console, 'warn')
+ .mockImplementation(() => {});
+ const malformedApiResponse = {
+ reasoning: 'This is a simple task.',
+ // model_choice is missing, which will cause a Zod parsing error.
+ };
+ vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
+ malformedApiResponse,
+ );
+
+ const decision = await strategy.route(
+ mockContext,
+ mockConfig,
+ mockBaseLlmClient,
+ );
+
+ expect(decision).toBeNull();
+ expect(consoleWarnSpy).toHaveBeenCalled();
+ consoleWarnSpy.mockRestore();
+ });
+
+ it('should filter out tool-related history before sending to classifier', async () => {
+ mockContext.history = [
+ { role: 'user', parts: [{ text: 'call a tool' }] },
+ { role: 'model', parts: [{ functionCall: { name: 'test_tool' } }] },
+ {
+ role: 'user',
+ parts: [
+ { functionResponse: { name: 'test_tool', response: { ok: true } } },
+ ],
+ },
+ { role: 'user', parts: [{ text: 'another user turn' }] },
+ ];
+ const mockApiResponse = {
+ reasoning: 'Simple.',
+ model_choice: 'flash',
+ };
+ vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
+ mockApiResponse,
+ );
+
+ await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
+
+ const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
+ .calls[0][0];
+ const contents = generateJsonCall.contents;
+
+ const expectedContents = [
+ { role: 'user', parts: [{ text: 'call a tool' }] },
+ { role: 'user', parts: [{ text: 'another user turn' }] },
+ { role: 'user', parts: [{ text: 'simple task' }] },
+ ];
+
+ expect(contents).toEqual(expectedContents);
+ });
+
+ it('should respect HISTORY_SEARCH_WINDOW and HISTORY_TURNS_FOR_CONTEXT', async () => {
+ const longHistory: Content[] = [];
+ for (let i = 0; i < 30; i++) {
+ longHistory.push({ role: 'user', parts: [{ text: `Message ${i}` }] });
+ // Add noise that should be filtered
+ if (i % 2 === 0) {
+ longHistory.push({
+ role: 'model',
+ parts: [{ functionCall: { name: 'noise', args: {} } }],
+ });
+ }
+ }
+ mockContext.history = longHistory;
+ const mockApiResponse = {
+ reasoning: 'Simple.',
+ model_choice: 'flash',
+ };
+ vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
+ mockApiResponse,
+ );
+
+ await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
+
+ const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
+ .calls[0][0];
+ const contents = generateJsonCall.contents;
+
+ // Manually calculate what the history should be
+ const HISTORY_SEARCH_WINDOW = 20;
+ const HISTORY_TURNS_FOR_CONTEXT = 4;
+ const historySlice = longHistory.slice(-HISTORY_SEARCH_WINDOW);
+ const cleanHistory = historySlice.filter(
+ (content) => !isFunctionCall(content) && !isFunctionResponse(content),
+ );
+ const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT);
+
+ expect(contents).toEqual([
+ ...finalHistory,
+ { role: 'user', parts: mockContext.request },
+ ]);
+ // There should be 4 history items + the current request
+ expect(contents).toHaveLength(5);
+ });
+
+ it('should use a fallback promptId if not found in context', async () => {
+ const consoleWarnSpy = vi
+ .spyOn(console, 'warn')
+ .mockImplementation(() => {});
+ vi.mocked(promptIdContext.getStore).mockReturnValue(undefined);
+ const mockApiResponse = {
+ reasoning: 'Simple.',
+ model_choice: 'flash',
+ };
+ vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
+ mockApiResponse,
+ );
+
+ await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
+
+ const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
+ .calls[0][0];
+
+ expect(generateJsonCall.promptId).toMatch(
+ /^classifier-router-fallback-\d+-\w+$/,
+ );
+ expect(consoleWarnSpy).toHaveBeenCalledWith(
+ expect.stringContaining(
+ 'Could not find promptId in context. This is unexpected. Using a fallback ID:',
+ ),
+ );
+ consoleWarnSpy.mockRestore();
+ });
+});
diff --git a/packages/core/src/routing/strategies/classifierStrategy.ts b/packages/core/src/routing/strategies/classifierStrategy.ts
new file mode 100644
index 0000000000..f3af80ecb6
--- /dev/null
+++ b/packages/core/src/routing/strategies/classifierStrategy.ts
@@ -0,0 +1,213 @@
+/**
+ * @license
+ * Copyright 2025 Google LLC
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+import { z } from 'zod';
+import type { BaseLlmClient } from '../../core/baseLlmClient.js';
+import { promptIdContext } from '../../utils/promptIdContext.js';
+import type {
+ RoutingContext,
+ RoutingDecision,
+ RoutingStrategy,
+} from '../routingStrategy.js';
+import {
+ DEFAULT_GEMINI_FLASH_MODEL,
+ DEFAULT_GEMINI_FLASH_LITE_MODEL,
+ DEFAULT_GEMINI_MODEL,
+} from '../../config/models.js';
+import {
+ type GenerateContentConfig,
+ createUserContent,
+ Type,
+} from '@google/genai';
+import type { Config } from '../../config/config.js';
+import {
+ isFunctionCall,
+ isFunctionResponse,
+} from '../../utils/messageInspectors.js';
+
+const CLASSIFIER_GENERATION_CONFIG: GenerateContentConfig = {
+ temperature: 0,
+ maxOutputTokens: 1024,
+ thinkingConfig: {
+ thinkingBudget: 512, // This counts towards output max, so we don't want -1.
+ },
+};
+
+// The number of recent history turns to provide to the router for context.
+const HISTORY_TURNS_FOR_CONTEXT = 4;
+const HISTORY_SEARCH_WINDOW = 20;
+
+const FLASH_MODEL = 'flash';
+const PRO_MODEL = 'pro';
+
+const CLASSIFIER_SYSTEM_PROMPT = `
+You are a specialized Task Routing AI. Your sole function is to analyze the user's request and classify its complexity. Choose between \`${FLASH_MODEL}\` (SIMPLE) or \`${PRO_MODEL}\` (COMPLEX).
+1. \`${FLASH_MODEL}\`: A fast, efficient model for simple, well-defined tasks.
+2. \`${PRO_MODEL}\`: A powerful, advanced model for complex, open-ended, or multi-step tasks.
+
+A task is COMPLEX (Choose \`${PRO_MODEL}\`) if it meets ONE OR MORE of the following criteria:
+1. **High Operational Complexity (Est. 4+ Steps/Tool Calls):** Requires dependent actions, significant planning, or multiple coordinated changes.
+2. **Strategic Planning & Conceptual Design:** Asking "how" or "why." Requires advice, architecture, or high-level strategy.
+3. **High Ambiguity or Large Scope (Extensive Investigation):** Broadly defined requests requiring extensive investigation.
+4. **Deep Debugging & Root Cause Analysis:** Diagnosing unknown or complex problems from symptoms.
+A task is SIMPLE (Choose \`${FLASH_MODEL}\`) if it is highly specific, bounded, and has Low Operational Complexity (Est. 1-3 tool calls). Operational simplicity overrides strategic phrasing.
+
+**Output Format:**
+Respond *only* in JSON format according to the following schema. Do not include any text outside the JSON structure.
+{
+ "type": "object",
+ "properties": {
+ "reasoning": {
+ "type": "string",
+ "description": "A brief, step-by-step explanation for the model choice, referencing the rubric."
+ },
+ "model_choice": {
+ "type": "string",
+ "enum": ["${FLASH_MODEL}", "${PRO_MODEL}"]
+ }
+ },
+ "required": ["reasoning", "model_choice"]
+}
+--- EXAMPLES ---
+**Example 1 (Strategic Planning):**
+*User Prompt:* "How should I architect the data pipeline for this new analytics service?"
+*Your JSON Output:*
+{
+ "reasoning": "The user is asking for high-level architectural design and strategy. This falls under 'Strategic Planning & Conceptual Design'.",
+ "model_choice": "${PRO_MODEL}"
+}
+**Example 2 (Simple Tool Use):**
+*User Prompt:* "list the files in the current directory"
+*Your JSON Output:*
+{
+ "reasoning": "This is a direct command requiring a single tool call (ls). It has Low Operational Complexity (1 step).",
+ "model_choice": "${FLASH_MODEL}"
+}
+**Example 3 (High Operational Complexity):**
+*User Prompt:* "I need to add a new 'email' field to the User schema in 'src/models/user.ts', migrate the database, and update the registration endpoint."
+*Your JSON Output:*
+{
+ "reasoning": "This request involves multiple coordinated steps across different files and systems. This meets the criteria for High Operational Complexity (4+ steps).",
+ "model_choice": "${PRO_MODEL}"
+}
+**Example 4 (Simple Read):**
+*User Prompt:* "Read the contents of 'package.json'."
+*Your JSON Output:*
+{
+ "reasoning": "This is a direct command requiring a single read. It has Low Operational Complexity (1 step).",
+ "model_choice": "${FLASH_MODEL}"
+}
+
+**Example 5 (Deep Debugging):**
+*User Prompt:* "I'm getting an error 'Cannot read property 'map' of undefined' when I click the save button. Can you fix it?"
+*Your JSON Output:*
+{
+ "reasoning": "The user is reporting an error symptom without a known cause. This requires investigation and falls under 'Deep Debugging'.",
+ "model_choice": "${PRO_MODEL}"
+}
+**Example 6 (Simple Edit despite Phrasing):**
+*User Prompt:* "What is the best way to rename the variable 'data' to 'userData' in 'src/utils.js'?"
+*Your JSON Output:*
+{
+ "reasoning": "Although the user uses strategic language ('best way'), the underlying task is a localized edit. The operational complexity is low (1-2 steps).",
+ "model_choice": "${FLASH_MODEL}"
+}
+`;
+
+const RESPONSE_SCHEMA = {
+ type: Type.OBJECT,
+ properties: {
+ reasoning: {
+ type: Type.STRING,
+ description:
+ 'A brief, step-by-step explanation for the model choice, referencing the rubric.',
+ },
+ model_choice: {
+ type: Type.STRING,
+ enum: [FLASH_MODEL, PRO_MODEL],
+ },
+ },
+ required: ['reasoning', 'model_choice'],
+};
+
+const ClassifierResponseSchema = z.object({
+ reasoning: z.string(),
+ model_choice: z.enum([FLASH_MODEL, PRO_MODEL]),
+});
+
+export class ClassifierStrategy implements RoutingStrategy {
+ readonly name = 'classifier';
+
+ async route(
+ context: RoutingContext,
+ _config: Config,
+ baseLlmClient: BaseLlmClient,
+  ): Promise<RoutingDecision | null> {
+ const startTime = Date.now();
+ try {
+ let promptId = promptIdContext.getStore();
+ if (!promptId) {
+ promptId = `classifier-router-fallback-${Date.now()}-${Math.random()
+ .toString(16)
+ .slice(2)}`;
+ console.warn(
+ `Could not find promptId in context. This is unexpected. Using a fallback ID: ${promptId}`,
+ );
+ }
+
+ const historySlice = context.history.slice(-HISTORY_SEARCH_WINDOW);
+
+ // Filter out tool-related turns.
+ // TODO - Consider using function req/res if they help accuracy.
+ const cleanHistory = historySlice.filter(
+ (content) => !isFunctionCall(content) && !isFunctionResponse(content),
+ );
+
+ // Take the last N turns from the *cleaned* history.
+ const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT);
+
+ const jsonResponse = await baseLlmClient.generateJson({
+ contents: [...finalHistory, createUserContent(context.request)],
+ schema: RESPONSE_SCHEMA,
+ model: DEFAULT_GEMINI_FLASH_LITE_MODEL,
+ systemInstruction: CLASSIFIER_SYSTEM_PROMPT,
+ config: CLASSIFIER_GENERATION_CONFIG,
+ abortSignal: context.signal,
+ promptId,
+ });
+
+ const routerResponse = ClassifierResponseSchema.parse(jsonResponse);
+
+ const reasoning = routerResponse.reasoning;
+ const latencyMs = Date.now() - startTime;
+
+ if (routerResponse.model_choice === FLASH_MODEL) {
+ return {
+ model: DEFAULT_GEMINI_FLASH_MODEL,
+ metadata: {
+ source: 'Classifier',
+ latencyMs,
+ reasoning,
+ },
+ };
+ } else {
+ return {
+ model: DEFAULT_GEMINI_MODEL,
+ metadata: {
+ source: 'Classifier',
+ reasoning,
+ latencyMs,
+ },
+ };
+ }
+ } catch (error) {
+ // If the classifier fails for any reason (API error, parsing error, etc.),
+ // we log it and return null to allow the composite strategy to proceed.
+ console.warn(`[Routing] ClassifierStrategy failed:`, error);
+ return null;
+ }
+ }
+}