feat(agents): migrate subagents to event-driven scheduler (#17567)

This commit is contained in:
Abhi
2026-01-26 17:12:55 -05:00
committed by GitHub
parent 13bc5f620c
commit 9d34ae52d6
8 changed files with 741 additions and 335 deletions

View File

@@ -0,0 +1,74 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, type Mocked } from 'vitest';
import { scheduleAgentTools } from './agent-scheduler.js';
import { Scheduler } from '../scheduler/scheduler.js';
import type { Config } from '../config/config.js';
import type { ToolRegistry } from '../tools/tool-registry.js';
import type { ToolCallRequestInfo } from '../scheduler/types.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
vi.mock('../scheduler/scheduler.js', () => ({
Scheduler: vi.fn().mockImplementation(() => ({
schedule: vi.fn().mockResolvedValue([{ status: 'success' }]),
})),
}));
describe('agent-scheduler', () => {
let mockConfig: Mocked<Config>;
let mockToolRegistry: Mocked<ToolRegistry>;
let mockMessageBus: Mocked<MessageBus>;
beforeEach(() => {
mockMessageBus = {} as Mocked<MessageBus>;
mockToolRegistry = {
getTool: vi.fn(),
} as unknown as Mocked<ToolRegistry>;
mockConfig = {
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
} as unknown as Mocked<Config>;
});
it('should create a scheduler with agent-specific config', async () => {
const requests: ToolCallRequestInfo[] = [
{
callId: 'call-1',
name: 'test-tool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-1',
},
];
const options = {
schedulerId: 'subagent-1',
parentCallId: 'parent-1',
toolRegistry: mockToolRegistry as unknown as ToolRegistry,
signal: new AbortController().signal,
};
const results = await scheduleAgentTools(
mockConfig as unknown as Config,
requests,
options,
);
expect(results).toEqual([{ status: 'success' }]);
expect(Scheduler).toHaveBeenCalledWith(
expect.objectContaining({
schedulerId: 'subagent-1',
parentCallId: 'parent-1',
messageBus: mockMessageBus,
}),
);
// Verify that the scheduler's config has the overridden tool registry
const schedulerConfig = vi.mocked(Scheduler).mock.calls[0][0].config;
expect(schedulerConfig.getToolRegistry()).toBe(mockToolRegistry);
});
});

View File

@@ -0,0 +1,66 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '../config/config.js';
import { Scheduler } from '../scheduler/scheduler.js';
import type {
ToolCallRequestInfo,
CompletedToolCall,
} from '../scheduler/types.js';
import type { ToolRegistry } from '../tools/tool-registry.js';
import type { EditorType } from '../utils/editor.js';
/**
* Options for scheduling agent tools.
*/
export interface AgentSchedulingOptions {
/** The unique ID for this agent's scheduler. */
schedulerId: string;
/** The ID of the tool call that invoked this agent. */
parentCallId?: string;
/** The tool registry specific to this agent. */
toolRegistry: ToolRegistry;
/** AbortSignal for cancellation. */
signal: AbortSignal;
/** Optional function to get the preferred editor for tool modifications. */
getPreferredEditor?: () => EditorType | undefined;
}
/**
* Schedules a batch of tool calls for an agent using the new event-driven Scheduler.
*
* @param config The global runtime configuration.
* @param requests The list of tool call requests from the agent.
* @param options Scheduling options including registry and IDs.
* @returns A promise that resolves to the completed tool calls.
*/
export async function scheduleAgentTools(
config: Config,
requests: ToolCallRequestInfo[],
options: AgentSchedulingOptions,
): Promise<CompletedToolCall[]> {
const {
schedulerId,
parentCallId,
toolRegistry,
signal,
getPreferredEditor,
} = options;
// Create a proxy/override of the config to provide the agent-specific tool registry.
const agentConfig: Config = Object.create(config);
agentConfig.getToolRegistry = () => toolRegistry;
const scheduler = new Scheduler({
config: agentConfig,
messageBus: config.getMessageBus(),
getPreferredEditor: getPreferredEditor ?? (() => undefined),
schedulerId,
parentCallId,
});
return scheduler.schedule(requests, signal);
}

