feat(core): Preliminary changes for subagent model routing. (#16035)

This commit is contained in:
joshualitt
2026-01-07 13:21:10 -08:00
committed by GitHub
parent 17b3eb730a
commit a1dd19738e
12 changed files with 200 additions and 15 deletions

View File

@@ -9,6 +9,7 @@ import {
resolveModel,
resolveClassifierModel,
isGemini2Model,
isAutoModel,
getDisplayString,
DEFAULT_GEMINI_MODEL,
PREVIEW_GEMINI_MODEL,
@@ -18,6 +19,7 @@ import {
GEMINI_MODEL_ALIAS_PRO,
GEMINI_MODEL_ALIAS_FLASH,
GEMINI_MODEL_ALIAS_FLASH_LITE,
GEMINI_MODEL_ALIAS_AUTO,
PREVIEW_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_MODEL_AUTO,
@@ -171,6 +173,26 @@ describe('isGemini2Model', () => {
});
});
// Covers the three auto aliases plus a sampling of concrete model names.
describe('isAutoModel', () => {
  it('should return true for "auto"', () => {
    const result = isAutoModel(GEMINI_MODEL_ALIAS_AUTO);
    expect(result).toBe(true);
  });
  it('should return true for "auto-gemini-3"', () => {
    const result = isAutoModel(PREVIEW_GEMINI_MODEL_AUTO);
    expect(result).toBe(true);
  });
  it('should return true for "auto-gemini-2.5"', () => {
    const result = isAutoModel(DEFAULT_GEMINI_MODEL_AUTO);
    expect(result).toBe(true);
  });
  it('should return false for concrete models', () => {
    // None of these are auto aliases, so each must be rejected.
    for (const concrete of [
      DEFAULT_GEMINI_MODEL,
      PREVIEW_GEMINI_MODEL,
      'some-random-model',
    ]) {
      expect(isAutoModel(concrete)).toBe(false);
    }
  });
});
describe('resolveClassifierModel', () => {
it('should return flash model when alias is flash', () => {
expect(

View File

@@ -146,6 +146,20 @@ export function isGemini2Model(model: string): boolean {
return /^gemini-2(\.|$)/.test(model);
}
/**
 * Checks whether the given model name refers to one of the "auto" routing
 * aliases rather than a concrete model.
 *
 * @param model The model name to check.
 * @returns True if the model is an auto alias.
 */
export function isAutoModel(model: string): boolean {
  const autoModels: readonly string[] = [
    GEMINI_MODEL_ALIAS_AUTO,
    PREVIEW_GEMINI_MODEL_AUTO,
    DEFAULT_GEMINI_MODEL_AUTO,
  ];
  return autoModels.includes(model);
}
/**
* Checks if the model supports multimodal function responses (multimodal data nested within function response).
* This is supported in Gemini 3.

View File

@@ -606,6 +606,7 @@ export class GeminiClient {
history: this.getChat().getHistory(/*curated=*/ true),
request,
signal,
requestedModel: this.config.getModel(),
};
let modelToUse: string;

View File

@@ -35,6 +35,8 @@ export interface RoutingContext {
request: PartListUnion;
/** An abort signal to cancel an LLM call during routing. */
signal: AbortSignal;
/** The model string requested for this turn, if any. */
requestedModel?: string;
}
/**

View File

@@ -281,4 +281,30 @@ describe('ClassifierStrategy', () => {
);
consoleWarnSpy.mockRestore();
});
it('should respect requestedModel from context in resolveClassifierModel', async () => {
  // The classifier picks flash even though the requested model is Pro.
  vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue({
    reasoning: 'Choice is flash',
    model_choice: 'flash',
  });
  const contextWithRequestedModel = {
    ...mockContext,
    requestedModel: DEFAULT_GEMINI_MODEL, // Pro model
  } as RoutingContext;
  const decision = await strategy.route(
    contextWithRequestedModel,
    mockConfig,
    mockBaseLlmClient,
  );
  expect(decision).not.toBeNull();
  // Since requestedModel is Pro and the choice is flash, resolution yields Flash.
  expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
});
});

View File

@@ -168,7 +168,7 @@ export class ClassifierStrategy implements RoutingStrategy {
const reasoning = routerResponse.reasoning;
const latencyMs = Date.now() - startTime;
const selectedModel = resolveClassifierModel(
config.getModel(),
context.requestedModel ?? config.getModel(),
routerResponse.model_choice,
config.getPreviewFeatures(),
);

View File

@@ -108,4 +108,25 @@ describe('FallbackStrategy', () => {
// Important: check that it queried snapshot with the RESOLVED model, not 'auto'
expect(mockService.snapshot).toHaveBeenCalledWith(DEFAULT_GEMINI_MODEL);
});
it('should respect requestedModel from context', async () => {
  const requestedModel = 'requested-model';
  // Config reports a different model; the context value must take precedence.
  vi.mocked(mockConfig.getModel).mockReturnValue('config-model');
  vi.mocked(mockService.snapshot).mockReturnValue({ available: true });
  const routingContext = { requestedModel } as RoutingContext;
  const decision = await strategy.route(routingContext, mockConfig, mockClient);
  expect(decision).toBeNull();
  // Availability should be checked against the context's requested model,
  // not the config's.
  expect(mockService.snapshot).toHaveBeenCalledWith(requestedModel);
});
});

View File

@@ -18,11 +18,11 @@ export class FallbackStrategy implements RoutingStrategy {
readonly name = 'fallback';
async route(
_context: RoutingContext,
context: RoutingContext,
config: Config,
_baseLlmClient: BaseLlmClient,
): Promise<RoutingDecision | null> {
const requestedModel = config.getModel();
const requestedModel = context.requestedModel ?? config.getModel();
const resolvedModel = resolveModel(
requestedModel,
config.getPreviewFeatures(),

View File

@@ -56,4 +56,25 @@ describe('OverrideStrategy', () => {
expect(decision).not.toBeNull();
expect(decision?.model).toBe(overrideModel);
});
it('should respect requestedModel from context', async () => {
  const requestedModel = 'requested-model';
  // Config reports a different model so the test can detect which one wins.
  const mockConfig = {
    getModel: () => 'config-model',
    getPreviewFeatures: () => false,
  } as Config;
  const routingContext = { requestedModel } as RoutingContext;
  const decision = await strategy.route(routingContext, mockConfig, mockClient);
  expect(decision).not.toBeNull();
  // The override strategy must honor the context's requested model.
  expect(decision?.model).toBe(requestedModel);
});
});

View File

@@ -5,11 +5,7 @@
*/
import type { Config } from '../../config/config.js';
import {
DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_MODEL_AUTO,
resolveModel,
} from '../../config/models.js';
import { isAutoModel, resolveModel } from '../../config/models.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import type {
RoutingContext,
@@ -24,18 +20,16 @@ export class OverrideStrategy implements RoutingStrategy {
readonly name = 'override';
async route(
_context: RoutingContext,
context: RoutingContext,
config: Config,
_baseLlmClient: BaseLlmClient,
): Promise<RoutingDecision | null> {
const overrideModel = config.getModel();
const overrideModel = context.requestedModel ?? config.getModel();
// If the model is 'auto' we should pass to the next strategy.
if (
overrideModel === DEFAULT_GEMINI_MODEL_AUTO ||
overrideModel === PREVIEW_GEMINI_MODEL_AUTO
)
if (isAutoModel(overrideModel)) {
return null;
}
// Return the overridden model name.
return {

View File

@@ -577,6 +577,81 @@ describe('ModelConfigService', () => {
});
});
// All tests start from a minimal static config and layer runtime-registered
// overrides on top of it.
describe('runtime overrides', () => {
  it('should resolve a simple runtime-registered override', () => {
    const baseConfig: ModelConfigServiceConfig = { aliases: {}, overrides: [] };
    const service = new ModelConfigService(baseConfig);
    service.registerRuntimeModelOverride({
      match: { model: 'gemini-pro' },
      modelConfig: { generateContentConfig: { temperature: 0.99 } },
    });
    const resolved = service.getResolvedConfig({ model: 'gemini-pro' });
    expect(resolved.model).toBe('gemini-pro');
    expect(resolved.generateContentConfig.temperature).toBe(0.99);
  });
  it('should prioritize runtime overrides over default overrides when they have the same specificity', () => {
    const baseConfig: ModelConfigServiceConfig = {
      aliases: {},
      overrides: [
        {
          match: { model: 'gemini-pro' },
          modelConfig: { generateContentConfig: { temperature: 0.1 } },
        },
      ],
    };
    const service = new ModelConfigService(baseConfig);
    service.registerRuntimeModelOverride({
      match: { model: 'gemini-pro' },
      modelConfig: { generateContentConfig: { temperature: 0.9 } },
    });
    const resolved = service.getResolvedConfig({ model: 'gemini-pro' });
    // Runtime overrides are appended after overrides/customOverrides, so they should win.
    expect(resolved.generateContentConfig.temperature).toBe(0.9);
  });
  it('should still respect specificity with runtime overrides', () => {
    const baseConfig: ModelConfigServiceConfig = { aliases: {}, overrides: [] };
    const service = new ModelConfigService(baseConfig);
    // More specific override registered first...
    service.registerRuntimeModelOverride({
      match: { model: 'gemini-pro', overrideScope: 'my-agent' },
      modelConfig: { generateContentConfig: { temperature: 0.1 } },
    });
    // ...less specific one registered later.
    service.registerRuntimeModelOverride({
      match: { model: 'gemini-pro' },
      modelConfig: { generateContentConfig: { temperature: 0.9 } },
    });
    const resolved = service.getResolvedConfig({
      model: 'gemini-pro',
      overrideScope: 'my-agent',
    });
    // Specificity should win over registration order.
    expect(resolved.generateContentConfig.temperature).toBe(0.1);
  });
});
describe('custom aliases', () => {
it('should resolve a custom alias', () => {
const config: ModelConfigServiceConfig = {

View File

@@ -65,6 +65,7 @@ export interface _ResolvedModelConfig {
export class ModelConfigService {
private readonly runtimeAliases: Record<string, ModelConfigAlias> = {};
private readonly runtimeOverrides: ModelConfigOverride[] = [];
// TODO(12597): Process config to build a typed alias hierarchy.
constructor(private readonly config: ModelConfigServiceConfig) {}
@@ -73,6 +74,10 @@ export class ModelConfigService {
this.runtimeAliases[aliasName] = alias;
}
/**
 * Registers a model configuration override at runtime, in addition to the
 * overrides supplied in the static service config.
 *
 * Appended overrides participate in resolution alongside the configured
 * ones; see getResolvedConfig, which concatenates them after the static
 * and custom overrides.
 *
 * @param override The override to add to the runtime override list.
 */
registerRuntimeModelOverride(override: ModelConfigOverride): void {
this.runtimeOverrides.push(override);
}
private resolveAlias(
aliasName: string,
aliases: Record<string, ModelConfigAlias>,
@@ -123,7 +128,11 @@ export class ModelConfigService {
...customAliases,
...this.runtimeAliases,
};
const allOverrides = [...overrides, ...customOverrides];
const allOverrides = [
...overrides,
...customOverrides,
...this.runtimeOverrides,
];
let baseModel: string | undefined = context.model;
let resolvedConfig: GenerateContentConfig = {};