mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-20 18:14:29 -07:00
[Gemma x Gemini CLI] Add an Experimental Gemma Router that uses a LiteRT-LM shim into the Composite Model Classifier Strategy (#17231)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Allen Hutchison <adh@google.com>
This commit is contained in:
@@ -9,6 +9,7 @@ import { ModelRouterService } from './modelRouterService.js';
|
||||
import { Config } from '../config/config.js';
|
||||
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import type { LocalLiteRtLmClient } from '../core/localLiteRtLmClient.js';
|
||||
import type { RoutingContext, RoutingDecision } from './routingStrategy.js';
|
||||
import { DefaultStrategy } from './strategies/defaultStrategy.js';
|
||||
import { CompositeStrategy } from './strategies/compositeStrategy.js';
|
||||
@@ -19,6 +20,7 @@ import { ClassifierStrategy } from './strategies/classifierStrategy.js';
|
||||
import { NumericalClassifierStrategy } from './strategies/numericalClassifierStrategy.js';
|
||||
import { logModelRouting } from '../telemetry/loggers.js';
|
||||
import { ModelRoutingEvent } from '../telemetry/types.js';
|
||||
import { GemmaClassifierStrategy } from './strategies/gemmaClassifierStrategy.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
|
||||
vi.mock('../config/config.js');
|
||||
@@ -30,6 +32,7 @@ vi.mock('./strategies/overrideStrategy.js');
|
||||
vi.mock('./strategies/approvalModeStrategy.js');
|
||||
vi.mock('./strategies/classifierStrategy.js');
|
||||
vi.mock('./strategies/numericalClassifierStrategy.js');
|
||||
vi.mock('./strategies/gemmaClassifierStrategy.js');
|
||||
vi.mock('../telemetry/loggers.js');
|
||||
vi.mock('../telemetry/types.js');
|
||||
|
||||
@@ -37,6 +40,7 @@ describe('ModelRouterService', () => {
|
||||
let service: ModelRouterService;
|
||||
let mockConfig: Config;
|
||||
let mockBaseLlmClient: BaseLlmClient;
|
||||
let mockLocalLiteRtLmClient: LocalLiteRtLmClient;
|
||||
let mockContext: RoutingContext;
|
||||
let mockCompositeStrategy: CompositeStrategy;
|
||||
|
||||
@@ -45,9 +49,20 @@ describe('ModelRouterService', () => {
|
||||
|
||||
mockConfig = new Config({} as never);
|
||||
mockBaseLlmClient = {} as BaseLlmClient;
|
||||
mockLocalLiteRtLmClient = {} as LocalLiteRtLmClient;
|
||||
vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient);
|
||||
vi.spyOn(mockConfig, 'getLocalLiteRtLmClient').mockReturnValue(
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(false);
|
||||
vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined);
|
||||
vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({
|
||||
enabled: false,
|
||||
classifier: {
|
||||
host: 'http://localhost:1234',
|
||||
model: 'gemma3-1b-gpu-custom',
|
||||
},
|
||||
});
|
||||
vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue(
|
||||
ApprovalMode.DEFAULT,
|
||||
);
|
||||
@@ -96,6 +111,36 @@ describe('ModelRouterService', () => {
|
||||
expect(compositeStrategyArgs[1]).toBe('agent-router');
|
||||
});
|
||||
|
||||
// Verifies that when getGemmaModelRouterSettings() reports `enabled: true`,
// the strategy list handed to CompositeStrategy contains a
// GemmaClassifierStrategy in addition to the other strategies.
it('should include GemmaClassifierStrategy when enabled', () => {
|
||||
// Override the default mock for this specific test so the Gemma router
// settings report `enabled: true`.
|
||||
vi.spyOn(mockConfig, 'getGemmaModelRouterSettings').mockReturnValue({
|
||||
enabled: true,
|
||||
classifier: {
|
||||
host: 'http://localhost:1234',
|
||||
model: 'gemma3-1b-gpu-custom',
|
||||
},
|
||||
});
|
||||
|
||||
// Clear previous mock calls from beforeEach, so that `mock.calls[0]` below
// refers to the construction triggered by this test.
|
||||
vi.mocked(CompositeStrategy).mockClear();
|
||||
|
||||
// Re-initialize the service to pick up the new config
|
||||
service = new ModelRouterService(mockConfig);
|
||||
|
||||
// Arguments of the first CompositeStrategy constructor call recorded since
// the mockClear() above: [0] is the child strategy list, [1] is compared
// against 'agent-router' at the end of the test.
const compositeStrategyArgs = vi.mocked(CompositeStrategy).mock.calls[0];
|
||||
const childStrategies = compositeStrategyArgs[0];
|
||||
|
||||
// Seven strategies are expected, with GemmaClassifierStrategy at index 3,
// directly after ApprovalModeStrategy and before ClassifierStrategy.
expect(childStrategies.length).toBe(7);
|
||||
expect(childStrategies[0]).toBeInstanceOf(FallbackStrategy);
|
||||
expect(childStrategies[1]).toBeInstanceOf(OverrideStrategy);
|
||||
expect(childStrategies[2]).toBeInstanceOf(ApprovalModeStrategy);
|
||||
expect(childStrategies[3]).toBeInstanceOf(GemmaClassifierStrategy);
|
||||
expect(childStrategies[4]).toBeInstanceOf(ClassifierStrategy);
|
||||
expect(childStrategies[5]).toBeInstanceOf(NumericalClassifierStrategy);
|
||||
expect(childStrategies[6]).toBeInstanceOf(DefaultStrategy);
|
||||
expect(compositeStrategyArgs[1]).toBe('agent-router');
|
||||
});
|
||||
|
||||
describe('route()', () => {
|
||||
const strategyDecision: RoutingDecision = {
|
||||
model: 'strategy-chosen-model',
|
||||
@@ -117,6 +162,7 @@ describe('ModelRouterService', () => {
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
mockLocalLiteRtLmClient,
|
||||
);
|
||||
expect(decision).toEqual(strategyDecision);
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user