fix(core): cache model routing decision in LocalAgentExecutor (#26548)

This commit is contained in:
AK
2026-05-07 17:18:22 -07:00
committed by GitHub
parent c52acebaa2
commit ebeea7570d
2 changed files with 107 additions and 19 deletions
@@ -2279,6 +2279,89 @@ describe('LocalAgentExecutor', () => {
); );
}); });
it('should cache the routing decision across multiple turns', async () => {
const definition = createTestDefinition();
definition.modelConfig.model = 'auto';
definition.runConfig.maxTurns = 3;
const mockRouter = {
route: vi.fn().mockResolvedValue({
model: 'routed-model',
metadata: { source: 'test', reasoning: 'test' },
}),
};
vi.spyOn(mockConfig, 'getModelRouterService').mockReturnValue(
mockRouter as unknown as ModelRouterService,
);
vi.spyOn(
mockConfig.modelConfigService,
'getResolvedConfig',
).mockReturnValue({
model: 'auto',
generateContentConfig: {},
} as unknown as ResolvedModelConfig);
const executor = await LocalAgentExecutor.create(
definition,
mockConfig,
onActivity,
);
mockModelResponse([
{
name: LS_TOOL_NAME,
args: {},
id: 'call1',
},
]);
mockModelResponse([
{
name: COMPLETE_TASK_TOOL_NAME,
args: { finalResult: 'done' },
id: 'call2',
},
]);
mockScheduleAgentTools.mockResolvedValueOnce([
{
status: 'success',
request: {
callId: 'call1',
name: LS_TOOL_NAME,
args: {},
prompt_id: 'test-prompt',
},
response: {
resultDisplay: 'ls result',
responseParts: [],
data: {},
},
},
]);
await executor.run({ goal: 'test' }, signal);
expect(mockRouter.route).toHaveBeenCalledTimes(1);
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
expect(mockSendMessageStream).toHaveBeenNthCalledWith(
1,
expect.objectContaining({ model: 'routed-model' }),
expect.any(Array),
expect.any(String),
expect.any(AbortSignal),
LlmRole.SUBAGENT,
);
expect(mockSendMessageStream).toHaveBeenNthCalledWith(
2,
expect.objectContaining({ model: 'routed-model' }),
expect.any(Array),
expect.any(String),
expect.any(AbortSignal),
LlmRole.SUBAGENT,
);
});
it('should NOT use model routing when the agent model is NOT "auto"', async () => { it('should NOT use model routing when the agent model is NOT "auto"', async () => {
const definition = createTestDefinition(); const definition = createTestDefinition();
definition.modelConfig.model = 'concrete-model'; definition.modelConfig.model = 'concrete-model';
+24 -19
View File
@@ -62,6 +62,7 @@ import { getErrorMessage } from '../utils/errors.js';
import { templateString } from './utils.js'; import { templateString } from './utils.js';
import { DEFAULT_GEMINI_MODEL, isAutoModel } from '../config/models.js'; import { DEFAULT_GEMINI_MODEL, isAutoModel } from '../config/models.js';
import type { RoutingContext } from '../routing/routingStrategy.js'; import type { RoutingContext } from '../routing/routingStrategy.js';
import { LRUCache } from 'mnemonist';
import { parseThought } from '../utils/thoughtUtils.js'; import { parseThought } from '../utils/thoughtUtils.js';
import { type z } from 'zod'; import { type z } from 'zod';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
@@ -127,6 +128,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
private readonly compressionService: ChatCompressionService; private readonly compressionService: ChatCompressionService;
private readonly parentCallId?: string; private readonly parentCallId?: string;
private hasFailedCompressionAttempt = false; private hasFailedCompressionAttempt = false;
private cache: LRUCache<string, string>;
private get executionContext(): AgentLoopContext { private get executionContext(): AgentLoopContext {
return { return {
@@ -311,6 +313,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
this.onActivity = onActivity; this.onActivity = onActivity;
this.compressionService = new ChatCompressionService(); this.compressionService = new ChatCompressionService();
this.parentCallId = parentCallId; this.parentCallId = parentCallId;
this.cache = new LRUCache<string, string>(10);
this.agentId = Math.random().toString(36).slice(2, 8); this.agentId = Math.random().toString(36).slice(2, 8);
} }
@@ -949,26 +952,28 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
}); });
const requestedModel = resolvedConfig.model; const requestedModel = resolvedConfig.model;
let modelToUse: string; let modelToUse: string | undefined;
if (isAutoModel(requestedModel)) { if (isAutoModel(requestedModel)) {
// TODO(joshualitt): This try / catch is inconsistent with the routing modelToUse = this.cache.get('modelToUse');
// behavior for the main agent. Ideally, we would have a universal
// policy for routing failure. Given routing failure does not necessarily // If not cached, fetch from the router and cache the result.
// mean generation will fail, we may want to share this logic with if (!modelToUse) {
// other places we use model routing. try {
try { const routingContext: RoutingContext = {
const routingContext: RoutingContext = { history: chat.getHistory(/*curated=*/ true),
history: chat.getHistory(/*curated=*/ true), request: message.parts || [],
request: message.parts || [], signal,
signal, requestedModel,
requestedModel, };
}; const router = this.context.config.getModelRouterService();
const router = this.context.config.getModelRouterService(); const decision = await router.route(routingContext);
const decision = await router.route(routingContext); modelToUse = decision.model;
modelToUse = decision.model; } catch (error) {
} catch (error) { debugLogger.warn(`Error during model routing: ${error}`);
debugLogger.warn(`Error during model routing: ${error}`); modelToUse = DEFAULT_GEMINI_MODEL;
modelToUse = DEFAULT_GEMINI_MODEL; }
// Cache the result regardless of whether it succeeded or fell back
this.cache.set('modelToUse', modelToUse);
} }
} else { } else {
modelToUse = requestedModel; modelToUse = requestedModel;