feat(core): Preliminary changes for subagent model routing. (#16035)

This commit is contained in:
joshualitt
2026-01-07 13:21:10 -08:00
committed by GitHub
parent 17b3eb730a
commit a1dd19738e
12 changed files with 200 additions and 15 deletions

View File

@@ -35,6 +35,8 @@ export interface RoutingContext {
request: PartListUnion;
/** An abort signal to cancel an LLM call during routing. */
signal: AbortSignal;
/** The model string requested for this turn, if any. */
requestedModel?: string;
}
/**

View File

@@ -281,4 +281,30 @@ describe('ClassifierStrategy', () => {
);
consoleWarnSpy.mockRestore();
});
it('should respect requestedModel from context in resolveClassifierModel', async () => {
  // Classifier answers "flash" even though the turn requested the Pro model.
  const classifierJson = {
    reasoning: 'Choice is flash',
    model_choice: 'flash',
  };
  vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(classifierJson);

  const routingContext = {
    ...mockContext,
    requestedModel: DEFAULT_GEMINI_MODEL, // Pro model
  } as RoutingContext;

  const decision = await strategy.route(
    routingContext,
    mockConfig,
    mockBaseLlmClient,
  );

  expect(decision).not.toBeNull();
  // Since requestedModel is Pro, and choice is flash, it should resolve to Flash
  expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
});
});

View File

@@ -168,7 +168,7 @@ export class ClassifierStrategy implements RoutingStrategy {
const reasoning = routerResponse.reasoning;
const latencyMs = Date.now() - startTime;
const selectedModel = resolveClassifierModel(
config.getModel(),
context.requestedModel ?? config.getModel(),
routerResponse.model_choice,
config.getPreviewFeatures(),
);

View File

@@ -108,4 +108,25 @@ describe('FallbackStrategy', () => {
// Important: check that it queried snapshot with the RESOLVED model, not 'auto'
expect(mockService.snapshot).toHaveBeenCalledWith(DEFAULT_GEMINI_MODEL);
});
it('should respect requestedModel from context', async () => {
  // The config's own model must be ignored in favor of the context's.
  vi.mocked(mockConfig.getModel).mockReturnValue('config-model');
  vi.mocked(mockService.snapshot).mockReturnValue({ available: true });

  const modelFromContext = 'requested-model';
  const decision = await strategy.route(
    { requestedModel: modelFromContext } as RoutingContext,
    mockConfig,
    mockClient,
  );

  expect(decision).toBeNull();
  // Should check availability of the requested model from context
  expect(mockService.snapshot).toHaveBeenCalledWith(modelFromContext);
});
});

View File

@@ -18,11 +18,11 @@ export class FallbackStrategy implements RoutingStrategy {
readonly name = 'fallback';
async route(
_context: RoutingContext,
context: RoutingContext,
config: Config,
_baseLlmClient: BaseLlmClient,
): Promise<RoutingDecision | null> {
const requestedModel = config.getModel();
const requestedModel = context.requestedModel ?? config.getModel();
const resolvedModel = resolveModel(
requestedModel,
config.getPreviewFeatures(),

View File

@@ -56,4 +56,25 @@ describe('OverrideStrategy', () => {
expect(decision).not.toBeNull();
expect(decision?.model).toBe(overrideModel);
});
it('should respect requestedModel from context', async () => {
  const modelFromContext = 'requested-model';
  // Config stub advertises a different model so precedence is observable.
  const stubConfig = {
    getModel: () => 'config-model',
    getPreviewFeatures: () => false,
  } as Config;

  const decision = await strategy.route(
    { requestedModel: modelFromContext } as RoutingContext,
    stubConfig,
    mockClient,
  );

  expect(decision).not.toBeNull();
  // The context's requested model wins over the config's model.
  expect(decision?.model).toBe(modelFromContext);
});
});

View File

@@ -5,11 +5,7 @@
*/
import type { Config } from '../../config/config.js';
import {
DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_MODEL_AUTO,
resolveModel,
} from '../../config/models.js';
import { isAutoModel, resolveModel } from '../../config/models.js';
import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import type {
RoutingContext,
@@ -24,18 +20,16 @@ export class OverrideStrategy implements RoutingStrategy {
readonly name = 'override';
async route(
_context: RoutingContext,
context: RoutingContext,
config: Config,
_baseLlmClient: BaseLlmClient,
): Promise<RoutingDecision | null> {
const overrideModel = config.getModel();
const overrideModel = context.requestedModel ?? config.getModel();
// If the model is 'auto' we should pass to the next strategy.
if (
overrideModel === DEFAULT_GEMINI_MODEL_AUTO ||
overrideModel === PREVIEW_GEMINI_MODEL_AUTO
)
if (isAutoModel(overrideModel)) {
return null;
}
// Return the overridden model name.
return {