diff --git a/packages/core/src/agents/local-executor.test.ts b/packages/core/src/agents/local-executor.test.ts
index c9cacf79f6..cb328e1125 100644
--- a/packages/core/src/agents/local-executor.test.ts
+++ b/packages/core/src/agents/local-executor.test.ts
@@ -2279,6 +2279,89 @@ describe('LocalAgentExecutor', () => {
     );
   });
 
+  it('should cache the routing decision across multiple turns', async () => {
+    const definition = createTestDefinition();
+    definition.modelConfig.model = 'auto';
+    definition.runConfig.maxTurns = 3;
+
+    const mockRouter = {
+      route: vi.fn().mockResolvedValue({
+        model: 'routed-model',
+        metadata: { source: 'test', reasoning: 'test' },
+      }),
+    };
+    vi.spyOn(mockConfig, 'getModelRouterService').mockReturnValue(
+      mockRouter as unknown as ModelRouterService,
+    );
+
+    vi.spyOn(
+      mockConfig.modelConfigService,
+      'getResolvedConfig',
+    ).mockReturnValue({
+      model: 'auto',
+      generateContentConfig: {},
+    } as unknown as ResolvedModelConfig);
+
+    const executor = await LocalAgentExecutor.create(
+      definition,
+      mockConfig,
+      onActivity,
+    );
+
+    mockModelResponse([
+      {
+        name: LS_TOOL_NAME,
+        args: {},
+        id: 'call1',
+      },
+    ]);
+    mockModelResponse([
+      {
+        name: COMPLETE_TASK_TOOL_NAME,
+        args: { finalResult: 'done' },
+        id: 'call2',
+      },
+    ]);
+
+    mockScheduleAgentTools.mockResolvedValueOnce([
+      {
+        status: 'success',
+        request: {
+          callId: 'call1',
+          name: LS_TOOL_NAME,
+          args: {},
+          prompt_id: 'test-prompt',
+        },
+        response: {
+          resultDisplay: 'ls result',
+          responseParts: [],
+          data: {},
+        },
+      },
+    ]);
+
+    await executor.run({ goal: 'test' }, signal);
+
+    expect(mockRouter.route).toHaveBeenCalledTimes(1);
+    expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
+    expect(mockSendMessageStream).toHaveBeenNthCalledWith(
+      1,
+      expect.objectContaining({ model: 'routed-model' }),
+      expect.any(Array),
+      expect.any(String),
+      expect.any(AbortSignal),
+      LlmRole.SUBAGENT,
+    );
+    expect(mockSendMessageStream).toHaveBeenNthCalledWith(
+      2,
+      expect.objectContaining({ model: 'routed-model' }),
+      expect.any(Array),
+      expect.any(String),
+      expect.any(AbortSignal),
+      LlmRole.SUBAGENT,
+    );
+  });
+
   it('should NOT use model routing when the agent model is NOT "auto"', async () => {
     const definition = createTestDefinition();
     definition.modelConfig.model = 'concrete-model';
diff --git a/packages/core/src/agents/local-executor.ts b/packages/core/src/agents/local-executor.ts
index c3572edb11..b05a80f0b7 100644
--- a/packages/core/src/agents/local-executor.ts
+++ b/packages/core/src/agents/local-executor.ts
@@ -62,6 +62,7 @@ import { getErrorMessage } from '../utils/errors.js';
 import { templateString } from './utils.js';
 import { DEFAULT_GEMINI_MODEL, isAutoModel } from '../config/models.js';
 import type { RoutingContext } from '../routing/routingStrategy.js';
+import { LRUCache } from 'mnemonist';
 import { parseThought } from '../utils/thoughtUtils.js';
 import { type z } from 'zod';
 import { debugLogger } from '../utils/debugLogger.js';
@@ -127,6 +128,7 @@ export class LocalAgentExecutor {
   private readonly compressionService: ChatCompressionService;
   private readonly parentCallId?: string;
   private hasFailedCompressionAttempt = false;
+  private cache: LRUCache<string, string>;
 
   private get executionContext(): AgentLoopContext {
     return {
@@ -311,6 +313,7 @@ export class LocalAgentExecutor {
     this.onActivity = onActivity;
     this.compressionService = new ChatCompressionService();
     this.parentCallId = parentCallId;
+    this.cache = new LRUCache<string, string>(10);
     this.agentId = Math.random().toString(36).slice(2, 8);
   }
 
@@ -949,26 +952,28 @@
     });
 
     const requestedModel = resolvedConfig.model;
-    let modelToUse: string;
+    let modelToUse: string | undefined;
     if (isAutoModel(requestedModel)) {
-      // TODO(joshualitt): This try / catch is inconsistent with the routing
-      // behavior for the main agent. Ideally, we would have a universal
-      // policy for routing failure. Given routing failure does not necessarily
-      // mean generation will fail, we may want to share this logic with
-      // other places we use model routing.
-      try {
-        const routingContext: RoutingContext = {
-          history: chat.getHistory(/*curated=*/ true),
-          request: message.parts || [],
-          signal,
-          requestedModel,
-        };
-        const router = this.context.config.getModelRouterService();
-        const decision = await router.route(routingContext);
-        modelToUse = decision.model;
-      } catch (error) {
-        debugLogger.warn(`Error during model routing: ${error}`);
-        modelToUse = DEFAULT_GEMINI_MODEL;
-      }
+      modelToUse = this.cache.get('modelToUse');
+
+      // If not cached, fetch from the router and cache the result.
+      if (!modelToUse) {
+        try {
+          const routingContext: RoutingContext = {
+            history: chat.getHistory(/*curated=*/ true),
+            request: message.parts || [],
+            signal,
+            requestedModel,
+          };
+          const router = this.context.config.getModelRouterService();
+          const decision = await router.route(routingContext);
+          modelToUse = decision.model;
+        } catch (error) {
+          debugLogger.warn(`Error during model routing: ${error}`);
+          modelToUse = DEFAULT_GEMINI_MODEL;
+        }
+        // Cache the result regardless of whether it succeeded or fell back
+        this.cache.set('modelToUse', modelToUse);
+      }
     } else {
       modelToUse = requestedModel;