mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 13:22:35 -07:00
fix(core): cache model routing decision in LocalAgentExecutor (#26548)
This commit is contained in:
@@ -2279,6 +2279,89 @@ describe('LocalAgentExecutor', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should cache the routing decision across multiple turns', async () => {
  // Agent on the 'auto' model, allowed up to three turns so the run can
  // issue more than one model call.
  const agentDefinition = createTestDefinition();
  agentDefinition.modelConfig.model = 'auto';
  agentDefinition.runConfig.maxTurns = 3;

  // Router stub whose `route` always resolves to the same decision.
  const routeSpy = vi.fn().mockResolvedValue({
    model: 'routed-model',
    metadata: { source: 'test', reasoning: 'test' },
  });
  const routerStub = { route: routeSpy };
  vi.spyOn(mockConfig, 'getModelRouterService').mockReturnValue(
    routerStub as unknown as ModelRouterService,
  );

  // The resolved model config also reports 'auto'.
  const resolvedAutoConfig = {
    model: 'auto',
    generateContentConfig: {},
  } as unknown as ResolvedModelConfig;
  vi.spyOn(mockConfig.modelConfigService, 'getResolvedConfig').mockReturnValue(
    resolvedAutoConfig,
  );

  const agentExecutor = await LocalAgentExecutor.create(
    agentDefinition,
    mockConfig,
    onActivity,
  );

  // Two model responses are set up: first an LS tool call, then the
  // complete-task call carrying the final result.
  const lsToolCall = { name: LS_TOOL_NAME, args: {}, id: 'call1' };
  const completeTaskCall = {
    name: COMPLETE_TASK_TOOL_NAME,
    args: { finalResult: 'done' },
    id: 'call2',
  };
  mockModelResponse([lsToolCall]);
  mockModelResponse([completeTaskCall]);

  // Successful scheduler result for the LS tool call.
  const lsToolResult = {
    status: 'success',
    request: {
      callId: 'call1',
      name: LS_TOOL_NAME,
      args: {},
      prompt_id: 'test-prompt',
    },
    response: {
      resultDisplay: 'ls result',
      responseParts: [],
      data: {},
    },
  };
  mockScheduleAgentTools.mockResolvedValueOnce([lsToolResult]);

  await agentExecutor.run({ goal: 'test' }, signal);

  // The router is consulted exactly once even though two messages were sent.
  expect(routeSpy).toHaveBeenCalledTimes(1);
  expect(mockSendMessageStream).toHaveBeenCalledTimes(2);

  // Both sends must carry the routed model.
  for (const nthCall of [1, 2]) {
    expect(mockSendMessageStream).toHaveBeenNthCalledWith(
      nthCall,
      expect.objectContaining({ model: 'routed-model' }),
      expect.any(Array),
      expect.any(String),
      expect.any(AbortSignal),
      LlmRole.SUBAGENT,
    );
  }
});
|
||||
|
||||
it('should NOT use model routing when the agent model is NOT "auto"', async () => {
|
||||
const definition = createTestDefinition();
|
||||
definition.modelConfig.model = 'concrete-model';
|
||||
|
||||
@@ -62,6 +62,7 @@ import { getErrorMessage } from '../utils/errors.js';
|
||||
import { templateString } from './utils.js';
|
||||
import { DEFAULT_GEMINI_MODEL, isAutoModel } from '../config/models.js';
|
||||
import type { RoutingContext } from '../routing/routingStrategy.js';
|
||||
import { LRUCache } from 'mnemonist';
|
||||
import { parseThought } from '../utils/thoughtUtils.js';
|
||||
import { type z } from 'zod';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
@@ -127,6 +128,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
private readonly compressionService: ChatCompressionService;
|
||||
private readonly parentCallId?: string;
|
||||
private hasFailedCompressionAttempt = false;
|
||||
private cache: LRUCache<string, string>;
|
||||
|
||||
private get executionContext(): AgentLoopContext {
|
||||
return {
|
||||
@@ -311,6 +313,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
this.onActivity = onActivity;
|
||||
this.compressionService = new ChatCompressionService();
|
||||
this.parentCallId = parentCallId;
|
||||
this.cache = new LRUCache<string, string>(10);
|
||||
|
||||
this.agentId = Math.random().toString(36).slice(2, 8);
|
||||
}
|
||||
@@ -949,26 +952,28 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
});
|
||||
const requestedModel = resolvedConfig.model;
|
||||
|
||||
let modelToUse: string;
|
||||
let modelToUse: string | undefined;
|
||||
if (isAutoModel(requestedModel)) {
|
||||
// TODO(joshualitt): This try / catch is inconsistent with the routing
|
||||
// behavior for the main agent. Ideally, we would have a universal
|
||||
// policy for routing failure. Given routing failure does not necessarily
|
||||
// mean generation will fail, we may want to share this logic with
|
||||
// other places we use model routing.
|
||||
try {
|
||||
const routingContext: RoutingContext = {
|
||||
history: chat.getHistory(/*curated=*/ true),
|
||||
request: message.parts || [],
|
||||
signal,
|
||||
requestedModel,
|
||||
};
|
||||
const router = this.context.config.getModelRouterService();
|
||||
const decision = await router.route(routingContext);
|
||||
modelToUse = decision.model;
|
||||
} catch (error) {
|
||||
debugLogger.warn(`Error during model routing: ${error}`);
|
||||
modelToUse = DEFAULT_GEMINI_MODEL;
|
||||
modelToUse = this.cache.get('modelToUse');
|
||||
|
||||
// If not cached, fetch from the router and cache the result.
|
||||
if (!modelToUse) {
|
||||
try {
|
||||
const routingContext: RoutingContext = {
|
||||
history: chat.getHistory(/*curated=*/ true),
|
||||
request: message.parts || [],
|
||||
signal,
|
||||
requestedModel,
|
||||
};
|
||||
const router = this.context.config.getModelRouterService();
|
||||
const decision = await router.route(routingContext);
|
||||
modelToUse = decision.model;
|
||||
} catch (error) {
|
||||
debugLogger.warn(`Error during model routing: ${error}`);
|
||||
modelToUse = DEFAULT_GEMINI_MODEL;
|
||||
}
|
||||
// Cache the result regardless of whether it succeeded or fell back
|
||||
this.cache.set('modelToUse', modelToUse);
|
||||
}
|
||||
} else {
|
||||
modelToUse = requestedModel;
|
||||
|
||||
Reference in New Issue
Block a user