feat(agents): migrate subagents to event-driven scheduler (#17567)

This commit is contained in:
Abhi
2026-01-26 17:12:55 -05:00
committed by GitHub
parent 13bc5f620c
commit 9d34ae52d6
8 changed files with 741 additions and 335 deletions

View File

@@ -15,7 +15,6 @@ import type {
FunctionDeclaration,
Schema,
} from '@google/genai';
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
import { ToolRegistry } from '../tools/tool-registry.js';
import { CompressionStatus } from '../core/turn.js';
import { type ToolCallRequestInfo } from '../scheduler/types.js';
@@ -48,7 +47,8 @@ import { zodToJsonSchema } from 'zod-to-json-schema';
import { debugLogger } from '../utils/debugLogger.js';
import { getModelConfigAlias } from './registry.js';
import { getVersion } from '../utils/version.js';
import { ApprovalMode } from '../policy/types.js';
import { getToolCallContext } from '../utils/toolCallContext.js';
import { scheduleAgentTools } from './agent-scheduler.js';
/** A callback function to report on agent activity. */
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
@@ -86,6 +86,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
private readonly runtimeContext: Config;
private readonly onActivity?: ActivityCallback;
private readonly compressionService: ChatCompressionService;
private readonly parentCallId?: string;
private hasFailedCompressionAttempt = false;
/**
@@ -158,11 +159,16 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
// Get the parent prompt ID from context
const parentPromptId = promptIdContext.getStore();
// Get the parent tool call ID from context
const toolContext = getToolCallContext();
const parentCallId = toolContext?.callId;
return new LocalAgentExecutor(
definition,
runtimeContext,
agentToolRegistry,
parentPromptId,
parentCallId,
onActivity,
);
}
@@ -178,6 +184,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
runtimeContext: Config,
toolRegistry: ToolRegistry,
parentPromptId: string | undefined,
parentCallId: string | undefined,
onActivity?: ActivityCallback,
) {
this.definition = definition;
@@ -185,6 +192,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
this.toolRegistry = toolRegistry;
this.onActivity = onActivity;
this.compressionService = new ChatCompressionService();
this.parentCallId = parentCallId;
const randomIdPart = Math.random().toString(36).slice(2, 8);
// parentPromptId will be undefined if this agent is invoked directly
@@ -763,26 +771,28 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
let submittedOutput: string | null = null;
let taskCompleted = false;
// We'll collect promises for the tool executions
const toolExecutionPromises: Array<Promise<Part[] | void>> = [];
// And we'll need a place to store the synchronous results (like complete_task or blocked calls)
const syncResponseParts: Part[] = [];
// We'll separate complete_task from other tools
const toolRequests: ToolCallRequestInfo[] = [];
// Map to keep track of tool name by callId for activity emission
const toolNameMap = new Map<string, string>();
// Synchronous results (like complete_task or unauthorized calls)
const syncResults = new Map<string, Part>();
for (const [index, functionCall] of functionCalls.entries()) {
const callId = functionCall.id ?? `${promptId}-${index}`;
const args = functionCall.args ?? {};
const toolName = functionCall.name as string;
this.emitActivity('TOOL_CALL_START', {
name: functionCall.name,
name: toolName,
args,
});
if (functionCall.name === TASK_COMPLETE_TOOL_NAME) {
if (toolName === TASK_COMPLETE_TOOL_NAME) {
if (taskCompleted) {
// We already have a completion from this turn. Ignore subsequent ones.
const error =
'Task already marked complete in this turn. Ignoring duplicate call.';
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: TASK_COMPLETE_TOOL_NAME,
response: { error },
@@ -791,7 +801,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
});
this.emitActivity('ERROR', {
context: 'tool_call',
name: functionCall.name,
name: toolName,
error,
});
continue;
@@ -809,7 +819,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
if (!validationResult.success) {
taskCompleted = false; // Validation failed, revoke completion
const error = `Output validation failed: ${JSON.stringify(validationResult.error.flatten())}`;
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: TASK_COMPLETE_TOOL_NAME,
response: { error },
@@ -818,7 +828,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
});
this.emitActivity('ERROR', {
context: 'tool_call',
name: functionCall.name,
name: toolName,
error,
});
continue;
@@ -833,7 +843,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
? outputValue
: JSON.stringify(outputValue, null, 2);
}
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: TASK_COMPLETE_TOOL_NAME,
response: { result: 'Output submitted and task completed.' },
@@ -841,14 +851,14 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
},
});
this.emitActivity('TOOL_CALL_END', {
name: functionCall.name,
name: toolName,
output: 'Output submitted and task completed.',
});
} else {
// Failed to provide required output.
taskCompleted = false; // Revoke completion status
const error = `Missing required argument '${outputName}' for completion.`;
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: TASK_COMPLETE_TOOL_NAME,
response: { error },
@@ -857,7 +867,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
});
this.emitActivity('ERROR', {
context: 'tool_call',
name: functionCall.name,
name: toolName,
error,
});
}
@@ -873,7 +883,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
typeof resultArg === 'string'
? resultArg
: JSON.stringify(resultArg, null, 2);
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: TASK_COMPLETE_TOOL_NAME,
response: { status: 'Result submitted and task completed.' },
@@ -881,7 +891,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
},
});
this.emitActivity('TOOL_CALL_END', {
name: functionCall.name,
name: toolName,
output: 'Result submitted and task completed.',
});
} else {
@@ -889,7 +899,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
taskCompleted = false; // Revoke completion
const error =
'Missing required "result" argument. You must provide your findings when calling complete_task.';
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: TASK_COMPLETE_TOOL_NAME,
response: { error },
@@ -898,7 +908,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
});
this.emitActivity('ERROR', {
context: 'tool_call',
name: functionCall.name,
name: toolName,
error,
});
}
@@ -907,14 +917,13 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
}
// Handle standard tools
if (!allowedToolNames.has(functionCall.name as string)) {
const error = createUnauthorizedToolError(functionCall.name as string);
if (!allowedToolNames.has(toolName)) {
const error = createUnauthorizedToolError(toolName);
debugLogger.warn(`[LocalAgentExecutor] Blocked call: ${error}`);
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: functionCall.name as string,
name: toolName,
id: callId,
response: { error },
},
@@ -922,7 +931,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
this.emitActivity('ERROR', {
context: 'tool_call_unauthorized',
name: functionCall.name,
name: toolName,
callId,
error,
});
@@ -930,53 +939,63 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
continue;
}
const requestInfo: ToolCallRequestInfo = {
toolRequests.push({
callId,
name: functionCall.name as string,
name: toolName,
args,
isClientInitiated: true,
isClientInitiated: false, // These are coming from the subagent (the "model")
prompt_id: promptId,
};
});
toolNameMap.set(callId, toolName);
}
// Create a promise for the tool execution
const executionPromise = (async () => {
const agentContext = Object.create(this.runtimeContext);
agentContext.getToolRegistry = () => this.toolRegistry;
agentContext.getApprovalMode = () => ApprovalMode.YOLO;
const { response: toolResponse } = await executeToolCall(
agentContext,
requestInfo,
// Execute standard tool calls using the new scheduler
if (toolRequests.length > 0) {
const completedCalls = await scheduleAgentTools(
this.runtimeContext,
toolRequests,
{
schedulerId: this.agentId,
parentCallId: this.parentCallId,
toolRegistry: this.toolRegistry,
signal,
);
},
);
if (toolResponse.error) {
for (const call of completedCalls) {
const toolName =
toolNameMap.get(call.request.callId) || call.request.name;
if (call.status === 'success') {
this.emitActivity('TOOL_CALL_END', {
name: toolName,
output: call.response.resultDisplay,
});
} else if (call.status === 'error') {
this.emitActivity('ERROR', {
context: 'tool_call',
name: functionCall.name,
error: toolResponse.error.message,
name: toolName,
error: call.response.error?.message || 'Unknown error',
});
} else {
this.emitActivity('TOOL_CALL_END', {
name: functionCall.name,
output: toolResponse.resultDisplay,
} else if (call.status === 'cancelled') {
this.emitActivity('ERROR', {
context: 'tool_call',
name: toolName,
error: 'Tool call was cancelled.',
});
}
return toolResponse.responseParts;
})();
toolExecutionPromises.push(executionPromise);
// Add result to syncResults to preserve order later
syncResults.set(call.request.callId, call.response.responseParts[0]);
}
}
// Wait for all tool executions to complete
const asyncResults = await Promise.all(toolExecutionPromises);
// Combine all response parts
const toolResponseParts: Part[] = [...syncResponseParts];
for (const result of asyncResults) {
if (result) {
toolResponseParts.push(...result);
// Reconstruct toolResponseParts in the original order
const toolResponseParts: Part[] = [];
for (const [index, functionCall] of functionCalls.entries()) {
const callId = functionCall.id ?? `${promptId}-${index}`;
const part = syncResults.get(callId);
if (part) {
toolResponseParts.push(part);
}
}