mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-07 20:00:37 -07:00
feat(core): Preliminary changes for subagent model routing. (#16035)
This commit is contained in:
@@ -35,6 +35,8 @@ export interface RoutingContext {
|
||||
request: PartListUnion;
|
||||
/** An abort signal to cancel an LLM call during routing. */
|
||||
signal: AbortSignal;
|
||||
/** The model string requested for this turn, if any. */
|
||||
requestedModel?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -281,4 +281,30 @@ describe('ClassifierStrategy', () => {
|
||||
);
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should respect requestedModel from context in resolveClassifierModel', async () => {
|
||||
const requestedModel = DEFAULT_GEMINI_MODEL; // Pro model
|
||||
const mockApiResponse = {
|
||||
reasoning: 'Choice is flash',
|
||||
model_choice: 'flash',
|
||||
};
|
||||
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
|
||||
mockApiResponse,
|
||||
);
|
||||
|
||||
const contextWithRequestedModel = {
|
||||
...mockContext,
|
||||
requestedModel,
|
||||
} as RoutingContext;
|
||||
|
||||
const decision = await strategy.route(
|
||||
contextWithRequestedModel,
|
||||
mockConfig,
|
||||
mockBaseLlmClient,
|
||||
);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
// Since requestedModel is Pro, and choice is flash, it should resolve to Flash
|
||||
expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -168,7 +168,7 @@ export class ClassifierStrategy implements RoutingStrategy {
|
||||
const reasoning = routerResponse.reasoning;
|
||||
const latencyMs = Date.now() - startTime;
|
||||
const selectedModel = resolveClassifierModel(
|
||||
config.getModel(),
|
||||
context.requestedModel ?? config.getModel(),
|
||||
routerResponse.model_choice,
|
||||
config.getPreviewFeatures(),
|
||||
);
|
||||
|
||||
@@ -108,4 +108,25 @@ describe('FallbackStrategy', () => {
|
||||
// Important: check that it queried snapshot with the RESOLVED model, not 'auto'
|
||||
expect(mockService.snapshot).toHaveBeenCalledWith(DEFAULT_GEMINI_MODEL);
|
||||
});
|
||||
|
||||
it('should respect requestedModel from context', async () => {
|
||||
const requestedModel = 'requested-model';
|
||||
const configModel = 'config-model';
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(configModel);
|
||||
vi.mocked(mockService.snapshot).mockReturnValue({ available: true });
|
||||
|
||||
const contextWithRequestedModel = {
|
||||
requestedModel,
|
||||
} as RoutingContext;
|
||||
|
||||
const decision = await strategy.route(
|
||||
contextWithRequestedModel,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
);
|
||||
|
||||
expect(decision).toBeNull();
|
||||
// Should check availability of the requested model from context
|
||||
expect(mockService.snapshot).toHaveBeenCalledWith(requestedModel);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -18,11 +18,11 @@ export class FallbackStrategy implements RoutingStrategy {
|
||||
readonly name = 'fallback';
|
||||
|
||||
async route(
|
||||
_context: RoutingContext,
|
||||
context: RoutingContext,
|
||||
config: Config,
|
||||
_baseLlmClient: BaseLlmClient,
|
||||
): Promise<RoutingDecision | null> {
|
||||
const requestedModel = config.getModel();
|
||||
const requestedModel = context.requestedModel ?? config.getModel();
|
||||
const resolvedModel = resolveModel(
|
||||
requestedModel,
|
||||
config.getPreviewFeatures(),
|
||||
|
||||
@@ -56,4 +56,25 @@ describe('OverrideStrategy', () => {
|
||||
expect(decision).not.toBeNull();
|
||||
expect(decision?.model).toBe(overrideModel);
|
||||
});
|
||||
|
||||
it('should respect requestedModel from context', async () => {
|
||||
const requestedModel = 'requested-model';
|
||||
const configModel = 'config-model';
|
||||
const mockConfig = {
|
||||
getModel: () => configModel,
|
||||
getPreviewFeatures: () => false,
|
||||
} as Config;
|
||||
const contextWithRequestedModel = {
|
||||
requestedModel,
|
||||
} as RoutingContext;
|
||||
|
||||
const decision = await strategy.route(
|
||||
contextWithRequestedModel,
|
||||
mockConfig,
|
||||
mockClient,
|
||||
);
|
||||
|
||||
expect(decision).not.toBeNull();
|
||||
expect(decision?.model).toBe(requestedModel);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -5,11 +5,7 @@
|
||||
*/
|
||||
|
||||
import type { Config } from '../../config/config.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_MODEL_AUTO,
|
||||
PREVIEW_GEMINI_MODEL_AUTO,
|
||||
resolveModel,
|
||||
} from '../../config/models.js';
|
||||
import { isAutoModel, resolveModel } from '../../config/models.js';
|
||||
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
|
||||
import type {
|
||||
RoutingContext,
|
||||
@@ -24,18 +20,16 @@ export class OverrideStrategy implements RoutingStrategy {
|
||||
readonly name = 'override';
|
||||
|
||||
async route(
|
||||
_context: RoutingContext,
|
||||
context: RoutingContext,
|
||||
config: Config,
|
||||
_baseLlmClient: BaseLlmClient,
|
||||
): Promise<RoutingDecision | null> {
|
||||
const overrideModel = config.getModel();
|
||||
const overrideModel = context.requestedModel ?? config.getModel();
|
||||
|
||||
// If the model is 'auto' we should pass to the next strategy.
|
||||
if (
|
||||
overrideModel === DEFAULT_GEMINI_MODEL_AUTO ||
|
||||
overrideModel === PREVIEW_GEMINI_MODEL_AUTO
|
||||
)
|
||||
if (isAutoModel(overrideModel)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Return the overridden model name.
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user