View File

@@ -55,6 +55,7 @@ import type {
} from './types.js';
import { AgentTerminateMode } from './types.js';
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
import type { ToolCallRequestInfo } from '../scheduler/types.js';
import { CompressionStatus } from '../core/turn.js';
import { ChatCompressionService } from '../services/chatCompressionService.js';
import type {
@@ -67,12 +68,12 @@ import type { ModelRouterService } from '../routing/modelRouterService.js';
const {
mockSendMessageStream,
mockExecuteToolCall,
mockScheduleAgentTools,
mockSetSystemInstruction,
mockCompress,
} = vi.hoisted(() => ({
mockSendMessageStream: vi.fn(),
mockExecuteToolCall: vi.fn(),
mockScheduleAgentTools: vi.fn(),
mockSetSystemInstruction: vi.fn(),
mockCompress: vi.fn(),
}));
@@ -101,8 +102,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
};
});
vi.mock('../core/nonInteractiveToolExecutor.js', () => ({
executeToolCall: mockExecuteToolCall,
vi.mock('./agent-scheduler.js', () => ({
scheduleAgentTools: mockScheduleAgentTools,
}));
vi.mock('../utils/version.js', () => ({
@@ -275,7 +276,7 @@ describe('LocalAgentExecutor', () => {
mockSetHistory.mockClear();
mockSendMessageStream.mockReset();
mockSetSystemInstruction.mockReset();
mockExecuteToolCall.mockReset();
mockScheduleAgentTools.mockReset();
mockedLogAgentStart.mockReset();
mockedLogAgentFinish.mockReset();
mockedPromptIdContext.getStore.mockReset();
@@ -540,34 +541,36 @@ describe('LocalAgentExecutor', () => {
[{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }],
'T1: Listing',
);
mockExecuteToolCall.mockResolvedValueOnce({
status: 'success',
request: {
callId: 'call1',
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'call1',
resultDisplay: 'file1.txt',
responseParts: [
{
functionResponse: {
name: LS_TOOL_NAME,
response: { result: 'file1.txt' },
id: 'call1',
mockScheduleAgentTools.mockResolvedValueOnce([
{
status: 'success',
request: {
callId: 'call1',
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'call1',
resultDisplay: 'file1.txt',
responseParts: [
{
functionResponse: {
name: LS_TOOL_NAME,
response: { result: 'file1.txt' },
id: 'call1',
},
},
},
],
error: undefined,
errorType: undefined,
contentLength: undefined,
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
});
]);
// Turn 2: Model calls complete_task with required output
mockModelResponse(
@@ -686,34 +689,36 @@ describe('LocalAgentExecutor', () => {
mockModelResponse([
{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' },
]);
mockExecuteToolCall.mockResolvedValueOnce({
status: 'success',
request: {
callId: 'call1',
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'call1',
resultDisplay: 'ok',
responseParts: [
{
functionResponse: {
name: LS_TOOL_NAME,
response: {},
id: 'call1',
mockScheduleAgentTools.mockResolvedValueOnce([
{
status: 'success',
request: {
callId: 'call1',
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'call1',
resultDisplay: 'ok',
responseParts: [
{
functionResponse: {
name: LS_TOOL_NAME,
response: {},
id: 'call1',
},
},
},
],
error: undefined,
errorType: undefined,
contentLength: undefined,
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
});
]);
mockModelResponse(
[
@@ -759,34 +764,36 @@ describe('LocalAgentExecutor', () => {
mockModelResponse([
{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' },
]);
mockExecuteToolCall.mockResolvedValueOnce({
status: 'success',
request: {
callId: 'call1',
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'call1',
resultDisplay: 'ok',
responseParts: [
{
functionResponse: {
name: LS_TOOL_NAME,
response: {},
id: 'call1',
mockScheduleAgentTools.mockResolvedValueOnce([
{
status: 'success',
request: {
callId: 'call1',
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'call1',
resultDisplay: 'ok',
responseParts: [
{
functionResponse: {
name: LS_TOOL_NAME,
response: {},
id: 'call1',
},
},
},
],
error: undefined,
errorType: undefined,
contentLength: undefined,
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
});
]);
// Turn 2 (protocol violation)
mockModelResponse([], 'I think I am done.');
@@ -959,33 +966,40 @@ describe('LocalAgentExecutor', () => {
resolveCalls = r;
});
mockExecuteToolCall.mockImplementation(async (_ctx, reqInfo) => {
callsStarted++;
if (callsStarted === 2) resolveCalls();
await vi.advanceTimersByTimeAsync(100);
return {
status: 'success',
request: reqInfo,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: reqInfo.callId,
resultDisplay: 'ok',
responseParts: [
{
functionResponse: {
name: reqInfo.name,
response: {},
id: reqInfo.callId,
mockScheduleAgentTools.mockImplementation(
async (_ctx, requests: ToolCallRequestInfo[]) => {
const results = await Promise.all(
requests.map(async (reqInfo) => {
callsStarted++;
if (callsStarted === 2) resolveCalls();
await vi.advanceTimersByTimeAsync(100);
return {
status: 'success',
request: reqInfo,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: reqInfo.callId,
resultDisplay: 'ok',
responseParts: [
{
functionResponse: {
name: reqInfo.name,
response: {},
id: reqInfo.callId,
},
},
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
};
});
};
}),
);
return results;
},
);
// Turn 2: Completion
mockModelResponse([
@@ -1005,7 +1019,7 @@ describe('LocalAgentExecutor', () => {
const output = await runPromise;
expect(mockExecuteToolCall).toHaveBeenCalledTimes(2);
expect(mockScheduleAgentTools).toHaveBeenCalledTimes(1);
expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);
// Safe access to message parts
@@ -1059,7 +1073,7 @@ describe('LocalAgentExecutor', () => {
await executor.run({ goal: 'Sec test' }, signal);
// Verify external executor was not called (Security held)
expect(mockExecuteToolCall).not.toHaveBeenCalled();
expect(mockScheduleAgentTools).not.toHaveBeenCalled();
// 2. Verify console warning
expect(consoleWarnSpy).toHaveBeenCalledWith(
@@ -1215,37 +1229,36 @@ describe('LocalAgentExecutor', () => {
mockModelResponse([
{ name: LS_TOOL_NAME, args: { path: '/fake' }, id: 'call1' },
]);
mockExecuteToolCall.mockResolvedValueOnce({
status: 'error',
request: {
callId: 'call1',
name: LS_TOOL_NAME,
args: { path: '/fake' },
isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'call1',
resultDisplay: '',
responseParts: [
{
functionResponse: {
name: LS_TOOL_NAME,
response: { error: toolErrorMessage },
id: 'call1',
},
},
],
error: {
type: 'ToolError',
message: toolErrorMessage,
mockScheduleAgentTools.mockResolvedValueOnce([
{
status: 'error',
request: {
callId: 'call1',
name: LS_TOOL_NAME,
args: { path: '/fake' },
isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 'call1',
resultDisplay: '',
responseParts: [
{
functionResponse: {
name: LS_TOOL_NAME,
response: { error: toolErrorMessage },
id: 'call1',
},
},
],
error: new Error(toolErrorMessage),
errorType: 'ToolError',
contentLength: 0,
},
errorType: 'ToolError',
contentLength: 0,
},
});
]);
// Turn 2: Model sees the error and completes
mockModelResponse([
@@ -1258,7 +1271,7 @@ describe('LocalAgentExecutor', () => {
const output = await executor.run({ goal: 'Tool failure test' }, signal);
expect(mockExecuteToolCall).toHaveBeenCalledTimes(1);
expect(mockScheduleAgentTools).toHaveBeenCalledTimes(1);
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
// Verify the error was reported in the activity stream
@@ -1391,28 +1404,30 @@ describe('LocalAgentExecutor', () => {
describe('run (Termination Conditions)', () => {
const mockWorkResponse = (id: string) => {
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
mockExecuteToolCall.mockResolvedValueOnce({
status: 'success',
request: {
callId: id,
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
mockScheduleAgentTools.mockResolvedValueOnce([
{
status: 'success',
request: {
callId: id,
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
});
]);
};
it('should terminate when max_turns is reached', async () => {
@@ -1505,23 +1520,27 @@ describe('LocalAgentExecutor', () => {
]);
// Long running tool
mockExecuteToolCall.mockImplementationOnce(async (_ctx, reqInfo) => {
await vi.advanceTimersByTimeAsync(61 * 1000);
return {
status: 'success',
request: reqInfo,
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 't1',
resultDisplay: 'ok',
responseParts: [],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
};
});
mockScheduleAgentTools.mockImplementationOnce(
async (_ctx, requests: ToolCallRequestInfo[]) => {
await vi.advanceTimersByTimeAsync(61 * 1000);
return [
{
status: 'success',
request: requests[0],
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: 't1',
resultDisplay: 'ok',
responseParts: [],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
];
},
);
// Recovery turn
mockModelResponse([], 'I give up');
@@ -1557,28 +1576,30 @@ describe('LocalAgentExecutor', () => {
describe('run (Recovery Turns)', () => {
const mockWorkResponse = (id: string) => {
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
mockExecuteToolCall.mockResolvedValueOnce({
status: 'success',
request: {
callId: id,
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
mockScheduleAgentTools.mockResolvedValueOnce([
{
status: 'success',
request: {
callId: id,
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
});
]);
};
it('should recover successfully if complete_task is called during the grace turn after MAX_TURNS', async () => {
@@ -1873,28 +1894,30 @@ describe('LocalAgentExecutor', () => {
describe('Telemetry and Logging', () => {
const mockWorkResponse = (id: string) => {
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
mockExecuteToolCall.mockResolvedValueOnce({
status: 'success',
request: {
callId: id,
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
mockScheduleAgentTools.mockResolvedValueOnce([
{
status: 'success',
request: {
callId: id,
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
});
]);
};
beforeEach(() => {
@@ -1960,28 +1983,30 @@ describe('LocalAgentExecutor', () => {
describe('Chat Compression', () => {
const mockWorkResponse = (id: string) => {
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
mockExecuteToolCall.mockResolvedValueOnce({
status: 'success',
request: {
callId: id,
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
mockScheduleAgentTools.mockResolvedValueOnce([
{
status: 'success',
request: {
callId: id,
name: LS_TOOL_NAME,
args: { path: '.' },
isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
});
]);
};
it('should attempt to compress chat history on each turn', async () => {

View File

@@ -15,7 +15,6 @@ import type {
FunctionDeclaration,
Schema,
} from '@google/genai';
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
import { ToolRegistry } from '../tools/tool-registry.js';
import { CompressionStatus } from '../core/turn.js';
import { type ToolCallRequestInfo } from '../scheduler/types.js';
@@ -48,7 +47,8 @@ import { zodToJsonSchema } from 'zod-to-json-schema';
import { debugLogger } from '../utils/debugLogger.js';
import { getModelConfigAlias } from './registry.js';
import { getVersion } from '../utils/version.js';
import { ApprovalMode } from '../policy/types.js';
import { getToolCallContext } from '../utils/toolCallContext.js';
import { scheduleAgentTools } from './agent-scheduler.js';
/** A callback function to report on agent activity. */
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
@@ -86,6 +86,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
private readonly runtimeContext: Config;
private readonly onActivity?: ActivityCallback;
private readonly compressionService: ChatCompressionService;
private readonly parentCallId?: string;
private hasFailedCompressionAttempt = false;
/**
@@ -158,11 +159,16 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
// Get the parent prompt ID from context
const parentPromptId = promptIdContext.getStore();
// Get the parent tool call ID from context
const toolContext = getToolCallContext();
const parentCallId = toolContext?.callId;
return new LocalAgentExecutor(
definition,
runtimeContext,
agentToolRegistry,
parentPromptId,
parentCallId,
onActivity,
);
}
@@ -178,6 +184,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
runtimeContext: Config,
toolRegistry: ToolRegistry,
parentPromptId: string | undefined,
parentCallId: string | undefined,
onActivity?: ActivityCallback,
) {
this.definition = definition;
@@ -185,6 +192,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
this.toolRegistry = toolRegistry;
this.onActivity = onActivity;
this.compressionService = new ChatCompressionService();
this.parentCallId = parentCallId;
const randomIdPart = Math.random().toString(36).slice(2, 8);
// parentPromptId will be undefined if this agent is invoked directly
@@ -763,26 +771,28 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
let submittedOutput: string | null = null;
let taskCompleted = false;
// We'll collect promises for the tool executions
const toolExecutionPromises: Array<Promise<Part[] | void>> = [];
// And we'll need a place to store the synchronous results (like complete_task or blocked calls)
const syncResponseParts: Part[] = [];
// We'll separate complete_task from other tools
const toolRequests: ToolCallRequestInfo[] = [];
// Map to keep track of tool name by callId for activity emission
const toolNameMap = new Map<string, string>();
// Synchronous results (like complete_task or unauthorized calls)
const syncResults = new Map<string, Part>();
for (const [index, functionCall] of functionCalls.entries()) {
const callId = functionCall.id ?? `${promptId}-${index}`;
const args = functionCall.args ?? {};
const toolName = functionCall.name as string;
this.emitActivity('TOOL_CALL_START', {
name: functionCall.name,
name: toolName,
args,
});
if (functionCall.name === TASK_COMPLETE_TOOL_NAME) {
if (toolName === TASK_COMPLETE_TOOL_NAME) {
if (taskCompleted) {
// We already have a completion from this turn. Ignore subsequent ones.
const error =
'Task already marked complete in this turn. Ignoring duplicate call.';
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: TASK_COMPLETE_TOOL_NAME,
response: { error },
@@ -791,7 +801,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
});
this.emitActivity('ERROR', {
context: 'tool_call',
name: functionCall.name,
name: toolName,
error,
});
continue;
@@ -809,7 +819,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
if (!validationResult.success) {
taskCompleted = false; // Validation failed, revoke completion
const error = `Output validation failed: ${JSON.stringify(validationResult.error.flatten())}`;
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: TASK_COMPLETE_TOOL_NAME,
response: { error },
@@ -818,7 +828,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
});
this.emitActivity('ERROR', {
context: 'tool_call',
name: functionCall.name,
name: toolName,
error,
});
continue;
@@ -833,7 +843,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
? outputValue
: JSON.stringify(outputValue, null, 2);
}
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: TASK_COMPLETE_TOOL_NAME,
response: { result: 'Output submitted and task completed.' },
@@ -841,14 +851,14 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
},
});
this.emitActivity('TOOL_CALL_END', {
name: functionCall.name,
name: toolName,
output: 'Output submitted and task completed.',
});
} else {
// Failed to provide required output.
taskCompleted = false; // Revoke completion status
const error = `Missing required argument '${outputName}' for completion.`;
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: TASK_COMPLETE_TOOL_NAME,
response: { error },
@@ -857,7 +867,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
});
this.emitActivity('ERROR', {
context: 'tool_call',
name: functionCall.name,
name: toolName,
error,
});
}
@@ -873,7 +883,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
typeof resultArg === 'string'
? resultArg
: JSON.stringify(resultArg, null, 2);
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: TASK_COMPLETE_TOOL_NAME,
response: { status: 'Result submitted and task completed.' },
@@ -881,7 +891,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
},
});
this.emitActivity('TOOL_CALL_END', {
name: functionCall.name,
name: toolName,
output: 'Result submitted and task completed.',
});
} else {
@@ -889,7 +899,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
taskCompleted = false; // Revoke completion
const error =
'Missing required "result" argument. You must provide your findings when calling complete_task.';
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: TASK_COMPLETE_TOOL_NAME,
response: { error },
@@ -898,7 +908,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
});
this.emitActivity('ERROR', {
context: 'tool_call',
name: functionCall.name,
name: toolName,
error,
});
}
@@ -907,14 +917,13 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
}
// Handle standard tools
if (!allowedToolNames.has(functionCall.name as string)) {
const error = createUnauthorizedToolError(functionCall.name as string);
if (!allowedToolNames.has(toolName)) {
const error = createUnauthorizedToolError(toolName);
debugLogger.warn(`[LocalAgentExecutor] Blocked call: ${error}`);
syncResponseParts.push({
syncResults.set(callId, {
functionResponse: {
name: functionCall.name as string,
name: toolName,
id: callId,
response: { error },
},
@@ -922,7 +931,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
this.emitActivity('ERROR', {
context: 'tool_call_unauthorized',
name: functionCall.name,
name: toolName,
callId,
error,
});
@@ -930,53 +939,63 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
continue;
}
const requestInfo: ToolCallRequestInfo = {
toolRequests.push({
callId,
name: functionCall.name as string,
name: toolName,
args,
isClientInitiated: true,
isClientInitiated: false, // These are coming from the subagent (the "model")
prompt_id: promptId,
};
});
toolNameMap.set(callId, toolName);
}
// Create a promise for the tool execution
const executionPromise = (async () => {
const agentContext = Object.create(this.runtimeContext);
agentContext.getToolRegistry = () => this.toolRegistry;
agentContext.getApprovalMode = () => ApprovalMode.YOLO;
const { response: toolResponse } = await executeToolCall(
agentContext,
requestInfo,
// Execute standard tool calls using the new scheduler
if (toolRequests.length > 0) {
const completedCalls = await scheduleAgentTools(
this.runtimeContext,
toolRequests,
{
schedulerId: this.agentId,
parentCallId: this.parentCallId,
toolRegistry: this.toolRegistry,
signal,
);
},
);
if (toolResponse.error) {
for (const call of completedCalls) {
const toolName =
toolNameMap.get(call.request.callId) || call.request.name;
if (call.status === 'success') {
this.emitActivity('TOOL_CALL_END', {
name: toolName,
output: call.response.resultDisplay,
});
} else if (call.status === 'error') {
this.emitActivity('ERROR', {
context: 'tool_call',
name: functionCall.name,
error: toolResponse.error.message,
name: toolName,
error: call.response.error?.message || 'Unknown error',
});
} else {
this.emitActivity('TOOL_CALL_END', {
name: functionCall.name,
output: toolResponse.resultDisplay,
} else if (call.status === 'cancelled') {
this.emitActivity('ERROR', {
context: 'tool_call',
name: toolName,
error: 'Tool call was cancelled.',
});
}
return toolResponse.responseParts;
})();
toolExecutionPromises.push(executionPromise);
// Add result to syncResults to preserve order later
syncResults.set(call.request.callId, call.response.responseParts[0]);
}
}
// Wait for all tool executions to complete
const asyncResults = await Promise.all(toolExecutionPromises);
// Combine all response parts
const toolResponseParts: Part[] = [...syncResponseParts];
for (const result of asyncResults) {
if (result) {
toolResponseParts.push(...result);
// Reconstruct toolResponseParts in the original order
const toolResponseParts: Part[] = [];
for (const [index, functionCall] of functionCalls.entries()) {
const callId = functionCall.id ?? `${promptId}-${index}`;
const part = syncResults.get(callId);
if (part) {
toolResponseParts.push(part);
}
}

View File

@@ -70,6 +70,10 @@ import { ROOT_SCHEDULER_ID } from './types.js';
import { ToolErrorType } from '../tools/tool-error.js';
import * as ToolUtils from '../utils/tool-utils.js';
import type { EditorType } from '../utils/editor.js';
import {
getToolCallContext,
type ToolCallContext,
} from '../utils/toolCallContext.js';
describe('Scheduler (Orchestrator)', () => {
let scheduler: Scheduler;
@@ -1010,4 +1014,68 @@ describe('Scheduler (Orchestrator)', () => {
expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1');
});
});
describe('Tool Call Context Propagation', () => {
it('should propagate context to the tool executor', async () => {
const schedulerId = 'custom-scheduler';
const parentCallId = 'parent-call';
const customScheduler = new Scheduler({
config: mockConfig,
messageBus: mockMessageBus,
getPreferredEditor,
schedulerId,
parentCallId,
});
const validatingCall: ValidatingToolCall = {
status: 'validating',
request: req1,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
// Mock queueLength to run the loop once
Object.defineProperty(mockStateManager, 'queueLength', {
get: vi.fn().mockReturnValueOnce(1).mockReturnValue(0),
configurable: true,
});
vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall);
Object.defineProperty(mockStateManager, 'firstActiveCall', {
get: vi.fn().mockReturnValue(validatingCall),
configurable: true,
});
vi.mocked(mockStateManager.getToolCall).mockReturnValue(validatingCall);
mockToolRegistry.getTool.mockReturnValue(mockTool);
mockPolicyEngine.check.mockResolvedValue({
decision: PolicyDecision.ALLOW,
});
let capturedContext: ToolCallContext | undefined;
mockExecutor.execute.mockImplementation(async () => {
capturedContext = getToolCallContext();
return {
status: 'success',
request: req1,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
response: {
callId: req1.callId,
responseParts: [],
resultDisplay: 'ok',
error: undefined,
errorType: undefined,
},
} as unknown as SuccessfulToolCall;
});
await customScheduler.schedule(req1, signal);
expect(capturedContext).toBeDefined();
expect(capturedContext!.callId).toBe(req1.callId);
expect(capturedContext!.schedulerId).toBe(schedulerId);
expect(capturedContext!.parentCallId).toBe(parentCallId);
});
});
});

View File

@@ -36,6 +36,7 @@ import {
type SerializableConfirmationDetails,
type ToolConfirmationRequest,
} from '../confirmation-bus/types.js';
import { runWithToolCallContext } from '../utils/toolCallContext.js';
interface SchedulerQueueItem {
requests: ToolCallRequestInfo[];
@@ -256,6 +257,7 @@ export class Scheduler {
return this.state.completedBatch;
} finally {
this.isProcessing = false;
this.state.clearBatch();
this._processNextInRequestQueue();
}
}
@@ -282,30 +284,39 @@ export class Scheduler {
request: ToolCallRequestInfo,
tool: AnyDeclarativeTool,
): ValidatingToolCall | ErroredToolCall {
try {
const invocation = tool.build(request.args);
return {
status: 'validating',
request,
tool,
invocation,
startTime: Date.now(),
return runWithToolCallContext(
{
callId: request.callId,
schedulerId: this.schedulerId,
};
} catch (e) {
return {
status: 'error',
request,
tool,
response: createErrorResponse(
request,
e instanceof Error ? e : new Error(String(e)),
ToolErrorType.INVALID_TOOL_PARAMS,
),
durationMs: 0,
schedulerId: this.schedulerId,
};
}
parentCallId: this.parentCallId,
},
() => {
try {
const invocation = tool.build(request.args);
return {
status: 'validating',
request,
tool,
invocation,
startTime: Date.now(),
schedulerId: this.schedulerId,
};
} catch (e) {
return {
status: 'error',
request,
tool,
response: createErrorResponse(
request,
e instanceof Error ? e : new Error(String(e)),
ToolErrorType.INVALID_TOOL_PARAMS,
),
durationMs: 0,
schedulerId: this.schedulerId,
};
}
},
);
}
// --- Phase 2: Processing Loop ---
@@ -460,17 +471,29 @@ export class Scheduler {
if (signal.aborted) throw new Error('Operation cancelled');
this.state.updateStatus(callId, 'executing');
const result = await this.executor.execute({
call: this.state.firstActiveCall as ExecutingToolCall,
signal,
outputUpdateHandler: (id, out) =>
this.state.updateStatus(id, 'executing', { liveOutput: out }),
onUpdateToolCall: (updated) => {
if (updated.status === 'executing' && updated.pid) {
this.state.updateStatus(callId, 'executing', { pid: updated.pid });
}
const activeCall = this.state.firstActiveCall as ExecutingToolCall;
const result = await runWithToolCallContext(
{
callId: activeCall.request.callId,
schedulerId: this.schedulerId,
parentCallId: this.parentCallId,
},
});
() =>
this.executor.execute({
call: activeCall,
signal,
outputUpdateHandler: (id, out) =>
this.state.updateStatus(id, 'executing', { liveOutput: out }),
onUpdateToolCall: (updated) => {
if (updated.status === 'executing' && updated.pid) {
this.state.updateStatus(callId, 'executing', {
pid: updated.pid,
});
}
},
}),
);
if (result.status === 'success') {
this.state.updateStatus(callId, 'success', result.response);

View File

@@ -0,0 +1,84 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import {
runWithToolCallContext,
getToolCallContext,
} from './toolCallContext.js';
describe('toolCallContext', () => {
it('should store and retrieve tool call context', () => {
const context = {
callId: 'test-call-id',
schedulerId: 'test-scheduler-id',
};
runWithToolCallContext(context, () => {
const storedContext = getToolCallContext();
expect(storedContext).toEqual(context);
});
});
it('should return undefined when no context is set', () => {
expect(getToolCallContext()).toBeUndefined();
});
it('should support nested contexts', () => {
const parentContext = {
callId: 'parent-call-id',
schedulerId: 'parent-scheduler-id',
};
const childContext = {
callId: 'child-call-id',
schedulerId: 'child-scheduler-id',
parentCallId: 'parent-call-id',
};
runWithToolCallContext(parentContext, () => {
expect(getToolCallContext()).toEqual(parentContext);
runWithToolCallContext(childContext, () => {
expect(getToolCallContext()).toEqual(childContext);
});
expect(getToolCallContext()).toEqual(parentContext);
});
});
it('should maintain isolation between parallel executions', async () => {
const context1 = {
callId: 'call-1',
schedulerId: 'scheduler-1',
};
const context2 = {
callId: 'call-2',
schedulerId: 'scheduler-2',
};
const promise1 = new Promise<void>((resolve) => {
runWithToolCallContext(context1, () => {
setTimeout(() => {
expect(getToolCallContext()).toEqual(context1);
resolve();
}, 10);
});
});
const promise2 = new Promise<void>((resolve) => {
runWithToolCallContext(context2, () => {
setTimeout(() => {
expect(getToolCallContext()).toEqual(context2);
resolve();
}, 5);
});
});
await Promise.all([promise1, promise2]);
});
});

View File

@@ -0,0 +1,47 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { AsyncLocalStorage } from 'node:async_hooks';
/**
* Contextual information for a tool call execution.
*/
export interface ToolCallContext {
/** The unique ID of the tool call. */
callId: string;
/** The ID of the scheduler managing the execution. */
schedulerId: string;
/** The ID of the parent tool call, if this is a nested execution (e.g., in a subagent). */
parentCallId?: string;
}
/**
* AsyncLocalStorage instance for tool call context.
*/
export const toolCallContext = new AsyncLocalStorage<ToolCallContext>();
/**
* Runs a function within a tool call context.
*
* @param context The context to set.
* @param fn The function to run.
* @returns The result of the function.
*/
export function runWithToolCallContext<T>(
context: ToolCallContext,
fn: () => T,
): T {
return toolCallContext.run(context, fn);
}
/**
* Retrieves the current tool call context.
*
* @returns The current ToolCallContext, or undefined if not in a context.
*/
export function getToolCallContext(): ToolCallContext | undefined {
return toolCallContext.getStore();
}