From bee5b638dd5a2b2ba6f86dbd484693217a81cd6f Mon Sep 17 00:00:00 2001 From: Abhi <43648792+abhipatel12@users.noreply.github.com> Date: Mon, 15 Sep 2025 19:51:25 -0400 Subject: [PATCH] feat(routing): Introduce Classifier-based Model Routing Strategy (#8455) --- .../src/routing/modelRouterService.test.ts | 14 +- .../core/src/routing/modelRouterService.ts | 8 +- .../strategies/classifierStrategy.test.ts | 277 ++++++++++++++++++ .../routing/strategies/classifierStrategy.ts | 213 ++++++++++++++ 4 files changed, 508 insertions(+), 4 deletions(-) create mode 100644 packages/core/src/routing/strategies/classifierStrategy.test.ts create mode 100644 packages/core/src/routing/strategies/classifierStrategy.ts diff --git a/packages/core/src/routing/modelRouterService.test.ts b/packages/core/src/routing/modelRouterService.test.ts index 0f83796787..ad2fda7ae1 100644 --- a/packages/core/src/routing/modelRouterService.test.ts +++ b/packages/core/src/routing/modelRouterService.test.ts @@ -13,6 +13,7 @@ import { DefaultStrategy } from './strategies/defaultStrategy.js'; import { CompositeStrategy } from './strategies/compositeStrategy.js'; import { FallbackStrategy } from './strategies/fallbackStrategy.js'; import { OverrideStrategy } from './strategies/overrideStrategy.js'; +import { ClassifierStrategy } from './strategies/classifierStrategy.js'; vi.mock('../config/config.js'); vi.mock('../core/baseLlmClient.js'); @@ -20,6 +21,7 @@ vi.mock('./strategies/defaultStrategy.js'); vi.mock('./strategies/compositeStrategy.js'); vi.mock('./strategies/fallbackStrategy.js'); vi.mock('./strategies/overrideStrategy.js'); +vi.mock('./strategies/classifierStrategy.js'); describe('ModelRouterService', () => { let service: ModelRouterService; @@ -36,7 +38,12 @@ describe('ModelRouterService', () => { vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient); mockCompositeStrategy = new CompositeStrategy( - [new FallbackStrategy(), new OverrideStrategy(), new DefaultStrategy()], + [ + new FallbackStrategy(), + new OverrideStrategy(), + new ClassifierStrategy(), + new DefaultStrategy(), + ], 'agent-router', ); vi.mocked(CompositeStrategy).mockImplementation( @@ -62,10 +69,11 @@ describe('ModelRouterService', () => { const compositeStrategyArgs = vi.mocked(CompositeStrategy).mock.calls[0]; const childStrategies = compositeStrategyArgs[0]; - expect(childStrategies.length).toBe(3); + expect(childStrategies.length).toBe(4); expect(childStrategies[0]).toBeInstanceOf(FallbackStrategy); expect(childStrategies[1]).toBeInstanceOf(OverrideStrategy); - expect(childStrategies[2]).toBeInstanceOf(DefaultStrategy); + expect(childStrategies[2]).toBeInstanceOf(ClassifierStrategy); + expect(childStrategies[3]).toBeInstanceOf(DefaultStrategy); expect(compositeStrategyArgs[1]).toBe('agent-router'); }); diff --git a/packages/core/src/routing/modelRouterService.ts b/packages/core/src/routing/modelRouterService.ts index a984125f89..b4feea5d26 100644 --- a/packages/core/src/routing/modelRouterService.ts +++ b/packages/core/src/routing/modelRouterService.ts @@ -11,6 +11,7 @@ import type { TerminalStrategy, } from './routingStrategy.js'; import { DefaultStrategy } from './strategies/defaultStrategy.js'; +import { ClassifierStrategy } from './strategies/classifierStrategy.js'; import { CompositeStrategy } from './strategies/compositeStrategy.js'; import { FallbackStrategy } from './strategies/fallbackStrategy.js'; import { OverrideStrategy } from './strategies/overrideStrategy.js'; @@ -31,7 +32,12 @@ export class ModelRouterService { // Initialize the composite strategy with the desired priority order. // The strategies are ordered in order of highest priority. return new CompositeStrategy( - [new FallbackStrategy(), new OverrideStrategy(), new DefaultStrategy()], + [ + new FallbackStrategy(), + new OverrideStrategy(), + new ClassifierStrategy(), + new DefaultStrategy(), + ], 'agent-router', ); } diff --git a/packages/core/src/routing/strategies/classifierStrategy.test.ts b/packages/core/src/routing/strategies/classifierStrategy.test.ts new file mode 100644 index 0000000000..34c6500df1 --- /dev/null +++ b/packages/core/src/routing/strategies/classifierStrategy.test.ts @@ -0,0 +1,277 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { ClassifierStrategy } from './classifierStrategy.js'; +import type { RoutingContext } from '../routingStrategy.js'; +import type { Config } from '../../config/config.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import { + isFunctionCall, + isFunctionResponse, +} from '../../utils/messageInspectors.js'; +import { + DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_FLASH_LITE_MODEL, + DEFAULT_GEMINI_MODEL, +} from '../../config/models.js'; +import { promptIdContext } from '../../utils/promptIdContext.js'; +import type { Content } from '@google/genai'; + +vi.mock('../../core/baseLlmClient.js'); +vi.mock('../../utils/promptIdContext.js'); + +describe('ClassifierStrategy', () => { + let strategy: ClassifierStrategy; + let mockContext: RoutingContext; + let mockConfig: Config; + let mockBaseLlmClient: BaseLlmClient; + + beforeEach(() => { + vi.clearAllMocks(); + + strategy = new ClassifierStrategy(); + mockContext = { + history: [], + request: [{ text: 'simple task' }], + signal: new AbortController().signal, + }; + mockConfig = {} as Config; + mockBaseLlmClient = { + generateJson: vi.fn(), + } as unknown as BaseLlmClient; + + vi.mocked(promptIdContext.getStore).mockReturnValue('test-prompt-id'); + }); + + it('should call generateJson with the correct parameters', async () => { + const mockApiResponse = { + reasoning: 'Simple task', + model_choice: 'flash', + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + await strategy.route(mockContext, mockConfig, mockBaseLlmClient); + + expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith( + expect.objectContaining({ + model: DEFAULT_GEMINI_FLASH_LITE_MODEL, + config: expect.objectContaining({ + temperature: 0, + maxOutputTokens: 1024, + thinkingConfig: { + thinkingBudget: 512, + }, + }), + promptId: 'test-prompt-id', + }), + ); + }); + + it('should route to FLASH model for a simple task', async () => { + const mockApiResponse = { + reasoning: 'This is a simple task.', + model_choice: 'flash', + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce(); + expect(decision).toEqual({ + model: DEFAULT_GEMINI_FLASH_MODEL, + metadata: { + source: 'Classifier', + latencyMs: expect.any(Number), + reasoning: mockApiResponse.reasoning, + }, + }); + }); + + it('should route to PRO model for a complex task', async () => { + const mockApiResponse = { + reasoning: 'This is a complex task.', + model_choice: 'pro', + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + mockContext.request = [{ text: 'how do I build a spaceship?' }]; + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce(); + expect(decision).toEqual({ + model: DEFAULT_GEMINI_MODEL, + metadata: { + source: 'Classifier', + latencyMs: expect.any(Number), + reasoning: mockApiResponse.reasoning, + }, + }); + }); + + it('should return null if the classifier API call fails', async () => { + const consoleWarnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => {}); + const testError = new Error('API Failure'); + vi.mocked(mockBaseLlmClient.generateJson).mockRejectedValue(testError); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision).toBeNull(); + expect(consoleWarnSpy).toHaveBeenCalled(); + consoleWarnSpy.mockRestore(); + }); + + it('should return null if the classifier returns a malformed JSON object', async () => { + const consoleWarnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => {}); + const malformedApiResponse = { + reasoning: 'This is a simple task.', + // model_choice is missing, which will cause a Zod parsing error. + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + malformedApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision).toBeNull(); + expect(consoleWarnSpy).toHaveBeenCalled(); + consoleWarnSpy.mockRestore(); + }); + + it('should filter out tool-related history before sending to classifier', async () => { + mockContext.history = [ + { role: 'user', parts: [{ text: 'call a tool' }] }, + { role: 'model', parts: [{ functionCall: { name: 'test_tool' } }] }, + { + role: 'user', + parts: [ + { functionResponse: { name: 'test_tool', response: { ok: true } } }, + ], + }, + { role: 'user', parts: [{ text: 'another user turn' }] }, + ]; + const mockApiResponse = { + reasoning: 'Simple.', + model_choice: 'flash', + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + await strategy.route(mockContext, mockConfig, mockBaseLlmClient); + + const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock + .calls[0][0]; + const contents = generateJsonCall.contents; + + const expectedContents = [ + { role: 'user', parts: [{ text: 'call a tool' }] }, + { role: 'user', parts: [{ text: 'another user turn' }] }, + { role: 'user', parts: [{ text: 'simple task' }] }, + ]; + + expect(contents).toEqual(expectedContents); + }); + + it('should respect HISTORY_SEARCH_WINDOW and HISTORY_TURNS_FOR_CONTEXT', async () => { + const longHistory: Content[] = []; + for (let i = 0; i < 30; i++) { + longHistory.push({ role: 'user', parts: [{ text: `Message ${i}` }] }); + // Add noise that should be filtered + if (i % 2 === 0) { + longHistory.push({ + role: 'model', + parts: [{ functionCall: { name: 'noise', args: {} } }], + }); + } + } + mockContext.history = longHistory; + const mockApiResponse = { + reasoning: 'Simple.', + model_choice: 'flash', + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + await strategy.route(mockContext, mockConfig, mockBaseLlmClient); + + const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock + .calls[0][0]; + const contents = generateJsonCall.contents; + + // Manually calculate what the history should be + const HISTORY_SEARCH_WINDOW = 20; + const HISTORY_TURNS_FOR_CONTEXT = 4; + const historySlice = longHistory.slice(-HISTORY_SEARCH_WINDOW); + const cleanHistory = historySlice.filter( + (content) => !isFunctionCall(content) && !isFunctionResponse(content), + ); + const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT); + + expect(contents).toEqual([ + ...finalHistory, + { role: 'user', parts: mockContext.request }, + ]); + // There should be 4 history items + the current request + expect(contents).toHaveLength(5); + }); + + it('should use a fallback promptId if not found in context', async () => { + const consoleWarnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => {}); + vi.mocked(promptIdContext.getStore).mockReturnValue(undefined); + const mockApiResponse = { + reasoning: 'Simple.', + model_choice: 'flash', + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + await strategy.route(mockContext, mockConfig, mockBaseLlmClient); + + const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock + .calls[0][0]; + + expect(generateJsonCall.promptId).toMatch( + /^classifier-router-fallback-\d+-\w+$/, + ); + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining( + 'Could not find promptId in context. This is unexpected. Using a fallback ID:', + ), + ); + consoleWarnSpy.mockRestore(); + }); +}); diff --git a/packages/core/src/routing/strategies/classifierStrategy.ts b/packages/core/src/routing/strategies/classifierStrategy.ts new file mode 100644 index 0000000000..f3af80ecb6 --- /dev/null +++ b/packages/core/src/routing/strategies/classifierStrategy.ts @@ -0,0 +1,213 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { z } from 'zod'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import { promptIdContext } from '../../utils/promptIdContext.js'; +import type { + RoutingContext, + RoutingDecision, + RoutingStrategy, +} from '../routingStrategy.js'; +import { + DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_FLASH_LITE_MODEL, + DEFAULT_GEMINI_MODEL, +} from '../../config/models.js'; +import { + type GenerateContentConfig, + createUserContent, + Type, +} from '@google/genai'; +import type { Config } from '../../config/config.js'; +import { + isFunctionCall, + isFunctionResponse, +} from '../../utils/messageInspectors.js'; + +const CLASSIFIER_GENERATION_CONFIG: GenerateContentConfig = { + temperature: 0, + maxOutputTokens: 1024, + thinkingConfig: { + thinkingBudget: 512, // This counts towards output max, so we don't want -1. + }, +}; + +// The number of recent history turns to provide to the router for context. +const HISTORY_TURNS_FOR_CONTEXT = 4; +const HISTORY_SEARCH_WINDOW = 20; + +const FLASH_MODEL = 'flash'; +const PRO_MODEL = 'pro'; + +const CLASSIFIER_SYSTEM_PROMPT = ` +You are a specialized Task Routing AI. Your sole function is to analyze the user's request and classify its complexity. Choose between \`${FLASH_MODEL}\` (SIMPLE) or \`${PRO_MODEL}\` (COMPLEX). +1. \`${FLASH_MODEL}\`: A fast, efficient model for simple, well-defined tasks. +2. \`${PRO_MODEL}\`: A powerful, advanced model for complex, open-ended, or multi-step tasks. + +A task is COMPLEX (Choose \`${PRO_MODEL}\`) if it meets ONE OR MORE of the following criteria: +1. **High Operational Complexity (Est. 4+ Steps/Tool Calls):** Requires dependent actions, significant planning, or multiple coordinated changes. +2. **Strategic Planning & Conceptual Design:** Asking "how" or "why." Requires advice, architecture, or high-level strategy. +3. **High Ambiguity or Large Scope (Extensive Investigation):** Broadly defined requests requiring extensive investigation. +4. **Deep Debugging & Root Cause Analysis:** Diagnosing unknown or complex problems from symptoms. +A task is SIMPLE (Choose \`${FLASH_MODEL}\`) if it is highly specific, bounded, and has Low Operational Complexity (Est. 1-3 tool calls). Operational simplicity overrides strategic phrasing. + +**Output Format:** +Respond *only* in JSON format according to the following schema. Do not include any text outside the JSON structure. +{ + "type": "object", + "properties": { + "reasoning": { + "type": "string", + "description": "A brief, step-by-step explanation for the model choice, referencing the rubric." + }, + "model_choice": { + "type": "string", + "enum": ["${FLASH_MODEL}", "${PRO_MODEL}"] + } + }, + "required": ["reasoning", "model_choice"] +} +--- EXAMPLES --- +**Example 1 (Strategic Planning):** +*User Prompt:* "How should I architect the data pipeline for this new analytics service?" +*Your JSON Output:* +{ + "reasoning": "The user is asking for high-level architectural design and strategy. This falls under 'Strategic Planning & Conceptual Design'.", + "model_choice": "${PRO_MODEL}" +} +**Example 2 (Simple Tool Use):** +*User Prompt:* "list the files in the current directory" +*Your JSON Output:* +{ + "reasoning": "This is a direct command requiring a single tool call (ls). It has Low Operational Complexity (1 step).", + "model_choice": "${FLASH_MODEL}" +} +**Example 3 (High Operational Complexity):** +*User Prompt:* "I need to add a new 'email' field to the User schema in 'src/models/user.ts', migrate the database, and update the registration endpoint." +*Your JSON Output:* +{ + "reasoning": "This request involves multiple coordinated steps across different files and systems. This meets the criteria for High Operational Complexity (4+ steps).", + "model_choice": "${PRO_MODEL}" +} +**Example 4 (Simple Read):** +*User Prompt:* "Read the contents of 'package.json'." +*Your JSON Output:* +{ + "reasoning": "This is a direct command requiring a single read. It has Low Operational Complexity (1 step).", + "model_choice": "${FLASH_MODEL}" +} + +**Example 5 (Deep Debugging):** +*User Prompt:* "I'm getting an error 'Cannot read property 'map' of undefined' when I click the save button. Can you fix it?" +*Your JSON Output:* +{ + "reasoning": "The user is reporting an error symptom without a known cause. This requires investigation and falls under 'Deep Debugging'.", + "model_choice": "${PRO_MODEL}" +} +**Example 6 (Simple Edit despite Phrasing):** +*User Prompt:* "What is the best way to rename the variable 'data' to 'userData' in 'src/utils.js'?" +*Your JSON Output:* +{ + "reasoning": "Although the user uses strategic language ('best way'), the underlying task is a localized edit. The operational complexity is low (1-2 steps).", + "model_choice": "${FLASH_MODEL}" +} +`; + +const RESPONSE_SCHEMA = { + type: Type.OBJECT, + properties: { + reasoning: { + type: Type.STRING, + description: + 'A brief, step-by-step explanation for the model choice, referencing the rubric.', + }, + model_choice: { + type: Type.STRING, + enum: [FLASH_MODEL, PRO_MODEL], + }, + }, + required: ['reasoning', 'model_choice'], +}; + +const ClassifierResponseSchema = z.object({ + reasoning: z.string(), + model_choice: z.enum([FLASH_MODEL, PRO_MODEL]), +}); + +export class ClassifierStrategy implements RoutingStrategy { + readonly name = 'classifier'; + + async route( + context: RoutingContext, + _config: Config, + baseLlmClient: BaseLlmClient, + ): Promise { + const startTime = Date.now(); + try { + let promptId = promptIdContext.getStore(); + if (!promptId) { + promptId = `classifier-router-fallback-${Date.now()}-${Math.random() + .toString(16) + .slice(2)}`; + console.warn( + `Could not find promptId in context. This is unexpected. Using a fallback ID: ${promptId}`, + ); + } + + const historySlice = context.history.slice(-HISTORY_SEARCH_WINDOW); + + // Filter out tool-related turns. + // TODO - Consider using function req/res if they help accuracy. + const cleanHistory = historySlice.filter( + (content) => !isFunctionCall(content) && !isFunctionResponse(content), + ); + + // Take the last N turns from the *cleaned* history. + const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT); + + const jsonResponse = await baseLlmClient.generateJson({ + contents: [...finalHistory, createUserContent(context.request)], + schema: RESPONSE_SCHEMA, + model: DEFAULT_GEMINI_FLASH_LITE_MODEL, + systemInstruction: CLASSIFIER_SYSTEM_PROMPT, + config: CLASSIFIER_GENERATION_CONFIG, + abortSignal: context.signal, + promptId, + }); + + const routerResponse = ClassifierResponseSchema.parse(jsonResponse); + + const reasoning = routerResponse.reasoning; + const latencyMs = Date.now() - startTime; + + if (routerResponse.model_choice === FLASH_MODEL) { + return { + model: DEFAULT_GEMINI_FLASH_MODEL, + metadata: { + source: 'Classifier', + latencyMs, + reasoning, + }, + }; + } else { + return { + model: DEFAULT_GEMINI_MODEL, + metadata: { + source: 'Classifier', + reasoning, + latencyMs, + }, + }; + } + } catch (error) { + // If the classifier fails for any reason (API error, parsing error, etc.), + // we log it and return null to allow the composite strategy to proceed. + console.warn(`[Routing] ClassifierStrategy failed:`, error); + return null; + } + } +}