mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 06:31:01 -07:00
Co-authored-by: matt korwel <matt.korwel@gmail.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import { describe, it, expect } from 'vitest';
|
||||
import {
|
||||
resolveModel,
|
||||
resolveClassifierModel,
|
||||
isGemini3Model,
|
||||
isGemini2Model,
|
||||
isAutoModel,
|
||||
getDisplayString,
|
||||
@@ -25,6 +26,29 @@ import {
|
||||
DEFAULT_GEMINI_MODEL_AUTO,
|
||||
} from './models.js';
|
||||
|
||||
describe('isGemini3Model', () => {
|
||||
it('should return true for gemini-3 models', () => {
|
||||
expect(isGemini3Model('gemini-3-pro-preview')).toBe(true);
|
||||
expect(isGemini3Model('gemini-3-flash-preview')).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for aliases that resolve to Gemini 3 when preview is enabled', () => {
|
||||
expect(isGemini3Model(GEMINI_MODEL_ALIAS_AUTO, true)).toBe(true);
|
||||
expect(isGemini3Model(GEMINI_MODEL_ALIAS_PRO, true)).toBe(true);
|
||||
expect(isGemini3Model(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for Gemini 2 models', () => {
|
||||
expect(isGemini3Model('gemini-2.5-pro')).toBe(false);
|
||||
expect(isGemini3Model('gemini-2.5-flash')).toBe(false);
|
||||
expect(isGemini3Model(DEFAULT_GEMINI_MODEL_AUTO)).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false for arbitrary strings', () => {
|
||||
expect(isGemini3Model('gpt-4')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getDisplayString', () => {
|
||||
it('should return Auto (Gemini 3) for preview auto model', () => {
|
||||
expect(getDisplayString(PREVIEW_GEMINI_MODEL_AUTO)).toBe('Auto (Gemini 3)');
|
||||
|
||||
@@ -137,6 +137,20 @@ export function isPreviewModel(model: string): boolean {
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the model is a Gemini 3 model.
|
||||
*
|
||||
* @param model The model name to check.
|
||||
* @returns True if the model is a Gemini 3 model.
|
||||
*/
|
||||
export function isGemini3Model(
|
||||
model: string,
|
||||
previewFeaturesEnabled = false,
|
||||
): boolean {
|
||||
const resolved = resolveModel(model, previewFeaturesEnabled);
|
||||
return /^gemini-3(\.|-|$)/.test(resolved);
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if the model is a Gemini 2.x model.
|
||||
*
|
||||
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
DEFAULT_GEMINI_MODEL_AUTO,
|
||||
PREVIEW_GEMINI_MODEL_AUTO,
|
||||
} from '../../config/models.js';
|
||||
import { promptIdContext } from '../../utils/promptIdContext.js';
|
||||
import type { Content } from '@google/genai';
|
||||
@@ -50,8 +51,8 @@ describe('ClassifierStrategy', () => {
|
||||
modelConfigService: {
|
||||
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
|
||||
},
|
||||
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
|
||||
getPreviewFeatures: () => false,
|
||||
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
|
||||
getPreviewFeatures: vi.fn().mockReturnValue(false),
|
||||
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
|
||||
} as unknown as Config;
|
||||
mockBaseLlmClient = {
|
||||
@@ -61,8 +62,9 @@ describe('ClassifierStrategy', () => {
|
||||
vi.spyOn(promptIdContext, 'getStore').mockReturnValue('test-prompt-id');
|
||||
});
|
||||
|
||||
it('should return null if numerical routing is enabled', async () => {
|
||||
it('should return null if numerical routing is enabled and model is Gemini 3', async () => {
|
||||
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
@@ -74,6 +76,24 @@ describe('ClassifierStrategy', () => {
|
||||
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should NOT return null if numerical routing is enabled but model is NOT Gemini 3', async () => {
|
||||
vi.mocked(mockConfig.getNumericalRoutingEnabled).mockResolvedValue(true);
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({
|
||||
reasoning: 'test',
|
||||
model_choice: 'flash',
|
||||
});
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should call generateJson with the correct parameters', async () => {
|
||||
const mockApiResponse = {
|
||||
reasoning: 'Simple task',
|
||||
|
||||
@@ -12,7 +12,7 @@ import type {
|
||||
RoutingDecision,
|
||||
RoutingStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
import { resolveClassifierModel } from '../../config/models.js';
|
||||
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
|
||||
import { createUserContent, Type } from '@google/genai';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import {
|
||||
@@ -133,7 +133,11 @@ export class ClassifierStrategy implements RoutingStrategy {
|
||||
): Promise<RoutingDecision | null> {
|
||||
const startTime = Date.now();
|
||||
try {
|
||||
if (await config.getNumericalRoutingEnabled()) {
|
||||
const model = context.requestedModel ?? config.getModel();
|
||||
if (
|
||||
(await config.getNumericalRoutingEnabled()) &&
|
||||
isGemini3Model(model, config.getPreviewFeatures())
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -164,7 +168,7 @@ export class ClassifierStrategy implements RoutingStrategy {
|
||||
const reasoning = routerResponse.reasoning;
|
||||
const latencyMs = Date.now() - startTime;
|
||||
const selectedModel = resolveClassifierModel(
|
||||
context.requestedModel ?? config.getModel(),
|
||||
model,
|
||||
routerResponse.model_choice,
|
||||
config.getPreviewFeatures(),
|
||||
);
|
||||
|
||||
@@ -10,9 +10,11 @@ import type { RoutingContext } from '../routingStrategy.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
PREVIEW_GEMINI_FLASH_MODEL,
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
DEFAULT_GEMINI_MODEL_AUTO,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
PREVIEW_GEMINI_MODEL_AUTO,
|
||||
} from '../../config/models.js';
|
||||
import { promptIdContext } from '../../utils/promptIdContext.js';
|
||||
import type { Content } from '@google/genai';
|
||||
@@ -46,8 +48,8 @@ describe('NumericalClassifierStrategy', () => {
|
||||
modelConfigService: {
|
||||
getResolvedConfig: vi.fn().mockReturnValue(mockResolvedConfig),
|
||||
},
|
||||
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
|
||||
getPreviewFeatures: () => false,
|
||||
getModel: vi.fn().mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO),
|
||||
getPreviewFeatures: vi.fn().mockReturnValue(false),
|
||||
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
|
||||
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
|
||||
getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
|
||||
@@ -76,6 +78,32 @@ describe('NumericalClassifierStrategy', () => {
|
||||
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return null if the model is not a Gemini 3 model', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return null if the model is explicitly a Gemini 2 model', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should call generateJson with the correct parameters and wrapped user content', async () => {
|
||||
const mockApiResponse = {
|
||||
complexity_reasoning: 'Simple task',
|
||||
@@ -120,7 +148,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
latencyMs: expect.any(Number),
|
||||
@@ -146,7 +174,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
latencyMs: expect.any(Number),
|
||||
@@ -172,7 +200,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL, // Routed to Flash because 60 < 80
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Strict)',
|
||||
latencyMs: expect.any(Number),
|
||||
@@ -198,7 +226,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Strict)',
|
||||
latencyMs: expect.any(Number),
|
||||
@@ -226,7 +254,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 60 < Threshold 70
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Remote)',
|
||||
latencyMs: expect.any(Number),
|
||||
@@ -252,7 +280,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Threshold 45.5
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Remote)',
|
||||
latencyMs: expect.any(Number),
|
||||
@@ -278,7 +306,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_MODEL, // Score 35 >= Threshold 30
|
||||
model: PREVIEW_GEMINI_MODEL, // Score 35 >= Threshold 30
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Remote)',
|
||||
latencyMs: expect.any(Number),
|
||||
@@ -306,7 +334,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL, // Score 40 < Default A/B Threshold 50
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
latencyMs: expect.any(Number),
|
||||
@@ -333,7 +361,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
latencyMs: expect.any(Number),
|
||||
@@ -360,7 +388,7 @@ describe('NumericalClassifierStrategy', () => {
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: 'NumericalClassifier (Control)',
|
||||
latencyMs: expect.any(Number),
|
||||
|
||||
@@ -12,7 +12,7 @@ import type {
|
||||
RoutingDecision,
|
||||
RoutingStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
import { resolveClassifierModel } from '../../config/models.js';
|
||||
import { resolveClassifierModel, isGemini3Model } from '../../config/models.js';
|
||||
import { createUserContent, Type } from '@google/genai';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import { debugLogger } from '../../utils/debugLogger.js';
|
||||
@@ -134,10 +134,15 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
): Promise<RoutingDecision | null> {
|
||||
const startTime = Date.now();
|
||||
try {
|
||||
const model = context.requestedModel ?? config.getModel();
|
||||
if (!(await config.getNumericalRoutingEnabled())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!isGemini3Model(model, config.getPreviewFeatures())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const promptId = getPromptIdWithFallback('classifier-router');
|
||||
|
||||
const finalHistory = context.history.slice(-HISTORY_TURNS_FOR_CONTEXT);
|
||||
@@ -177,7 +182,7 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
|
||||
);
|
||||
|
||||
const selectedModel = resolveClassifierModel(
|
||||
config.getModel(),
|
||||
model,
|
||||
modelAlias,
|
||||
config.getPreviewFeatures(),
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user