From 538e6cd19a8e8cf4a445ce9709477e5b7dfa59d4 Mon Sep 17 00:00:00 2001 From: Abhi <43648792+abhipatel12@users.noreply.github.com> Date: Thu, 11 Sep 2025 13:38:50 -0400 Subject: [PATCH] feat(routing): Initialize model routing architecture (#8153) --- packages/a2a-server/src/agent/task.ts | 2 +- .../src/ui/hooks/useQuotaAndFallback.test.ts | 2 - .../cli/src/zed-integration/zedIntegration.ts | 1 + packages/core/src/config/config.test.ts | 6 +- packages/core/src/config/config.ts | 21 +- packages/core/src/config/models.test.ts | 83 +++++++ packages/core/src/config/models.ts | 32 +++ packages/core/src/core/client.test.ts | 184 ++++++++++++++- packages/core/src/core/client.ts | 36 ++- .../core/src/core/contentGenerator.test.ts | 3 - packages/core/src/core/contentGenerator.ts | 6 - packages/core/src/core/geminiChat.test.ts | 36 ++- packages/core/src/core/geminiChat.ts | 28 +-- packages/core/src/core/subagent.test.ts | 10 +- packages/core/src/core/subagent.ts | 1 + packages/core/src/core/turn.test.ts | 38 +++- packages/core/src/core/turn.ts | 2 + .../src/routing/modelRouterService.test.ts | 96 ++++++++ .../core/src/routing/modelRouterService.ts | 54 +++++ packages/core/src/routing/routingStrategy.ts | 76 +++++++ .../strategies/compositeStrategy.test.ts | 215 ++++++++++++++++++ .../routing/strategies/compositeStrategy.ts | 109 +++++++++ .../strategies/defaultStrategy.test.ts | 32 +++ .../src/routing/strategies/defaultStrategy.ts | 33 +++ .../strategies/fallbackStrategy.test.ts | 86 +++++++ .../routing/strategies/fallbackStrategy.ts | 43 ++++ .../strategies/overrideStrategy.test.ts | 55 +++++ .../routing/strategies/overrideStrategy.ts | 40 ++++ 28 files changed, 1263 insertions(+), 67 deletions(-) create mode 100644 packages/core/src/config/models.test.ts create mode 100644 packages/core/src/routing/modelRouterService.test.ts create mode 100644 packages/core/src/routing/modelRouterService.ts create mode 100644 packages/core/src/routing/routingStrategy.ts create mode 100644 
packages/core/src/routing/strategies/compositeStrategy.test.ts create mode 100644 packages/core/src/routing/strategies/compositeStrategy.ts create mode 100644 packages/core/src/routing/strategies/defaultStrategy.test.ts create mode 100644 packages/core/src/routing/strategies/defaultStrategy.ts create mode 100644 packages/core/src/routing/strategies/fallbackStrategy.test.ts create mode 100644 packages/core/src/routing/strategies/fallbackStrategy.ts create mode 100644 packages/core/src/routing/strategies/overrideStrategy.test.ts create mode 100644 packages/core/src/routing/strategies/overrideStrategy.ts diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index 9fdebe1b2a..2461f9893c 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -132,7 +132,7 @@ export class Task { id: this.id, contextId: this.contextId, taskState: this.taskState, - model: this.config.getContentGeneratorConfig().model, + model: this.config.getModel(), mcpServers: servers, availableTools, }; diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts index 7dd93eb72e..6d7782694f 100644 --- a/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts @@ -57,7 +57,6 @@ describe('useQuotaAndFallback', () => { // Spy on the method that requires the private field and mock its return. // This is cleaner than modifying the config class for tests. 
vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({ - model: 'gemini-pro', authType: AuthType.LOGIN_WITH_GOOGLE, }); @@ -128,7 +127,6 @@ describe('useQuotaAndFallback', () => { it('should return null and take no action if authType is not LOGIN_WITH_GOOGLE', async () => { // Override the default mock from beforeEach for this specific test vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({ - model: 'gemini-pro', authType: AuthType.USE_GEMINI, }); diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index a6a502e4a4..db2669c178 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -255,6 +255,7 @@ class Session { try { const responseStream = await chat.sendMessageStream( + this.config.getModel(), { message: nextMessage?.parts ?? [], config: { diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 81096b69d5..66ce8b00a1 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -206,9 +206,7 @@ describe('Server Config (config.ts)', () => { it('should refresh auth and update config', async () => { const config = new Config(baseParams); const authType = AuthType.USE_GEMINI; - const newModel = 'gemini-flash'; const mockContentConfig = { - model: newModel, apiKey: 'test-key', }; @@ -226,10 +224,8 @@ describe('Server Config (config.ts)', () => { config, authType, ); - // Verify that contentGeneratorConfig is updated with the new model + // Verify that contentGeneratorConfig is updated expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig); - expect(config.getContentGeneratorConfig().model).toBe(newModel); - expect(config.getModel()).toBe(newModel); // getModel() should return the updated model expect(GeminiClient).toHaveBeenCalledWith(config); // Verify that fallback mode is reset 
expect(config.isInFallbackMode()).toBe(false); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index fc49da7a54..3dcf72bd0c 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -44,6 +44,7 @@ import { StartSessionEvent } from '../telemetry/index.js'; import { DEFAULT_GEMINI_EMBEDDING_MODEL, DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_MODEL, } from './models.js'; import { shouldAttemptBrowserLaunch } from '../utils/browser.js'; import type { MCPOAuthConfig } from '../mcp/oauth-provider.js'; @@ -62,6 +63,7 @@ import { RipgrepFallbackEvent, } from '../telemetry/types.js'; import type { FallbackModelHandler } from '../fallback/types.js'; +import { ModelRouterService } from '../routing/modelRouterService.js'; import { OutputFormat } from '../output/types.js'; // Re-export OAuth config type @@ -270,6 +272,7 @@ export class Config { private readonly usageStatisticsEnabled: boolean; private geminiClient!: GeminiClient; private baseLlmClient!: BaseLlmClient; + private modelRouterService: ModelRouterService; private readonly fileFiltering: { respectGitIgnore: boolean; respectGeminiIgnore: boolean; @@ -282,7 +285,7 @@ export class Config { private readonly proxy: string | undefined; private readonly cwd: string; private readonly bugCommand: BugCommandSettings | undefined; - private readonly model: string; + private model: string; private readonly extensionContextFilePaths: string[]; private readonly noBrowser: boolean; private readonly folderTrustFeature: boolean; @@ -372,7 +375,7 @@ export class Config { this.cwd = params.cwd ?? process.cwd(); this.fileDiscoveryService = params.fileDiscoveryService ?? null; this.bugCommand = params.bugCommand; - this.model = params.model; + this.model = params.model || DEFAULT_GEMINI_MODEL; this.extensionContextFilePaths = params.extensionContextFilePaths ?? []; this.maxSessionTurns = params.maxSessionTurns ?? 
-1; this.experimentalZedIntegration = @@ -424,6 +427,7 @@ export class Config { setGlobalDispatcher(new ProxyAgent(this.getProxy() as string)); } this.geminiClient = new GeminiClient(this); + this.modelRouterService = new ModelRouterService(this); } /** @@ -523,13 +527,16 @@ export class Config { } getModel(): string { - return this.contentGeneratorConfig?.model || this.model; + return this.model; } setModel(newModel: string): void { - if (this.contentGeneratorConfig) { - this.contentGeneratorConfig.model = newModel; + // Do not allow Pro usage if the user is in fallback mode. + if (newModel.includes('pro') && this.isInFallbackMode()) { + return; } + + this.model = newModel; } isInFallbackMode(): boolean { @@ -699,6 +706,10 @@ export class Config { return this.geminiClient; } + getModelRouterService(): ModelRouterService { + return this.modelRouterService; + } + getEnableRecursiveFileSearch(): boolean { return this.fileFiltering.enableRecursiveFileSearch; } diff --git a/packages/core/src/config/models.test.ts b/packages/core/src/config/models.test.ts new file mode 100644 index 0000000000..8c790dd1ae --- /dev/null +++ b/packages/core/src/config/models.test.ts @@ -0,0 +1,83 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { + getEffectiveModel, + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_FLASH_LITE_MODEL, +} from './models.js'; + +describe('getEffectiveModel', () => { + describe('When NOT in fallback mode', () => { + const isInFallbackMode = false; + + it('should return the Pro model when Pro is requested', () => { + const model = getEffectiveModel(isInFallbackMode, DEFAULT_GEMINI_MODEL); + expect(model).toBe(DEFAULT_GEMINI_MODEL); + }); + + it('should return the Flash model when Flash is requested', () => { + const model = getEffectiveModel( + isInFallbackMode, + DEFAULT_GEMINI_FLASH_MODEL, + ); + 
expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + }); + + it('should return the Lite model when Lite is requested', () => { + const model = getEffectiveModel( + isInFallbackMode, + DEFAULT_GEMINI_FLASH_LITE_MODEL, + ); + expect(model).toBe(DEFAULT_GEMINI_FLASH_LITE_MODEL); + }); + + it('should return a custom model name when requested', () => { + const customModel = 'custom-model-v1'; + const model = getEffectiveModel(isInFallbackMode, customModel); + expect(model).toBe(customModel); + }); + }); + + describe('When IN fallback mode', () => { + const isInFallbackMode = true; + + it('should downgrade the Pro model to the Flash model', () => { + const model = getEffectiveModel(isInFallbackMode, DEFAULT_GEMINI_MODEL); + expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + }); + + it('should return the Flash model when Flash is requested', () => { + const model = getEffectiveModel( + isInFallbackMode, + DEFAULT_GEMINI_FLASH_MODEL, + ); + expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + }); + + it('should HONOR the Lite model when Lite is requested', () => { + const model = getEffectiveModel( + isInFallbackMode, + DEFAULT_GEMINI_FLASH_LITE_MODEL, + ); + expect(model).toBe(DEFAULT_GEMINI_FLASH_LITE_MODEL); + }); + + it('should HONOR any model with "lite" in its name', () => { + const customLiteModel = 'gemini-2.5-custom-lite-vNext'; + const model = getEffectiveModel(isInFallbackMode, customLiteModel); + expect(model).toBe(customLiteModel); + }); + + it('should downgrade any other custom model to the Flash model', () => { + const customModel = 'custom-model-v1-unlisted'; + const model = getEffectiveModel(isInFallbackMode, customModel); + expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + }); + }); +}); diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index 1d2c1310a7..a0aa73bfdd 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -12,3 +12,35 @@ export const DEFAULT_GEMINI_EMBEDDING_MODEL = 
'gemini-embedding-001'; // Some thinking models do not default to dynamic thinking which is done by a value of -1 export const DEFAULT_THINKING_MODE = -1; + +/** + * Determines the effective model to use, applying fallback logic if necessary. + * + * When fallback mode is active, this function enforces the use of the standard + * fallback model. However, it makes an exception for "lite" models (any model + * with "lite" in its name), allowing them to be used to preserve cost savings. + * This ensures that "pro" models are always downgraded, while "lite" model + * requests are honored. + * + * @param isInFallbackMode Whether the application is in fallback mode. + * @param requestedModel The model that was originally requested. + * @returns The effective model name. + */ +export function getEffectiveModel( + isInFallbackMode: boolean, + requestedModel: string, +): string { + // If we are not in fallback mode, simply use the requested model. + if (!isInFallbackMode) { + return requestedModel; + } + + // If a "lite" model is requested, honor it. This allows for variations of + // lite models without needing to list them all as constants. + if (requestedModel.includes('lite')) { + return requestedModel; + } + + // Default fallback for Gemini CLI. 
+ return DEFAULT_GEMINI_FLASH_MODEL; +} diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 3e4c169560..21f7c88989 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -41,6 +41,7 @@ import { setSimulate429 } from '../utils/testUtils.js'; import { tokenLimit } from './tokenLimits.js'; import { ideContext } from '../ide/ideContext.js'; import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js'; +import type { ModelRouterService } from '../routing/modelRouterService.js'; // Mock fs module to prevent actual file system operations during tests const mockFileSystem = new Map(); @@ -234,7 +235,7 @@ describe('Gemini Client (client.ts)', () => { mockContentGenerator = { generateContent: mockGenerateContentFn, generateContentStream: vi.fn(), - countTokens: vi.fn(), + countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }), embedContent: vi.fn(), batchEmbedContents: vi.fn(), } as unknown as ContentGenerator; @@ -248,7 +249,6 @@ describe('Gemini Client (client.ts)', () => { }; const fileService = new FileDiscoveryService('/test/dir'); const contentGeneratorConfig: ContentGeneratorConfig = { - model: 'test-model', apiKey: 'test-key', vertexai: false, authType: AuthType.USE_GEMINI, @@ -281,6 +281,9 @@ describe('Gemini Client (client.ts)', () => { getDirectories: vi.fn().mockReturnValue(['/test/dir']), }), getGeminiClient: vi.fn(), + getModelRouterService: vi.fn().mockReturnValue({ + route: vi.fn().mockResolvedValue({ model: 'default-routed-model' }), + }), isInFallbackMode: vi.fn().mockReturnValue(false), setFallbackMode: vi.fn(), getChatCompression: vi.fn().mockReturnValue(undefined), @@ -1116,7 +1119,12 @@ ${JSON.stringify( // Assert expect(ideContext.getIdeContext).toHaveBeenCalled(); + // The `turn.run` method is now called with the model name as the first + // argument. 
We use `expect.any(String)` because this test is + // concerned with the IDE context logic, not the model routing, + // which is tested in its own dedicated suite. expect(mockTurnRunFn).toHaveBeenCalledWith( + expect.any(String), initialRequest, expect.any(Object), ); @@ -1506,6 +1514,178 @@ ${JSON.stringify( ); }); + describe('Model Routing', () => { + let mockRouterService: { route: Mock }; + + beforeEach(() => { + mockRouterService = { + route: vi + .fn() + .mockResolvedValue({ model: 'routed-model', reason: 'test' }), + }; + vi.mocked(mockConfig.getModelRouterService).mockReturnValue( + mockRouterService as unknown as ModelRouterService, + ); + + mockTurnRunFn.mockReturnValue( + (async function* () { + yield { type: 'content', value: 'Hello' }; + })(), + ); + }); + + it('should use the model router service to select a model on the first turn', async () => { + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-1', + ); + await fromAsync(stream); // consume stream + + expect(mockConfig.getModelRouterService).toHaveBeenCalled(); + expect(mockRouterService.route).toHaveBeenCalled(); + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'routed-model', // The model from the router + [{ text: 'Hi' }], + expect.any(Object), + ); + }); + + it('should use the same model for subsequent turns in the same prompt (stickiness)', async () => { + // First turn + let stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-1', + ); + await fromAsync(stream); + + expect(mockRouterService.route).toHaveBeenCalledTimes(1); + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'routed-model', + [{ text: 'Hi' }], + expect.any(Object), + ); + + // Second turn + stream = client.sendMessageStream( + [{ text: 'Continue' }], + new AbortController().signal, + 'prompt-1', + ); + await fromAsync(stream); + + // Router should not be called again + expect(mockRouterService.route).toHaveBeenCalledTimes(1); + // 
Should stick to the first model + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'routed-model', + [{ text: 'Continue' }], + expect.any(Object), + ); + }); + + it('should reset the sticky model and re-route when the prompt_id changes', async () => { + // First prompt + let stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-1', + ); + await fromAsync(stream); + + expect(mockRouterService.route).toHaveBeenCalledTimes(1); + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'routed-model', + [{ text: 'Hi' }], + expect.any(Object), + ); + + // New prompt + mockRouterService.route.mockResolvedValue({ + model: 'new-routed-model', + reason: 'test', + }); + stream = client.sendMessageStream( + [{ text: 'A new topic' }], + new AbortController().signal, + 'prompt-2', + ); + await fromAsync(stream); + + // Router should be called again for the new prompt + expect(mockRouterService.route).toHaveBeenCalledTimes(2); + // Should use the newly routed model + expect(mockTurnRunFn).toHaveBeenCalledWith( + 'new-routed-model', + [{ text: 'A new topic' }], + expect.any(Object), + ); + }); + + it('should use the fallback model and bypass routing when in fallback mode', async () => { + vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true); + mockRouterService.route.mockResolvedValue({ + model: DEFAULT_GEMINI_FLASH_MODEL, + reason: 'fallback', + }); + + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-1', + ); + await fromAsync(stream); + + expect(mockTurnRunFn).toHaveBeenCalledWith( + DEFAULT_GEMINI_FLASH_MODEL, + [{ text: 'Hi' }], + expect.any(Object), + ); + }); + + it('should stick to the fallback model for the entire sequence even if fallback mode ends', async () => { + // Start the sequence in fallback mode + vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true); + mockRouterService.route.mockResolvedValue({ + model: DEFAULT_GEMINI_FLASH_MODEL, + reason: 'fallback', + }); + 
let stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-fallback-stickiness', + ); + await fromAsync(stream); + + // First call should use fallback model + expect(mockTurnRunFn).toHaveBeenCalledWith( + DEFAULT_GEMINI_FLASH_MODEL, + [{ text: 'Hi' }], + expect.any(Object), + ); + + // End fallback mode + vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false); + + // Second call in the same sequence + stream = client.sendMessageStream( + [{ text: 'Continue' }], + new AbortController().signal, + 'prompt-fallback-stickiness', + ); + await fromAsync(stream); + + // Router should still not be called, and it should stick to the fallback model + expect(mockTurnRunFn).toHaveBeenCalledTimes(2); // Ensure it was called again + expect(mockTurnRunFn).toHaveBeenLastCalledWith( + DEFAULT_GEMINI_FLASH_MODEL, // Still the fallback model + [{ text: 'Continue' }], + expect.any(Object), + ); + }); + }); + describe('Editor context delta', () => { const mockStream = (async function* () { yield { type: 'content', value: 'Hello' }; diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index db331ad26a..e1436234e2 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -49,6 +49,7 @@ import { } from '../telemetry/types.js'; import type { IdeContext, File } from '../ide/types.js'; import { handleFallback } from '../fallback/handler.js'; +import type { RoutingContext } from '../routing/routingStrategy.js'; export function isThinkingSupported(model: string) { if (model.startsWith('gemini-2.5')) return true; @@ -118,6 +119,7 @@ export class GeminiClient { private readonly loopDetector: LoopDetectionService; private lastPromptId: string; + private currentSequenceModel: string | null = null; private lastSentIdeContext: IdeContext | undefined; private forceFullIdeContext = true; @@ -430,11 +432,11 @@ export class GeminiClient { signal: AbortSignal, prompt_id: string, turns: number = 
MAX_TURNS, - originalModel?: string, ): AsyncGenerator { if (this.lastPromptId !== prompt_id) { this.loopDetector.reset(prompt_id); this.lastPromptId = prompt_id; + this.currentSequenceModel = null; } this.sessionTurnCount++; if ( @@ -450,9 +452,6 @@ export class GeminiClient { return new Turn(this.getChat(), prompt_id); } - // Track the original model from the first call to detect model switching - const initialModel = originalModel || this.config.getModel(); - const compressed = await this.tryCompressChat(prompt_id); if (compressed.compressionStatus === CompressionStatus.COMPRESSED) { @@ -494,7 +493,26 @@ export class GeminiClient { return turn; } - const resultStream = turn.run(request, signal); + const routingContext: RoutingContext = { + history: this.getChat().getHistory(/*curated=*/ true), + request, + signal, + }; + + let modelToUse: string; + + // Determine Model (Stickiness vs. Routing) + if (this.currentSequenceModel) { + modelToUse = this.currentSequenceModel; + } else { + const router = await this.config.getModelRouterService(); + const decision = await router.route(routingContext); + modelToUse = decision.model; + // Lock the model for the rest of the sequence + this.currentSequenceModel = modelToUse; + } + + const resultStream = turn.run(modelToUse, request, signal); for await (const event of resultStream) { if (this.loopDetector.addAndCheck(event)) { yield { type: GeminiEventType.LoopDetected }; @@ -506,11 +524,8 @@ export class GeminiClient { } } if (!turn.pendingToolCalls.length && signal && !signal.aborted) { - // Check if model was switched during the call (likely due to quota error) - const currentModel = this.config.getModel(); - if (currentModel !== initialModel) { - // Model was switched (likely due to quota error fallback) - // Don't continue with recursive call to prevent unwanted Flash execution + // Check if next speaker check is needed + if (this.config.getQuotaErrorOccurred()) { return turn; } @@ -540,7 +555,6 @@ export class 
GeminiClient { signal, prompt_id, boundedTurns - 1, - initialModel, ); } } diff --git a/packages/core/src/core/contentGenerator.test.ts b/packages/core/src/core/contentGenerator.test.ts index eba9d353ec..3084c84bd4 100644 --- a/packages/core/src/core/contentGenerator.test.ts +++ b/packages/core/src/core/contentGenerator.test.ts @@ -29,7 +29,6 @@ describe('createContentGenerator', () => { ); const generator = await createContentGenerator( { - model: 'test-model', authType: AuthType.LOGIN_WITH_GOOGLE, }, mockConfig, @@ -51,7 +50,6 @@ describe('createContentGenerator', () => { vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never); const generator = await createContentGenerator( { - model: 'test-model', apiKey: 'test-api-key', authType: AuthType.USE_GEMINI, }, @@ -85,7 +83,6 @@ describe('createContentGenerator', () => { vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never); const generator = await createContentGenerator( { - model: 'test-model', apiKey: 'test-api-key', authType: AuthType.USE_GEMINI, }, diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index 4a794fd1f4..12f8ac7ae8 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -14,7 +14,6 @@ import type { } from '@google/genai'; import { GoogleGenAI } from '@google/genai'; import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js'; -import { DEFAULT_GEMINI_MODEL } from '../config/models.js'; import type { Config } from '../config/config.js'; import type { UserTierId } from '../code_assist/types.js'; @@ -50,7 +49,6 @@ export enum AuthType { } export type ContentGeneratorConfig = { - model: string; apiKey?: string; vertexai?: boolean; authType?: AuthType; @@ -66,11 +64,7 @@ export function createContentGeneratorConfig( const googleCloudProject = process.env['GOOGLE_CLOUD_PROJECT'] || undefined; const googleCloudLocation = 
process.env['GOOGLE_CLOUD_LOCATION'] || undefined; - // Use runtime model from config if available; otherwise, fall back to parameter or default - const effectiveModel = config.getModel() || DEFAULT_GEMINI_MODEL; - const contentGeneratorConfig: ContentGeneratorConfig = { - model: effectiveModel, authType, proxy: config?.getProxy(), }; diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 06f235e6b8..4d5b6f4ab1 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -169,6 +169,7 @@ describe('GeminiChat', () => { // 2. Action & Assert: The stream processing should complete without throwing an error // because the presence of a tool call makes the empty final chunk acceptable. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-tool-call-empty-end', ); @@ -220,6 +221,7 @@ describe('GeminiChat', () => { // 2. Action & Assert: The stream should fail because there's no finish reason. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-no-finish-empty-end', ); @@ -265,6 +267,7 @@ describe('GeminiChat', () => { // 2. Action & Assert: The stream should complete without throwing an error. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-valid-then-invalid-end', ); @@ -321,6 +324,7 @@ describe('GeminiChat', () => { // 2. Action: Send a message and consume the stream. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-malformed-chunk', ); @@ -371,6 +375,7 @@ describe('GeminiChat', () => { // 2. Action: Send a message and consume the stream. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-empty-chunk-consolidation', ); @@ -428,6 +433,7 @@ describe('GeminiChat', () => { // 2. Action: Send a message and consume the stream. 
const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-multi-chunk', ); @@ -475,6 +481,7 @@ describe('GeminiChat', () => { // 2. Action: Send a message and fully consume the stream to trigger history recording. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test message' }, 'prompt-id-mixed-chunk', ); @@ -537,6 +544,7 @@ describe('GeminiChat', () => { // 3. Action: Send the function response back to the model and consume the stream. const stream = await chat.sendMessageStream( + 'test-model', { message: { functionResponse: { @@ -588,17 +596,23 @@ describe('GeminiChat', () => { ); const stream = await chat.sendMessageStream( + 'test-model', { message: 'hello' }, 'prompt-id-1', ); for await (const _ of stream) { - // consume stream to trigger internal logic + // consume stream } expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith( { - model: 'gemini-pro', - contents: [{ role: 'user', parts: [{ text: 'hello' }] }], + model: 'test-model', + contents: [ + { + role: 'user', + parts: [{ text: 'hello' }], + }, + ], config: {}, }, 'prompt-id-1', @@ -809,6 +823,7 @@ describe('GeminiChat', () => { // ACT: Send a message and collect all events from the stream. const stream = await chat.sendMessageStream( + 'test-model', { message: 'test' }, 'prompt-id-yield-retry', ); @@ -849,6 +864,7 @@ describe('GeminiChat', () => { ); const stream = await chat.sendMessageStream( + 'test-model', { message: 'test' }, 'prompt-id-retry-success', ); @@ -909,6 +925,7 @@ describe('GeminiChat', () => { ); const stream = await chat.sendMessageStream( + 'test-model', { message: 'test' }, 'prompt-id-retry-fail', ); @@ -964,6 +981,7 @@ describe('GeminiChat', () => { // 3. Send a new message const stream = await chat.sendMessageStream( + 'test-model', { message: 'Second question' }, 'prompt-id-retry-existing', ); @@ -1034,6 +1052,7 @@ describe('GeminiChat', () => { // 2. Call the method and consume the stream. 
const stream = await chat.sendMessageStream( + 'test-model', { message: 'test empty stream' }, 'prompt-id-empty-stream', ); @@ -1113,6 +1132,7 @@ describe('GeminiChat', () => { // 3. Start the first stream and consume only the first chunk to pause it const firstStream = await chat.sendMessageStream( + 'test-model', { message: 'first' }, 'prompt-1', ); @@ -1121,6 +1141,7 @@ describe('GeminiChat', () => { // 4. While the first stream is paused, start the second call. It will block. const secondStreamPromise = chat.sendMessageStream( + 'test-model', { message: 'second' }, 'prompt-2', ); @@ -1180,6 +1201,7 @@ describe('GeminiChat', () => { ); const stream = await chat.sendMessageStream( + 'test-model', { message: 'test' }, 'prompt-id-res3', ); @@ -1234,12 +1256,9 @@ describe('GeminiChat', () => { }); it('should call handleFallback with the specific failed model and retry if handler returns true', async () => { - const FAILED_MODEL = 'gemini-2.5-pro'; - vi.mocked(mockConfig.getModel).mockReturnValue(FAILED_MODEL); const authType = AuthType.LOGIN_WITH_GOOGLE; vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ authType, - model: FAILED_MODEL, }); const isInFallbackModeSpy = vi.spyOn(mockConfig, 'isInFallbackMode'); @@ -1267,6 +1286,7 @@ describe('GeminiChat', () => { }); const stream = await chat.sendMessageStream( + 'test-model', { message: 'trigger 429' }, 'prompt-id-fb1', ); @@ -1282,7 +1302,7 @@ describe('GeminiChat', () => { expect(mockHandleFallback).toHaveBeenCalledTimes(1); expect(mockHandleFallback).toHaveBeenCalledWith( mockConfig, - FAILED_MODEL, + 'test-model', authType, error429, ); @@ -1300,6 +1320,7 @@ describe('GeminiChat', () => { mockHandleFallback.mockResolvedValue(false); const stream = await chat.sendMessageStream( + 'test-model', { message: 'test stop' }, 'prompt-id-fb2', ); @@ -1357,6 +1378,7 @@ describe('GeminiChat', () => { // Send a message and consume the stream const stream = await chat.sendMessageStream( + 'test-model', { 
message: 'test' }, 'prompt-id-discard-test', ); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index f054578aaa..9f7f19ba9a 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -19,7 +19,10 @@ import { toParts } from '../code_assist/converter.js'; import { createUserContent } from '@google/genai'; import { retryWithBackoff } from '../utils/retry.js'; import type { Config } from '../config/config.js'; -import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { + DEFAULT_GEMINI_FLASH_MODEL, + getEffectiveModel, +} from '../config/models.js'; import { hasCycleInSchema } from '../tools/tools.js'; import type { StructuredError } from './turn.js'; import type { CompletedToolCall } from './coreToolScheduler.js'; @@ -206,6 +209,7 @@ export class GeminiChat { * ``` */ async sendMessageStream( + model: string, params: SendMessageParameters, prompt_id: string, ): Promise> { @@ -253,6 +257,7 @@ export class GeminiChat { } const stream = await self.makeApiCallAndProcessStream( + model, requestContents, params, prompt_id, @@ -317,18 +322,17 @@ export class GeminiChat { } private async makeApiCallAndProcessStream( + model: string, requestContents: Content[], params: SendMessageParameters, prompt_id: string, userContent: Content, ): Promise> { - let currentAttemptModel: string | undefined; - const apiCall = () => { - const modelToUse = this.config.isInFallbackMode() - ? 
DEFAULT_GEMINI_FLASH_MODEL - : this.config.getModel(); - currentAttemptModel = modelToUse; + const modelToUse = getEffectiveModel( + this.config.isInFallbackMode(), + model, + ); if ( this.config.getQuotaErrorOccurred() && @@ -352,15 +356,7 @@ export class GeminiChat { const onPersistent429Callback = async ( authType?: string, error?: unknown, - ) => { - if (!currentAttemptModel) return null; - return await handleFallback( - this.config, - currentAttemptModel, - authType, - error, - ); - }; + ) => await handleFallback(this.config, model, authType, error); const streamResponse = await retryWithBackoff(apiCall, { shouldRetry: (error: unknown) => { diff --git a/packages/core/src/core/subagent.test.ts b/packages/core/src/core/subagent.test.ts index e3b51cacc7..ed5b549097 100644 --- a/packages/core/src/core/subagent.test.ts +++ b/packages/core/src/core/subagent.test.ts @@ -499,7 +499,7 @@ describe('subagent.ts', () => { expect(scope.output.emitted_vars).toEqual({}); expect(mockSendMessageStream).toHaveBeenCalledTimes(1); // Check the initial message - expect(mockSendMessageStream.mock.calls[0][0].message).toEqual([ + expect(mockSendMessageStream.mock.calls[0][1].message).toEqual([ { text: 'Get Started!' }, ]); }); @@ -543,7 +543,7 @@ describe('subagent.ts', () => { expect(mockSendMessageStream).toHaveBeenCalledTimes(1); // Check the tool response sent back in the second call - const secondCallArgs = mockSendMessageStream.mock.calls[0][0]; + const secondCallArgs = mockSendMessageStream.mock.calls[0][1]; expect(secondCallArgs.message).toEqual([{ text: 'Get Started!' 
}]); }); @@ -605,7 +605,7 @@ describe('subagent.ts', () => { ); // Check the response sent back to the model - const secondCallArgs = mockSendMessageStream.mock.calls[1][0]; + const secondCallArgs = mockSendMessageStream.mock.calls[1][1]; expect(secondCallArgs.message).toEqual([ { text: 'file1.txt\nfile2.ts' }, ]); @@ -653,7 +653,7 @@ describe('subagent.ts', () => { await scope.runNonInteractive(new ContextState()); // The agent should send the specific error message from responseParts. - const secondCallArgs = mockSendMessageStream.mock.calls[1][0]; + const secondCallArgs = mockSendMessageStream.mock.calls[1][1]; expect(secondCallArgs.message).toEqual([ { @@ -699,7 +699,7 @@ describe('subagent.ts', () => { await scope.runNonInteractive(new ContextState()); // Check the nudge message sent in Turn 2 - const secondCallArgs = mockSendMessageStream.mock.calls[1][0]; + const secondCallArgs = mockSendMessageStream.mock.calls[1][1]; // We check that the message contains the required variable name and the nudge phrasing. 
expect(secondCallArgs.message[0].text).toContain('required_var'); diff --git a/packages/core/src/core/subagent.ts b/packages/core/src/core/subagent.ts index fe5ac1ff87..15cf5af910 100644 --- a/packages/core/src/core/subagent.ts +++ b/packages/core/src/core/subagent.ts @@ -430,6 +430,7 @@ export class SubAgentScope { }; const responseStream = await chat.sendMessageStream( + this.modelConfig.model, messageParams, promptId, ); diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts index 16fdd90fd9..d3451166a9 100644 --- a/packages/core/src/core/turn.test.ts +++ b/packages/core/src/core/turn.test.ts @@ -97,6 +97,7 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Hi' }]; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -104,6 +105,7 @@ describe('Turn', () => { } expect(mockSendMessageStream).toHaveBeenCalledWith( + 'test-model', { message: reqParts, config: { abortSignal: expect.any(AbortSignal) }, @@ -144,6 +146,7 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Use tools' }]; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -206,7 +209,11 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test abort' }]; - for await (const event of turn.run(reqParts, abortController.signal)) { + for await (const event of turn.run( + 'test-model', + reqParts, + abortController.signal, + )) { events.push(event); } expect(events).toEqual([ @@ -227,6 +234,7 @@ describe('Turn', () => { mockMaybeIncludeSchemaDepthContext.mockResolvedValue(undefined); const events = []; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -267,6 +275,7 @@ describe('Turn', () => { const events = []; for await (const event of turn.run( + 'test-model', [{ text: 'Test undefined tool parts' }], new AbortController().signal, )) { @@ -323,6 
+332,7 @@ describe('Turn', () => { const events = []; for await (const event of turn.run( + 'test-model', [{ text: 'Test finish reason' }], new AbortController().signal, )) { @@ -370,6 +380,7 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Generate long text' }]; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -407,6 +418,7 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test safety' }]; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -443,6 +455,7 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test no finish reason' }]; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -487,6 +500,7 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test multiple responses' }]; for await (const event of turn.run( + 'test-model', reqParts, new AbortController().signal, )) { @@ -529,6 +543,7 @@ describe('Turn', () => { const events = []; for await (const event of turn.run( + 'test-model', [{ text: 'Test citations' }], new AbortController().signal, )) { @@ -578,6 +593,7 @@ describe('Turn', () => { const events = []; for await (const event of turn.run( + 'test-model', [{ text: 'test' }], new AbortController().signal, )) { @@ -624,6 +640,7 @@ describe('Turn', () => { const events = []; for await (const event of turn.run( + 'test-model', [{ text: 'test' }], new AbortController().signal, )) { @@ -669,6 +686,7 @@ describe('Turn', () => { const events = []; for await (const event of turn.run( + 'test-model', [{ text: 'test' }], new AbortController().signal, )) { @@ -705,7 +723,11 @@ describe('Turn', () => { const events = []; const reqParts: Part[] = [{ text: 'Test malformed error handling' }]; - for await (const event of turn.run(reqParts, abortController.signal)) { + for await (const event of turn.run( + 
'test-model', + reqParts, + abortController.signal, + )) { events.push(event); } @@ -727,7 +749,11 @@ describe('Turn', () => { mockSendMessageStream.mockResolvedValue(mockResponseStream); const events = []; - for await (const event of turn.run([], new AbortController().signal)) { + for await (const event of turn.run( + 'test-model', + [], + new AbortController().signal, + )) { events.push(event); } @@ -752,7 +778,11 @@ describe('Turn', () => { })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); const reqParts: Part[] = [{ text: 'Hi' }]; - for await (const _ of turn.run(reqParts, new AbortController().signal)) { + for await (const _ of turn.run( + 'test-model', + reqParts, + new AbortController().signal, + )) { // consume stream } expect(turn.getDebugResponses()).toEqual([resp1, resp2]); diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 66e7801448..8f38f7d2ed 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -211,6 +211,7 @@ export class Turn { ) {} // The run method yields simpler events suitable for server logic async *run( + model: string, req: PartListUnion, signal: AbortSignal, ): AsyncGenerator { @@ -218,6 +219,7 @@ export class Turn { // Note: This assumes `sendMessageStream` yields events like // { type: StreamEventType.RETRY } or { type: StreamEventType.CHUNK, value: GenerateContentResponse } const responseStream = await this.chat.sendMessageStream( + model, { message: req, config: { diff --git a/packages/core/src/routing/modelRouterService.test.ts b/packages/core/src/routing/modelRouterService.test.ts new file mode 100644 index 0000000000..0f83796787 --- /dev/null +++ b/packages/core/src/routing/modelRouterService.test.ts @@ -0,0 +1,96 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { ModelRouterService } from './modelRouterService.js'; +import { Config } 
from '../config/config.js'; +import type { BaseLlmClient } from '../core/baseLlmClient.js'; +import type { RoutingContext, RoutingDecision } from './routingStrategy.js'; +import { DefaultStrategy } from './strategies/defaultStrategy.js'; +import { CompositeStrategy } from './strategies/compositeStrategy.js'; +import { FallbackStrategy } from './strategies/fallbackStrategy.js'; +import { OverrideStrategy } from './strategies/overrideStrategy.js'; + +vi.mock('../config/config.js'); +vi.mock('../core/baseLlmClient.js'); +vi.mock('./strategies/defaultStrategy.js'); +vi.mock('./strategies/compositeStrategy.js'); +vi.mock('./strategies/fallbackStrategy.js'); +vi.mock('./strategies/overrideStrategy.js'); + +describe('ModelRouterService', () => { + let service: ModelRouterService; + let mockConfig: Config; + let mockBaseLlmClient: BaseLlmClient; + let mockContext: RoutingContext; + let mockCompositeStrategy: CompositeStrategy; + + beforeEach(() => { + vi.clearAllMocks(); + + mockConfig = new Config({} as never); + mockBaseLlmClient = {} as BaseLlmClient; + vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient); + + mockCompositeStrategy = new CompositeStrategy( + [new FallbackStrategy(), new OverrideStrategy(), new DefaultStrategy()], + 'agent-router', + ); + vi.mocked(CompositeStrategy).mockImplementation( + () => mockCompositeStrategy, + ); + + service = new ModelRouterService(mockConfig); + + mockContext = { + history: [], + request: [{ text: 'test prompt' }], + signal: new AbortController().signal, + }; + }); + + it('should initialize with a CompositeStrategy', () => { + expect(CompositeStrategy).toHaveBeenCalled(); + expect(service['strategy']).toBeInstanceOf(CompositeStrategy); + }); + + it('should initialize the CompositeStrategy with the correct child strategies in order', () => { + // This test relies on the mock implementation detail of the constructor + const compositeStrategyArgs = vi.mocked(CompositeStrategy).mock.calls[0]; + const 
childStrategies = compositeStrategyArgs[0]; + + expect(childStrategies.length).toBe(3); + expect(childStrategies[0]).toBeInstanceOf(FallbackStrategy); + expect(childStrategies[1]).toBeInstanceOf(OverrideStrategy); + expect(childStrategies[2]).toBeInstanceOf(DefaultStrategy); + expect(compositeStrategyArgs[1]).toBe('agent-router'); + }); + + describe('route()', () => { + it('should delegate routing to the composite strategy', async () => { + const strategyDecision: RoutingDecision = { + model: 'strategy-chosen-model', + metadata: { + source: 'test-router/fallback', + latencyMs: 10, + reasoning: 'Strategy reasoning', + }, + }; + const strategySpy = vi + .spyOn(mockCompositeStrategy, 'route') + .mockResolvedValue(strategyDecision); + + const decision = await service.route(mockContext); + + expect(strategySpy).toHaveBeenCalledWith( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + expect(decision).toEqual(strategyDecision); + }); + }); +}); diff --git a/packages/core/src/routing/modelRouterService.ts b/packages/core/src/routing/modelRouterService.ts new file mode 100644 index 0000000000..a984125f89 --- /dev/null +++ b/packages/core/src/routing/modelRouterService.ts @@ -0,0 +1,54 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../config/config.js'; +import type { + RoutingContext, + RoutingDecision, + TerminalStrategy, +} from './routingStrategy.js'; +import { DefaultStrategy } from './strategies/defaultStrategy.js'; +import { CompositeStrategy } from './strategies/compositeStrategy.js'; +import { FallbackStrategy } from './strategies/fallbackStrategy.js'; +import { OverrideStrategy } from './strategies/overrideStrategy.js'; + +/** + * A centralized service for making model routing decisions. 
+ */ +export class ModelRouterService { + private config: Config; + private strategy: TerminalStrategy; + + constructor(config: Config) { + this.config = config; + this.strategy = this.initializeDefaultStrategy(); + } + + private initializeDefaultStrategy(): TerminalStrategy { + // Initialize the composite strategy with the desired priority order. + // The strategies are ordered in order of highest priority. + return new CompositeStrategy( + [new FallbackStrategy(), new OverrideStrategy(), new DefaultStrategy()], + 'agent-router', + ); + } + + /** + * Determines which model to use for a given request context. + * + * @param context The full context of the request. + * @returns A promise that resolves to a RoutingDecision. + */ + async route(context: RoutingContext): Promise { + const decision = await this.strategy.route( + context, + this.config, + this.config.getBaseLlmClient(), + ); + + return decision; + } +} diff --git a/packages/core/src/routing/routingStrategy.ts b/packages/core/src/routing/routingStrategy.ts new file mode 100644 index 0000000000..d5d8df8dc9 --- /dev/null +++ b/packages/core/src/routing/routingStrategy.ts @@ -0,0 +1,76 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Content, PartListUnion } from '@google/genai'; +import type { BaseLlmClient } from '../core/baseLlmClient.js'; +import type { Config } from '../config/config.js'; + +/** + * The output of a routing decision. It specifies which model to use and why. + */ +export interface RoutingDecision { + /** The model identifier string to use for the next API call (e.g., 'gemini-2.5-pro'). */ + model: string; + /** + * Metadata about the routing decision for logging purposes. + */ + metadata: { + source: string; + latencyMs: number; + reasoning: string; + error?: string; + }; +} + +/** + * The context provided to the router for making a decision. + */ +export interface RoutingContext { + /** The full history of the conversation. 
*/ + history: Content[]; + /** The immediate request parts to be processed. */ + request: PartListUnion; + /** An abort signal to cancel an LLM call during routing. */ + signal: AbortSignal; +} + +/** + * The core interface that all routing strategies must implement. + * Strategies implementing this interface may decline a request by returning null. + */ +export interface RoutingStrategy { + /** The name of the strategy (e.g., 'fallback', 'override', 'composite'). */ + readonly name: string; + + /** + * Determines which model to use for a given request context. + * @param context The full context of the request. + * @param config The current configuration. + * @param client A reference to the GeminiClient, allowing the strategy to make its own API calls if needed. + * @returns A promise that resolves to a RoutingDecision, or null if the strategy is not applicable. + */ + route( + context: RoutingContext, + config: Config, + baseLlmClient: BaseLlmClient, + ): Promise; +} + +/** + * A strategy that is guaranteed to return a decision. It must not return null. + * This is used to ensure that a composite chain always terminates. + */ +export interface TerminalStrategy extends RoutingStrategy { + /** + * Determines which model to use for a given request context. + * @returns A promise that resolves to a RoutingDecision. 
+ */ + route( + context: RoutingContext, + config: Config, + baseLlmClient: BaseLlmClient, + ): Promise; +} diff --git a/packages/core/src/routing/strategies/compositeStrategy.test.ts b/packages/core/src/routing/strategies/compositeStrategy.test.ts new file mode 100644 index 0000000000..7fb2c393b4 --- /dev/null +++ b/packages/core/src/routing/strategies/compositeStrategy.test.ts @@ -0,0 +1,215 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { CompositeStrategy } from './compositeStrategy.js'; +import type { + RoutingContext, + RoutingDecision, + RoutingStrategy, + TerminalStrategy, +} from '../routingStrategy.js'; +import type { Config } from '../../config/config.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; + +describe('CompositeStrategy', () => { + let mockContext: RoutingContext; + let mockConfig: Config; + let mockBaseLlmClient: BaseLlmClient; + let mockStrategy1: RoutingStrategy; + let mockStrategy2: RoutingStrategy; + let mockTerminalStrategy: TerminalStrategy; + + beforeEach(() => { + vi.clearAllMocks(); + + mockContext = {} as RoutingContext; + mockConfig = {} as Config; + mockBaseLlmClient = {} as BaseLlmClient; + + mockStrategy1 = { + name: 'strategy1', + route: vi.fn().mockResolvedValue(null), + }; + + mockStrategy2 = { + name: 'strategy2', + route: vi.fn().mockResolvedValue(null), + }; + + mockTerminalStrategy = { + name: 'terminal', + route: vi.fn().mockResolvedValue({ + model: 'terminal-model', + metadata: { + source: 'terminal', + latencyMs: 10, + reasoning: 'Terminal decision', + }, + }), + }; + }); + + it('should try strategies in order and return the first successful decision', async () => { + const decision: RoutingDecision = { + model: 'strategy2-model', + metadata: { + source: 'strategy2', + latencyMs: 20, + reasoning: 'Strategy 2 decided', + }, + }; + vi.spyOn(mockStrategy2, 
'route').mockResolvedValue(decision); + + const composite = new CompositeStrategy( + [mockStrategy1, mockStrategy2, mockTerminalStrategy], + 'test-router', + ); + + const result = await composite.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(mockStrategy1.route).toHaveBeenCalledWith( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + expect(mockStrategy2.route).toHaveBeenCalledWith( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + expect(mockTerminalStrategy.route).not.toHaveBeenCalled(); + + expect(result.model).toBe('strategy2-model'); + expect(result.metadata.source).toBe('test-router/strategy2'); + }); + + it('should fall back to the terminal strategy if no other strategy provides a decision', async () => { + const composite = new CompositeStrategy( + [mockStrategy1, mockStrategy2, mockTerminalStrategy], + 'test-router', + ); + + const result = await composite.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(mockStrategy1.route).toHaveBeenCalledTimes(1); + expect(mockStrategy2.route).toHaveBeenCalledTimes(1); + expect(mockTerminalStrategy.route).toHaveBeenCalledTimes(1); + + expect(result.model).toBe('terminal-model'); + expect(result.metadata.source).toBe('test-router/terminal'); + }); + + it('should handle errors in non-terminal strategies and continue', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + vi.spyOn(mockStrategy1, 'route').mockRejectedValue( + new Error('Strategy 1 failed'), + ); + + const composite = new CompositeStrategy( + [mockStrategy1, mockTerminalStrategy], + 'test-router', + ); + + const result = await composite.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + "[Routing] Strategy 'strategy1' failed. Continuing to next strategy. 
Error:", + expect.any(Error), + ); + expect(result.model).toBe('terminal-model'); + consoleErrorSpy.mockRestore(); + }); + + it('should re-throw an error from the terminal strategy', async () => { + const consoleErrorSpy = vi + .spyOn(console, 'error') + .mockImplementation(() => {}); + const terminalError = new Error('Terminal strategy failed'); + vi.spyOn(mockTerminalStrategy, 'route').mockRejectedValue(terminalError); + + const composite = new CompositeStrategy([mockTerminalStrategy]); + + await expect( + composite.route(mockContext, mockConfig, mockBaseLlmClient), + ).rejects.toThrow(terminalError); + + expect(consoleErrorSpy).toHaveBeenCalledWith( + "[Routing] Critical Error: Terminal strategy 'terminal' failed. Routing cannot proceed. Error:", + terminalError, + ); + consoleErrorSpy.mockRestore(); + }); + + it('should correctly finalize the decision metadata', async () => { + const decision: RoutingDecision = { + model: 'some-model', + metadata: { + source: 'child-source', + latencyMs: 50, + reasoning: 'Child reasoning', + }, + }; + vi.spyOn(mockStrategy1, 'route').mockResolvedValue(decision); + + const composite = new CompositeStrategy( + [mockStrategy1, mockTerminalStrategy], + 'my-composite', + ); + + const result = await composite.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(result.model).toBe('some-model'); + expect(result.metadata.source).toBe('my-composite/child-source'); + expect(result.metadata.reasoning).toBe('Child reasoning'); + // It should keep the child's latency + expect(result.metadata.latencyMs).toBe(50); + }); + + it('should calculate total latency if child latency is not provided', async () => { + const decision: RoutingDecision = { + model: 'some-model', + metadata: { + source: 'child-source', + // No latencyMs here + latencyMs: 0, + reasoning: 'Child reasoning', + }, + }; + vi.spyOn(mockStrategy1, 'route').mockResolvedValue(decision); + + const composite = new CompositeStrategy( + [mockStrategy1, 
mockTerminalStrategy], + 'my-composite', + ); + + const result = await composite.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(result.metadata.latencyMs).toBeGreaterThanOrEqual(0); + }); +}); diff --git a/packages/core/src/routing/strategies/compositeStrategy.ts b/packages/core/src/routing/strategies/compositeStrategy.ts new file mode 100644 index 0000000000..42646fc4e3 --- /dev/null +++ b/packages/core/src/routing/strategies/compositeStrategy.ts @@ -0,0 +1,109 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../../config/config.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { + RoutingContext, + RoutingDecision, + RoutingStrategy, + TerminalStrategy, +} from '../routingStrategy.js'; + +/** + * A strategy that attempts a list of child strategies in order (Chain of Responsibility). + */ +export class CompositeStrategy implements TerminalStrategy { + readonly name: string; + + private strategies: [...RoutingStrategy[], TerminalStrategy]; + + /** + * Initializes the CompositeStrategy. + * @param strategies The strategies to try, in order of priority. The last strategy must be terminal. + * @param name The name of this composite configuration (e.g., 'router' or 'composite'). + */ + constructor( + strategies: [...RoutingStrategy[], TerminalStrategy], + name: string = 'composite', + ) { + this.strategies = strategies; + this.name = name; + } + + async route( + context: RoutingContext, + config: Config, + baseLlmClient: BaseLlmClient, + ): Promise { + const startTime = performance.now(); + + // Separate non-terminal strategies from the terminal one. + // This separation allows TypeScript to understand the control flow guarantees. 
+ const nonTerminalStrategies = this.strategies.slice( + 0, + -1, + ) as RoutingStrategy[]; + const terminalStrategy = this.strategies[ + this.strategies.length - 1 + ] as TerminalStrategy; + + // Try non-terminal strategies, allowing them to fail gracefully. + for (const strategy of nonTerminalStrategies) { + try { + const decision = await strategy.route(context, config, baseLlmClient); + if (decision) { + return this.finalizeDecision(decision, startTime); + } + } catch (error) { + console.error( + `[Routing] Strategy '${strategy.name}' failed. Continuing to next strategy. Error:`, + error, + ); + } + } + + // If no other strategy matched, execute the terminal strategy. + try { + const decision = await terminalStrategy.route( + context, + config, + baseLlmClient, + ); + + return this.finalizeDecision(decision, startTime); + } catch (error) { + console.error( + `[Routing] Critical Error: Terminal strategy '${terminalStrategy.name}' failed. Routing cannot proceed. Error:`, + error, + ); + throw error; + } + } + + /** + * Helper function to enhance the decision metadata with composite information. + */ + private finalizeDecision( + decision: RoutingDecision, + startTime: number, + ): RoutingDecision { + const endTime = performance.now(); + const totalLatency = endTime - startTime; + + // Combine the source paths: composite_name/child_source (e.g. 
'router/default') + const compositeSource = `${this.name}/${decision.metadata.source}`; + + return { + ...decision, + metadata: { + ...decision.metadata, + source: compositeSource, + latencyMs: decision.metadata.latencyMs || totalLatency, + }, + }; + } +} diff --git a/packages/core/src/routing/strategies/defaultStrategy.test.ts b/packages/core/src/routing/strategies/defaultStrategy.test.ts new file mode 100644 index 0000000000..1c739545a4 --- /dev/null +++ b/packages/core/src/routing/strategies/defaultStrategy.test.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { DefaultStrategy } from './defaultStrategy.js'; +import type { RoutingContext } from '../routingStrategy.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import { DEFAULT_GEMINI_MODEL } from '../../config/models.js'; +import type { Config } from '../../config/config.js'; + +describe('DefaultStrategy', () => { + it('should always route to the default Gemini model', async () => { + const strategy = new DefaultStrategy(); + const mockContext = {} as RoutingContext; + const mockConfig = {} as Config; + const mockClient = {} as BaseLlmClient; + + const decision = await strategy.route(mockContext, mockConfig, mockClient); + + expect(decision).toEqual({ + model: DEFAULT_GEMINI_MODEL, + metadata: { + source: 'default', + latencyMs: 0, + reasoning: `Routing to default model: ${DEFAULT_GEMINI_MODEL}`, + }, + }); + }); +}); diff --git a/packages/core/src/routing/strategies/defaultStrategy.ts b/packages/core/src/routing/strategies/defaultStrategy.ts new file mode 100644 index 0000000000..dba7949f9e --- /dev/null +++ b/packages/core/src/routing/strategies/defaultStrategy.ts @@ -0,0 +1,33 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../../config/config.js'; +import type { BaseLlmClient } from 
'../../core/baseLlmClient.js'; +import type { + RoutingContext, + RoutingDecision, + TerminalStrategy, +} from '../routingStrategy.js'; +import { DEFAULT_GEMINI_MODEL } from '../../config/models.js'; + +export class DefaultStrategy implements TerminalStrategy { + readonly name = 'default'; + + async route( + _context: RoutingContext, + _config: Config, + _baseLlmClient: BaseLlmClient, + ): Promise { + return { + model: DEFAULT_GEMINI_MODEL, + metadata: { + source: this.name, + latencyMs: 0, + reasoning: `Routing to default model: ${DEFAULT_GEMINI_MODEL}`, + }, + }; + } +} diff --git a/packages/core/src/routing/strategies/fallbackStrategy.test.ts b/packages/core/src/routing/strategies/fallbackStrategy.test.ts new file mode 100644 index 0000000000..dfda72d4ca --- /dev/null +++ b/packages/core/src/routing/strategies/fallbackStrategy.test.ts @@ -0,0 +1,86 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { FallbackStrategy } from './fallbackStrategy.js'; +import type { RoutingContext } from '../routingStrategy.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { Config } from '../../config/config.js'; +import { + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_FLASH_LITE_MODEL, +} from '../../config/models.js'; + +describe('FallbackStrategy', () => { + const strategy = new FallbackStrategy(); + const mockContext = {} as RoutingContext; + const mockClient = {} as BaseLlmClient; + + it('should return null when not in fallback mode', async () => { + const mockConfig = { + isInFallbackMode: () => false, + getModel: () => DEFAULT_GEMINI_MODEL, + } as Config; + + const decision = await strategy.route(mockContext, mockConfig, mockClient); + expect(decision).toBeNull(); + }); + + describe('when in fallback mode', () => { + it('should downgrade a pro model to the flash model', async () => { + const mockConfig = { + 
isInFallbackMode: () => true, + getModel: () => DEFAULT_GEMINI_MODEL, + } as Config; + + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + ); + + expect(decision).not.toBeNull(); + expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + expect(decision?.metadata.source).toBe('fallback'); + expect(decision?.metadata.reasoning).toContain('In fallback mode'); + }); + + it('should honor a lite model request', async () => { + const mockConfig = { + isInFallbackMode: () => true, + getModel: () => DEFAULT_GEMINI_FLASH_LITE_MODEL, + } as Config; + + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + ); + + expect(decision).not.toBeNull(); + expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_LITE_MODEL); + expect(decision?.metadata.source).toBe('fallback'); + }); + + it('should use the flash model if flash is requested', async () => { + const mockConfig = { + isInFallbackMode: () => true, + getModel: () => DEFAULT_GEMINI_FLASH_MODEL, + } as Config; + + const decision = await strategy.route( + mockContext, + mockConfig, + mockClient, + ); + + expect(decision).not.toBeNull(); + expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL); + expect(decision?.metadata.source).toBe('fallback'); + }); + }); +}); diff --git a/packages/core/src/routing/strategies/fallbackStrategy.ts b/packages/core/src/routing/strategies/fallbackStrategy.ts new file mode 100644 index 0000000000..aef01743aa --- /dev/null +++ b/packages/core/src/routing/strategies/fallbackStrategy.ts @@ -0,0 +1,43 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../../config/config.js'; +import { getEffectiveModel } from '../../config/models.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { + RoutingContext, + RoutingDecision, + RoutingStrategy, +} from '../routingStrategy.js'; + +export class FallbackStrategy implements RoutingStrategy { 
+ readonly name = 'fallback'; + + async route( + _context: RoutingContext, + config: Config, + _baseLlmClient: BaseLlmClient, + ): Promise<RoutingDecision | null> { + const isInFallbackMode: boolean = config.isInFallbackMode(); + + if (!isInFallbackMode) { + return null; + } + + const effectiveModel = getEffectiveModel( + isInFallbackMode, + config.getModel(), + ); + return { + model: effectiveModel, + metadata: { + source: this.name, + latencyMs: 0, + reasoning: `In fallback mode. Using: ${effectiveModel}`, + }, + }; + } +} diff --git a/packages/core/src/routing/strategies/overrideStrategy.test.ts b/packages/core/src/routing/strategies/overrideStrategy.test.ts new file mode 100644 index 0000000000..69c4088f8d --- /dev/null +++ b/packages/core/src/routing/strategies/overrideStrategy.test.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { OverrideStrategy } from './overrideStrategy.js'; +import type { RoutingContext } from '../routingStrategy.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { Config } from '../../config/config.js'; + +describe('OverrideStrategy', () => { + const strategy = new OverrideStrategy(); + const mockContext = {} as RoutingContext; + const mockClient = {} as BaseLlmClient; + + it('should return null when no override model is specified', async () => { + const mockConfig = { + getModel: () => '', // Simulate no model override + } as Config; + + const decision = await strategy.route(mockContext, mockConfig, mockClient); + expect(decision).toBeNull(); + }); + + it('should return a decision with the override model when one is specified', async () => { + const overrideModel = 'gemini-2.5-pro-custom'; + const mockConfig = { + getModel: () => overrideModel, + } as Config; + + const decision = await strategy.route(mockContext, mockConfig, mockClient); + + expect(decision).not.toBeNull(); + 
expect(decision?.model).toBe(overrideModel); + expect(decision?.metadata.source).toBe('override'); + expect(decision?.metadata.reasoning).toContain( + 'Routing bypassed by forced model directive', + ); + expect(decision?.metadata.reasoning).toContain(overrideModel); + }); + + it('should handle different override model names', async () => { + const overrideModel = 'gemini-2.5-flash-experimental'; + const mockConfig = { + getModel: () => overrideModel, + } as Config; + + const decision = await strategy.route(mockContext, mockConfig, mockClient); + + expect(decision).not.toBeNull(); + expect(decision?.model).toBe(overrideModel); + }); +}); diff --git a/packages/core/src/routing/strategies/overrideStrategy.ts b/packages/core/src/routing/strategies/overrideStrategy.ts new file mode 100644 index 0000000000..b3aef6c332 --- /dev/null +++ b/packages/core/src/routing/strategies/overrideStrategy.ts @@ -0,0 +1,40 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../../config/config.js'; +import type { BaseLlmClient } from '../../core/baseLlmClient.js'; +import type { + RoutingContext, + RoutingDecision, + RoutingStrategy, +} from '../routingStrategy.js'; + +/** + * Handles cases where the user explicitly specifies a model (override). + */ +export class OverrideStrategy implements RoutingStrategy { + readonly name = 'override'; + + async route( + _context: RoutingContext, + config: Config, + _baseLlmClient: BaseLlmClient, + ): Promise<RoutingDecision | null> { + const overrideModel = config.getModel(); + if (overrideModel) { + return { + model: overrideModel, + metadata: { + source: this.name, + latencyMs: 0, + reasoning: `Routing bypassed by forced model directive. Using: ${overrideModel}`, + }, + }; + } + // No override specified, pass to the next strategy. + return null; + } +}