mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-14 23:31:13 -07:00
feat(agents): migrate subagents to event-driven scheduler (#17567)
This commit is contained in:
@@ -15,7 +15,6 @@ import type {
|
||||
FunctionDeclaration,
|
||||
Schema,
|
||||
} from '@google/genai';
|
||||
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
|
||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
||||
import { CompressionStatus } from '../core/turn.js';
|
||||
import { type ToolCallRequestInfo } from '../scheduler/types.js';
|
||||
@@ -48,7 +47,8 @@ import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { getModelConfigAlias } from './registry.js';
|
||||
import { getVersion } from '../utils/version.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
import { getToolCallContext } from '../utils/toolCallContext.js';
|
||||
import { scheduleAgentTools } from './agent-scheduler.js';
|
||||
|
||||
/** A callback function to report on agent activity. */
|
||||
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
|
||||
@@ -86,6 +86,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
private readonly runtimeContext: Config;
|
||||
private readonly onActivity?: ActivityCallback;
|
||||
private readonly compressionService: ChatCompressionService;
|
||||
private readonly parentCallId?: string;
|
||||
private hasFailedCompressionAttempt = false;
|
||||
|
||||
/**
|
||||
@@ -158,11 +159,16 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
// Get the parent prompt ID from context
|
||||
const parentPromptId = promptIdContext.getStore();
|
||||
|
||||
// Get the parent tool call ID from context
|
||||
const toolContext = getToolCallContext();
|
||||
const parentCallId = toolContext?.callId;
|
||||
|
||||
return new LocalAgentExecutor(
|
||||
definition,
|
||||
runtimeContext,
|
||||
agentToolRegistry,
|
||||
parentPromptId,
|
||||
parentCallId,
|
||||
onActivity,
|
||||
);
|
||||
}
|
||||
@@ -178,6 +184,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
runtimeContext: Config,
|
||||
toolRegistry: ToolRegistry,
|
||||
parentPromptId: string | undefined,
|
||||
parentCallId: string | undefined,
|
||||
onActivity?: ActivityCallback,
|
||||
) {
|
||||
this.definition = definition;
|
||||
@@ -185,6 +192,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.onActivity = onActivity;
|
||||
this.compressionService = new ChatCompressionService();
|
||||
this.parentCallId = parentCallId;
|
||||
|
||||
const randomIdPart = Math.random().toString(36).slice(2, 8);
|
||||
// parentPromptId will be undefined if this agent is invoked directly
|
||||
@@ -763,26 +771,28 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
let submittedOutput: string | null = null;
|
||||
let taskCompleted = false;
|
||||
|
||||
// We'll collect promises for the tool executions
|
||||
const toolExecutionPromises: Array<Promise<Part[] | void>> = [];
|
||||
// And we'll need a place to store the synchronous results (like complete_task or blocked calls)
|
||||
const syncResponseParts: Part[] = [];
|
||||
// We'll separate complete_task from other tools
|
||||
const toolRequests: ToolCallRequestInfo[] = [];
|
||||
// Map to keep track of tool name by callId for activity emission
|
||||
const toolNameMap = new Map<string, string>();
|
||||
// Synchronous results (like complete_task or unauthorized calls)
|
||||
const syncResults = new Map<string, Part>();
|
||||
|
||||
for (const [index, functionCall] of functionCalls.entries()) {
|
||||
const callId = functionCall.id ?? `${promptId}-${index}`;
|
||||
const args = functionCall.args ?? {};
|
||||
const toolName = functionCall.name as string;
|
||||
|
||||
this.emitActivity('TOOL_CALL_START', {
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
args,
|
||||
});
|
||||
|
||||
if (functionCall.name === TASK_COMPLETE_TOOL_NAME) {
|
||||
if (toolName === TASK_COMPLETE_TOOL_NAME) {
|
||||
if (taskCompleted) {
|
||||
// We already have a completion from this turn. Ignore subsequent ones.
|
||||
const error =
|
||||
'Task already marked complete in this turn. Ignoring duplicate call.';
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { error },
|
||||
@@ -791,7 +801,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
});
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
error,
|
||||
});
|
||||
continue;
|
||||
@@ -809,7 +819,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
if (!validationResult.success) {
|
||||
taskCompleted = false; // Validation failed, revoke completion
|
||||
const error = `Output validation failed: ${JSON.stringify(validationResult.error.flatten())}`;
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { error },
|
||||
@@ -818,7 +828,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
});
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
error,
|
||||
});
|
||||
continue;
|
||||
@@ -833,7 +843,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
? outputValue
|
||||
: JSON.stringify(outputValue, null, 2);
|
||||
}
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { result: 'Output submitted and task completed.' },
|
||||
@@ -841,14 +851,14 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
},
|
||||
});
|
||||
this.emitActivity('TOOL_CALL_END', {
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
output: 'Output submitted and task completed.',
|
||||
});
|
||||
} else {
|
||||
// Failed to provide required output.
|
||||
taskCompleted = false; // Revoke completion status
|
||||
const error = `Missing required argument '${outputName}' for completion.`;
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { error },
|
||||
@@ -857,7 +867,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
});
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
error,
|
||||
});
|
||||
}
|
||||
@@ -873,7 +883,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
typeof resultArg === 'string'
|
||||
? resultArg
|
||||
: JSON.stringify(resultArg, null, 2);
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { status: 'Result submitted and task completed.' },
|
||||
@@ -881,7 +891,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
},
|
||||
});
|
||||
this.emitActivity('TOOL_CALL_END', {
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
output: 'Result submitted and task completed.',
|
||||
});
|
||||
} else {
|
||||
@@ -889,7 +899,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
taskCompleted = false; // Revoke completion
|
||||
const error =
|
||||
'Missing required "result" argument. You must provide your findings when calling complete_task.';
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { error },
|
||||
@@ -898,7 +908,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
});
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
error,
|
||||
});
|
||||
}
|
||||
@@ -907,14 +917,13 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
}
|
||||
|
||||
// Handle standard tools
|
||||
if (!allowedToolNames.has(functionCall.name as string)) {
|
||||
const error = createUnauthorizedToolError(functionCall.name as string);
|
||||
|
||||
if (!allowedToolNames.has(toolName)) {
|
||||
const error = createUnauthorizedToolError(toolName);
|
||||
debugLogger.warn(`[LocalAgentExecutor] Blocked call: ${error}`);
|
||||
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: functionCall.name as string,
|
||||
name: toolName,
|
||||
id: callId,
|
||||
response: { error },
|
||||
},
|
||||
@@ -922,7 +931,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call_unauthorized',
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
callId,
|
||||
error,
|
||||
});
|
||||
@@ -930,53 +939,63 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
continue;
|
||||
}
|
||||
|
||||
const requestInfo: ToolCallRequestInfo = {
|
||||
toolRequests.push({
|
||||
callId,
|
||||
name: functionCall.name as string,
|
||||
name: toolName,
|
||||
args,
|
||||
isClientInitiated: true,
|
||||
isClientInitiated: false, // These are coming from the subagent (the "model")
|
||||
prompt_id: promptId,
|
||||
};
|
||||
});
|
||||
toolNameMap.set(callId, toolName);
|
||||
}
|
||||
|
||||
// Create a promise for the tool execution
|
||||
const executionPromise = (async () => {
|
||||
const agentContext = Object.create(this.runtimeContext);
|
||||
agentContext.getToolRegistry = () => this.toolRegistry;
|
||||
agentContext.getApprovalMode = () => ApprovalMode.YOLO;
|
||||
|
||||
const { response: toolResponse } = await executeToolCall(
|
||||
agentContext,
|
||||
requestInfo,
|
||||
// Execute standard tool calls using the new scheduler
|
||||
if (toolRequests.length > 0) {
|
||||
const completedCalls = await scheduleAgentTools(
|
||||
this.runtimeContext,
|
||||
toolRequests,
|
||||
{
|
||||
schedulerId: this.agentId,
|
||||
parentCallId: this.parentCallId,
|
||||
toolRegistry: this.toolRegistry,
|
||||
signal,
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
if (toolResponse.error) {
|
||||
for (const call of completedCalls) {
|
||||
const toolName =
|
||||
toolNameMap.get(call.request.callId) || call.request.name;
|
||||
if (call.status === 'success') {
|
||||
this.emitActivity('TOOL_CALL_END', {
|
||||
name: toolName,
|
||||
output: call.response.resultDisplay,
|
||||
});
|
||||
} else if (call.status === 'error') {
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: functionCall.name,
|
||||
error: toolResponse.error.message,
|
||||
name: toolName,
|
||||
error: call.response.error?.message || 'Unknown error',
|
||||
});
|
||||
} else {
|
||||
this.emitActivity('TOOL_CALL_END', {
|
||||
name: functionCall.name,
|
||||
output: toolResponse.resultDisplay,
|
||||
} else if (call.status === 'cancelled') {
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: toolName,
|
||||
error: 'Tool call was cancelled.',
|
||||
});
|
||||
}
|
||||
|
||||
return toolResponse.responseParts;
|
||||
})();
|
||||
|
||||
toolExecutionPromises.push(executionPromise);
|
||||
// Add result to syncResults to preserve order later
|
||||
syncResults.set(call.request.callId, call.response.responseParts[0]);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all tool executions to complete
|
||||
const asyncResults = await Promise.all(toolExecutionPromises);
|
||||
|
||||
// Combine all response parts
|
||||
const toolResponseParts: Part[] = [...syncResponseParts];
|
||||
for (const result of asyncResults) {
|
||||
if (result) {
|
||||
toolResponseParts.push(...result);
|
||||
// Reconstruct toolResponseParts in the original order
|
||||
const toolResponseParts: Part[] = [];
|
||||
for (const [index, functionCall] of functionCalls.entries()) {
|
||||
const callId = functionCall.id ?? `${promptId}-${index}`;
|
||||
const part = syncResults.get(callId);
|
||||
if (part) {
|
||||
toolResponseParts.push(part);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user