feat(plan): support automatic model switching for Plan Mode (#20240)

This commit is contained in:
Jerop Kipruto
2026-02-24 19:15:14 -05:00
committed by GitHub
parent 1f9da6723f
commit bf278ef2b0
19 changed files with 422 additions and 31 deletions

View File

@@ -14,10 +14,12 @@ import { DefaultStrategy } from './strategies/defaultStrategy.js';
import { CompositeStrategy } from './strategies/compositeStrategy.js';
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
import { OverrideStrategy } from './strategies/overrideStrategy.js';
import { ApprovalModeStrategy } from './strategies/approvalModeStrategy.js';
import { ClassifierStrategy } from './strategies/classifierStrategy.js';
import { NumericalClassifierStrategy } from './strategies/numericalClassifierStrategy.js';
import { logModelRouting } from '../telemetry/loggers.js';
import { ModelRoutingEvent } from '../telemetry/types.js';
import { ApprovalMode } from '../policy/types.js';
vi.mock('../config/config.js');
vi.mock('../core/baseLlmClient.js');
@@ -25,6 +27,7 @@ vi.mock('./strategies/defaultStrategy.js');
vi.mock('./strategies/compositeStrategy.js');
vi.mock('./strategies/fallbackStrategy.js');
vi.mock('./strategies/overrideStrategy.js');
vi.mock('./strategies/approvalModeStrategy.js');
vi.mock('./strategies/classifierStrategy.js');
vi.mock('./strategies/numericalClassifierStrategy.js');
vi.mock('../telemetry/loggers.js');
@@ -45,11 +48,15 @@ describe('ModelRouterService', () => {
vi.spyOn(mockConfig, 'getBaseLlmClient').mockReturnValue(mockBaseLlmClient);
vi.spyOn(mockConfig, 'getNumericalRoutingEnabled').mockResolvedValue(false);
vi.spyOn(mockConfig, 'getClassifierThreshold').mockResolvedValue(undefined);
vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue(
ApprovalMode.DEFAULT,
);
mockCompositeStrategy = new CompositeStrategy(
[
new FallbackStrategy(),
new OverrideStrategy(),
new ApprovalModeStrategy(),
new ClassifierStrategy(),
new NumericalClassifierStrategy(),
new DefaultStrategy(),
@@ -79,12 +86,13 @@ describe('ModelRouterService', () => {
const compositeStrategyArgs = vi.mocked(CompositeStrategy).mock.calls[0];
const childStrategies = compositeStrategyArgs[0];
expect(childStrategies.length).toBe(5);
expect(childStrategies.length).toBe(6);
expect(childStrategies[0]).toBeInstanceOf(FallbackStrategy);
expect(childStrategies[1]).toBeInstanceOf(OverrideStrategy);
expect(childStrategies[2]).toBeInstanceOf(ClassifierStrategy);
expect(childStrategies[3]).toBeInstanceOf(NumericalClassifierStrategy);
expect(childStrategies[4]).toBeInstanceOf(DefaultStrategy);
expect(childStrategies[2]).toBeInstanceOf(ApprovalModeStrategy);
expect(childStrategies[3]).toBeInstanceOf(ClassifierStrategy);
expect(childStrategies[4]).toBeInstanceOf(NumericalClassifierStrategy);
expect(childStrategies[5]).toBeInstanceOf(DefaultStrategy);
expect(compositeStrategyArgs[1]).toBe('agent-router');
});
@@ -127,6 +135,7 @@ describe('ModelRouterService', () => {
'Strategy reasoning',
false,
undefined,
ApprovalMode.DEFAULT,
false,
undefined,
);
@@ -153,6 +162,7 @@ describe('ModelRouterService', () => {
'An exception occurred during routing.',
true,
'Strategy failed',
ApprovalMode.DEFAULT,
false,
undefined,
);

View File

@@ -16,6 +16,7 @@ import { NumericalClassifierStrategy } from './strategies/numericalClassifierStr
import { CompositeStrategy } from './strategies/compositeStrategy.js';
import { FallbackStrategy } from './strategies/fallbackStrategy.js';
import { OverrideStrategy } from './strategies/overrideStrategy.js';
import { ApprovalModeStrategy } from './strategies/approvalModeStrategy.js';
import { logModelRouting } from '../telemetry/loggers.js';
import { ModelRoutingEvent } from '../telemetry/types.js';
@@ -40,6 +41,7 @@ export class ModelRouterService {
[
new FallbackStrategy(),
new OverrideStrategy(),
new ApprovalModeStrategy(),
new ClassifierStrategy(),
new NumericalClassifierStrategy(),
new DefaultStrategy(),
@@ -105,6 +107,7 @@ export class ModelRouterService {
decision!.metadata.reasoning,
failed,
error_message,
this.config.getApprovalMode(),
enableNumericalRouting,
classifierThreshold,
);

View File

@@ -0,0 +1,187 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { ApprovalModeStrategy } from './approvalModeStrategy.js';
import type { RoutingContext } from '../routingStrategy.js';
import type { Config } from '../../config/config.js';
import {
DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_MODEL_AUTO,
} from '../../config/models.js';
import { ApprovalMode } from '../../policy/types.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
describe('ApprovalModeStrategy', () => {
let strategy: ApprovalModeStrategy;
let mockContext: RoutingContext;
let mockConfig: Config;
let mockBaseLlmClient: BaseLlmClient;
beforeEach(() => {
vi.clearAllMocks();
strategy = new ApprovalModeStrategy();
mockContext = {
history: [],
request: [{ text: 'test' }],
signal: new AbortController().signal,
};
mockConfig = {
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
getApprovedPlanPath: vi.fn().mockReturnValue(undefined),
getPlanModeRoutingEnabled: vi.fn().mockResolvedValue(true),
} as unknown as Config;
mockBaseLlmClient = {} as BaseLlmClient;
});
it('should return null if the model is not an auto model', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision).toBeNull();
});
it('should return null if plan mode routing is disabled', async () => {
vi.mocked(mockConfig.getPlanModeRoutingEnabled).mockResolvedValue(false);
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision).toBeNull();
});
it('should route to PRO model if ApprovalMode is PLAN (Gemini 2.5)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision).toEqual({
model: DEFAULT_GEMINI_MODEL,
metadata: {
source: 'approval-mode',
latencyMs: expect.any(Number),
reasoning: 'Routing to Pro model because ApprovalMode is PLAN.',
},
});
});
it('should route to PRO model if ApprovalMode is PLAN (Gemini 3)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision).toEqual({
model: PREVIEW_GEMINI_MODEL,
metadata: {
source: 'approval-mode',
latencyMs: expect.any(Number),
reasoning: 'Routing to Pro model because ApprovalMode is PLAN.',
},
});
});
it('should route to FLASH model if an approved plan exists (Gemini 2.5)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT);
vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(
'/path/to/plan.md',
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision).toEqual({
model: DEFAULT_GEMINI_FLASH_MODEL,
metadata: {
source: 'approval-mode',
latencyMs: expect.any(Number),
reasoning:
'Routing to Flash model because an approved plan exists at /path/to/plan.md.',
},
});
});
it('should route to FLASH model if an approved plan exists (Gemini 3)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT);
vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(
'/path/to/plan.md',
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision).toEqual({
model: PREVIEW_GEMINI_FLASH_MODEL,
metadata: {
source: 'approval-mode',
latencyMs: expect.any(Number),
reasoning:
'Routing to Flash model because an approved plan exists at /path/to/plan.md.',
},
});
});
it('should return null if not in PLAN mode and no approved plan exists', async () => {
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.DEFAULT);
vi.mocked(mockConfig.getApprovedPlanPath).mockReturnValue(undefined);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision).toBeNull();
});
it('should prioritize requestedModel over config model if it is an auto model', async () => {
mockContext.requestedModel = PREVIEW_GEMINI_MODEL_AUTO;
vi.mocked(mockConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision?.model).toBe(PREVIEW_GEMINI_MODEL);
});
});

