mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 13:22:35 -07:00
fix(core): cache model routing decision in LocalAgentExecutor (#26548)
This commit is contained in:
@@ -2279,6 +2279,89 @@ describe('LocalAgentExecutor', () => {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Verifies that when the agent model is 'auto', the model router is consulted
// only once per run and the routed model is reused for every model request.
it('should cache the routing decision across multiple turns', async () => {
|
||||||
|
const definition = createTestDefinition();
|
||||||
|
definition.modelConfig.model = 'auto';
|
||||||
|
definition.runConfig.maxTurns = 3;
|
||||||
|
|
||||||
|
// Stub router whose route() always resolves to 'routed-model'.
const mockRouter = {
|
||||||
|
route: vi.fn().mockResolvedValue({
|
||||||
|
model: 'routed-model',
|
||||||
|
metadata: { source: 'test', reasoning: 'test' },
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
vi.spyOn(mockConfig, 'getModelRouterService').mockReturnValue(
|
||||||
|
mockRouter as unknown as ModelRouterService,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Make the resolved model config report 'auto' as well
// (presumably so the executor takes its routing path; confirm against the executor).
vi.spyOn(
|
||||||
|
mockConfig.modelConfigService,
|
||||||
|
'getResolvedConfig',
|
||||||
|
).mockReturnValue({
|
||||||
|
model: 'auto',
|
||||||
|
generateContentConfig: {},
|
||||||
|
} as unknown as ResolvedModelConfig);
|
||||||
|
|
||||||
|
const executor = await LocalAgentExecutor.create(
|
||||||
|
definition,
|
||||||
|
mockConfig,
|
||||||
|
onActivity,
|
||||||
|
);
|
||||||
|
|
||||||
|
// First mocked model response: an ls tool call ('call1').
mockModelResponse([
|
||||||
|
{
|
||||||
|
name: LS_TOOL_NAME,
|
||||||
|
args: {},
|
||||||
|
id: 'call1',
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
// Second mocked model response: a complete-task tool call ('call2') with finalResult 'done'.
mockModelResponse([
|
||||||
|
{
|
||||||
|
name: COMPLETE_TASK_TOOL_NAME,
|
||||||
|
args: { finalResult: 'done' },
|
||||||
|
id: 'call2',
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
// Successful scheduler result for the ls call ('call1').
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||||
|
{
|
||||||
|
status: 'success',
|
||||||
|
request: {
|
||||||
|
callId: 'call1',
|
||||||
|
name: LS_TOOL_NAME,
|
||||||
|
args: {},
|
||||||
|
prompt_id: 'test-prompt',
|
||||||
|
},
|
||||||
|
response: {
|
||||||
|
resultDisplay: 'ls result',
|
||||||
|
responseParts: [],
|
||||||
|
data: {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
await executor.run({ goal: 'test' }, signal);
|
||||||
|
|
||||||
|
// The router must be hit exactly once even though two model requests were sent.
expect(mockRouter.route).toHaveBeenCalledTimes(1);
|
||||||
|
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||||
|
// Both requests must carry the routed model.
expect(mockSendMessageStream).toHaveBeenNthCalledWith(
|
||||||
|
1,
|
||||||
|
expect.objectContaining({ model: 'routed-model' }),
|
||||||
|
expect.any(Array),
|
||||||
|
expect.any(String),
|
||||||
|
expect.any(AbortSignal),
|
||||||
|
LlmRole.SUBAGENT,
|
||||||
|
);
|
||||||
|
expect(mockSendMessageStream).toHaveBeenNthCalledWith(
|
||||||
|
2,
|
||||||
|
expect.objectContaining({ model: 'routed-model' }),
|
||||||
|
expect.any(Array),
|
||||||
|
expect.any(String),
|
||||||
|
expect.any(AbortSignal),
|
||||||
|
LlmRole.SUBAGENT,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
it('should NOT use model routing when the agent model is NOT "auto"', async () => {
|
it('should NOT use model routing when the agent model is NOT "auto"', async () => {
|
||||||
const definition = createTestDefinition();
|
const definition = createTestDefinition();
|
||||||
definition.modelConfig.model = 'concrete-model';
|
definition.modelConfig.model = 'concrete-model';
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ import { getErrorMessage } from '../utils/errors.js';
|
|||||||
import { templateString } from './utils.js';
|
import { templateString } from './utils.js';
|
||||||
import { DEFAULT_GEMINI_MODEL, isAutoModel } from '../config/models.js';
|
import { DEFAULT_GEMINI_MODEL, isAutoModel } from '../config/models.js';
|
||||||
import type { RoutingContext } from '../routing/routingStrategy.js';
|
import type { RoutingContext } from '../routing/routingStrategy.js';
|
||||||
|
import { LRUCache } from 'mnemonist';
|
||||||
import { parseThought } from '../utils/thoughtUtils.js';
|
import { parseThought } from '../utils/thoughtUtils.js';
|
||||||
import { type z } from 'zod';
|
import { type z } from 'zod';
|
||||||
import { debugLogger } from '../utils/debugLogger.js';
|
import { debugLogger } from '../utils/debugLogger.js';
|
||||||
@@ -127,6 +128,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
private readonly compressionService: ChatCompressionService;
|
private readonly compressionService: ChatCompressionService;
|
||||||
private readonly parentCallId?: string;
|
private readonly parentCallId?: string;
|
||||||
private hasFailedCompressionAttempt = false;
|
private hasFailedCompressionAttempt = false;
|
||||||
|
private cache: LRUCache<string, string>;
|
||||||
|
|
||||||
private get executionContext(): AgentLoopContext {
|
private get executionContext(): AgentLoopContext {
|
||||||
return {
|
return {
|
||||||
@@ -311,6 +313,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
this.onActivity = onActivity;
|
this.onActivity = onActivity;
|
||||||
this.compressionService = new ChatCompressionService();
|
this.compressionService = new ChatCompressionService();
|
||||||
this.parentCallId = parentCallId;
|
this.parentCallId = parentCallId;
|
||||||
|
this.cache = new LRUCache<string, string>(10);
|
||||||
|
|
||||||
this.agentId = Math.random().toString(36).slice(2, 8);
|
this.agentId = Math.random().toString(36).slice(2, 8);
|
||||||
}
|
}
|
||||||
@@ -949,26 +952,28 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
});
|
});
|
||||||
const requestedModel = resolvedConfig.model;
|
const requestedModel = resolvedConfig.model;
|
||||||
|
|
||||||
let modelToUse: string;
|
let modelToUse: string | undefined;
|
||||||
if (isAutoModel(requestedModel)) {
|
if (isAutoModel(requestedModel)) {
|
||||||
// TODO(joshualitt): This try / catch is inconsistent with the routing
|
modelToUse = this.cache.get('modelToUse');
|
||||||
// behavior for the main agent. Ideally, we would have a universal
|
|
||||||
// policy for routing failure. Given routing failure does not necessarily
|
// If not cached, fetch from the router and cache the result.
|
||||||
// mean generation will fail, we may want to share this logic with
|
if (!modelToUse) {
|
||||||
// other places we use model routing.
|
try {
|
||||||
try {
|
const routingContext: RoutingContext = {
|
||||||
const routingContext: RoutingContext = {
|
history: chat.getHistory(/*curated=*/ true),
|
||||||
history: chat.getHistory(/*curated=*/ true),
|
request: message.parts || [],
|
||||||
request: message.parts || [],
|
signal,
|
||||||
signal,
|
requestedModel,
|
||||||
requestedModel,
|
};
|
||||||
};
|
const router = this.context.config.getModelRouterService();
|
||||||
const router = this.context.config.getModelRouterService();
|
const decision = await router.route(routingContext);
|
||||||
const decision = await router.route(routingContext);
|
modelToUse = decision.model;
|
||||||
modelToUse = decision.model;
|
} catch (error) {
|
||||||
} catch (error) {
|
debugLogger.warn(`Error during model routing: ${error}`);
|
||||||
debugLogger.warn(`Error during model routing: ${error}`);
|
modelToUse = DEFAULT_GEMINI_MODEL;
|
||||||
modelToUse = DEFAULT_GEMINI_MODEL;
|
}
|
||||||
|
// Cache the result regardless of whether it succeeded or fell back
|
||||||
|
this.cache.set('modelToUse', modelToUse);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
modelToUse = requestedModel;
|
modelToUse = requestedModel;
|
||||||
|
|||||||
Reference in New Issue
Block a user