mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 05:12:55 -07:00
Co-authored-by: matt korwel <matt.korwel@gmail.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import { describe, it, expect } from 'vitest';
|
|||||||
import {
|
import {
|
||||||
resolveModel,
|
resolveModel,
|
||||||
resolveClassifierModel,
|
resolveClassifierModel,
|
||||||
|
isGemini3Model,
|
||||||
isGemini2Model,
|
isGemini2Model,
|
||||||
isAutoModel,
|
isAutoModel,
|
||||||
getDisplayString,
|
getDisplayString,
|
||||||
@@ -25,6 +26,29 @@ import {
|
|||||||
DEFAULT_GEMINI_MODEL_AUTO,
|
DEFAULT_GEMINI_MODEL_AUTO,
|
||||||
} from './models.js';
|
} from './models.js';
|
||||||
|
|
||||||
|
describe('isGemini3Model', () => {
|
||||||
|
it('should return true for gemini-3 models', () => {
|
||||||
|
expect(isGemini3Model('gemini-3-pro-preview')).toBe(true);
|
||||||
|
expect(isGemini3Model('gemini-3-flash-preview')).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return true for aliases that resolve to Gemini 3', () => {
|
||||||
|
expect(isGemini3Model(GEMINI_MODEL_ALIAS_AUTO, true)).toBe(true);
|
||||||
|
expect(isGemini3Model(GEMINI_MODEL_ALIAS_PRO, true)).toBe(true);
|
||||||
|
expect(isGemini3Model(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return false for Gemini 2 models', () => {
|
||||||
|
expect(isGemini3Model('gemini-2.5-pro')).toBe(false);
|
||||||
|
expect(isGemini3Model('gemini-2.5-flash')).toBe(false);
|
||||||
|
expect(isGemini3Model(DEFAULT_GEMINI_MODEL_AUTO)).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return false for arbitrary strings', () => {
|
||||||
|
expect(isGemini3Model('gpt-4')).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('getDisplayString', () => {
|
describe('getDisplayString', () => {
|
||||||
it('should return Auto (Gemini 3) for preview auto model', () => {
|
it('should return Auto (Gemini 3) for preview auto model', () => {
|
||||||
expect(getDisplayString(PREVIEW_GEMINI_MODEL_AUTO)).toBe('Auto (Gemini 3)');
|
expect(getDisplayString(PREVIEW_GEMINI_MODEL_AUTO)).toBe('Auto (Gemini 3)');
|
||||||
|
|||||||
@@ -137,6 +137,21 @@ export function isPreviewModel(model: string): boolean {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if the model is a Gemini 3 model.
|
||||||
|
*
|
||||||
|
* @param model The model name to check.
|
||||||
|
* @param previewFeaturesEnabled A boolean indicating if preview features are enabled.
|
||||||
|
* @returns True if the model is a Gemini 3 model.
|
||||||
|
*/
|
||||||
|
export function isGemini3Model(
|
||||||
|
model: string,
|
||||||
|
previewFeaturesEnabled: boolean = false,
|
||||||
|
): boolean {
|
||||||
|
const resolved = resolveModel(model, previewFeaturesEnabled);
|
||||||
|
return /^gemini-3(\.|-|$)/.test(resolved);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Checks if the model is a Gemini 2.x model.
|
* Checks if the model is a Gemini 2.x model.
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -4,19 +4,16 @@
|
|||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
|
||||||
import { ClassifierStrategy } from './classifierStrategy.js';
|
import { ClassifierStrategy } from './classifierStrategy.js';
|
||||||
import type { RoutingContext } from '../routingStrategy.js';
|
import type { RoutingContext } from '../routingStrategy.js';
|
||||||
import type { Config } from '../../config/config.js';
|
import type { Config } from '../../config/config.js';
|
||||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||||
import {
|
|
||||||
isFunctionCall,
|
|
||||||
isFunctionResponse,
|
|
||||||
} from '../../utils/messageInspectors.js';
|
|
||||||
import {
|
import {
|
||||||
DEFAULT_GEMINI_FLASH_MODEL,
|
DEFAULT_GEMINI_FLASH_MODEL,
|
||||||
DEFAULT_GEMINI_MODEL,
|
DEFAULT_GEMINI_MODEL,
|
||||||
DEFAULT_GEMINI_MODEL_AUTO,
|
DEFAULT_GEMINI_MODEL_AUTO,
|
||||||
|
PREVIEW_GEMINI_MODEL_AUTO,
|
||||||
} from '../../config/models.js';
|
} from '../../config/models.js';
|
||||||
import { promptIdContext } from '../../utils/promptIdContext.js';
|
import { promptIdContext } from '../../utils/promptIdContext.js';
|
||||||
import type { Content } from '@google/genai';
|
import type { Content } from '@google/genai';
|
||||||
@@ -31,6 +28,9 @@ describe('ClassifierStrategy', () => {
|
|||||||
let mockConfig: Config;
|
let mockConfig: Config;
|
||||||
let mockBaseLlmClient: BaseLlmClient;
|
let mockBaseLlmClient: BaseLlmClient;
|
||||||
let mockResolvedConfig: ResolvedModelConfig;
|
let mockResolvedConfig: ResolvedModelConfig;
|
||||||
|
let mockGetModel: Mock;
|
||||||
|
let mockGetNumericalRoutingEnabled: Mock;
|
||||||
|
let mockGenerateJson: Mock;
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
@@ -46,23 +46,30 @@ describe('ClassifierStrategy', () => {
|
|||||||
model: 'classifier',
|
model: 'classifier',
|
||||||
generateContentConfig: {},
|
generateContentConfig: {},
|
||||||
} as unknown as ResolvedModelConfig;
|
} as unknown as ResolvedModelConfig;
|
||||||
|
|
||||||
|
mockGetModel = vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
|
||||||
|
mockGetNumericalRoutingEnabled = vi.fn().mockResolvedValue(false);
|
||||||
|
mockGenerateJson = vi.fn();
|
||||||
|
|
||||||
mockConfig = {
|
mockConfig = {
|
||||||
modelConfigService: {
|
modelConfigService: {
|
||||||
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
|
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
|
||||||
},
|
},
|
||||||
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
|
getModel: mockGetModel,
|
||||||
getPreviewFeatures: () => false,
|
getPreviewFeatures: () => false,
|
||||||
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
|
getNumericalRoutingEnabled: mockGetNumericalRoutingEnabled,
|
||||||
} as unknown as Config;
|
} as unknown as Config;
|
||||||
|
|
||||||
mockBaseLlmClient = {
|
mockBaseLlmClient = {
|
||||||
generateJson: vi.fn(),
|
generateJson: mockGenerateJson,
|
||||||
} as unknown as BaseLlmClient;
|
} as unknown as BaseLlmClient;
|
||||||
|
|
||||||
vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id');
|
vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id');
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should return null if numerical routing is enabled', async () => {
|
it('should return null if numerical routing is enabled and model is Gemini 3', async () => {
|
||||||
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
|
mockGetNumericalRoutingEnabled.mockResolvedValue(true);
|
||||||
|
mockGetModel.mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);
|
||||||
|
|
||||||
const decision = await strategy.route(
|
const decision = await strategy.route(
|
||||||
mockContext,
|
mockContext,
|
||||||
@@ -71,7 +78,25 @@ describe('ClassifierStrategy', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(decision).toBeNull();
|
expect(decision).toBeNull();
|
||||||
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
expect(mockGenerateJson).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should NOT return null if numerical routing is enabled but model is NOT Gemini 3', async () => {
|
||||||
|
mockGetNumericalRoutingEnabled.mockResolvedValue(true);
|
||||||
|
mockGetModel.mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
|
||||||
|
mockGenerateJson.mockResolvedValue({
|
||||||
|
reasoning: 'test',
|
||||||
|
model_choice: 'flash',
|
||||||
|
});
|
||||||
|
|
||||||
|
const decision = await strategy.route(
|
||||||
|
mockContext,
|
||||||
|
mockConfig,
|
||||||
|
mockBaseLlmClient,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(decision).not.toBeNull();
|
||||||
|
expect(mockGenerateJson).toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should call generateJson with the correct parameters', async () => {
|
it('should call generateJson with the correct parameters', async () => {
|
||||||
@@ -79,13 +104,11 @@ describe('ClassifierStrategy', () => {
|
|||||||
reasoning: 'Simple task',
|
reasoning: 'Simple task',
|
||||||
model_choice: 'flash',
|
model_choice: 'flash',
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||||
mockApiResponse,
|
|
||||||
);
|
|
||||||
|
|
||||||
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||||
|
|
||||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
|
expect(mockGenerateJson).toHaveBeenCalledWith(
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
modelConfigKey: { model: mockResolvedConfig.model },
|
modelConfigKey: { model: mockResolvedConfig.model },
|
||||||
promptId: 'test-prompt-id',
|
promptId: 'test-prompt-id',
|
||||||
@@ -98,9 +121,7 @@ describe('ClassifierStrategy', () => {
|
|||||||
reasoning: 'This is a simple task.',
|
reasoning: 'This is a simple task.',
|
||||||
model_choice: 'flash',
|
model_choice: 'flash',
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||||
mockApiResponse,
|
|
||||||
);
|
|
||||||
|
|
||||||
const decision = await strategy.route(
|
const decision = await strategy.route(
|
||||||
mockContext,
|
mockContext,
|
||||||
@@ -108,7 +129,7 @@ describe('ClassifierStrategy', () => {
|
|||||||
mockBaseLlmClient,
|
mockBaseLlmClient,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce();
|
expect(mockGenerateJson).toHaveBeenCalledOnce();
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||||
metadata: {
|
metadata: {
|
||||||
@@ -124,9 +145,7 @@ describe('ClassifierStrategy', () => {
|
|||||||
reasoning: 'This is a complex task.',
|
reasoning: 'This is a complex task.',
|
||||||
model_choice: 'pro',
|
model_choice: 'pro',
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||||
mockApiResponse,
|
|
||||||
);
|
|
||||||
mockContext.request = [{ text: 'how do I build a spaceship?' }];
|
mockContext.request = [{ text: 'how do I build a spaceship?' }];
|
||||||
|
|
||||||
const decision = await strategy.route(
|
const decision = await strategy.route(
|
||||||
@@ -135,7 +154,7 @@ describe('ClassifierStrategy', () => {
|
|||||||
mockBaseLlmClient,
|
mockBaseLlmClient,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledOnce();
|
expect(mockGenerateJson).toHaveBeenCalledOnce();
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: DEFAULT_GEMINI_MODEL,
|
model: DEFAULT_GEMINI_MODEL,
|
||||||
metadata: {
|
metadata: {
|
||||||
@@ -151,7 +170,7 @@ describe('ClassifierStrategy', () => {
|
|||||||
.spyOn(debugLogger, 'warn')
|
.spyOn(debugLogger, 'warn')
|
||||||
.mockImplementation(() => {});
|
.mockImplementation(() => {});
|
||||||
const testError = new Error('API Failure');
|
const testError = new Error('API Failure');
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockRejectedValue(testError);
|
mockGenerateJson.mockRejectedValue(testError);
|
||||||
|
|
||||||
const decision = await strategy.route(
|
const decision = await strategy.route(
|
||||||
mockContext,
|
mockContext,
|
||||||
@@ -172,9 +191,7 @@ describe('ClassifierStrategy', () => {
|
|||||||
reasoning: 'This is a simple task.',
|
reasoning: 'This is a simple task.',
|
||||||
// model_choice is missing, which will cause a Zod parsing error.
|
// model_choice is missing, which will cause a Zod parsing error.
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
mockGenerateJson.mockResolvedValue(malformedApiResponse);
|
||||||
malformedApiResponse,
|
|
||||||
);
|
|
||||||
|
|
||||||
const decision = await strategy.route(
|
const decision = await strategy.route(
|
||||||
mockContext,
|
mockContext,
|
||||||
@@ -203,14 +220,11 @@ describe('ClassifierStrategy', () => {
|
|||||||
reasoning: 'Simple.',
|
reasoning: 'Simple.',
|
||||||
model_choice: 'flash',
|
model_choice: 'flash',
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||||
mockApiResponse,
|
|
||||||
);
|
|
||||||
|
|
||||||
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||||
|
|
||||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
const generateJsonCall = mockGenerateJson.mock.calls[0][0];
|
||||||
.calls[0][0];
|
|
||||||
const contents = generateJsonCall.contents;
|
const contents = generateJsonCall.contents;
|
||||||
|
|
||||||
const expectedContents = [
|
const expectedContents = [
|
||||||
@@ -239,14 +253,11 @@ describe('ClassifierStrategy', () => {
|
|||||||
reasoning: 'Simple.',
|
reasoning: 'Simple.',
|
||||||
model_choice: 'flash',
|
model_choice: 'flash',
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||||
mockApiResponse,
|
|
||||||
);
|
|
||||||
|
|
||||||
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||||
|
|
||||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
const generateJsonCall = mockGenerateJson.mock.calls[0][0];
|
||||||
.calls[0][0];
|
|
||||||
const contents = generateJsonCall.contents;
|
const contents = generateJsonCall.contents;
|
||||||
|
|
||||||
// Manually calculate what the history should be
|
// Manually calculate what the history should be
|
||||||
@@ -254,7 +265,10 @@ describe('ClassifierStrategy', () => {
|
|||||||
const HISTORY_TURNS_FOR_CONTEXT = 4;
|
const HISTORY_TURNS_FOR_CONTEXT = 4;
|
||||||
const historySlice = longHistory.slice(-HISTORY_SEARCH_WINDOW);
|
const historySlice = longHistory.slice(-HISTORY_SEARCH_WINDOW);
|
||||||
const cleanHistory = historySlice.filter(
|
const cleanHistory = historySlice.filter(
|
||||||
(content) => !isFunctionCall(content) && !isFunctionResponse(content),
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
(content: any) =>
|
||||||
|
!content.parts?.[0]?.functionCall &&
|
||||||
|
!content.parts?.[0]?.functionResponse,
|
||||||
);
|
);
|
||||||
const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT);
|
const finalHistory = cleanHistory.slice(-HISTORY_TURNS_FOR_CONTEXT);
|
||||||
|
|
||||||
@@ -275,14 +289,11 @@ describe('ClassifierStrategy', () => {
|
|||||||
reasoning: 'Simple.',
|
reasoning: 'Simple.',
|
||||||
model_choice: 'flash',
|
model_choice: 'flash',
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||||
mockApiResponse,
|
|
||||||
);
|
|
||||||
|
|
||||||
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
await strategy.route(mockContext, mockConfig, mockBaseLlmClient);
|
||||||
|
|
||||||
const generateJsonCall = vi.mocked(mockBaseLlmClient.generateJson).mock
|
const generateJsonCall = mockGenerateJson.mock.calls[0][0];
|
||||||
.calls[0][0];
|
|
||||||
|
|
||||||
expect(generateJsonCall.promptId).toMatch(
|
expect(generateJsonCall.promptId).toMatch(
|
||||||
/^classifier-router-fallback-\d+-\w+$/,
|
/^classifier-router-fallback-\d+-\w+$/,
|
||||||
@@ -301,9 +312,7 @@ describe('ClassifierStrategy', () => {
|
|||||||
reasoning: 'Choice is flash',
|
reasoning: 'Choice is flash',
|
||||||
model_choice: 'flash',
|
model_choice: 'flash',
|
||||||
};
|
};
|
||||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||||
mockApiResponse,
|
|
||||||
);
|
|
||||||
|
|
||||||
const contextWithRequestedModel = {
|
const contextWithRequestedModel = {
|
||||||
...mockContext,
|
...mockContext,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import type {
|
|||||||
RoutingDecision,
|
RoutingDecision,
|
||||||
RoutingStrategy,
|
RoutingStrategy,
|
||||||
} from '../routingStrategy.js';
|
} from '../routingStrategy.js';
|
||||||
import { resolveClassifierModel } from '../../config/models.js';
|
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
|
||||||
import { createUserContent, Type } from '@google/genai';
|
import { createUserContent, Type } from '@google/genai';
|
||||||
import type { Config } from '../../config/config.js';
|
import type { Config } from '../../config/config.js';
|
||||||
import {
|
import {
|
||||||
@@ -133,7 +133,12 @@ export class ClassifierStrategy implements RoutingStrategy {
|
|||||||
): Promise<RoutingDecision | null> {
|
): Promise<RoutingDecision | null> {
|
||||||
const startTime = Date.now();
|
const startTime = Date.now();
|
||||||
try {
|
try {
|
||||||
if (await config.getNumericalRoutingEnabled()) {
|
const model = context.requestedModel ?? config.getModel();
|
||||||
|
const previewFeaturesEnabled = config.getPreviewFeatures();
|
||||||
|
if (
|
||||||
|
(await config.getNumericalRoutingEnabled()) &&
|
||||||
|
isGemini3Model(model, previewFeaturesEnabled)
|
||||||
|
) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -164,7 +169,7 @@ export class ClassifierStrategy implements RoutingStrategy {
|
|||||||
const reasoning = routerResponse.reasoning;
|
const reasoning = routerResponse.reasoning;
|
||||||
const latencyMs = Date.now() - startTime;
|
const latencyMs = Date.now() - startTime;
|
||||||
const selectedModel = resolveClassifierModel(
|
const selectedModel = resolveClassifierModel(
|
||||||
context.requestedModel ?? config.getModel(),
|
model,
|
||||||
routerResponse.model_choice,
|
routerResponse.model_choice,
|
||||||
config.getPreviewFeatures(),
|
config.getPreviewFeatures(),
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -10,9 +10,12 @@ import type { RoutingContext } from '../routingStrategy.js';
|
|||||||
import type { Config } from '../../config/config.js';
|
import type { Config } from '../../config/config.js';
|
||||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||||
import {
|
import {
|
||||||
DEFAULT_GEMINI_FLASH_MODEL,
|
PREVIEW_GEMINI_FLASH_MODEL,
|
||||||
DEFAULT_GEMINI_MODEL,
|
PREVIEW_GEMINI_MODEL,
|
||||||
DEFAULT_GEMINI_MODEL_AUTO,
|
DEFAULT_GEMINI_MODEL_AUTO,
|
||||||
|
DEFAULT_GEMINI_MODEL,
|
||||||
|
PREVIEW_GEMINI_MODEL_AUTO,
|
||||||
|
GEMINI_MODEL_ALIAS_AUTO,
|
||||||
} from '../../config/models.js';
|
} from '../../config/models.js';
|
||||||
import { promptIdContext } from '../../utils/promptIdContext.js';
|
import { promptIdContext } from '../../utils/promptIdContext.js';
|
||||||
import type { Content } from '@google/genai';
|
import type { Content } from '@google/genai';
|
||||||
@@ -46,7 +49,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
modelConfigService: {
|
modelConfigService: {
|
||||||
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
|
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
|
||||||
},
|
},
|
||||||
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
|
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
|
||||||
getPreviewFeatures: () => false,
|
getPreviewFeatures: () => false,
|
||||||
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
|
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
|
||||||
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
|
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
|
||||||
@@ -76,6 +79,54 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should return null if the model is not a Gemini 3 model', async () => {
|
||||||
|
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
|
||||||
|
|
||||||
|
const decision = await strategy.route(
|
||||||
|
mockContext,
|
||||||
|
mockConfig,
|
||||||
|
mockBaseLlmClient,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(decision).toBeNull();
|
||||||
|
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return null if the model is explicitly a Gemini 2 model', async () => {
|
||||||
|
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL);
|
||||||
|
|
||||||
|
const decision = await strategy.route(
|
||||||
|
mockContext,
|
||||||
|
mockConfig,
|
||||||
|
mockBaseLlmClient,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(decision).toBeNull();
|
||||||
|
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return a decision if model is auto and preview features are enabled (resolves to Gemini 3)', async () => {
|
||||||
|
vi.mocked(mockConfig.getModel).mockReturnValue(GEMINI_MODEL_ALIAS_AUTO);
|
||||||
|
vi.spyOn(mockConfig, 'getPreviewFeatures').mockReturnValue(true);
|
||||||
|
|
||||||
|
const mockApiResponse = {
|
||||||
|
complexity_reasoning: 'Simple task',
|
||||||
|
complexity_score: 10,
|
||||||
|
};
|
||||||
|
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||||
|
mockApiResponse,
|
||||||
|
);
|
||||||
|
|
||||||
|
const decision = await strategy.route(
|
||||||
|
mockContext,
|
||||||
|
mockConfig,
|
||||||
|
mockBaseLlmClient,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(decision).not.toBeNull();
|
||||||
|
expect(mockBaseLlmClient.generateJson).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
it('should call generateJson with the correct parameters and wrapped user content', async () => {
|
it('should call generateJson with the correct parameters and wrapped user content', async () => {
|
||||||
const mockApiResponse = {
|
const mockApiResponse = {
|
||||||
complexity_reasoning: 'Simple task',
|
complexity_reasoning: 'Simple task',
|
||||||
@@ -120,7 +171,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Control)',
|
source: 'NumericalClassifier (Control)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
@@ -146,7 +197,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: DEFAULT_GEMINI_MODEL,
|
model: PREVIEW_GEMINI_MODEL,
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Control)',
|
source: 'NumericalClassifier (Control)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
@@ -172,7 +223,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: DEFAULT_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
|
model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Strict)',
|
source: 'NumericalClassifier (Strict)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
@@ -198,7 +249,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: DEFAULT_GEMINI_MODEL,
|
model: PREVIEW_GEMINI_MODEL,
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Strict)',
|
source: 'NumericalClassifier (Strict)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
@@ -226,7 +277,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
|
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Remote)',
|
source: 'NumericalClassifier (Remote)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
@@ -252,7 +303,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
|
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Remote)',
|
source: 'NumericalClassifier (Remote)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
@@ -278,7 +329,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: DEFAULT_GEMINI_MODEL, // Score 35 >= Threshold 30
|
model: PREVIEW_GEMINI_MODEL, // Score 35 >= Threshold 30
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Remote)',
|
source: 'NumericalClassifier (Remote)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
@@ -306,7 +357,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
|
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Control)',
|
source: 'NumericalClassifier (Control)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
@@ -333,7 +384,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Control)',
|
source: 'NumericalClassifier (Control)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
@@ -360,7 +411,7 @@ describe('NumericalClassifierStrategy', () => {
|
|||||||
);
|
);
|
||||||
|
|
||||||
expect(decision).toEqual({
|
expect(decision).toEqual({
|
||||||
model: DEFAULT_GEMINI_MODEL,
|
model: PREVIEW_GEMINI_MODEL,
|
||||||
metadata: {
|
metadata: {
|
||||||
source: 'NumericalClassifier (Control)',
|
source: 'NumericalClassifier (Control)',
|
||||||
latencyMs: expect.any(Number),
|
latencyMs: expect.any(Number),
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import type {
|
|||||||
RoutingDecision,
|
RoutingDecision,
|
||||||
RoutingStrategy,
|
RoutingStrategy,
|
||||||
} from '../routingStrategy.js';
|
} from '../routingStrategy.js';
|
||||||
import { resolveClassifierModel } from '../../config/models.js';
|
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
|
||||||
import { createUserContent, Type } from '@google/genai';
|
import { createUserContent, Type } from '@google/genai';
|
||||||
import type { Config } from '../../config/config.js';
|
import type { Config } from '../../config/config.js';
|
||||||
import { debugLogger } from '../../utils/debugLogger.js';
|
import { debugLogger } from '../../utils/debugLogger.js';
|
||||||
@@ -134,10 +134,16 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
|||||||
): Promise<RoutingDecision | null> {
|
): Promise<RoutingDecision | null> {
|
||||||
const startTime = Date.now();
|
const startTime = Date.now();
|
||||||
try {
|
try {
|
||||||
|
const model = context.requestedModel ?? config.getModel();
|
||||||
|
const previewFeaturesEnabled = config.getPreviewFeatures();
|
||||||
if (!(await config.getNumericalRoutingEnabled())) {
|
if (!(await config.getNumericalRoutingEnabled())) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!isGemini3Model(model, previewFeaturesEnabled)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
const promptId = getPromptIdWithFallback('classifier-router');
|
const promptId = getPromptIdWithFallback('classifier-router');
|
||||||
|
|
||||||
const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT);
|
const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT);
|
||||||
@@ -177,7 +183,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
|||||||
);
|
);
|
||||||
|
|
||||||
const selectedModel = resolveClassifierModel(
|
const selectedModel = resolveClassifierModel(
|
||||||
config.getModel(),
|
model,
|
||||||
modelAlias,
|
modelAlias,
|
||||||
config.getPreviewFeatures(),
|
config.getPreviewFeatures(),
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user