mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 05:12:55 -07:00
feat(routing): Introduce Classifier-based Model Routing Strategy (#8455)
This commit is contained in:
@@ -13,6 +13,7 @@ import { DefaultStrategy } from './strategies/defaultStrategy.js';
|
|||||||
import { CompositeStrategy } from './strategies/compositeStrategy.js';
|
import { CompositeStrategy } from './strategies/compositeStrategy.js';
|
||||||
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
|
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
|
||||||
import { OverrideStrategy } from './strategies/overrideStrategy.js';
|
import { OverrideStrategy } from './strategies/overrideStrategy.js';
|
||||||
|
import { ClassifierStrategy } from './strategies/classifierStrategy.js';
|
||||||
|
|
||||||
vi.mock('../config/config.js');
|
vi.mock('../config/config.js');
|
||||||
vi.mock('../core/baseLlmClient.js');
|
vi.mock('../core/baseLlmClient.js');
|
||||||
@@ -20,6 +21,7 @@ vi.mock('./strategies/defaultStrategy.js');
|
|||||||
vi.mock('./strategies/compositeStrategy.js');
|
vi.mock('./strategies/compositeStrategy.js');
|
||||||
vi.mock('./strategies/fallbackStrategy.js');
|
vi.mock('./strategies/fallbackStrategy.js');
|
||||||
vi.mock('./strategies/overrideStrategy.js');
|
vi.mock('./strategies/overrideStrategy.js');
|
||||||
|
vi.mock('./strategies/classifierStrategy.js');
|
||||||
|
|
||||||
describe('ModelRouterService', () => {
|
describe('ModelRouterService', () => {
|
||||||
let service: ModelRouterService;
|
let service: ModelRouterService;
|
||||||
@@ -36,7 +38,12 @@ describe('ModelRouterService', () => {
|
|||||||
vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient);
|
vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient);
|
||||||
|
|
||||||
mockCompositeStrategy = new CompositeStrategy(
|
mockCompositeStrategy = new CompositeStrategy(
|
||||||
[new FallbackStrategy(), new OverrideStrategy(), new DefaultStrategy()],
|
[
|
||||||
|
new FallbackStrategy(),
|
||||||
|
new OverrideStrategy(),
|
||||||
|
new ClassifierStrategy(),
|
||||||
|
new DefaultStrategy(),
|
||||||
|
],
|
||||||
'agent-router',
|
'agent-router',
|
||||||
);
|
);
|
||||||
vi.mocked(CompositeStrategy).mockImplementation(
|
vi.mocked(CompositeStrategy).mockImplementation(
|
||||||
@@ -62,10 +69,11 @@ describe('ModelRouterService', () => {
|
|||||||
const compositeStrategyArgs = vi.mocked(CompositeStrategy).mock.calls[0];
|
const compositeStrategyArgs = vi.mocked(CompositeStrategy).mock.calls[0];
|
||||||
const childStrategies = compositeStrategyArgs[0];
|
const childStrategies = compositeStrategyArgs[0];
|
||||||
|
|
||||||
expect(childStrategies.length).toBe(3);
|
expect(childStrategies.length).toBe(4);
|
||||||
expect(childStrategies[0]).toBeInstanceOf(FallbackStrategy);
|
expect(childStrategies[0]).toBeInstanceOf(FallbackStrategy);
|
||||||
expect(childStrategies[1]).toBeInstanceOf(OverrideStrategy);
|
expect(childStrategies[1]).toBeInstanceOf(OverrideStrategy);
|
||||||
expect(childStrategies[2]).toBeInstanceOf(DefaultStrategy);
|
expect(childStrategies[2]).toBeInstanceOf(ClassifierStrategy);
|
||||||
|
expect(childStrategies[3]).toBeInstanceOf(DefaultStrategy);
|
||||||
expect(compositeStrategyArgs[1]).toBe('agent-router');
|
expect(compositeStrategyArgs[1]).toBe('agent-router');
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import type {
|
|||||||
TerminalStrategy,
|
TerminalStrategy,
|
||||||
} from './routingStrategy.js';
|
} from './routingStrategy.js';
|
||||||
import { DefaultStrategy } from './strategies/defaultStrategy.js';
|
import { DefaultStrategy } from './strategies/defaultStrategy.js';
|
||||||
|
import { ClassifierStrategy } from './strategies/classifierStrategy.js';
|
||||||
import { CompositeStrategy } from './strategies/compositeStrategy.js';
|
import { CompositeStrategy } from './strategies/compositeStrategy.js';
|
||||||
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
|
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
|
||||||
import { OverrideStrategy } from './strategies/overrideStrategy.js';
|
import { OverrideStrategy } from './strategies/overrideStrategy.js';
|
||||||
@@ -31,7 +32,12 @@ export class ModelRouterService {
|
|||||||
// Initialize the composite strategy with the desired priority order.
|
// Initialize the composite strategy with the desired priority order.
|
||||||
// The strategies are ordered in order of highest priority.
|
// The strategies are ordered in order of highest priority.
|
||||||
return new CompositeStrategy(
|
return new CompositeStrategy(
|
||||||
[new FallbackStrategy(), new OverrideStrategy(), new DefaultStrategy()],
|
[
|
||||||
|
new FallbackStrategy(),
|
||||||
|
new OverrideStrategy(),
|
||||||
|
new ClassifierStrategy(),
|
||||||
|
new DefaultStrategy(),
|
||||||
|
],
|
||||||
'agent-router',
|
'agent-router',
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,277 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||||
|
import { ClassifierStrategy } from './classifierStrategy.js';
|
||||||
|
import type { RoutingContext } from '../routingStrategy.js';
|
||||||
|
import type { Config } from '../../config/config.js';
|
||||||
|
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||||
|
import {
|
||||||
|
isFunctionCall,
|
||||||
|
isFunctionResponse,
|
||||||
|
} from '../../utils/messageInspectors.js';
|
||||||
|
import {
|
||||||
|
DEFAULT_GEMINI_FLASH_MODEL,
|
||||||
|
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||||
|
DEFAULT_GEMINI_MODEL,
|
||||||
|
} from '../../config/models.js';
|
||||||
|
import { promptIdContext } from '../../utils/promptIdContext.js';
|
||||||
|
import type { Content } from '@google/genai';
|
||||||
|
|
||||||
|
vi.mock('../../core/baseLlmClient.js');
|
||||||
|
vi.mock('../../utils/promptIdContext.js');
|
||||||
|
|
||||||
|
describe('ClassifierStrategy', () => {
|
||||||
|
let strategy: ClassifierStrategy;
|
||||||
|
let mockContext: RoutingContext;
|
||||||
|
let mockConfig: Config;
|
||||||
|
let mockBaseLlmClient: BaseLlmClient;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks();
|
||||||
|
|
||||||
|
strategy = new ClassifierStrategy();
|
||||||
|
mockContext = {
|
||||||
|
history: [],
|
||||||
|
request: [{ text: 'simple task' }],
|
||||||
|
signal: new AbortController().signal,
|
||||||
|
};
|
||||||
|
mockConfig = {} as Config;
|
||||||
|
mockBaseLlmClient = {
|
||||||
|
generateJson: vi.fn(),
|
||||||
|
} as unknown as BaseLlmClient;
|
||||||
|
|
||||||
|
vi.mocked(promptIdContext.getStore).mockReturnValue('test-prompt-id');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should call generateJson with the correct parameters', async () => {
|
||||||
|
const mockApiResponse = {
|
||||||
|
reasoning: 'Simple task',
|
||||||
|
model_choice: 'flash',
|
||||||
|
};
|
||||||
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
|
mockApiResponse,
|
||||||
|
);
|
||||||
|
|
||||||
|
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||||
|
|
||||||
|
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
model: DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||||
|
config: expect.objectContaining({
|
||||||
|
temperature: 0,
|
||||||
|
maxOutputTokens: 1024,
|
||||||
|
thinkingConfig: {
|
||||||
|
thinkingBudget: 512,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
promptId: 'test-prompt-id',
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should route to FLASH model for a simple task', async () => {
|
||||||
|
const mockApiResponse = {
|
||||||
|
reasoning: 'This is a simple task.',
|
||||||
|
model_choice: 'flash',
|
||||||
|
};
|
||||||
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
|
mockApiResponse,
|
||||||
|
);
|
||||||
|
|
||||||
|
const decision = await strategy.route(
|
||||||
|
mockContext,
|
||||||
|
mockConfig,
|
||||||
|
mockBaseLlmClient,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce();
|
||||||
|
expect(decision).toEqual({
|
||||||
|
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||||
|
metadata: {
|
||||||
|
source: 'Classifier',
|
||||||
|
latencyMs: expect.any(Number),
|
||||||
|
reasoning: mockApiResponse.reasoning,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should route to PRO model for a complex task', async () => {
|
||||||
|
const mockApiResponse = {
|
||||||
|
reasoning: 'This is a complex task.',
|
||||||
|
model_choice: 'pro',
|
||||||
|
};
|
||||||
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
|
mockApiResponse,
|
||||||
|
);
|
||||||
|
mockContext.request = [{ text: 'how do I build a spaceship?' }];
|
||||||
|
|
||||||
|
const decision = await strategy.route(
|
||||||
|
mockContext,
|
||||||
|
mockConfig,
|
||||||
|
mockBaseLlmClient,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce();
|
||||||
|
expect(decision).toEqual({
|
||||||
|
model: DEFAULT_GEMINI_MODEL,
|
||||||
|
metadata: {
|
||||||
|
source: 'Classifier',
|
||||||
|
latencyMs: expect.any(Number),
|
||||||
|
reasoning: mockApiResponse.reasoning,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return null if the classifier API call fails', async () => {
|
||||||
|
const consoleWarnSpy = vi
|
||||||
|
.spyOn(console, 'warn')
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
const testError = new Error('API Failure');
|
||||||
|
vi.mocked(mockBaseLlmClient.generateJson).mockRejectedValue(testError);
|
||||||
|
|
||||||
|
const decision = await strategy.route(
|
||||||
|
mockContext,
|
||||||
|
mockConfig,
|
||||||
|
mockBaseLlmClient,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(decision).toBeNull();
|
||||||
|
expect(consoleWarnSpy).toHaveBeenCalled();
|
||||||
|
consoleWarnSpy.mockRestore();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return null if the classifier returns a malformed JSON object', async () => {
|
||||||
|
const consoleWarnSpy = vi
|
||||||
|
.spyOn(console, 'warn')
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
const malformedApiResponse = {
|
||||||
|
reasoning: 'This is a simple task.',
|
||||||
|
// model_choice is missing, which will cause a Zod parsing error.
|
||||||
|
};
|
||||||
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
|
malformedApiResponse,
|
||||||
|
);
|
||||||
|
|
||||||
|
const decision = await strategy.route(
|
||||||
|
mockContext,
|
||||||
|
mockConfig,
|
||||||
|
mockBaseLlmClient,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(decision).toBeNull();
|
||||||
|
expect(consoleWarnSpy).toHaveBeenCalled();
|
||||||
|
consoleWarnSpy.mockRestore();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should filter out tool-related history before sending to classifier', async () => {
|
||||||
|
mockContext.history = [
|
||||||
|
{ role: 'user', parts: [{ text: 'call a tool' }] },
|
||||||
|
{ role: 'model', parts: [{ functionCall: { name: 'test_tool' } }] },
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [
|
||||||
|
{ functionResponse: { name: 'test_tool', response: { ok: true } } },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{ role: 'user', parts: [{ text: 'another user turn' }] },
|
||||||
|
];
|
||||||
|
const mockApiResponse = {
|
||||||
|
reasoning: 'Simple.',
|
||||||
|
model_choice: 'flash',
|
||||||
|
};
|
||||||
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
|
mockApiResponse,
|
||||||
|
);
|
||||||
|
|
||||||
|
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||||
|
|
||||||
|
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||||
|
.calls[0][0];
|
||||||
|
const contents = generateJsonCall.contents;
|
||||||
|
|
||||||
|
const expectedContents = [
|
||||||
|
{ role: 'user', parts: [{ text: 'call a tool' }] },
|
||||||
|
{ role: 'user', parts: [{ text: 'another user turn' }] },
|
||||||
|
{ role: 'user', parts: [{ text: 'simple task' }] },
|
||||||
|
];
|
||||||
|
|
||||||
|
expect(contents).toEqual(expectedContents);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should respect HISTORY_SEARCH_WINDOW and HISTORY_TURNS_FOR_CONTEXT', async () => {
|
||||||
|
const longHistory: Content[] = [];
|
||||||
|
for (let i = 0; i < 30; i++) {
|
||||||
|
longHistory.push({ role: 'user', parts: [{ text: `Message ${i}` }] });
|
||||||
|
// Add noise that should be filtered
|
||||||
|
if (i % 2 === 0) {
|
||||||
|
longHistory.push({
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ functionCall: { name: 'noise', args: {} } }],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mockContext.history = longHistory;
|
||||||
|
const mockApiResponse = {
|
||||||
|
reasoning: 'Simple.',
|
||||||
|
model_choice: 'flash',
|
||||||
|
};
|
||||||
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
|
mockApiResponse,
|
||||||
|
);
|
||||||
|
|
||||||
|
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||||
|
|
||||||
|
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||||
|
.calls[0][0];
|
||||||
|
const contents = generateJsonCall.contents;
|
||||||
|
|
||||||
|
// Manually calculate what the history should be
|
||||||
|
const HISTORY_SEARCH_WINDOW = 20;
|
||||||
|
const HISTORY_TURNS_FOR_CONTEXT = 4;
|
||||||
|
const historySlice = longHistory.slice(-HISTORY_SEARCH_WINDOW);
|
||||||
|
const cleanHistory = historySlice.filter(
|
||||||
|
(content) => !isFunctionCall(content) && !isFunctionResponse(content),
|
||||||
|
);
|
||||||
|
const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT);
|
||||||
|
|
||||||
|
expect(contents).toEqual([
|
||||||
|
...finalHistory,
|
||||||
|
{ role: 'user', parts: mockContext.request },
|
||||||
|
]);
|
||||||
|
// There should be 4 history items + the current request
|
||||||
|
expect(contents).toHaveLength(5);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should use a fallback promptId if not found in context', async () => {
|
||||||
|
const consoleWarnSpy = vi
|
||||||
|
.spyOn(console, 'warn')
|
||||||
|
.mockImplementation(() => {});
|
||||||
|
vi.mocked(promptIdContext.getStore).mockReturnValue(undefined);
|
||||||
|
const mockApiResponse = {
|
||||||
|
reasoning: 'Simple.',
|
||||||
|
model_choice: 'flash',
|
||||||
|
};
|
||||||
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
|
mockApiResponse,
|
||||||
|
);
|
||||||
|
|
||||||
|
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||||
|
|
||||||
|
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||||
|
.calls[0][0];
|
||||||
|
|
||||||
|
expect(generateJsonCall.promptId).toMatch(
|
||||||
|
/^classifier-router-fallback-\d+-\w+$/,
|
||||||
|
);
|
||||||
|
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
||||||
|
expect.stringContaining(
|
||||||
|
'Could not find promptId in context. This is unexpected. Using a fallback ID:',
|
||||||
|
),
|
||||||
|
);
|
||||||
|
consoleWarnSpy.mockRestore();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -0,0 +1,213 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { z } from 'zod';
|
||||||
|
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||||
|
import { promptIdContext } from '../../utils/promptIdContext.js';
|
||||||
|
import type {
|
||||||
|
RoutingContext,
|
||||||
|
RoutingDecision,
|
||||||
|
RoutingStrategy,
|
||||||
|
} from '../routingStrategy.js';
|
||||||
|
import {
|
||||||
|
DEFAULT_GEMINI_FLASH_MODEL,
|
||||||
|
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||||
|
DEFAULT_GEMINI_MODEL,
|
||||||
|
} from '../../config/models.js';
|
||||||
|
import {
|
||||||
|
type GenerateContentConfig,
|
||||||
|
createUserContent,
|
||||||
|
Type,
|
||||||
|
} from '@google/genai';
|
||||||
|
import type { Config } from '../../config/config.js';
|
||||||
|
import {
|
||||||
|
isFunctionCall,
|
||||||
|
isFunctionResponse,
|
||||||
|
} from '../../utils/messageInspectors.js';
|
||||||
|
|
||||||
|
const CLASSIFIER_GENERATION_CONFIG: GenerateContentConfig = {
|
||||||
|
temperature: 0,
|
||||||
|
maxOutputTokens: 1024,
|
||||||
|
thinkingConfig: {
|
||||||
|
thinkingBudget: 512, // This counts towards output max, so we don't want -1.
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
// The number of recent history turns to provide to the router for context.
|
||||||
|
const HISTORY_TURNS_FOR_CONTEXT = 4;
|
||||||
|
const HISTORY_SEARCH_WINDOW = 20;
|
||||||
|
|
||||||
|
const FLASH_MODEL = 'flash';
|
||||||
|
const PRO_MODEL = 'pro';
|
||||||
|
|
||||||
|
const CLASSIFIER_SYSTEM_PROMPT = `
|
||||||
|
You are a specialized Task Routing AI. Your sole function is to analyze the user's request and classify its complexity. Choose between \`${FLASH_MODEL}\` (SIMPLE) or \`${PRO_MODEL}\` (COMPLEX).
|
||||||
|
1. \`${FLASH_MODEL}\`: A fast, efficient model for simple, well-defined tasks.
|
||||||
|
2. \`${PRO_MODEL}\`: A powerful, advanced model for complex, open-ended, or multi-step tasks.
|
||||||
|
<complexity_rubric>
|
||||||
|
A task is COMPLEX (Choose \`${PRO_MODEL}\`) if it meets ONE OR MORE of the following criteria:
|
||||||
|
1. **High Operational Complexity (Est. 4+ Steps/Tool Calls):** Requires dependent actions, significant planning, or multiple coordinated changes.
|
||||||
|
2. **Strategic Planning & Conceptual Design:** Asking "how" or "why." Requires advice, architecture, or high-level strategy.
|
||||||
|
3. **High Ambiguity or Large Scope (Extensive Investigation):** Broadly defined requests requiring extensive investigation.
|
||||||
|
4. **Deep Debugging & Root Cause Analysis:** Diagnosing unknown or complex problems from symptoms.
|
||||||
|
A task is SIMPLE (Choose \`${FLASH_MODEL}\`) if it is highly specific, bounded, and has Low Operational Complexity (Est. 1-3 tool calls). Operational simplicity overrides strategic phrasing.
|
||||||
|
</complexity_rubric>
|
||||||
|
**Output Format:**
|
||||||
|
Respond *only* in JSON format according to the following schema. Do not include any text outside the JSON structure.
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"reasoning": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "A brief, step-by-step explanation for the model choice, referencing the rubric."
|
||||||
|
},
|
||||||
|
"model_choice": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["${FLASH_MODEL}", "${PRO_MODEL}"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["reasoning", "model_choice"]
|
||||||
|
}
|
||||||
|
--- EXAMPLES ---
|
||||||
|
**Example 1 (Strategic Planning):**
|
||||||
|
*User Prompt:* "How should I architect the data pipeline for this new analytics service?"
|
||||||
|
*Your JSON Output:*
|
||||||
|
{
|
||||||
|
"reasoning": "The user is asking for high-level architectural design and strategy. This falls under 'Strategic Planning & Conceptual Design'.",
|
||||||
|
"model_choice": "${PRO_MODEL}"
|
||||||
|
}
|
||||||
|
**Example 2 (Simple Tool Use):**
|
||||||
|
*User Prompt:* "list the files in the current directory"
|
||||||
|
*Your JSON Output:*
|
||||||
|
{
|
||||||
|
"reasoning": "This is a direct command requiring a single tool call (ls). It has Low Operational Complexity (1 step).",
|
||||||
|
"model_choice": "${FLASH_MODEL}"
|
||||||
|
}
|
||||||
|
**Example 3 (High Operational Complexity):**
|
||||||
|
*User Prompt:* "I need to add a new 'email' field to the User schema in 'src/models/user.ts', migrate the database, and update the registration endpoint."
|
||||||
|
*Your JSON Output:*
|
||||||
|
{
|
||||||
|
"reasoning": "This request involves multiple coordinated steps across different files and systems. This meets the criteria for High Operational Complexity (4+ steps).",
|
||||||
|
"model_choice": "${PRO_MODEL}"
|
||||||
|
}
|
||||||
|
**Example 4 (Simple Read):**
|
||||||
|
*User Prompt:* "Read the contents of 'package.json'."
|
||||||
|
*Your JSON Output:*
|
||||||
|
{
|
||||||
|
"reasoning": "This is a direct command requiring a single read. It has Low Operational Complexity (1 step).",
|
||||||
|
"model_choice": "${FLASH_MODEL}"
|
||||||
|
}
|
||||||
|
|
||||||
|
**Example 5 (Deep Debugging):**
|
||||||
|
*User Prompt:* "I'm getting an error 'Cannot read property 'map' of undefined' when I click the save button. Can you fix it?"
|
||||||
|
*Your JSON Output:*
|
||||||
|
{
|
||||||
|
"reasoning": "The user is reporting an error symptom without a known cause. This requires investigation and falls under 'Deep Debugging'.",
|
||||||
|
"model_choice": "${PRO_MODEL}"
|
||||||
|
}
|
||||||
|
**Example 6 (Simple Edit despite Phrasing):**
|
||||||
|
*User Prompt:* "What is the best way to rename the variable 'data' to 'userData' in 'src/utils.js'?"
|
||||||
|
*Your JSON Output:*
|
||||||
|
{
|
||||||
|
"reasoning": "Although the user uses strategic language ('best way'), the underlying task is a localized edit. The operational complexity is low (1-2 steps).",
|
||||||
|
"model_choice": "${FLASH_MODEL}"
|
||||||
|
}
|
||||||
|
`;
|
||||||
|
|
||||||
|
const RESPONSE_SCHEMA = {
|
||||||
|
type: Type.OBJECT,
|
||||||
|
properties: {
|
||||||
|
reasoning: {
|
||||||
|
type: Type.STRING,
|
||||||
|
description:
|
||||||
|
'A brief, step-by-step explanation for the model choice, referencing the rubric.',
|
||||||
|
},
|
||||||
|
model_choice: {
|
||||||
|
type: Type.STRING,
|
||||||
|
enum: [FLASH_MODEL, PRO_MODEL],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
required: ['reasoning', 'model_choice'],
|
||||||
|
};
|
||||||
|
|
||||||
|
const ClassifierResponseSchema = z.object({
|
||||||
|
reasoning: z.string(),
|
||||||
|
model_choice: z.enum([FLASH_MODEL, PRO_MODEL]),
|
||||||
|
});
|
||||||
|
|
||||||
|
export class ClassifierStrategy implements RoutingStrategy {
|
||||||
|
readonly name = 'classifier';
|
||||||
|
|
||||||
|
async route(
|
||||||
|
context: RoutingContext,
|
||||||
|
_config: Config,
|
||||||
|
baseLlmClient: BaseLlmClient,
|
||||||
|
): Promise<RoutingDecision | null> {
|
||||||
|
const startTime = Date.now();
|
||||||
|
try {
|
||||||
|
let promptId = promptIdContext.getStore();
|
||||||
|
if (!promptId) {
|
||||||
|
promptId = `classifier-router-fallback-${Date.now()}-${Math.random()
|
||||||
|
.toString(16)
|
||||||
|
.slice(2)}`;
|
||||||
|
console.warn(
|
||||||
|
`Could not find promptId in context. This is unexpected. Using a fallback ID: ${promptId}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const historySlice = context.history.slice(-HISTORY_SEARCH_WINDOW);
|
||||||
|
|
||||||
|
// Filter out tool-related turns.
|
||||||
|
// TODO - Consider using function req/res if they help accuracy.
|
||||||
|
const cleanHistory = historySlice.filter(
|
||||||
|
(content) => !isFunctionCall(content) && !isFunctionResponse(content),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Take the last N turns from the *cleaned* history.
|
||||||
|
const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT);
|
||||||
|
|
||||||
|
const jsonResponse = await baseLlmClient.generateJson({
|
||||||
|
contents: [...finalHistory, createUserContent(context.request)],
|
||||||
|
schema: RESPONSE_SCHEMA,
|
||||||
|
model: DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||||
|
systemInstruction: CLASSIFIER_SYSTEM_PROMPT,
|
||||||
|
config: CLASSIFIER_GENERATION_CONFIG,
|
||||||
|
abortSignal: context.signal,
|
||||||
|
promptId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const routerResponse = ClassifierResponseSchema.parse(jsonResponse);
|
||||||
|
|
||||||
|
const reasoning = routerResponse.reasoning;
|
||||||
|
const latencyMs = Date.now() - startTime;
|
||||||
|
|
||||||
|
if (routerResponse.model_choice === FLASH_MODEL) {
|
||||||
|
return {
|
||||||
|
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||||
|
metadata: {
|
||||||
|
source: 'Classifier',
|
||||||
|
latencyMs,
|
||||||
|
reasoning,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
return {
|
||||||
|
model: DEFAULT_GEMINI_MODEL,
|
||||||
|
metadata: {
|
||||||
|
source: 'Classifier',
|
||||||
|
reasoning,
|
||||||
|
latencyMs,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// If the classifier fails for any reason (API error, parsing error, etc.),
|
||||||
|
// we log it and return null to allow the composite strategy to proceed.
|
||||||
|
console.warn(`[Routing] ClassifierStrategy failed:`, error);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user