mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-17 00:31:44 -07:00
feat(plan): support automatic model switching for Plan Mode (#20240)
This commit is contained in:
@@ -14,10 +14,12 @@ import { DefaultStrategy } from './strategies/defaultStrategy.js';
|
||||
import { CompositeStrategy } from './strategies/compositeStrategy.js';
|
||||
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
|
||||
import { OverrideStrategy } from './strategies/overrideStrategy.js';
|
||||
import { ApprovalModeStrategy } from './strategies/approvalModeStrategy.js';
|
||||
import { ClassifierStrategy } from './strategies/classifierStrategy.js';
|
||||
import { NumericalClassifierStrategy } from './strategies/numericalClassifierStrategy.js';
|
||||
import { logModelRouting } from '../telemetry/loggers.js';
|
||||
import { ModelRoutingEvent } from '../telemetry/types.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
|
||||
vi.mock('../config/config.js');
|
||||
vi.mock('../core/baseLlmClient.js');
|
||||
@@ -25,6 +27,7 @@ vi.mock('./strategies/defaultStrategy.js');
|
||||
vi.mock('./strategies/compositeStrategy.js');
|
||||
vi.mock('./strategies/fallbackStrategy.js');
|
||||
vi.mock('./strategies/overrideStrategy.js');
|
||||
vi.mock('./strategies/approvalModeStrategy.js');
|
||||
vi.mock('./strategies/classifierStrategy.js');
|
||||
vi.mock('./strategies/numericalClassifierStrategy.js');
|
||||
vi.mock('../telemetry/loggers.js');
|
||||
@@ -45,11 +48,15 @@ describe('ModelRouterService', () => {
|
||||
vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient);
|
||||
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(false);
|
||||
vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined);
|
||||
vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue(
|
||||
ApprovalMode.DEFAULT,
|
||||
);
|
||||
|
||||
mockCompositeStrategy = new CompositeStrategy(
|
||||
[
|
||||
new FallbackStrategy(),
|
||||
new OverrideStrategy(),
|
||||
new ApprovalModeStrategy(),
|
||||
new ClassifierStrategy(),
|
||||
new NumericalClassifierStrategy(),
|
||||
new DefaultStrategy(),
|
||||
@@ -79,12 +86,13 @@ describe('ModelRouterService', () => {
|
||||
const compositeStrategyArgs = vi.mocked(CompositeStrategy).mock.calls[0];
|
||||
const childStrategies = compositeStrategyArgs[0];
|
||||
|
||||
expect(childStrategies.length).toBe(5);
|
||||
expect(childStrategies.length).toBe(6);
|
||||
expect(childStrategies[0]).toBeInstanceOf(FallbackStrategy);
|
||||
expect(childStrategies[1]).toBeInstanceOf(OverrideStrategy);
|
||||
expect(childStrategies[2]).toBeInstanceOf(ClassifierStrategy);
|
||||
expect(childStrategies[3]).toBeInstanceOf(NumericalClassifierStrategy);
|
||||
expect(childStrategies[4]).toBeInstanceOf(DefaultStrategy);
|
||||
expect(childStrategies[2]).toBeInstanceOf(ApprovalModeStrategy);
|
||||
expect(childStrategies[3]).toBeInstanceOf(ClassifierStrategy);
|
||||
expect(childStrategies[4]).toBeInstanceOf(NumericalClassifierStrategy);
|
||||
expect(childStrategies[5]).toBeInstanceOf(DefaultStrategy);
|
||||
expect(compositeStrategyArgs[1]).toBe('agent-router');
|
||||
});
|
||||
|
||||
@@ -127,6 +135,7 @@ describe('ModelRouterService', () => {
|
||||
'Strategy reasoning',
|
||||
false,
|
||||
undefined,
|
||||
ApprovalMode.DEFAULT,
|
||||
false,
|
||||
undefined,
|
||||
);
|
||||
@@ -153,6 +162,7 @@ describe('ModelRouterService', () => {
|
||||
'An exception occurred during routing.',
|
||||
true,
|
||||
'Strategy failed',
|
||||
ApprovalMode.DEFAULT,
|
||||
false,
|
||||
undefined,
|
||||
);
|
||||
|
||||
@@ -16,6 +16,7 @@ import { NumericalClassifierStrategy } from './strategies/numericalClassifierStr
|
||||
import { CompositeStrategy } from './strategies/compositeStrategy.js';
|
||||
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
|
||||
import { OverrideStrategy } from './strategies/overrideStrategy.js';
|
||||
import { ApprovalModeStrategy } from './strategies/approvalModeStrategy.js';
|
||||
|
||||
import { logModelRouting } from '../telemetry/loggers.js';
|
||||
import { ModelRoutingEvent } from '../telemetry/types.js';
|
||||
@@ -40,6 +41,7 @@ export class ModelRouterService {
|
||||
[
|
||||
new FallbackStrategy(),
|
||||
new OverrideStrategy(),
|
||||
new ApprovalModeStrategy(),
|
||||
new ClassifierStrategy(),
|
||||
new NumericalClassifierStrategy(),
|
||||
new DefaultStrategy(),
|
||||
@@ -105,6 +107,7 @@ export class ModelRouterService {
|
||||
decision!.metadata.reasoning,
|
||||
failed,
|
||||
error_message,
|
||||
this.config.getApprovalMode(),
|
||||
enableNumericalRouting,
|
||||
classifierThreshold,
|
||||
);
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { ApprovalModeStrategy } from './approvalModeStrategy.js';
|
||||
import type { RoutingContext } from '../routingStrategy.js';
|
||||
import type { Config } from '../../config/config.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
PREVIEW_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_MODEL_AUTO,
|
||||
PREVIEW_GEMINI_MODEL_AUTO,
|
||||
} from '../../config/models.js';
|
||||
import { ApprovalMode } from '../../policy/types.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
|
||||
describe('ApprovalModeStrategy', () => {
|
||||
let strategy: ApprovalModeStrategy;
|
||||
let mockContext: RoutingContext;
|
||||
let mockConfig: Config;
|
||||
let mockBaseLlmClient: BaseLlmClient;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
strategy = new ApprovalModeStrategy();
|
||||
mockContext = {
|
||||
history: [],
|
||||
request: [{ text: 'test' }],
|
||||
signal: new AbortController().signal,
|
||||
};
|
||||
|
||||
mockConfig = {
|
||||
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
|
||||
getApprovedPlanPath: vi.fn().mockReturnValue(undefined),
|
||||
getPlanModeRoutingEnabled: vi.fn().mockResolvedValue(true),
|
||||
} as unknown as Config;
|
||||
|
||||
mockBaseLlmClient = {} as BaseLlmClient;
|
||||
});
|
||||
|
||||
it('should return null if the model is not an auto model', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
});
|
||||
|
||||
it('should return null if plan mode routing is disabled', async () => {
|
||||
vi.mocked(mockConfig.getPlanModeRoutingEnabled).mockResolvedValue(false);
|
||||
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
});
|
||||
|
||||
it('should route to PRO model if ApprovalMode is PLAN (Gemini 2.5)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
|
||||
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: 'approval-mode',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: 'Routing to Pro model because ApprovalMode is PLAN.',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should route to PRO model if ApprovalMode is PLAN (Gemini 3)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);
|
||||
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
metadata: {
|
||||
source: 'approval-mode',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning: 'Routing to Pro model because ApprovalMode is PLAN.',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should route to FLASH model if an approved plan exists (Gemini 2.5)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
|
||||
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT);
|
||||
vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(
|
||||
'/path/to/plan.md',
|
||||
);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
metadata: {
|
||||
source: 'approval-mode',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning:
|
||||
'Routing to Flash model because an approved plan exists at /path/to/plan.md.',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should route to FLASH model if an approved plan exists (Gemini 3)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);
|
||||
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT);
|
||||
vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(
|
||||
'/path/to/plan.md',
|
||||
);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
);
|
||||
|
||||
expect(decision).toEqual({
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||
metadata: {
|
||||
source: 'approval-mode',
|
||||
latencyMs: expect.any(Number),
|
||||
reasoning:
|
||||
'Routing to Flash model because an approved plan exists at /path/to/plan.md.',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should return null if not in PLAN mode and no approved plan exists', async () => {
|
||||
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT);
|
||||
vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(undefined);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
});
|
||||
|
||||
it('should prioritize requestedModel over config model if it is an auto model', async () => {
|
||||
mockContext.requestedModel = PREVIEW_GEMINI_MODEL_AUTO;
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
|
||||
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);
|
||||
|
||||
const decision = await strategy.route(
|
||||
mockContext,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
);
|
||||
|
||||
expect(decision?.model).toBe(PREVIEW_GEMINI_MODEL);
|
||||
});
|
||||
});
|
||||
83
packages/core/src/routing/strategies/approvalModeStrategy.ts
Normal file
83
packages/core/src/routing/strategies/approvalModeStrategy.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../../config/config.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
PREVIEW_GEMINI_FLASH_MODEL,
|
||||
isAutoModel,
|
||||
isPreviewModel,
|
||||
} from '../../config/models.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import { ApprovalMode } from '../../policy/types.js';
|
||||
import type {
|
||||
RoutingContext,
|
||||
RoutingDecision,
|
||||
RoutingStrategy,
|
||||
} from '../routingStrategy.js';
|
||||
|
||||
/**
|
||||
* A strategy that routes based on the current ApprovalMode and plan status.
|
||||
*
|
||||
* - In PLAN mode: Routes to the PRO model for high-quality planning.
|
||||
* - In other modes with an approved plan: Routes to the FLASH model for efficient implementation.
|
||||
*/
|
||||
export class ApprovalModeStrategy implements RoutingStrategy {
|
||||
readonly name = 'approval-mode';
|
||||
|
||||
async route(
|
||||
context: RoutingContext,
|
||||
config: Config,
|
||||
_baseLlmClient: BaseLlmClient,
|
||||
): Promise<RoutingDecision | null> {
|
||||
const model = context.requestedModel ?? config.getModel();
|
||||
|
||||
// This strategy only applies to "auto" models.
|
||||
if (!isAutoModel(model)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!(await config.getPlanModeRoutingEnabled())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const startTime = Date.now();
|
||||
const approvalMode = config.getApprovalMode();
|
||||
const approvedPlanPath = config.getApprovedPlanPath();
|
||||
|
||||
const isPreview = isPreviewModel(model);
|
||||
|
||||
// 1. Planning Phase: If ApprovalMode === PLAN, explicitly route to the Pro model.
|
||||
if (approvalMode === ApprovalMode.PLAN) {
|
||||
const proModel = isPreview ? PREVIEW_GEMINI_MODEL : DEFAULT_GEMINI_MODEL;
|
||||
return {
|
||||
model: proModel,
|
||||
metadata: {
|
||||
source: this.name,
|
||||
latencyMs: Date.now() - startTime,
|
||||
reasoning: 'Routing to Pro model because ApprovalMode is PLAN.',
|
||||
},
|
||||
};
|
||||
} else if (approvedPlanPath) {
|
||||
// 2. Implementation Phase: If ApprovalMode !== PLAN AND an approved plan path is set, prefer the Flash model.
|
||||
const flashModel = isPreview
|
||||
? PREVIEW_GEMINI_FLASH_MODEL
|
||||
: DEFAULT_GEMINI_FLASH_MODEL;
|
||||
return {
|
||||
model: flashModel,
|
||||
metadata: {
|
||||
source: this.name,
|
||||
latencyMs: Date.now() - startTime,
|
||||
reasoning: `Routing to Flash model because an approved plan exists at ${approvedPlanPath}.`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user