# feat(routing): Introduce useModelRouter feature flag (#8366)

Abhi authored on 2025-09-12 15:57:07 -04:00, committed by GitHub
parent bc7c7fe466 · commit c15774ce68
14 changed files with 267 additions and 48 deletions
+25
@@ -523,6 +523,31 @@ describe('Server Config (config.ts)', () => {
    });
  });

+  describe('UseModelRouter Configuration', () => {
+    it('should default useModelRouter to false when not provided', () => {
+      const config = new Config(baseParams);
+      expect(config.getUseModelRouter()).toBe(false);
+    });
+
+    it('should set useModelRouter to true when provided as true', () => {
+      const paramsWithModelRouter: ConfigParameters = {
+        ...baseParams,
+        useModelRouter: true,
+      };
+      const config = new Config(paramsWithModelRouter);
+      expect(config.getUseModelRouter()).toBe(true);
+    });
+
+    it('should set useModelRouter to false when explicitly provided as false', () => {
+      const paramsWithModelRouter: ConfigParameters = {
+        ...baseParams,
+        useModelRouter: false,
+      };
+      const config = new Config(paramsWithModelRouter);
+      expect(config.getUseModelRouter()).toBe(false);
+    });
+  });
+
  describe('createToolRegistry', () => {
    it('should register a tool if coreTools contains an argument-specific pattern', async () => {
      const params: ConfigParameters = {
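These tests rely on a shared `baseParams` fixture defined earlier in the test file. For orientation, a rough sketch of the shape involved — the fixture below is illustrative, not the commit's actual `baseParams`, which may carry more required fields:

```ts
import { Config, type ConfigParameters } from './config.js';

// Illustrative fixture only: the real baseParams lives elsewhere in
// config.test.ts and may include additional required fields.
const baseParams: ConfigParameters = {
  sessionId: 'test-session',
  targetDir: '/tmp/project',
  cwd: '/tmp/project',
  debugMode: false,
  model: 'gemini-2.5-pro',
};

// The flag defaults off and is threaded straight through to the getter.
const config = new Config({ ...baseParams, useModelRouter: true });
console.log(config.getUseModelRouter()); // true
```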
+8 -2
@@ -44,7 +44,6 @@ import { StartSessionEvent } from '../telemetry/index.js';
import {
  DEFAULT_GEMINI_EMBEDDING_MODEL,
  DEFAULT_GEMINI_FLASH_MODEL,
-  DEFAULT_GEMINI_MODEL,
} from './models.js';
import { shouldAttemptBrowserLaunch } from '../utils/browser.js';
import type { MCPOAuthConfig } from '../mcp/oauth-provider.js';
@@ -240,6 +239,7 @@ export interface ConfigParameters {
  useSmartEdit?: boolean;
  policyEngineConfig?: PolicyEngineConfig;
  output?: OutputSettings;
+  useModelRouter?: boolean;
}

export class Config {
@@ -327,6 +327,7 @@ export class Config {
  private readonly messageBus: MessageBus;
  private readonly policyEngine: PolicyEngine;
  private readonly outputSettings: OutputSettings;
+  private readonly useModelRouter: boolean;

  constructor(params: ConfigParameters) {
    this.sessionId = params.sessionId;
@@ -376,7 +377,7 @@ export class Config {
    this.cwd = params.cwd ?? process.cwd();
    this.fileDiscoveryService = params.fileDiscoveryService ?? null;
    this.bugCommand = params.bugCommand;
-    this.model = params.model || DEFAULT_GEMINI_MODEL;
+    this.model = params.model;
    this.extensionContextFilePaths = params.extensionContextFilePaths ?? [];
    this.maxSessionTurns = params.maxSessionTurns ?? -1;
    this.experimentalZedIntegration =
@@ -411,6 +412,7 @@ export class Config {
    this.enableToolOutputTruncation =
      params.enableToolOutputTruncation ?? false;
    this.useSmartEdit = params.useSmartEdit ?? true;
+    this.useModelRouter = params.useModelRouter ?? false;
    this.extensionManagement = params.extensionManagement ?? true;
    this.storage = new Storage(this.targetDir);
    this.enablePromptCompletion = params.enablePromptCompletion ?? false;
@@ -931,6 +933,10 @@ export class Config {
      : OutputFormat.TEXT;
  }

+  getUseModelRouter(): boolean {
+    return this.useModelRouter;
+  }
+
  async getGitService(): Promise<GitService> {
    if (!this.gitService) {
      this.gitService = new GitService(this.targetDir, this.storage);
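Two details here are easy to miss: `this.model = params.model` drops the old `|| DEFAULT_GEMINI_MODEL` fallback, so whoever assembles `ConfigParameters` now owns the default and can hand in the `'auto'` sentinel instead, and `getUseModelRouter()` exposes the flag to downstream code. A minimal sketch of how a caller might connect the two — `resolveStartupModel` is a hypothetical helper, not part of this commit:

```ts
import { DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL_AUTO } from './models.js';

// Hypothetical helper: pick the model string to store in ConfigParameters.
function resolveStartupModel(
  cliModelFlag: string | undefined,
  useModelRouter: boolean,
): string {
  if (cliModelFlag) return cliModelFlag; // an explicit override always wins
  // With the router enabled, store the sentinel and defer the real choice
  // to the routing strategies; otherwise pin the usual default.
  return useModelRouter ? DEFAULT_GEMINI_MODEL_AUTO : DEFAULT_GEMINI_MODEL;
}
```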
+2
@@ -8,6 +8,8 @@ export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro';
export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash';
export const DEFAULT_GEMINI_FLASH_LITE_MODEL = 'gemini-2.5-flash-lite';
+export const DEFAULT_GEMINI_MODEL_AUTO = 'auto';
export const DEFAULT_GEMINI_EMBEDDING_MODEL = 'gemini-embedding-001';

// Some thinking models do not default to dynamic thinking, which is enabled with a value of -1.
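The new `'auto'` sentinel means "let the router choose" rather than naming a real model, so call sites compare against the constant before doing anything model-specific. A trivial illustration (the helper is hypothetical; the codebase compares against the constant directly):

```ts
import { DEFAULT_GEMINI_MODEL_AUTO } from './models.js';

// Hypothetical guard; shown only to make the sentinel's role concrete.
function isAutoModel(model: string): boolean {
  return model === DEFAULT_GEMINI_MODEL_AUTO;
}

isAutoModel('auto'); // true  — defer model choice to the routing strategies
isAutoModel('gemini-2.5-pro'); // false — treat as a forced model override
```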
+31 -20
@@ -33,7 +33,10 @@ import type { ChatRecordingService } from '../services/chatRecordingService.js';
import type { ContentGenerator } from './contentGenerator.js';
import {
  DEFAULT_GEMINI_FLASH_MODEL,
+  DEFAULT_GEMINI_MODEL,
+  DEFAULT_GEMINI_MODEL_AUTO,
  DEFAULT_THINKING_MODE,
+  getEffectiveModel,
} from '../config/models.js';
import { LoopDetectionService } from '../services/loopDetectionService.js';
import { ideContextStore } from '../ide/ideContext.js';
@@ -52,14 +55,14 @@ import { handleFallback } from '../fallback/handler.js';
import type { RoutingContext } from '../routing/routingStrategy.js';

export function isThinkingSupported(model: string) {
-  if (model.startsWith('gemini-2.5')) return true;
-  return false;
+  return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
}

export function isThinkingDefault(model: string) {
-  if (model.startsWith('gemini-2.5-flash-lite')) return false;
-  if (model.startsWith('gemini-2.5')) return true;
-  return false;
+  if (model.startsWith('gemini-2.5-flash-lite')) {
+    return false;
+  }
+  return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
}
/**
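After the rewrite, both predicates treat the `'auto'` sentinel as thinking-capable, while `gemini-2.5-flash-lite` still opts out of thinking by default. An illustrative vitest sketch of the resulting truth table (these cases are inferred from the code above, not tests added by the commit):

```ts
import { describe, expect, it } from 'vitest';
import { isThinkingDefault, isThinkingSupported } from './client.js';

describe('thinking predicates with the auto sentinel', () => {
  it('treats auto as thinking-capable and thinking-by-default', () => {
    expect(isThinkingSupported('auto')).toBe(true);
    expect(isThinkingDefault('auto')).toBe(true);
  });

  it('keeps flash-lite supported but not thinking-by-default', () => {
    expect(isThinkingSupported('gemini-2.5-flash-lite')).toBe(true);
    expect(isThinkingDefault('gemini-2.5-flash-lite')).toBe(false);
  });
});
```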
@@ -227,23 +230,21 @@ export class GeminiClient {
    const userMemory = this.config.getUserMemory();
    const systemInstruction = getCoreSystemPrompt(userMemory);
    const model = this.config.getModel();

-    const generateContentConfigWithThinking = isThinkingSupported(model)
-      ? {
-          ...this.generateContentConfig,
-          thinkingConfig: {
-            thinkingBudget: -1,
-            includeThoughts: true,
-            ...(!isThinkingDefault(model)
-              ? { thinkingBudget: DEFAULT_THINKING_MODE }
-              : {}),
-          },
-        }
-      : this.generateContentConfig;
+    const config: GenerateContentConfig = { ...this.generateContentConfig };
+    if (isThinkingSupported(model)) {
+      config.thinkingConfig = {
+        includeThoughts: true,
+        thinkingBudget: DEFAULT_THINKING_MODE,
+      };
+    }

    return new GeminiChat(
      this.config,
      {
        systemInstruction,
-        ...generateContentConfigWithThinking,
+        ...config,
        tools,
      },
      history,
@@ -790,6 +791,18 @@ export class GeminiClient {
    prompt_id: string,
    force: boolean = false,
  ): Promise<ChatCompressionInfo> {
+    // If the model is 'auto', use a placeholder model for the check.
+    // Compression occurs before we choose a model, so calling `countTokens`
+    // before the model is chosen would result in an error.
+    const configModel = this.config.getModel();
+    let model: string =
+      configModel === DEFAULT_GEMINI_MODEL_AUTO
+        ? DEFAULT_GEMINI_MODEL
+        : configModel;
+
+    // Check whether the model needs to fall back.
+    model = getEffectiveModel(this.config.isInFallbackMode(), model);
+
    const curatedHistory = this.getChat().getHistory(true);

    // Regardless of `force`, don't do anything if the history is empty.
@@ -804,8 +817,6 @@ export class GeminiClient {
      };
    }

-    const model = this.config.getModel();
-
    const { totalTokens: originalTokenCount } =
      await this.getContentGeneratorOrFail().countTokens({
        model,
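Compression runs before routing has picked a concrete model, and `countTokens` cannot be called with a non-model string like `'auto'`, so the code substitutes `DEFAULT_GEMINI_MODEL` as a placeholder and then applies the fallback check. The same resolution step in isolation, with local stand-ins for the constants and for `getEffectiveModel` (all names below are illustrative):

```ts
// Local stand-ins for DEFAULT_GEMINI_MODEL_AUTO / DEFAULT_GEMINI_MODEL.
const AUTO = 'auto';
const PRO = 'gemini-2.5-pro';

// effectiveModel wraps the repo's getEffectiveModel(isInFallbackMode, model).
function resolveCountTokensModel(
  configModel: string,
  effectiveModel: (model: string) => string,
): string {
  // 'auto' is only a routing sentinel; token counting needs a real model.
  const placeholder = configModel === AUTO ? PRO : configModel;
  return effectiveModel(placeholder);
}
```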
+9 -3
@@ -231,6 +231,7 @@ export class GeminiChat {
      : [params.message];
    const userMessageContent = partListUnionToString(toParts(userMessage));

    this.chatRecordingService.recordMessage({
+      model,
      type: 'user',
      content: userMessageContent,
    });
@@ -371,7 +372,7 @@ export class GeminiChat {
      authType: this.config.getContentGeneratorConfig()?.authType,
    });

-    return this.processStreamResponse(streamResponse, userContent);
+    return this.processStreamResponse(model, streamResponse, userContent);
  }
/**
@@ -474,6 +475,7 @@ export class GeminiChat {
  }

  private async *processStreamResponse(
+    model: string,
    streamResponse: AsyncGenerator<GenerateContentResponse>,
    userInput: Content,
  ): AsyncGenerator<GenerateContentResponse> {
@@ -552,6 +554,7 @@ export class GeminiChat {
      if (responseText.trim()) {
        this.chatRecordingService.recordMessage({
+          model,
          type: 'gemini',
          content: responseText,
        });
@@ -662,7 +665,10 @@ export class GeminiChat {
   * Records completed tool calls with full metadata.
   * This is called by external components when tool calls complete, before sending responses to Gemini.
   */
-  recordCompletedToolCalls(toolCalls: CompletedToolCall[]): void {
+  recordCompletedToolCalls(
+    model: string,
+    toolCalls: CompletedToolCall[],
+  ): void {
    const toolCallRecords = toolCalls.map((call) => {
      const resultDisplayRaw = call.response?.resultDisplay;
      const resultDisplay =
@@ -679,7 +685,7 @@ export class GeminiChat {
      };
    });

-    this.chatRecordingService.recordToolCalls(toolCallRecords);
+    this.chatRecordingService.recordToolCalls(model, toolCallRecords);
  }
/**
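The `model` parameter added throughout GeminiChat exists because, once routing is in play, the model that actually served a turn can differ from `config.getModel()` — which may now hold `'auto'` — so recordings must be labeled by the caller. A minimal sketch of the new calling convention (the interface and names are illustrative):

```ts
// Illustrative shape of the recording dependency after this change.
interface ToolCallRecorder {
  recordCompletedToolCalls(model: string, calls: unknown[]): void;
}

function finishTurn(
  chat: ToolCallRecorder,
  routedModel: string, // the concrete model the router actually chose
  calls: unknown[],
): void {
  // The caller supplies the routed model instead of the recorder reading
  // (possibly 'auto') state back out of Config.
  chat.recordCompletedToolCalls(routedModel, calls);
}
```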
@@ -9,15 +9,16 @@ import { OverrideStrategy } from './overrideStrategy.js';
import type { RoutingContext } from '../routingStrategy.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import type { Config } from '../../config/config.js';
+import { DEFAULT_GEMINI_MODEL_AUTO } from '../../config/models.js';

describe('OverrideStrategy', () => {
  const strategy = new OverrideStrategy();
  const mockContext = {} as RoutingContext;
  const mockClient = {} as BaseLlmClient;

-  it('should return null when no override model is specified', async () => {
+  it('should return null when the override model is auto', async () => {
    const mockConfig = {
-      getModel: () => '', // Simulate no model override
+      getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
    } as Config;

    const decision = await strategy.route(mockContext, mockConfig, mockClient);
@@ -5,6 +5,7 @@
 */

import type { Config } from '../../config/config.js';
+import { DEFAULT_GEMINI_MODEL_AUTO } from '../../config/models.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import type {
  RoutingContext,
@@ -24,17 +25,18 @@ export class OverrideStrategy implements RoutingStrategy {
    _baseLlmClient: BaseLlmClient,
  ): Promise<RoutingDecision | null> {
    const overrideModel = config.getModel();

-    if (overrideModel) {
-      return {
-        model: overrideModel,
-        metadata: {
-          source: this.name,
-          latencyMs: 0,
-          reasoning: `Routing bypassed by forced model directive. Using: ${overrideModel}`,
-        },
-      };
-    }
-
-    // No override specified, pass to the next strategy.
-    return null;
+    // If the model is 'auto', we should pass to the next strategy.
+    if (overrideModel === DEFAULT_GEMINI_MODEL_AUTO) return null;
+
+    // Return the overridden model name.
+    return {
+      model: overrideModel,
+      metadata: {
+        source: this.name,
+        latencyMs: 0,
+        reasoning: `Routing bypassed by forced model directive. Using: ${overrideModel}`,
+      },
+    };
  }
}
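The contract here is that a strategy returning `null` defers to the next one in the chain, so an `'auto'` model now falls through `OverrideStrategy` to whatever comes after it, while any concrete model name short-circuits routing entirely. A hedged sketch of how such a chain might be driven (the composition below is illustrative; the repository's actual router wiring may differ):

```ts
// Simplified stand-ins for RoutingDecision / RoutingStrategy.
type Decision = { model: string; reasoning: string } | null;
type Strategy = (configuredModel: string) => Promise<Decision>;

async function route(
  strategies: Strategy[],
  configuredModel: string,
): Promise<Decision> {
  for (const strategy of strategies) {
    const decision = await strategy(configuredModel);
    if (decision !== null) return decision; // first non-null decision wins
  }
  return null; // nothing decided; the caller applies a final default
}
```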
@@ -127,7 +127,11 @@ describe('ChatRecordingService', () => {
    const writeFileSyncSpy = vi
      .spyOn(fs, 'writeFileSync')
      .mockImplementation(() => undefined);

-    chatRecordingService.recordMessage({ type: 'user', content: 'Hello' });
+    chatRecordingService.recordMessage({
+      type: 'user',
+      content: 'Hello',
+      model: 'gemini-pro',
+    });

    expect(mkdirSyncSpy).toHaveBeenCalled();
    expect(writeFileSyncSpy).toHaveBeenCalled();

    const conversation = JSON.parse(
@@ -161,6 +165,7 @@ describe('ChatRecordingService', () => {
    chatRecordingService.recordMessage({
      type: 'user',
      content: 'World',
+      model: 'gemini-pro',
    });

    expect(mkdirSyncSpy).toHaveBeenCalled();
@@ -311,7 +316,7 @@ describe('ChatRecordingService', () => {
      status: 'awaiting_approval',
      timestamp: new Date().toISOString(),
    };

-    chatRecordingService.recordToolCalls([toolCall]);
+    chatRecordingService.recordToolCalls('gemini-pro', [toolCall]);

    expect(mkdirSyncSpy).toHaveBeenCalled();
    expect(writeFileSyncSpy).toHaveBeenCalled();
@@ -358,7 +363,7 @@ describe('ChatRecordingService', () => {
      status: 'awaiting_approval',
      timestamp: new Date().toISOString(),
    };

-    chatRecordingService.recordToolCalls([toolCall]);
+    chatRecordingService.recordToolCalls('gemini-pro', [toolCall]);

    expect(mkdirSyncSpy).toHaveBeenCalled();
    expect(writeFileSyncSpy).toHaveBeenCalled();
@@ -195,6 +195,7 @@ export class ChatRecordingService {
   * Records a message in the conversation.
   */
  recordMessage(message: {
+    model: string;
    type: ConversationRecordExtra['type'];
    content: PartListUnion;
  }): void {
@@ -209,7 +210,7 @@ export class ChatRecordingService {
      ...msg,
      thoughts: this.queuedThoughts,
      tokens: this.queuedTokens,
-      model: this.config.getModel(),
+      model: message.model,
    });
    this.queuedThoughts = [];
    this.queuedTokens = null;
@@ -279,7 +280,7 @@ export class ChatRecordingService {
   * Adds tool calls to the last message in the conversation (which should be from Gemini).
   * This method enriches tool calls with metadata from the ToolRegistry.
   */
-  recordToolCalls(toolCalls: ToolCallRecord[]): void {
+  recordToolCalls(model: string, toolCalls: ToolCallRecord[]): void {
    if (!this.conversationFile) return;

    // Enrich tool calls with metadata from the ToolRegistry.
@@ -318,7 +319,7 @@ export class ChatRecordingService {
      type: 'gemini' as const,
      toolCalls: enrichedToolCalls,
      thoughts: this.queuedThoughts,
-      model: this.config.getModel(),
+      model,
    };

    // If there are any queued thoughts, join them to this message.
    if (this.queuedThoughts.length > 0) {