mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-12 07:01:09 -07:00
feat(core): Wire up model routing to subagents. (#16043)
This commit is contained in:
@@ -57,8 +57,12 @@ import { AgentTerminateMode } from './types.js';
|
||||
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
|
||||
import { CompressionStatus } from '../core/turn.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import type {
|
||||
ModelConfigKey,
|
||||
ResolvedModelConfig,
|
||||
} from '../services/modelConfigService.js';
|
||||
import { getModelConfigAlias } from './registry.js';
|
||||
import type { ModelRouterService } from '../routing/modelRouterService.js';
|
||||
|
||||
const {
|
||||
mockSendMessageStream,
|
||||
@@ -1192,6 +1196,101 @@ describe('LocalAgentExecutor', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('Model Routing', () => {
|
||||
it('should use model routing when the agent model is "auto"', async () => {
|
||||
const definition = createTestDefinition();
|
||||
definition.modelConfig.model = 'auto';
|
||||
|
||||
const mockRouter = {
|
||||
route: vi.fn().mockResolvedValue({
|
||||
model: 'routed-model',
|
||||
metadata: { source: 'test', reasoning: 'test' },
|
||||
}),
|
||||
};
|
||||
vi.spyOn(mockConfig, 'getModelRouterService').mockReturnValue(
|
||||
mockRouter as unknown as ModelRouterService,
|
||||
);
|
||||
|
||||
// Mock resolved config to return 'auto'
|
||||
vi.spyOn(
|
||||
mockConfig.modelConfigService,
|
||||
'getResolvedConfig',
|
||||
).mockReturnValue({
|
||||
model: 'auto',
|
||||
generateContentConfig: {},
|
||||
} as unknown as ResolvedModelConfig);
|
||||
|
||||
const executor = await LocalAgentExecutor.create(
|
||||
definition,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
|
||||
mockModelResponse([
|
||||
{
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'done' },
|
||||
id: 'call1',
|
||||
},
|
||||
]);
|
||||
|
||||
await executor.run({ goal: 'test' }, signal);
|
||||
|
||||
expect(mockRouter.route).toHaveBeenCalled();
|
||||
expect(mockSendMessageStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ model: 'routed-model' }),
|
||||
expect.any(Array),
|
||||
expect.any(String),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
});
|
||||
|
||||
it('should NOT use model routing when the agent model is NOT "auto"', async () => {
|
||||
const definition = createTestDefinition();
|
||||
definition.modelConfig.model = 'concrete-model';
|
||||
|
||||
const mockRouter = {
|
||||
route: vi.fn(),
|
||||
};
|
||||
vi.spyOn(mockConfig, 'getModelRouterService').mockReturnValue(
|
||||
mockRouter as unknown as ModelRouterService,
|
||||
);
|
||||
|
||||
// Mock resolved config to return 'concrete-model'
|
||||
vi.spyOn(
|
||||
mockConfig.modelConfigService,
|
||||
'getResolvedConfig',
|
||||
).mockReturnValue({
|
||||
model: 'concrete-model',
|
||||
generateContentConfig: {},
|
||||
} as unknown as ResolvedModelConfig);
|
||||
|
||||
const executor = await LocalAgentExecutor.create(
|
||||
definition,
|
||||
mockConfig,
|
||||
onActivity,
|
||||
);
|
||||
|
||||
mockModelResponse([
|
||||
{
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
args: { finalResult: 'done' },
|
||||
id: 'call1',
|
||||
},
|
||||
]);
|
||||
|
||||
await executor.run({ goal: 'test' }, signal);
|
||||
|
||||
expect(mockRouter.route).not.toHaveBeenCalled();
|
||||
expect(mockSendMessageStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ model: 'concrete-model' }),
|
||||
expect.any(Array),
|
||||
expect.any(String),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('run (Termination Conditions)', () => {
|
||||
const mockWorkResponse = (id: string) => {
|
||||
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
||||
|
||||
@@ -40,6 +40,8 @@ import type {
|
||||
} from './types.js';
|
||||
import { AgentTerminateMode } from './types.js';
|
||||
import { templateString } from './utils.js';
|
||||
import { DEFAULT_GEMINI_MODEL, isAutoModel } from '../config/models.js';
|
||||
import type { RoutingContext } from '../routing/routingStrategy.js';
|
||||
import { parseThought } from '../utils/thoughtUtils.js';
|
||||
import { type z } from 'zod';
|
||||
import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||
@@ -589,9 +591,44 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
signal: AbortSignal,
|
||||
promptId: string,
|
||||
): Promise<{ functionCalls: FunctionCall[]; textResponse: string }> {
|
||||
const modelConfigAlias = getModelConfigAlias(this.definition);
|
||||
|
||||
// Resolve the model config early to get the concrete model string (which may be `auto`).
|
||||
const resolvedConfig =
|
||||
this.runtimeContext.modelConfigService.getResolvedConfig({
|
||||
model: modelConfigAlias,
|
||||
overrideScope: this.definition.name,
|
||||
});
|
||||
const requestedModel = resolvedConfig.model;
|
||||
|
||||
let modelToUse: string;
|
||||
if (isAutoModel(requestedModel)) {
|
||||
// TODO(joshualitt): This try / catch is inconsistent with the routing
|
||||
// behavior for the main agent. Ideally, we would have a universal
|
||||
// policy for routing failure. Given routing failure does not necessarily
|
||||
// mean generation will fail, we may want to share this logic with
|
||||
// other places we use model routing.
|
||||
try {
|
||||
const routingContext: RoutingContext = {
|
||||
history: chat.getHistory(/*curated=*/ true),
|
||||
request: message.parts || [],
|
||||
signal,
|
||||
requestedModel,
|
||||
};
|
||||
const router = this.runtimeContext.getModelRouterService();
|
||||
const decision = await router.route(routingContext);
|
||||
modelToUse = decision.model;
|
||||
} catch (error) {
|
||||
debugLogger.warn(`Error during model routing: ${error}`);
|
||||
modelToUse = DEFAULT_GEMINI_MODEL;
|
||||
}
|
||||
} else {
|
||||
modelToUse = requestedModel;
|
||||
}
|
||||
|
||||
const responseStream = await chat.sendMessageStream(
|
||||
{
|
||||
model: getModelConfigAlias(this.definition),
|
||||
model: modelToUse,
|
||||
overrideScope: this.definition.name,
|
||||
},
|
||||
message.parts || [],
|
||||
|
||||
@@ -243,6 +243,51 @@ describe('AgentRegistry', () => {
|
||||
});
|
||||
|
||||
describe('registration logic', () => {
|
||||
it('should register runtime overrides when the model is "auto"', async () => {
|
||||
const autoAgent: LocalAgentDefinition = {
|
||||
...MOCK_AGENT_V1,
|
||||
name: 'AutoAgent',
|
||||
modelConfig: { ...MOCK_AGENT_V1.modelConfig, model: 'auto' },
|
||||
};
|
||||
|
||||
const registerOverrideSpy = vi.spyOn(
|
||||
mockConfig.modelConfigService,
|
||||
'registerRuntimeModelOverride',
|
||||
);
|
||||
|
||||
await registry.testRegisterAgent(autoAgent);
|
||||
|
||||
// Should register one alias for the custom model config.
|
||||
expect(
|
||||
mockConfig.modelConfigService.getResolvedConfig({
|
||||
model: getModelConfigAlias(autoAgent),
|
||||
}),
|
||||
).toStrictEqual({
|
||||
model: 'auto',
|
||||
generateContentConfig: {
|
||||
temperature: autoAgent.modelConfig.temp,
|
||||
topP: autoAgent.modelConfig.top_p,
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: -1,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Should register one override for the agent name (scope)
|
||||
expect(registerOverrideSpy).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Check scope override
|
||||
expect(registerOverrideSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
match: { overrideScope: autoAgent.name },
|
||||
modelConfig: expect.objectContaining({
|
||||
generateContentConfig: expect.any(Object),
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should register a valid agent definition', async () => {
|
||||
await registry.testRegisterAgent(MOCK_AGENT_V1);
|
||||
expect(registry.getDefinition('MockAgent')).toEqual(MOCK_AGENT_V1);
|
||||
|
||||
@@ -20,8 +20,8 @@ import {
|
||||
GEMINI_MODEL_ALIAS_AUTO,
|
||||
PREVIEW_GEMINI_FLASH_MODEL,
|
||||
isPreviewModel,
|
||||
isAutoModel,
|
||||
} from '../config/models.js';
|
||||
import type { ModelConfigAlias } from '../services/modelConfigService.js';
|
||||
|
||||
/**
|
||||
* Returns the model config alias for a given agent definition.
|
||||
@@ -199,7 +199,10 @@ export class AgentRegistry {
|
||||
|
||||
this.agents.set(definition.name, definition);
|
||||
|
||||
// Register model config.
|
||||
// Register model config. We always create a runtime alias. However,
|
||||
// if the user is using `auto` as a model string then we also create
|
||||
// runtime overrides to ensure the subagent generation settings are
|
||||
// respected regardless of the final model string from routing.
|
||||
// TODO(12916): Migrate sub-agents where possible to static configs.
|
||||
const modelConfig = definition.modelConfig;
|
||||
let model = modelConfig.model;
|
||||
@@ -207,24 +210,35 @@ export class AgentRegistry {
|
||||
model = this.config.getModel();
|
||||
}
|
||||
|
||||
const runtimeAlias: ModelConfigAlias = {
|
||||
modelConfig: {
|
||||
model,
|
||||
generateContentConfig: {
|
||||
temperature: modelConfig.temp,
|
||||
topP: modelConfig.top_p,
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: modelConfig.thinkingBudget ?? -1,
|
||||
},
|
||||
},
|
||||
const generateContentConfig = {
|
||||
temperature: modelConfig.temp,
|
||||
topP: modelConfig.top_p,
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: modelConfig.thinkingBudget ?? -1,
|
||||
},
|
||||
};
|
||||
|
||||
this.config.modelConfigService.registerRuntimeModelConfig(
|
||||
getModelConfigAlias(definition),
|
||||
runtimeAlias,
|
||||
{
|
||||
modelConfig: {
|
||||
model,
|
||||
generateContentConfig,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
if (isAutoModel(model)) {
|
||||
this.config.modelConfigService.registerRuntimeModelOverride({
|
||||
match: {
|
||||
overrideScope: definition.name,
|
||||
},
|
||||
modelConfig: {
|
||||
generateContentConfig,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user