mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-21 02:24:09 -07:00
feat(routing): Initialize model routing architecture (#8153)
This commit is contained in:
@@ -206,9 +206,7 @@ describe('Server Config (config.ts)', () => {
|
||||
it('should refresh auth and update config', async () => {
|
||||
const config = new Config(baseParams);
|
||||
const authType = AuthType.USE_GEMINI;
|
||||
const newModel = 'gemini-flash';
|
||||
const mockContentConfig = {
|
||||
model: newModel,
|
||||
apiKey: 'test-key',
|
||||
};
|
||||
|
||||
@@ -226,10 +224,8 @@ describe('Server Config (config.ts)', () => {
|
||||
config,
|
||||
authType,
|
||||
);
|
||||
// Verify that contentGeneratorConfig is updated with the new model
|
||||
// Verify that contentGeneratorConfig is updated
|
||||
expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig);
|
||||
expect(config.getContentGeneratorConfig().model).toBe(newModel);
|
||||
expect(config.getModel()).toBe(newModel); // getModel() should return the updated model
|
||||
expect(GeminiClient).toHaveBeenCalledWith(config);
|
||||
// Verify that fallback mode is reset
|
||||
expect(config.isInFallbackMode()).toBe(false);
|
||||
|
||||
@@ -44,6 +44,7 @@ import { StartSessionEvent } from '../telemetry/index.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
} from './models.js';
|
||||
import { shouldAttemptBrowserLaunch } from '../utils/browser.js';
|
||||
import type { MCPOAuthConfig } from '../mcp/oauth-provider.js';
|
||||
@@ -62,6 +63,7 @@ import {
|
||||
RipgrepFallbackEvent,
|
||||
} from '../telemetry/types.js';
|
||||
import type { FallbackModelHandler } from '../fallback/types.js';
|
||||
import { ModelRouterService } from '../routing/modelRouterService.js';
|
||||
import { OutputFormat } from '../output/types.js';
|
||||
|
||||
// Re-export OAuth config type
|
||||
@@ -270,6 +272,7 @@ export class Config {
|
||||
private readonly usageStatisticsEnabled: boolean;
|
||||
private geminiClient!: GeminiClient;
|
||||
private baseLlmClient!: BaseLlmClient;
|
||||
private modelRouterService: ModelRouterService;
|
||||
private readonly fileFiltering: {
|
||||
respectGitIgnore: boolean;
|
||||
respectGeminiIgnore: boolean;
|
||||
@@ -282,7 +285,7 @@ export class Config {
|
||||
private readonly proxy: string | undefined;
|
||||
private readonly cwd: string;
|
||||
private readonly bugCommand: BugCommandSettings | undefined;
|
||||
private readonly model: string;
|
||||
private model: string;
|
||||
private readonly extensionContextFilePaths: string[];
|
||||
private readonly noBrowser: boolean;
|
||||
private readonly folderTrustFeature: boolean;
|
||||
@@ -372,7 +375,7 @@ export class Config {
|
||||
this.cwd = params.cwd ?? process.cwd();
|
||||
this.fileDiscoveryService = params.fileDiscoveryService ?? null;
|
||||
this.bugCommand = params.bugCommand;
|
||||
this.model = params.model;
|
||||
this.model = params.model || DEFAULT_GEMINI_MODEL;
|
||||
this.extensionContextFilePaths = params.extensionContextFilePaths ?? [];
|
||||
this.maxSessionTurns = params.maxSessionTurns ?? -1;
|
||||
this.experimentalZedIntegration =
|
||||
@@ -424,6 +427,7 @@ export class Config {
|
||||
setGlobalDispatcher(new ProxyAgent(this.getProxy() as string));
|
||||
}
|
||||
this.geminiClient = new GeminiClient(this);
|
||||
this.modelRouterService = new ModelRouterService(this);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -523,13 +527,16 @@ export class Config {
|
||||
}
|
||||
|
||||
getModel(): string {
|
||||
return this.contentGeneratorConfig?.model || this.model;
|
||||
return this.model;
|
||||
}
|
||||
|
||||
setModel(newModel: string): void {
|
||||
if (this.contentGeneratorConfig) {
|
||||
this.contentGeneratorConfig.model = newModel;
|
||||
// Do not allow Pro usage if the user is in fallback mode.
|
||||
if (newModel.includes('pro') && this.isInFallbackMode()) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.model = newModel;
|
||||
}
|
||||
|
||||
isInFallbackMode(): boolean {
|
||||
@@ -699,6 +706,10 @@ export class Config {
|
||||
return this.geminiClient;
|
||||
}
|
||||
|
||||
getModelRouterService(): ModelRouterService {
|
||||
return this.modelRouterService;
|
||||
}
|
||||
|
||||
getEnableRecursiveFileSearch(): boolean {
|
||||
return this.fileFiltering.enableRecursiveFileSearch;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import {
|
||||
getEffectiveModel,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||
} from './models.js';
|
||||
|
||||
describe('getEffectiveModel', () => {
|
||||
describe('When NOT in fallback mode', () => {
|
||||
const isInFallbackMode = false;
|
||||
|
||||
it('should return the Pro model when Pro is requested', () => {
|
||||
const model = getEffectiveModel(isInFallbackMode, DEFAULT_GEMINI_MODEL);
|
||||
expect(model).toBe(DEFAULT_GEMINI_MODEL);
|
||||
});
|
||||
|
||||
it('should return the Flash model when Flash is requested', () => {
|
||||
const model = getEffectiveModel(
|
||||
isInFallbackMode,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
});
|
||||
|
||||
it('should return the Lite model when Lite is requested', () => {
|
||||
const model = getEffectiveModel(
|
||||
isInFallbackMode,
|
||||
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||
);
|
||||
expect(model).toBe(DEFAULT_GEMINI_FLASH_LITE_MODEL);
|
||||
});
|
||||
|
||||
it('should return a custom model name when requested', () => {
|
||||
const customModel = 'custom-model-v1';
|
||||
const model = getEffectiveModel(isInFallbackMode, customModel);
|
||||
expect(model).toBe(customModel);
|
||||
});
|
||||
});
|
||||
|
||||
describe('When IN fallback mode', () => {
|
||||
const isInFallbackMode = true;
|
||||
|
||||
it('should downgrade the Pro model to the Flash model', () => {
|
||||
const model = getEffectiveModel(isInFallbackMode, DEFAULT_GEMINI_MODEL);
|
||||
expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
});
|
||||
|
||||
it('should return the Flash model when Flash is requested', () => {
|
||||
const model = getEffectiveModel(
|
||||
isInFallbackMode,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
});
|
||||
|
||||
it('should HONOR the Lite model when Lite is requested', () => {
|
||||
const model = getEffectiveModel(
|
||||
isInFallbackMode,
|
||||
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||
);
|
||||
expect(model).toBe(DEFAULT_GEMINI_FLASH_LITE_MODEL);
|
||||
});
|
||||
|
||||
it('should HONOR any model with "lite" in its name', () => {
|
||||
const customLiteModel = 'gemini-2.5-custom-lite-vNext';
|
||||
const model = getEffectiveModel(isInFallbackMode, customLiteModel);
|
||||
expect(model).toBe(customLiteModel);
|
||||
});
|
||||
|
||||
it('should downgrade any other custom model to the Flash model', () => {
|
||||
const customModel = 'custom-model-v1-unlisted';
|
||||
const model = getEffectiveModel(isInFallbackMode, customModel);
|
||||
expect(model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -12,3 +12,35 @@ export const DEFAULT_GEMINI_EMBEDDING_MODEL = 'gemini-embedding-001';
|
||||
|
||||
// Some thinking models do not default to dynamic thinking which is done by a value of -1
|
||||
export const DEFAULT_THINKING_MODE = -1;
|
||||
|
||||
/**
|
||||
* Determines the effective model to use, applying fallback logic if necessary.
|
||||
*
|
||||
* When fallback mode is active, this function enforces the use of the standard
|
||||
* fallback model. However, it makes an exception for "lite" models (any model
|
||||
* with "lite" in its name), allowing them to be used to preserve cost savings.
|
||||
* This ensures that "pro" models are always downgraded, while "lite" model
|
||||
* requests are honored.
|
||||
*
|
||||
* @param isInFallbackMode Whether the application is in fallback mode.
|
||||
* @param requestedModel The model that was originally requested.
|
||||
* @returns The effective model name.
|
||||
*/
|
||||
export function getEffectiveModel(
|
||||
isInFallbackMode: boolean,
|
||||
requestedModel: string,
|
||||
): string {
|
||||
// If we are not in fallback mode, simply use the requested model.
|
||||
if (!isInFallbackMode) {
|
||||
return requestedModel;
|
||||
}
|
||||
|
||||
// If a "lite" model is requested, honor it. This allows for variations of
|
||||
// lite models without needing to list them all as constants.
|
||||
if (requestedModel.includes('lite')) {
|
||||
return requestedModel;
|
||||
}
|
||||
|
||||
// Default fallback for Gemini CLI.
|
||||
return DEFAULT_GEMINI_FLASH_MODEL;
|
||||
}
|
||||
|
||||
@@ -41,6 +41,7 @@ import { setSimulate429 } from '../utils/testUtils.js';
|
||||
import { tokenLimit } from './tokenLimits.js';
|
||||
import { ideContext } from '../ide/ideContext.js';
|
||||
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
||||
import type { ModelRouterService } from '../routing/modelRouterService.js';
|
||||
|
||||
// Mock fs module to prevent actual file system operations during tests
|
||||
const mockFileSystem = new Map<string, string>();
|
||||
@@ -234,7 +235,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
mockContentGenerator = {
|
||||
generateContent: mockGenerateContentFn,
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 100 }),
|
||||
embedContent: vi.fn(),
|
||||
batchEmbedContents: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
@@ -248,7 +249,6 @@ describe('Gemini Client (client.ts)', () => {
|
||||
};
|
||||
const fileService = new FileDiscoveryService('/test/dir');
|
||||
const contentGeneratorConfig: ContentGeneratorConfig = {
|
||||
model: 'test-model',
|
||||
apiKey: 'test-key',
|
||||
vertexai: false,
|
||||
authType: AuthType.USE_GEMINI,
|
||||
@@ -281,6 +281,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
getDirectories: vi.fn().mockReturnValue(['/test/dir']),
|
||||
}),
|
||||
getGeminiClient: vi.fn(),
|
||||
getModelRouterService: vi.fn().mockReturnValue({
|
||||
route: vi.fn().mockResolvedValue({ model: 'default-routed-model' }),
|
||||
}),
|
||||
isInFallbackMode: vi.fn().mockReturnValue(false),
|
||||
setFallbackMode: vi.fn(),
|
||||
getChatCompression: vi.fn().mockReturnValue(undefined),
|
||||
@@ -1116,7 +1119,12 @@ ${JSON.stringify(
|
||||
|
||||
// Assert
|
||||
expect(ideContext.getIdeContext).toHaveBeenCalled();
|
||||
// The `turn.run` method is now called with the model name as the first
|
||||
// argument. We use `expect.any(String)` because this test is
|
||||
// concerned with the IDE context logic, not the model routing,
|
||||
// which is tested in its own dedicated suite.
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
initialRequest,
|
||||
expect.any(Object),
|
||||
);
|
||||
@@ -1506,6 +1514,178 @@ ${JSON.stringify(
|
||||
);
|
||||
});
|
||||
|
||||
describe('Model Routing', () => {
|
||||
let mockRouterService: { route: Mock };
|
||||
|
||||
beforeEach(() => {
|
||||
mockRouterService = {
|
||||
route: vi
|
||||
.fn()
|
||||
.mockResolvedValue({ model: 'routed-model', reason: 'test' }),
|
||||
};
|
||||
vi.mocked(mockConfig.getModelRouterService).mockReturnValue(
|
||||
mockRouterService as unknown as ModelRouterService,
|
||||
);
|
||||
|
||||
mockTurnRunFn.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
})(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the model router service to select a model on the first turn', async () => {
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
new AbortController().signal,
|
||||
'prompt-1',
|
||||
);
|
||||
await fromAsync(stream); // consume stream
|
||||
|
||||
expect(mockConfig.getModelRouterService).toHaveBeenCalled();
|
||||
expect(mockRouterService.route).toHaveBeenCalled();
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
'routed-model', // The model from the router
|
||||
[{ text: 'Hi' }],
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the same model for subsequent turns in the same prompt (stickiness)', async () => {
|
||||
// First turn
|
||||
let stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
new AbortController().signal,
|
||||
'prompt-1',
|
||||
);
|
||||
await fromAsync(stream);
|
||||
|
||||
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
'routed-model',
|
||||
[{ text: 'Hi' }],
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// Second turn
|
||||
stream = client.sendMessageStream(
|
||||
[{ text: 'Continue' }],
|
||||
new AbortController().signal,
|
||||
'prompt-1',
|
||||
);
|
||||
await fromAsync(stream);
|
||||
|
||||
// Router should not be called again
|
||||
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
|
||||
// Should stick to the first model
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
'routed-model',
|
||||
[{ text: 'Continue' }],
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should reset the sticky model and re-route when the prompt_id changes', async () => {
|
||||
// First prompt
|
||||
let stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
new AbortController().signal,
|
||||
'prompt-1',
|
||||
);
|
||||
await fromAsync(stream);
|
||||
|
||||
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
'routed-model',
|
||||
[{ text: 'Hi' }],
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// New prompt
|
||||
mockRouterService.route.mockResolvedValue({
|
||||
model: 'new-routed-model',
|
||||
reason: 'test',
|
||||
});
|
||||
stream = client.sendMessageStream(
|
||||
[{ text: 'A new topic' }],
|
||||
new AbortController().signal,
|
||||
'prompt-2',
|
||||
);
|
||||
await fromAsync(stream);
|
||||
|
||||
// Router should be called again for the new prompt
|
||||
expect(mockRouterService.route).toHaveBeenCalledTimes(2);
|
||||
// Should use the newly routed model
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
'new-routed-model',
|
||||
[{ text: 'A new topic' }],
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the fallback model and bypass routing when in fallback mode', async () => {
|
||||
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
|
||||
mockRouterService.route.mockResolvedValue({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
reason: 'fallback',
|
||||
});
|
||||
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
new AbortController().signal,
|
||||
'prompt-1',
|
||||
);
|
||||
await fromAsync(stream);
|
||||
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
[{ text: 'Hi' }],
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should stick to the fallback model for the entire sequence even if fallback mode ends', async () => {
|
||||
// Start the sequence in fallback mode
|
||||
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
|
||||
mockRouterService.route.mockResolvedValue({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
reason: 'fallback',
|
||||
});
|
||||
let stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
new AbortController().signal,
|
||||
'prompt-fallback-stickiness',
|
||||
);
|
||||
await fromAsync(stream);
|
||||
|
||||
// First call should use fallback model
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
[{ text: 'Hi' }],
|
||||
expect.any(Object),
|
||||
);
|
||||
|
||||
// End fallback mode
|
||||
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
|
||||
|
||||
// Second call in the same sequence
|
||||
stream = client.sendMessageStream(
|
||||
[{ text: 'Continue' }],
|
||||
new AbortController().signal,
|
||||
'prompt-fallback-stickiness',
|
||||
);
|
||||
await fromAsync(stream);
|
||||
|
||||
// Router should still not be called, and it should stick to the fallback model
|
||||
expect(mockTurnRunFn).toHaveBeenCalledTimes(2); // Ensure it was called again
|
||||
expect(mockTurnRunFn).toHaveBeenLastCalledWith(
|
||||
DEFAULT_GEMINI_FLASH_MODEL, // Still the fallback model
|
||||
[{ text: 'Continue' }],
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Editor context delta', () => {
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
|
||||
@@ -49,6 +49,7 @@ import {
|
||||
} from '../telemetry/types.js';
|
||||
import type { IdeContext, File } from '../ide/types.js';
|
||||
import { handleFallback } from '../fallback/handler.js';
|
||||
import type { RoutingContext } from '../routing/routingStrategy.js';
|
||||
|
||||
export function isThinkingSupported(model: string) {
|
||||
if (model.startsWith('gemini-2.5')) return true;
|
||||
@@ -118,6 +119,7 @@ export class GeminiClient {
|
||||
|
||||
private readonly loopDetector: LoopDetectionService;
|
||||
private lastPromptId: string;
|
||||
private currentSequenceModel: string | null = null;
|
||||
private lastSentIdeContext: IdeContext | undefined;
|
||||
private forceFullIdeContext = true;
|
||||
|
||||
@@ -430,11 +432,11 @@ export class GeminiClient {
|
||||
signal: AbortSignal,
|
||||
prompt_id: string,
|
||||
turns: number = MAX_TURNS,
|
||||
originalModel?: string,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
if (this.lastPromptId !== prompt_id) {
|
||||
this.loopDetector.reset(prompt_id);
|
||||
this.lastPromptId = prompt_id;
|
||||
this.currentSequenceModel = null;
|
||||
}
|
||||
this.sessionTurnCount++;
|
||||
if (
|
||||
@@ -450,9 +452,6 @@ export class GeminiClient {
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
}
|
||||
|
||||
// Track the original model from the first call to detect model switching
|
||||
const initialModel = originalModel || this.config.getModel();
|
||||
|
||||
const compressed = await this.tryCompressChat(prompt_id);
|
||||
|
||||
if (compressed.compressionStatus === CompressionStatus.COMPRESSED) {
|
||||
@@ -494,7 +493,26 @@ export class GeminiClient {
|
||||
return turn;
|
||||
}
|
||||
|
||||
const resultStream = turn.run(request, signal);
|
||||
const routingContext: RoutingContext = {
|
||||
history: this.getChat().getHistory(/*curated=*/ true),
|
||||
request,
|
||||
signal,
|
||||
};
|
||||
|
||||
let modelToUse: string;
|
||||
|
||||
// Determine Model (Stickiness vs. Routing)
|
||||
if (this.currentSequenceModel) {
|
||||
modelToUse = this.currentSequenceModel;
|
||||
} else {
|
||||
const router = await this.config.getModelRouterService();
|
||||
const decision = await router.route(routingContext);
|
||||
modelToUse = decision.model;
|
||||
// Lock the model for the rest of the sequence
|
||||
this.currentSequenceModel = modelToUse;
|
||||
}
|
||||
|
||||
const resultStream = turn.run(modelToUse, request, signal);
|
||||
for await (const event of resultStream) {
|
||||
if (this.loopDetector.addAndCheck(event)) {
|
||||
yield { type: GeminiEventType.LoopDetected };
|
||||
@@ -506,11 +524,8 @@ export class GeminiClient {
|
||||
}
|
||||
}
|
||||
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
||||
// Check if model was switched during the call (likely due to quota error)
|
||||
const currentModel = this.config.getModel();
|
||||
if (currentModel !== initialModel) {
|
||||
// Model was switched (likely due to quota error fallback)
|
||||
// Don't continue with recursive call to prevent unwanted Flash execution
|
||||
// Check if next speaker check is needed
|
||||
if (this.config.getQuotaErrorOccurred()) {
|
||||
return turn;
|
||||
}
|
||||
|
||||
@@ -540,7 +555,6 @@ export class GeminiClient {
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
initialModel,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,6 @@ describe('createContentGenerator', () => {
|
||||
);
|
||||
const generator = await createContentGenerator(
|
||||
{
|
||||
model: 'test-model',
|
||||
authType: AuthType.LOGIN_WITH_GOOGLE,
|
||||
},
|
||||
mockConfig,
|
||||
@@ -51,7 +50,6 @@ describe('createContentGenerator', () => {
|
||||
vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);
|
||||
const generator = await createContentGenerator(
|
||||
{
|
||||
model: 'test-model',
|
||||
apiKey: 'test-api-key',
|
||||
authType: AuthType.USE_GEMINI,
|
||||
},
|
||||
@@ -85,7 +83,6 @@ describe('createContentGenerator', () => {
|
||||
vi.mocked(GoogleGenAI).mockImplementation(() => mockGenerator as never);
|
||||
const generator = await createContentGenerator(
|
||||
{
|
||||
model: 'test-model',
|
||||
apiKey: 'test-api-key',
|
||||
authType: AuthType.USE_GEMINI,
|
||||
},
|
||||
|
||||
@@ -14,7 +14,6 @@ import type {
|
||||
} from '@google/genai';
|
||||
import { GoogleGenAI } from '@google/genai';
|
||||
import { createCodeAssistContentGenerator } from '../code_assist/codeAssist.js';
|
||||
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
import type { UserTierId } from '../code_assist/types.js';
|
||||
@@ -50,7 +49,6 @@ export enum AuthType {
|
||||
}
|
||||
|
||||
export type ContentGeneratorConfig = {
|
||||
model: string;
|
||||
apiKey?: string;
|
||||
vertexai?: boolean;
|
||||
authType?: AuthType;
|
||||
@@ -66,11 +64,7 @@ export function createContentGeneratorConfig(
|
||||
const googleCloudProject = process.env['GOOGLE_CLOUD_PROJECT'] || undefined;
|
||||
const googleCloudLocation = process.env['GOOGLE_CLOUD_LOCATION'] || undefined;
|
||||
|
||||
// Use runtime model from config if available; otherwise, fall back to parameter or default
|
||||
const effectiveModel = config.getModel() || DEFAULT_GEMINI_MODEL;
|
||||
|
||||
const contentGeneratorConfig: ContentGeneratorConfig = {
|
||||
model: effectiveModel,
|
||||
authType,
|
||||
proxy: config?.getProxy(),
|
||||
};
|
||||
|
||||
@@ -169,6 +169,7 @@ describe('GeminiChat', () => {
|
||||
// 2. Action & Assert: The stream processing should complete without throwing an error
|
||||
// because the presence of a tool call makes the empty final chunk acceptable.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test message' },
|
||||
'prompt-id-tool-call-empty-end',
|
||||
);
|
||||
@@ -220,6 +221,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Action & Assert: The stream should fail because there's no finish reason.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test message' },
|
||||
'prompt-id-no-finish-empty-end',
|
||||
);
|
||||
@@ -265,6 +267,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Action & Assert: The stream should complete without throwing an error.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test message' },
|
||||
'prompt-id-valid-then-invalid-end',
|
||||
);
|
||||
@@ -321,6 +324,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Action: Send a message and consume the stream.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test message' },
|
||||
'prompt-id-malformed-chunk',
|
||||
);
|
||||
@@ -371,6 +375,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Action: Send a message and consume the stream.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test message' },
|
||||
'prompt-id-empty-chunk-consolidation',
|
||||
);
|
||||
@@ -428,6 +433,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Action: Send a message and consume the stream.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test message' },
|
||||
'prompt-id-multi-chunk',
|
||||
);
|
||||
@@ -475,6 +481,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Action: Send a message and fully consume the stream to trigger history recording.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test message' },
|
||||
'prompt-id-mixed-chunk',
|
||||
);
|
||||
@@ -537,6 +544,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 3. Action: Send the function response back to the model and consume the stream.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{
|
||||
message: {
|
||||
functionResponse: {
|
||||
@@ -588,17 +596,23 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'hello' },
|
||||
'prompt-id-1',
|
||||
);
|
||||
for await (const _ of stream) {
|
||||
// consume stream to trigger internal logic
|
||||
// consume stream
|
||||
}
|
||||
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith(
|
||||
{
|
||||
model: 'gemini-pro',
|
||||
contents: [{ role: 'user', parts: [{ text: 'hello' }] }],
|
||||
model: 'test-model',
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: 'hello' }],
|
||||
},
|
||||
],
|
||||
config: {},
|
||||
},
|
||||
'prompt-id-1',
|
||||
@@ -809,6 +823,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// ACT: Send a message and collect all events from the stream.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
'prompt-id-yield-retry',
|
||||
);
|
||||
@@ -849,6 +864,7 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
'prompt-id-retry-success',
|
||||
);
|
||||
@@ -909,6 +925,7 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
'prompt-id-retry-fail',
|
||||
);
|
||||
@@ -964,6 +981,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 3. Send a new message
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'Second question' },
|
||||
'prompt-id-retry-existing',
|
||||
);
|
||||
@@ -1034,6 +1052,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 2. Call the method and consume the stream.
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test empty stream' },
|
||||
'prompt-id-empty-stream',
|
||||
);
|
||||
@@ -1113,6 +1132,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 3. Start the first stream and consume only the first chunk to pause it
|
||||
const firstStream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'first' },
|
||||
'prompt-1',
|
||||
);
|
||||
@@ -1121,6 +1141,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// 4. While the first stream is paused, start the second call. It will block.
|
||||
const secondStreamPromise = chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'second' },
|
||||
'prompt-2',
|
||||
);
|
||||
@@ -1180,6 +1201,7 @@ describe('GeminiChat', () => {
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
'prompt-id-res3',
|
||||
);
|
||||
@@ -1234,12 +1256,9 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
|
||||
it('should call handleFallback with the specific failed model and retry if handler returns true', async () => {
|
||||
const FAILED_MODEL = 'gemini-2.5-pro';
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(FAILED_MODEL);
|
||||
const authType = AuthType.LOGIN_WITH_GOOGLE;
|
||||
vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({
|
||||
authType,
|
||||
model: FAILED_MODEL,
|
||||
});
|
||||
|
||||
const isInFallbackModeSpy = vi.spyOn(mockConfig, 'isInFallbackMode');
|
||||
@@ -1267,6 +1286,7 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'trigger 429' },
|
||||
'prompt-id-fb1',
|
||||
);
|
||||
@@ -1282,7 +1302,7 @@ describe('GeminiChat', () => {
|
||||
expect(mockHandleFallback).toHaveBeenCalledTimes(1);
|
||||
expect(mockHandleFallback).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
FAILED_MODEL,
|
||||
'test-model',
|
||||
authType,
|
||||
error429,
|
||||
);
|
||||
@@ -1300,6 +1320,7 @@ describe('GeminiChat', () => {
|
||||
mockHandleFallback.mockResolvedValue(false);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test stop' },
|
||||
'prompt-id-fb2',
|
||||
);
|
||||
@@ -1357,6 +1378,7 @@ describe('GeminiChat', () => {
|
||||
|
||||
// Send a message and consume the stream
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
'prompt-id-discard-test',
|
||||
);
|
||||
|
||||
@@ -19,7 +19,10 @@ import { toParts } from '../code_assist/converter.js';
|
||||
import { createUserContent } from '@google/genai';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
getEffectiveModel,
|
||||
} from '../config/models.js';
|
||||
import { hasCycleInSchema } from '../tools/tools.js';
|
||||
import type { StructuredError } from './turn.js';
|
||||
import type { CompletedToolCall } from './coreToolScheduler.js';
|
||||
@@ -206,6 +209,7 @@ export class GeminiChat {
|
||||
* ```
|
||||
*/
|
||||
async sendMessageStream(
|
||||
model: string,
|
||||
params: SendMessageParameters,
|
||||
prompt_id: string,
|
||||
): Promise<AsyncGenerator<StreamEvent>> {
|
||||
@@ -253,6 +257,7 @@ export class GeminiChat {
|
||||
}
|
||||
|
||||
const stream = await self.makeApiCallAndProcessStream(
|
||||
model,
|
||||
requestContents,
|
||||
params,
|
||||
prompt_id,
|
||||
@@ -317,18 +322,17 @@ export class GeminiChat {
|
||||
}
|
||||
|
||||
private async makeApiCallAndProcessStream(
|
||||
model: string,
|
||||
requestContents: Content[],
|
||||
params: SendMessageParameters,
|
||||
prompt_id: string,
|
||||
userContent: Content,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
let currentAttemptModel: string | undefined;
|
||||
|
||||
const apiCall = () => {
|
||||
const modelToUse = this.config.isInFallbackMode()
|
||||
? DEFAULT_GEMINI_FLASH_MODEL
|
||||
: this.config.getModel();
|
||||
currentAttemptModel = modelToUse;
|
||||
const modelToUse = getEffectiveModel(
|
||||
this.config.isInFallbackMode(),
|
||||
model,
|
||||
);
|
||||
|
||||
if (
|
||||
this.config.getQuotaErrorOccurred() &&
|
||||
@@ -352,15 +356,7 @@ export class GeminiChat {
|
||||
const onPersistent429Callback = async (
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
) => {
|
||||
if (!currentAttemptModel) return null;
|
||||
return await handleFallback(
|
||||
this.config,
|
||||
currentAttemptModel,
|
||||
authType,
|
||||
error,
|
||||
);
|
||||
};
|
||||
) => await handleFallback(this.config, model, authType, error);
|
||||
|
||||
const streamResponse = await retryWithBackoff(apiCall, {
|
||||
shouldRetry: (error: unknown) => {
|
||||
|
||||
@@ -499,7 +499,7 @@ describe('subagent.ts', () => {
|
||||
expect(scope.output.emitted_vars).toEqual({});
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
|
||||
// Check the initial message
|
||||
expect(mockSendMessageStream.mock.calls[0][0].message).toEqual([
|
||||
expect(mockSendMessageStream.mock.calls[0][1].message).toEqual([
|
||||
{ text: 'Get Started!' },
|
||||
]);
|
||||
});
|
||||
@@ -543,7 +543,7 @@ describe('subagent.ts', () => {
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Check the tool response sent back in the second call
|
||||
const secondCallArgs = mockSendMessageStream.mock.calls[0][0];
|
||||
const secondCallArgs = mockSendMessageStream.mock.calls[0][1];
|
||||
expect(secondCallArgs.message).toEqual([{ text: 'Get Started!' }]);
|
||||
});
|
||||
|
||||
@@ -605,7 +605,7 @@ describe('subagent.ts', () => {
|
||||
);
|
||||
|
||||
// Check the response sent back to the model
|
||||
const secondCallArgs = mockSendMessageStream.mock.calls[1][0];
|
||||
const secondCallArgs = mockSendMessageStream.mock.calls[1][1];
|
||||
expect(secondCallArgs.message).toEqual([
|
||||
{ text: 'file1.txt\nfile2.ts' },
|
||||
]);
|
||||
@@ -653,7 +653,7 @@ describe('subagent.ts', () => {
|
||||
await scope.runNonInteractive(new ContextState());
|
||||
|
||||
// The agent should send the specific error message from responseParts.
|
||||
const secondCallArgs = mockSendMessageStream.mock.calls[1][0];
|
||||
const secondCallArgs = mockSendMessageStream.mock.calls[1][1];
|
||||
|
||||
expect(secondCallArgs.message).toEqual([
|
||||
{
|
||||
@@ -699,7 +699,7 @@ describe('subagent.ts', () => {
|
||||
await scope.runNonInteractive(new ContextState());
|
||||
|
||||
// Check the nudge message sent in Turn 2
|
||||
const secondCallArgs = mockSendMessageStream.mock.calls[1][0];
|
||||
const secondCallArgs = mockSendMessageStream.mock.calls[1][1];
|
||||
|
||||
// We check that the message contains the required variable name and the nudge phrasing.
|
||||
expect(secondCallArgs.message[0].text).toContain('required_var');
|
||||
|
||||
@@ -430,6 +430,7 @@ export class SubAgentScope {
|
||||
};
|
||||
|
||||
const responseStream = await chat.sendMessageStream(
|
||||
this.modelConfig.model,
|
||||
messageParams,
|
||||
promptId,
|
||||
);
|
||||
|
||||
@@ -97,6 +97,7 @@ describe('Turn', () => {
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Hi' }];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -104,6 +105,7 @@ describe('Turn', () => {
|
||||
}
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledWith(
|
||||
'test-model',
|
||||
{
|
||||
message: reqParts,
|
||||
config: { abortSignal: expect.any(AbortSignal) },
|
||||
@@ -144,6 +146,7 @@ describe('Turn', () => {
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Use tools' }];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -206,7 +209,11 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test abort' }];
|
||||
for await (const event of turn.run(reqParts, abortController.signal)) {
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
reqParts,
|
||||
abortController.signal,
|
||||
)) {
|
||||
events.push(event);
|
||||
}
|
||||
expect(events).toEqual([
|
||||
@@ -227,6 +234,7 @@ describe('Turn', () => {
|
||||
mockMaybeIncludeSchemaDepthContext.mockResolvedValue(undefined);
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -267,6 +275,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
[{ text: 'Test undefined tool parts' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -323,6 +332,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
[{ text: 'Test finish reason' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -370,6 +380,7 @@ describe('Turn', () => {
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Generate long text' }];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -407,6 +418,7 @@ describe('Turn', () => {
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test safety' }];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -443,6 +455,7 @@ describe('Turn', () => {
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test no finish reason' }];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -487,6 +500,7 @@ describe('Turn', () => {
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test multiple responses' }];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -529,6 +543,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
[{ text: 'Test citations' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -578,6 +593,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
[{ text: 'test' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -624,6 +640,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
[{ text: 'test' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -669,6 +686,7 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
[{ text: 'test' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
@@ -705,7 +723,11 @@ describe('Turn', () => {
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test malformed error handling' }];
|
||||
|
||||
for await (const event of turn.run(reqParts, abortController.signal)) {
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
reqParts,
|
||||
abortController.signal,
|
||||
)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
@@ -727,7 +749,11 @@ describe('Turn', () => {
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run([], new AbortController().signal)) {
|
||||
for await (const event of turn.run(
|
||||
'test-model',
|
||||
[],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
@@ -752,7 +778,11 @@ describe('Turn', () => {
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
const reqParts: Part[] = [{ text: 'Hi' }];
|
||||
for await (const _ of turn.run(reqParts, new AbortController().signal)) {
|
||||
for await (const _ of turn.run(
|
||||
'test-model',
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
// consume stream
|
||||
}
|
||||
expect(turn.getDebugResponses()).toEqual([resp1, resp2]);
|
||||
|
||||
@@ -211,6 +211,7 @@ export class Turn {
|
||||
) {}
|
||||
// The run method yields simpler events suitable for server logic
|
||||
async *run(
|
||||
model: string,
|
||||
req: PartListUnion,
|
||||
signal: AbortSignal,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||
@@ -218,6 +219,7 @@ export class Turn {
|
||||
// Note: This assumes `sendMessageStream` yields events like
|
||||
// { type: StreamEventType.RETRY } or { type: StreamEventType.CHUNK, value: GenerateContentResponse }
|
||||
const responseStream = await this.chat.sendMessageStream(
|
||||
model,
|
||||
{
|
||||
message: req,
|
||||
config: {
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ModelRouterService } from './modelRouterService.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import type { RoutingContext, RoutingDecision } from './routingStrategy.js';
|
||||
import { DefaultStrategy } from './strategies/defaultStrategy.js';
|
||||
import { CompositeStrategy } from './strategies/compositeStrategy.js';
|
||||
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
|
||||
import { OverrideStrategy } from './strategies/overrideStrategy.js';
|
||||
|
||||
vi.mock('../config/config.js');
|
||||
vi.mock('../core/baseLlmClient.js');
|
||||
vi.mock('./strategies/defaultStrategy.js');
|
||||
vi.mock('./strategies/compositeStrategy.js');
|
||||
vi.mock('./strategies/fallbackStrategy.js');
|
||||
vi.mock('./strategies/overrideStrategy.js');
|
||||
|
||||
// Unit tests for ModelRouterService. The strategy modules are replaced via
// vi.mock (hoisted above the imports by vitest), so `new CompositeStrategy(...)`
// below constructs mock instances, and the real service receives the
// mockCompositeStrategy through the mocked constructor implementation.
describe('ModelRouterService', () => {
  let service: ModelRouterService;
  let mockConfig: Config;
  let mockBaseLlmClient: BaseLlmClient;
  let mockContext: RoutingContext;
  let mockCompositeStrategy: CompositeStrategy;

  beforeEach(() => {
    vi.clearAllMocks();

    // Config is auto-mocked; only getBaseLlmClient is stubbed because it is
    // the single Config method the service calls during route().
    mockConfig = new Config({} as never);
    mockBaseLlmClient = {} as BaseLlmClient;
    vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient);

    // Pre-build the composite the service should end up holding, then force
    // the mocked CompositeStrategy constructor to return exactly this instance.
    mockCompositeStrategy = new CompositeStrategy(
      [new FallbackStrategy(), new OverrideStrategy(), new DefaultStrategy()],
      'agent-router',
    );
    vi.mocked(CompositeStrategy).mockImplementation(
      () => mockCompositeStrategy,
    );

    service = new ModelRouterService(mockConfig);

    mockContext = {
      history: [],
      request: [{ text: 'test prompt' }],
      signal: new AbortController().signal,
    };
  });

  it('should initialize with a CompositeStrategy', () => {
    expect(CompositeStrategy).toHaveBeenCalled();
    // Bracket access deliberately reaches the private 'strategy' field.
    expect(service['strategy']).toBeInstanceOf(CompositeStrategy);
  });

  it('should initialize the CompositeStrategy with the correct child strategies in order', () => {
    // This test relies on the mock implementation detail of the constructor
    const compositeStrategyArgs = vi.mocked(CompositeStrategy).mock.calls[0];
    const childStrategies = compositeStrategyArgs[0];

    // Priority order matters: fallback > override > default (terminal).
    expect(childStrategies.length).toBe(3);
    expect(childStrategies[0]).toBeInstanceOf(FallbackStrategy);
    expect(childStrategies[1]).toBeInstanceOf(OverrideStrategy);
    expect(childStrategies[2]).toBeInstanceOf(DefaultStrategy);
    expect(compositeStrategyArgs[1]).toBe('agent-router');
  });

  describe('route()', () => {
    it('should delegate routing to the composite strategy', async () => {
      const strategyDecision: RoutingDecision = {
        model: 'strategy-chosen-model',
        metadata: {
          source: 'test-router/fallback',
          latencyMs: 10,
          reasoning: 'Strategy reasoning',
        },
      };
      const strategySpy = vi
        .spyOn(mockCompositeStrategy, 'route')
        .mockResolvedValue(strategyDecision);

      const decision = await service.route(mockContext);

      // The service must forward the context, its Config, and the client
      // obtained from Config, and return the strategy's decision unchanged.
      expect(strategySpy).toHaveBeenCalledWith(
        mockContext,
        mockConfig,
        mockBaseLlmClient,
      );
      expect(decision).toEqual(strategyDecision);
    });
  });
});
|
||||
@@ -0,0 +1,54 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import type {
|
||||
RoutingContext,
|
||||
RoutingDecision,
|
||||
TerminalStrategy,
|
||||
} from './routingStrategy.js';
|
||||
import { DefaultStrategy } from './strategies/defaultStrategy.js';
|
||||
import { CompositeStrategy } from './strategies/compositeStrategy.js';
|
||||
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
|
||||
import { OverrideStrategy } from './strategies/overrideStrategy.js';
|
||||
|
||||
/**
|
||||
* A centralized service for making model routing decisions.
|
||||
*/
|
||||
export class ModelRouterService {
|
||||
private config: Config;
|
||||
private strategy: TerminalStrategy;
|
||||
|
||||
constructor(config: Config) {
|
||||
this.config = config;
|
||||
this.strategy = this.initializeDefaultStrategy();
|
||||
}
|
||||
|
||||
private initializeDefaultStrategy(): TerminalStrategy {
|
||||
// Initialize the composite strategy with the desired priority order.
|
||||
// The strategies are ordered in order of highest priority.
|
||||
return new CompositeStrategy(
|
||||
[new FallbackStrategy(), new OverrideStrategy(), new DefaultStrategy()],
|
||||
'agent-router',
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Determines which model to use for a given request context.
|
||||
*
|
||||
* @param context The full context of the request.
|
||||
* @returns A promise that resolves to a RoutingDecision.
|
||||
*/
|
||||
async route(context: RoutingContext): Promise<RoutingDecision> {
|
||||
const decision = await this.strategy.route(
|
||||
context,
|
||||
this.config,
|
||||
this.config.getBaseLlmClient(),
|
||||
);
|
||||
|
||||
return decision;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Content, PartListUnion } from '@google/genai';
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
/**
 * The output of a routing decision. It specifies which model to use and why.
 */
export interface RoutingDecision {
  /** The model identifier string to use for the next API call (e.g., 'gemini-2.5-pro'). */
  model: string;
  /**
   * Metadata about the routing decision for logging purposes.
   */
  metadata: {
    /** Which strategy (or composite/child path) produced this decision. */
    source: string;
    /** Time spent making the routing decision, in milliseconds. */
    latencyMs: number;
    /** Human-readable explanation of why this model was chosen. */
    reasoning: string;
    /** Populated when the decision was reached despite a routing error. */
    error?: string;
  };
}

/**
 * The context provided to the router for making a decision.
 */
export interface RoutingContext {
  /** The full history of the conversation. */
  history: Content[];
  /** The immediate request parts to be processed. */
  request: PartListUnion;
  /** An abort signal to cancel an LLM call during routing. */
  signal: AbortSignal;
}

/**
 * The core interface that all routing strategies must implement.
 * Strategies implementing this interface may decline a request by returning null.
 */
export interface RoutingStrategy {
  /** The name of the strategy (e.g., 'fallback', 'override', 'composite'). */
  readonly name: string;

  /**
   * Determines which model to use for a given request context.
   * @param context The full context of the request.
   * @param config The current configuration.
   * @param baseLlmClient A reference to the BaseLlmClient, allowing the strategy to make its own API calls if needed.
   * @returns A promise that resolves to a RoutingDecision, or null if the strategy is not applicable.
   */
  route(
    context: RoutingContext,
    config: Config,
    baseLlmClient: BaseLlmClient,
  ): Promise<RoutingDecision | null>;
}

/**
 * A strategy that is guaranteed to return a decision. It must not return null.
 * This is used to ensure that a composite chain always terminates.
 */
export interface TerminalStrategy extends RoutingStrategy {
  /**
   * Determines which model to use for a given request context.
   * Narrows RoutingStrategy.route: null is not a permitted result.
   * @returns A promise that resolves to a RoutingDecision.
   */
  route(
    context: RoutingContext,
    config: Config,
    baseLlmClient: BaseLlmClient,
  ): Promise<RoutingDecision>;
}
|
||||
@@ -0,0 +1,215 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { CompositeStrategy } from './compositeStrategy.js';
|
||||
import type {
|
||||
RoutingContext,
|
||||
RoutingDecision,
|
||||
RoutingStrategy,
|
||||
TerminalStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
|
||||
// Unit tests for CompositeStrategy. Child strategies are plain objects with
// vi.fn() route methods, so ordering and argument forwarding can be asserted
// precisely. Several tests pin the exact console.error message strings emitted
// by the composite's error handling.
describe('CompositeStrategy', () => {
  let mockContext: RoutingContext;
  let mockConfig: Config;
  let mockBaseLlmClient: BaseLlmClient;
  let mockStrategy1: RoutingStrategy;
  let mockStrategy2: RoutingStrategy;
  let mockTerminalStrategy: TerminalStrategy;

  beforeEach(() => {
    vi.clearAllMocks();

    // The composite never inspects these; empty stubs are sufficient.
    mockContext = {} as RoutingContext;
    mockConfig = {} as Config;
    mockBaseLlmClient = {} as BaseLlmClient;

    // Non-terminal strategies decline by default (resolve null).
    mockStrategy1 = {
      name: 'strategy1',
      route: vi.fn().mockResolvedValue(null),
    };

    mockStrategy2 = {
      name: 'strategy2',
      route: vi.fn().mockResolvedValue(null),
    };

    // Terminal strategy always produces a decision.
    mockTerminalStrategy = {
      name: 'terminal',
      route: vi.fn().mockResolvedValue({
        model: 'terminal-model',
        metadata: {
          source: 'terminal',
          latencyMs: 10,
          reasoning: 'Terminal decision',
        },
      }),
    };
  });

  it('should try strategies in order and return the first successful decision', async () => {
    const decision: RoutingDecision = {
      model: 'strategy2-model',
      metadata: {
        source: 'strategy2',
        latencyMs: 20,
        reasoning: 'Strategy 2 decided',
      },
    };
    vi.spyOn(mockStrategy2, 'route').mockResolvedValue(decision);

    const composite = new CompositeStrategy(
      [mockStrategy1, mockStrategy2, mockTerminalStrategy],
      'test-router',
    );

    const result = await composite.route(
      mockContext,
      mockConfig,
      mockBaseLlmClient,
    );

    expect(mockStrategy1.route).toHaveBeenCalledWith(
      mockContext,
      mockConfig,
      mockBaseLlmClient,
    );
    expect(mockStrategy2.route).toHaveBeenCalledWith(
      mockContext,
      mockConfig,
      mockBaseLlmClient,
    );
    // Chain short-circuits once strategy2 returns a decision.
    expect(mockTerminalStrategy.route).not.toHaveBeenCalled();

    expect(result.model).toBe('strategy2-model');
    // Source is prefixed with the composite's name.
    expect(result.metadata.source).toBe('test-router/strategy2');
  });

  it('should fall back to the terminal strategy if no other strategy provides a decision', async () => {
    const composite = new CompositeStrategy(
      [mockStrategy1, mockStrategy2, mockTerminalStrategy],
      'test-router',
    );

    const result = await composite.route(
      mockContext,
      mockConfig,
      mockBaseLlmClient,
    );

    expect(mockStrategy1.route).toHaveBeenCalledTimes(1);
    expect(mockStrategy2.route).toHaveBeenCalledTimes(1);
    expect(mockTerminalStrategy.route).toHaveBeenCalledTimes(1);

    expect(result.model).toBe('terminal-model');
    expect(result.metadata.source).toBe('test-router/terminal');
  });

  it('should handle errors in non-terminal strategies and continue', async () => {
    const consoleErrorSpy = vi
      .spyOn(console, 'error')
      .mockImplementation(() => {});
    vi.spyOn(mockStrategy1, 'route').mockRejectedValue(
      new Error('Strategy 1 failed'),
    );

    const composite = new CompositeStrategy(
      [mockStrategy1, mockTerminalStrategy],
      'test-router',
    );

    const result = await composite.route(
      mockContext,
      mockConfig,
      mockBaseLlmClient,
    );

    // Non-terminal failures are logged and swallowed; routing proceeds.
    expect(consoleErrorSpy).toHaveBeenCalledWith(
      "[Routing] Strategy 'strategy1' failed. Continuing to next strategy. Error:",
      expect.any(Error),
    );
    expect(result.model).toBe('terminal-model');
    consoleErrorSpy.mockRestore();
  });

  it('should re-throw an error from the terminal strategy', async () => {
    const consoleErrorSpy = vi
      .spyOn(console, 'error')
      .mockImplementation(() => {});
    const terminalError = new Error('Terminal strategy failed');
    vi.spyOn(mockTerminalStrategy, 'route').mockRejectedValue(terminalError);

    const composite = new CompositeStrategy([mockTerminalStrategy]);

    // Terminal failure is fatal: logged and re-thrown to the caller.
    await expect(
      composite.route(mockContext, mockConfig, mockBaseLlmClient),
    ).rejects.toThrow(terminalError);

    expect(consoleErrorSpy).toHaveBeenCalledWith(
      "[Routing] Critical Error: Terminal strategy 'terminal' failed. Routing cannot proceed. Error:",
      terminalError,
    );
    consoleErrorSpy.mockRestore();
  });

  it('should correctly finalize the decision metadata', async () => {
    const decision: RoutingDecision = {
      model: 'some-model',
      metadata: {
        source: 'child-source',
        latencyMs: 50,
        reasoning: 'Child reasoning',
      },
    };
    vi.spyOn(mockStrategy1, 'route').mockResolvedValue(decision);

    const composite = new CompositeStrategy(
      [mockStrategy1, mockTerminalStrategy],
      'my-composite',
    );

    const result = await composite.route(
      mockContext,
      mockConfig,
      mockBaseLlmClient,
    );

    expect(result.model).toBe('some-model');
    expect(result.metadata.source).toBe('my-composite/child-source');
    expect(result.metadata.reasoning).toBe('Child reasoning');
    // It should keep the child's latency
    expect(result.metadata.latencyMs).toBe(50);
  });

  it('should calculate total latency if child latency is not provided', async () => {
    const decision: RoutingDecision = {
      model: 'some-model',
      metadata: {
        source: 'child-source',
        // No latencyMs here
        latencyMs: 0,
        reasoning: 'Child reasoning',
      },
    };
    vi.spyOn(mockStrategy1, 'route').mockResolvedValue(decision);

    const composite = new CompositeStrategy(
      [mockStrategy1, mockTerminalStrategy],
      'my-composite',
    );

    const result = await composite.route(
      mockContext,
      mockConfig,
      mockBaseLlmClient,
    );

    // 0 is falsy, so the composite substitutes its own measured latency.
    expect(result.metadata.latencyMs).toBeGreaterThanOrEqual(0);
  });
});
|
||||
@@ -0,0 +1,109 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../../config/config.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type {
|
||||
RoutingContext,
|
||||
RoutingDecision,
|
||||
RoutingStrategy,
|
||||
TerminalStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
|
||||
/**
|
||||
* A strategy that attempts a list of child strategies in order (Chain of Responsibility).
|
||||
*/
|
||||
export class CompositeStrategy implements TerminalStrategy {
|
||||
readonly name: string;
|
||||
|
||||
private strategies: [...RoutingStrategy[], TerminalStrategy];
|
||||
|
||||
/**
|
||||
* Initializes the CompositeStrategy.
|
||||
* @param strategies The strategies to try, in order of priority. The last strategy must be terminal.
|
||||
* @param name The name of this composite configuration (e.g., 'router' or 'composite').
|
||||
*/
|
||||
constructor(
|
||||
strategies: [...RoutingStrategy[], TerminalStrategy],
|
||||
name: string = 'composite',
|
||||
) {
|
||||
this.strategies = strategies;
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
async route(
|
||||
context: RoutingContext,
|
||||
config: Config,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
): Promise<RoutingDecision> {
|
||||
const startTime = performance.now();
|
||||
|
||||
// Separate non-terminal strategies from the terminal one.
|
||||
// This separation allows TypeScript to understand the control flow guarantees.
|
||||
const nonTerminalStrategies = this.strategies.slice(
|
||||
0,
|
||||
-1,
|
||||
) as RoutingStrategy[];
|
||||
const terminalStrategy = this.strategies[
|
||||
this.strategies.length - 1
|
||||
] as TerminalStrategy;
|
||||
|
||||
// Try non-terminal strategies, allowing them to fail gracefully.
|
||||
for (const strategy of nonTerminalStrategies) {
|
||||
try {
|
||||
const decision = await strategy.route(context, config, baseLlmClient);
|
||||
if (decision) {
|
||||
return this.finalizeDecision(decision, startTime);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`[Routing] Strategy '${strategy.name}' failed. Continuing to next strategy. Error:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// If no other strategy matched, execute the terminal strategy.
|
||||
try {
|
||||
const decision = await terminalStrategy.route(
|
||||
context,
|
||||
config,
|
||||
baseLlmClient,
|
||||
);
|
||||
|
||||
return this.finalizeDecision(decision, startTime);
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`[Routing] Critical Error: Terminal strategy '${terminalStrategy.name}' failed. Routing cannot proceed. Error:`,
|
||||
error,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to enhance the decision metadata with composite information.
|
||||
*/
|
||||
private finalizeDecision(
|
||||
decision: RoutingDecision,
|
||||
startTime: number,
|
||||
): RoutingDecision {
|
||||
const endTime = performance.now();
|
||||
const totalLatency = endTime - startTime;
|
||||
|
||||
// Combine the source paths: composite_name/child_source (e.g. 'router/default')
|
||||
const compositeSource = `${this.name}/${decision.metadata.source}`;
|
||||
|
||||
return {
|
||||
...decision,
|
||||
metadata: {
|
||||
...decision.metadata,
|
||||
source: compositeSource,
|
||||
latencyMs: decision.metadata.latencyMs || totalLatency,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { DefaultStrategy } from './defaultStrategy.js';
|
||||
import type { RoutingContext } from '../routingStrategy.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import { DEFAULT_GEMINI_MODEL } from '../../config/models.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
|
||||
describe('DefaultStrategy', () => {
|
||||
it('should always route to the default Gemini model', async () => {
|
||||
const strategy = new DefaultStrategy();
|
||||
const mockContext = {} as RoutingContext;
|
||||
const mockConfig = {} as Config;
|
||||
const mockClient = {} as BaseLlmClient;
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: 'default',
|
||||
latencyMs: 0,
|
||||
reasoning: `Routing to default model: ${DEFAULT_GEMINI_MODEL}`,
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,33 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../../config/config.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type {
|
||||
RoutingContext,
|
||||
RoutingDecision,
|
||||
TerminalStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
import { DEFAULT_GEMINI_MODEL } from '../../config/models.js';
|
||||
|
||||
export class DefaultStrategy implements TerminalStrategy {
|
||||
readonly name = 'default';
|
||||
|
||||
async route(
|
||||
_context: RoutingContext,
|
||||
_config: Config,
|
||||
_baseLlmClient: BaseLlmClient,
|
||||
): Promise<RoutingDecision> {
|
||||
return {
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: this.name,
|
||||
latencyMs: 0,
|
||||
reasoning: `Routing to default model: ${DEFAULT_GEMINI_MODEL}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { FallbackStrategy } from './fallbackStrategy.js';
|
||||
import type { RoutingContext } from '../routingStrategy.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||
} from '../../config/models.js';
|
||||
|
||||
describe('FallbackStrategy', () => {
|
||||
const strategy = new FallbackStrategy();
|
||||
const mockContext = {} as RoutingContext;
|
||||
const mockClient = {} as BaseLlmClient;
|
||||
|
||||
it('should return null when not in fallback mode', async () => {
|
||||
const mockConfig = {
|
||||
isInFallbackMode: () => false,
|
||||
getModel: () => DEFAULT_GEMINI_MODEL,
|
||||
} as Config;
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
expect(decision).toBeNull();
|
||||
});
|
||||
|
||||
describe('when in fallback mode', () => {
|
||||
it('should downgrade a pro model to the flash model', async () => {
|
||||
const mockConfig = {
|
||||
isInFallbackMode: () => true,
|
||||
getModel: () => DEFAULT_GEMINI_MODEL,
|
||||
} as Config;
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
expect(decision?.metadata.source).toBe('fallback');
|
||||
expect(decision?.metadata.reasoning).toContain('In fallback mode');
|
||||
});
|
||||
|
||||
it('should honor a lite model request', async () => {
|
||||
const mockConfig = {
|
||||
isInFallbackMode: () => true,
|
||||
getModel: () => DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||
} as Config;
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_LITE_MODEL);
|
||||
expect(decision?.metadata.source).toBe('fallback');
|
||||
});
|
||||
|
||||
it('should use the flash model if flash is requested', async () => {
|
||||
const mockConfig = {
|
||||
isInFallbackMode: () => true,
|
||||
getModel: () => DEFAULT_GEMINI_FLASH_MODEL,
|
||||
} as Config;
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
expect(decision?.metadata.source).toBe('fallback');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,43 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../../config/config.js';
|
||||
import { getEffectiveModel } from '../../config/models.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type {
|
||||
RoutingContext,
|
||||
RoutingDecision,
|
||||
RoutingStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
|
||||
export class FallbackStrategy implements RoutingStrategy {
|
||||
readonly name = 'fallback';
|
||||
|
||||
async route(
|
||||
_context: RoutingContext,
|
||||
config: Config,
|
||||
_baseLlmClient: BaseLlmClient,
|
||||
): Promise<RoutingDecision | null> {
|
||||
const isInFallbackMode: boolean = config.isInFallbackMode();
|
||||
|
||||
if (!isInFallbackMode) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const effectiveModel = getEffectiveModel(
|
||||
isInFallbackMode,
|
||||
config.getModel(),
|
||||
);
|
||||
return {
|
||||
model: effectiveModel,
|
||||
metadata: {
|
||||
source: this.name,
|
||||
latencyMs: 0,
|
||||
reasoning: `In fallback mode. Using: ${effectiveModel}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { OverrideStrategy } from './overrideStrategy.js';
|
||||
import type { RoutingContext } from '../routingStrategy.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
|
||||
describe('OverrideStrategy', () => {
|
||||
const strategy = new OverrideStrategy();
|
||||
const mockContext = {} as RoutingContext;
|
||||
const mockClient = {} as BaseLlmClient;
|
||||
|
||||
it('should return null when no override model is specified', async () => {
|
||||
const mockConfig = {
|
||||
getModel: () => '', // Simulate no model override
|
||||
} as Config;
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
expect(decision).toBeNull();
|
||||
});
|
||||
|
||||
it('should return a decision with the override model when one is specified', async () => {
|
||||
const overrideModel = 'gemini-2.5-pro-custom';
|
||||
const mockConfig = {
|
||||
getModel: () => overrideModel,
|
||||
} as Config;
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
expect(decision?.model).toBe(overrideModel);
|
||||
expect(decision?.metadata.source).toBe('override');
|
||||
expect(decision?.metadata.reasoning).toContain(
|
||||
'Routing bypassed by forced model directive',
|
||||
);
|
||||
expect(decision?.metadata.reasoning).toContain(overrideModel);
|
||||
});
|
||||
|
||||
it('should handle different override model names', async () => {
|
||||
const overrideModel = 'gemini-2.5-flash-experimental';
|
||||
const mockConfig = {
|
||||
getModel: () => overrideModel,
|
||||
} as Config;
|
||||
|
||||
const decision = await strategy.route(mockContext, mockConfig, mockClient);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
expect(decision?.model).toBe(overrideModel);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,40 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../../config/config.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type {
|
||||
RoutingContext,
|
||||
RoutingDecision,
|
||||
RoutingStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
|
||||
/**
|
||||
* Handles cases where the user explicitly specifies a model (override).
|
||||
*/
|
||||
export class OverrideStrategy implements RoutingStrategy {
|
||||
readonly name = 'override';
|
||||
|
||||
async route(
|
||||
_context: RoutingContext,
|
||||
config: Config,
|
||||
_baseLlmClient: BaseLlmClient,
|
||||
): Promise<RoutingDecision | null> {
|
||||
const overrideModel = config.getModel();
|
||||
if (overrideModel) {
|
||||
return {
|
||||
model: overrideModel,
|
||||
metadata: {
|
||||
source: this.name,
|
||||
latencyMs: 0,
|
||||
reasoning: `Routing bypassed by forced model directive. Using: ${overrideModel}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
// No override specified, pass to the next strategy.
|
||||
return null;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user