View File

@@ -0,0 +1,83 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '../../config/config.js';
import {
DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
isAutoModel,
isPreviewModel,
} from '../../config/models.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import { ApprovalMode } from '../../policy/types.js';
import type {
RoutingContext,
RoutingDecision,
RoutingStrategy,
} from '../routingStrategy.js';
/**
* A strategy that routes based on the current ApprovalMode and plan status.
*
* - In PLAN mode: Routes to the PRO model for high-quality planning.
* - In other modes with an approved plan: Routes to the FLASH model for efficient implementation.
*/
export class ApprovalModeStrategy implements RoutingStrategy {
readonly name = 'approval-mode';
async route(
context: RoutingContext,
config: Config,
_baseLlmClient: BaseLlmClient,
): Promise<RoutingDecision | null> {
const model = context.requestedModel ?? config.getModel();
// This strategy only applies to "auto" models.
if (!isAutoModel(model)) {
return null;
}
if (!(await config.getPlanModeRoutingEnabled())) {
return null;
}
const startTime = Date.now();
const approvalMode = config.getApprovalMode();
const approvedPlanPath = config.getApprovedPlanPath();
const isPreview = isPreviewModel(model);
// 1. Planning Phase: If ApprovalMode === PLAN, explicitly route to the Pro model.
if (approvalMode === ApprovalMode.PLAN) {
const proModel = isPreview ? PREVIEW_GEMINI_MODEL : DEFAULT_GEMINI_MODEL;
return {
model: proModel,
metadata: {
source: this.name,
latencyMs: Date.now() - startTime,
reasoning: 'Routing to Pro model because ApprovalMode is PLAN.',
},
};
} else if (approvedPlanPath) {
// 2. Implementation Phase: If ApprovalMode !== PLAN AND an approved plan path is set, prefer the Flash model.
const flashModel = isPreview
? PREVIEW_GEMINI_FLASH_MODEL
: DEFAULT_GEMINI_FLASH_MODEL;
return {
model: flashModel,
metadata: {
source: this.name,
latencyMs: Date.now() - startTime,
reasoning: `Routing to Flash model because an approved plan exists at ${approvedPlanPath}.`,
},
};
}
return null;
}
}