diff --git a/packages/core/src/agents/agent-scheduler.test.ts b/packages/core/src/agents/agent-scheduler.test.ts new file mode 100644 index 0000000000..5edcb664b6 --- /dev/null +++ b/packages/core/src/agents/agent-scheduler.test.ts @@ -0,0 +1,74 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, type Mocked } from 'vitest'; +import { scheduleAgentTools } from './agent-scheduler.js'; +import { Scheduler } from '../scheduler/scheduler.js'; +import type { Config } from '../config/config.js'; +import type { ToolRegistry } from '../tools/tool-registry.js'; +import type { ToolCallRequestInfo } from '../scheduler/types.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; + +vi.mock('../scheduler/scheduler.js', () => ({ + Scheduler: vi.fn().mockImplementation(() => ({ + schedule: vi.fn().mockResolvedValue([{ status: 'success' }]), + })), +})); + +describe('agent-scheduler', () => { + let mockConfig: Mocked; + let mockToolRegistry: Mocked; + let mockMessageBus: Mocked; + + beforeEach(() => { + mockMessageBus = {} as Mocked; + mockToolRegistry = { + getTool: vi.fn(), + } as unknown as Mocked; + mockConfig = { + getMessageBus: vi.fn().mockReturnValue(mockMessageBus), + getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + } as unknown as Mocked; + }); + + it('should create a scheduler with agent-specific config', async () => { + const requests: ToolCallRequestInfo[] = [ + { + callId: 'call-1', + name: 'test-tool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-1', + }, + ]; + + const options = { + schedulerId: 'subagent-1', + parentCallId: 'parent-1', + toolRegistry: mockToolRegistry as unknown as ToolRegistry, + signal: new AbortController().signal, + }; + + const results = await scheduleAgentTools( + mockConfig as unknown as Config, + requests, + options, + ); + + expect(results).toEqual([{ status: 'success' }]); + expect(Scheduler).toHaveBeenCalledWith( + expect.objectContaining({ + schedulerId: 'subagent-1', + parentCallId: 'parent-1', + messageBus: mockMessageBus, + }), + ); + + // Verify that the scheduler's config has the overridden tool registry + const schedulerConfig = vi.mocked(Scheduler).mock.calls[0][0].config; + expect(schedulerConfig.getToolRegistry()).toBe(mockToolRegistry); + }); +}); diff --git a/packages/core/src/agents/agent-scheduler.ts b/packages/core/src/agents/agent-scheduler.ts new file mode 100644 index 0000000000..c3201b7255 --- /dev/null +++ b/packages/core/src/agents/agent-scheduler.ts @@ -0,0 +1,66 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../config/config.js'; +import { Scheduler } from '../scheduler/scheduler.js'; +import type { + ToolCallRequestInfo, + CompletedToolCall, +} from '../scheduler/types.js'; +import type { ToolRegistry } from '../tools/tool-registry.js'; +import type { EditorType } from '../utils/editor.js'; + +/** + * Options for scheduling agent tools. + */ +export interface AgentSchedulingOptions { + /** The unique ID for this agent's scheduler. */ + schedulerId: string; + /** The ID of the tool call that invoked this agent. */ + parentCallId?: string; + /** The tool registry specific to this agent. */ + toolRegistry: ToolRegistry; + /** AbortSignal for cancellation. */ + signal: AbortSignal; + /** Optional function to get the preferred editor for tool modifications. */ + getPreferredEditor?: () => EditorType | undefined; +} + +/** + * Schedules a batch of tool calls for an agent using the new event-driven Scheduler. + * + * @param config The global runtime configuration. + * @param requests The list of tool call requests from the agent. + * @param options Scheduling options including registry and IDs. + * @returns A promise that resolves to the completed tool calls. + */ +export async function scheduleAgentTools( + config: Config, + requests: ToolCallRequestInfo[], + options: AgentSchedulingOptions, +): Promise { + const { + schedulerId, + parentCallId, + toolRegistry, + signal, + getPreferredEditor, + } = options; + + // Create a proxy/override of the config to provide the agent-specific tool registry. + const agentConfig: Config = Object.create(config); + agentConfig.getToolRegistry = () => toolRegistry; + + const scheduler = new Scheduler({ + config: agentConfig, + messageBus: config.getMessageBus(), + getPreferredEditor: getPreferredEditor ?? (() => undefined), + schedulerId, + parentCallId, + }); + + return scheduler.schedule(requests, signal); +} diff --git a/packages/core/src/agents/local-executor.test.ts b/packages/core/src/agents/local-executor.test.ts index b9e6488c1e..3cb7b188fd 100644 --- a/packages/core/src/agents/local-executor.test.ts +++ b/packages/core/src/agents/local-executor.test.ts @@ -55,6 +55,7 @@ import type { } from './types.js'; import { AgentTerminateMode } from './types.js'; import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js'; +import type { ToolCallRequestInfo } from '../scheduler/types.js'; import { CompressionStatus } from '../core/turn.js'; import { ChatCompressionService } from '../services/chatCompressionService.js'; import type { @@ -67,12 +68,12 @@ import type { ModelRouterService } from '../routing/modelRouterService.js'; const { mockSendMessageStream, - mockExecuteToolCall, + mockScheduleAgentTools, mockSetSystemInstruction, mockCompress, } = vi.hoisted(() => ({ mockSendMessageStream: vi.fn(), - mockExecuteToolCall: vi.fn(), + mockScheduleAgentTools: vi.fn(), mockSetSystemInstruction: vi.fn(), mockCompress: vi.fn(), })); @@ -101,8 +102,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => { }; }); -vi.mock('../core/nonInteractiveToolExecutor.js', () => ({ - executeToolCall: mockExecuteToolCall, +vi.mock('./agent-scheduler.js', () => ({ + scheduleAgentTools: mockScheduleAgentTools, })); vi.mock('../utils/version.js', () => ({ @@ -275,7 +276,7 @@ describe('LocalAgentExecutor', () => { mockSetHistory.mockClear(); mockSendMessageStream.mockReset(); mockSetSystemInstruction.mockReset(); - mockExecuteToolCall.mockReset(); + mockScheduleAgentTools.mockReset(); mockedLogAgentStart.mockReset(); mockedLogAgentFinish.mockReset(); mockedPromptIdContext.getStore.mockReset(); @@ -540,34 +541,36 @@ describe('LocalAgentExecutor', () => { [{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }], 'T1: Listing', ); - mockExecuteToolCall.mockResolvedValueOnce({ - status: 'success', - request: { - callId: 'call1', - name: LS_TOOL_NAME, - args: { path: '.' }, - isClientInitiated: false, - prompt_id: 'test-prompt', - }, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: 'call1', - resultDisplay: 'file1.txt', - responseParts: [ - { - functionResponse: { - name: LS_TOOL_NAME, - response: { result: 'file1.txt' }, - id: 'call1', + mockScheduleAgentTools.mockResolvedValueOnce([ + { + status: 'success', + request: { + callId: 'call1', + name: LS_TOOL_NAME, + args: { path: '.' }, + isClientInitiated: false, + prompt_id: 'test-prompt', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'call1', + resultDisplay: 'file1.txt', + responseParts: [ + { + functionResponse: { + name: LS_TOOL_NAME, + response: { result: 'file1.txt' }, + id: 'call1', + }, }, - }, - ], - error: undefined, - errorType: undefined, - contentLength: undefined, + ], + error: undefined, + errorType: undefined, + contentLength: undefined, + }, }, - }); + ]); // Turn 2: Model calls complete_task with required output mockModelResponse( @@ -686,34 +689,36 @@ describe('LocalAgentExecutor', () => { mockModelResponse([ { name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }, ]); - mockExecuteToolCall.mockResolvedValueOnce({ - status: 'success', - request: { - callId: 'call1', - name: LS_TOOL_NAME, - args: { path: '.' }, - isClientInitiated: false, - prompt_id: 'test-prompt', - }, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: 'call1', - resultDisplay: 'ok', - responseParts: [ - { - functionResponse: { - name: LS_TOOL_NAME, - response: {}, - id: 'call1', + mockScheduleAgentTools.mockResolvedValueOnce([ + { + status: 'success', + request: { + callId: 'call1', + name: LS_TOOL_NAME, + args: { path: '.' }, + isClientInitiated: false, + prompt_id: 'test-prompt', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'call1', + resultDisplay: 'ok', + responseParts: [ + { + functionResponse: { + name: LS_TOOL_NAME, + response: {}, + id: 'call1', + }, }, - }, - ], - error: undefined, - errorType: undefined, - contentLength: undefined, + ], + error: undefined, + errorType: undefined, + contentLength: undefined, + }, }, - }); + ]); mockModelResponse( [ @@ -759,34 +764,36 @@ describe('LocalAgentExecutor', () => { mockModelResponse([ { name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }, ]); - mockExecuteToolCall.mockResolvedValueOnce({ - status: 'success', - request: { - callId: 'call1', - name: LS_TOOL_NAME, - args: { path: '.' }, - isClientInitiated: false, - prompt_id: 'test-prompt', - }, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: 'call1', - resultDisplay: 'ok', - responseParts: [ - { - functionResponse: { - name: LS_TOOL_NAME, - response: {}, - id: 'call1', + mockScheduleAgentTools.mockResolvedValueOnce([ + { + status: 'success', + request: { + callId: 'call1', + name: LS_TOOL_NAME, + args: { path: '.' }, + isClientInitiated: false, + prompt_id: 'test-prompt', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'call1', + resultDisplay: 'ok', + responseParts: [ + { + functionResponse: { + name: LS_TOOL_NAME, + response: {}, + id: 'call1', + }, }, - }, - ], - error: undefined, - errorType: undefined, - contentLength: undefined, + ], + error: undefined, + errorType: undefined, + contentLength: undefined, + }, }, - }); + ]); // Turn 2 (protocol violation) mockModelResponse([], 'I think I am done.'); @@ -959,33 +966,40 @@ describe('LocalAgentExecutor', () => { resolveCalls = r; }); - mockExecuteToolCall.mockImplementation(async (_ctx, reqInfo) => { - callsStarted++; - if (callsStarted === 2) resolveCalls(); - await vi.advanceTimersByTimeAsync(100); - return { - status: 'success', - request: reqInfo, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: reqInfo.callId, - resultDisplay: 'ok', - responseParts: [ - { - functionResponse: { - name: reqInfo.name, - response: {}, - id: reqInfo.callId, + mockScheduleAgentTools.mockImplementation( + async (_ctx, requests: ToolCallRequestInfo[]) => { + const results = await Promise.all( + requests.map(async (reqInfo) => { + callsStarted++; + if (callsStarted === 2) resolveCalls(); + await vi.advanceTimersByTimeAsync(100); + return { + status: 'success', + request: reqInfo, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: reqInfo.callId, + resultDisplay: 'ok', + responseParts: [ + { + functionResponse: { + name: reqInfo.name, + response: {}, + id: reqInfo.callId, + }, + }, + ], + error: undefined, + errorType: undefined, + contentLength: undefined, }, - }, - ], - error: undefined, - errorType: undefined, - contentLength: undefined, - }, - }; - }); + }; + }), + ); + return results; + }, + ); // Turn 2: Completion mockModelResponse([ @@ -1005,7 +1019,7 @@ describe('LocalAgentExecutor', () => { const output = await runPromise; - expect(mockExecuteToolCall).toHaveBeenCalledTimes(2); + expect(mockScheduleAgentTools).toHaveBeenCalledTimes(1); expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL); // Safe access to message parts @@ -1059,7 +1073,7 @@ describe('LocalAgentExecutor', () => { await executor.run({ goal: 'Sec test' }, signal); // Verify external executor was not called (Security held) - expect(mockExecuteToolCall).not.toHaveBeenCalled(); + expect(mockScheduleAgentTools).not.toHaveBeenCalled(); // 2. Verify console warning expect(consoleWarnSpy).toHaveBeenCalledWith( @@ -1215,37 +1229,36 @@ describe('LocalAgentExecutor', () => { mockModelResponse([ { name: LS_TOOL_NAME, args: { path: '/fake' }, id: 'call1' }, ]); - mockExecuteToolCall.mockResolvedValueOnce({ - status: 'error', - request: { - callId: 'call1', - name: LS_TOOL_NAME, - args: { path: '/fake' }, - isClientInitiated: false, - prompt_id: 'test-prompt', - }, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: 'call1', - resultDisplay: '', - responseParts: [ - { - functionResponse: { - name: LS_TOOL_NAME, - response: { error: toolErrorMessage }, - id: 'call1', - }, - }, - ], - error: { - type: 'ToolError', - message: toolErrorMessage, + mockScheduleAgentTools.mockResolvedValueOnce([ + { + status: 'error', + request: { + callId: 'call1', + name: LS_TOOL_NAME, + args: { path: '/fake' }, + isClientInitiated: false, + prompt_id: 'test-prompt', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 'call1', + resultDisplay: '', + responseParts: [ + { + functionResponse: { + name: LS_TOOL_NAME, + response: { error: toolErrorMessage }, + id: 'call1', + }, + }, + ], + error: new Error(toolErrorMessage), + errorType: 'ToolError', + contentLength: 0, }, - errorType: 'ToolError', - contentLength: 0, }, - }); + ]); // Turn 2: Model sees the error and completes mockModelResponse([ @@ -1258,7 +1271,7 @@ describe('LocalAgentExecutor', () => { const output = await executor.run({ goal: 'Tool failure test' }, signal); - expect(mockExecuteToolCall).toHaveBeenCalledTimes(1); + expect(mockScheduleAgentTools).toHaveBeenCalledTimes(1); expect(mockSendMessageStream).toHaveBeenCalledTimes(2); // Verify the error was reported in the activity stream @@ -1391,28 +1404,30 @@ describe('LocalAgentExecutor', () => { describe('run (Termination Conditions)', () => { const mockWorkResponse = (id: string) => { mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]); - mockExecuteToolCall.mockResolvedValueOnce({ - status: 'success', - request: { - callId: id, - name: LS_TOOL_NAME, - args: { path: '.' }, - isClientInitiated: false, - prompt_id: 'test-prompt', + mockScheduleAgentTools.mockResolvedValueOnce([ + { + status: 'success', + request: { + callId: id, + name: LS_TOOL_NAME, + args: { path: '.' }, + isClientInitiated: false, + prompt_id: 'test-prompt', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: id, + resultDisplay: 'ok', + responseParts: [ + { functionResponse: { name: LS_TOOL_NAME, response: {}, id } }, + ], + error: undefined, + errorType: undefined, + contentLength: undefined, + }, }, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: id, - resultDisplay: 'ok', - responseParts: [ - { functionResponse: { name: LS_TOOL_NAME, response: {}, id } }, - ], - error: undefined, - errorType: undefined, - contentLength: undefined, - }, - }); + ]); }; it('should terminate when max_turns is reached', async () => { @@ -1505,23 +1520,27 @@ describe('LocalAgentExecutor', () => { ]); // Long running tool - mockExecuteToolCall.mockImplementationOnce(async (_ctx, reqInfo) => { - await vi.advanceTimersByTimeAsync(61 * 1000); - return { - status: 'success', - request: reqInfo, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: 't1', - resultDisplay: 'ok', - responseParts: [], - error: undefined, - errorType: undefined, - contentLength: undefined, - }, - }; - }); + mockScheduleAgentTools.mockImplementationOnce( + async (_ctx, requests: ToolCallRequestInfo[]) => { + await vi.advanceTimersByTimeAsync(61 * 1000); + return [ + { + status: 'success', + request: requests[0], + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: 't1', + resultDisplay: 'ok', + responseParts: [], + error: undefined, + errorType: undefined, + contentLength: undefined, + }, + }, + ]; + }, + ); // Recovery turn mockModelResponse([], 'I give up'); @@ -1557,28 +1576,30 @@ describe('LocalAgentExecutor', () => { describe('run (Recovery Turns)', () => { const mockWorkResponse = (id: string) => { mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]); - mockExecuteToolCall.mockResolvedValueOnce({ - status: 'success', - request: { - callId: id, - name: LS_TOOL_NAME, - args: { path: '.' }, - isClientInitiated: false, - prompt_id: 'test-prompt', + mockScheduleAgentTools.mockResolvedValueOnce([ + { + status: 'success', + request: { + callId: id, + name: LS_TOOL_NAME, + args: { path: '.' }, + isClientInitiated: false, + prompt_id: 'test-prompt', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: id, + resultDisplay: 'ok', + responseParts: [ + { functionResponse: { name: LS_TOOL_NAME, response: {}, id } }, + ], + error: undefined, + errorType: undefined, + contentLength: undefined, + }, }, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: id, - resultDisplay: 'ok', - responseParts: [ - { functionResponse: { name: LS_TOOL_NAME, response: {}, id } }, - ], - error: undefined, - errorType: undefined, - contentLength: undefined, - }, - }); + ]); }; it('should recover successfully if complete_task is called during the grace turn after MAX_TURNS', async () => { @@ -1873,28 +1894,30 @@ describe('LocalAgentExecutor', () => { describe('Telemetry and Logging', () => { const mockWorkResponse = (id: string) => { mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]); - mockExecuteToolCall.mockResolvedValueOnce({ - status: 'success', - request: { - callId: id, - name: LS_TOOL_NAME, - args: { path: '.' }, - isClientInitiated: false, - prompt_id: 'test-prompt', + mockScheduleAgentTools.mockResolvedValueOnce([ + { + status: 'success', + request: { + callId: id, + name: LS_TOOL_NAME, + args: { path: '.' }, + isClientInitiated: false, + prompt_id: 'test-prompt', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: id, + resultDisplay: 'ok', + responseParts: [ + { functionResponse: { name: LS_TOOL_NAME, response: {}, id } }, + ], + error: undefined, + errorType: undefined, + contentLength: undefined, + }, }, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: id, - resultDisplay: 'ok', - responseParts: [ - { functionResponse: { name: LS_TOOL_NAME, response: {}, id } }, - ], - error: undefined, - errorType: undefined, - contentLength: undefined, - }, - }); + ]); }; beforeEach(() => { @@ -1960,28 +1983,30 @@ describe('LocalAgentExecutor', () => { describe('Chat Compression', () => { const mockWorkResponse = (id: string) => { mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]); - mockExecuteToolCall.mockResolvedValueOnce({ - status: 'success', - request: { - callId: id, - name: LS_TOOL_NAME, - args: { path: '.' }, - isClientInitiated: false, - prompt_id: 'test-prompt', + mockScheduleAgentTools.mockResolvedValueOnce([ + { + status: 'success', + request: { + callId: id, + name: LS_TOOL_NAME, + args: { path: '.' }, + isClientInitiated: false, + prompt_id: 'test-prompt', + }, + tool: {} as AnyDeclarativeTool, + invocation: {} as AnyToolInvocation, + response: { + callId: id, + resultDisplay: 'ok', + responseParts: [ + { functionResponse: { name: LS_TOOL_NAME, response: {}, id } }, + ], + error: undefined, + errorType: undefined, + contentLength: undefined, + }, }, - tool: {} as AnyDeclarativeTool, - invocation: {} as AnyToolInvocation, - response: { - callId: id, - resultDisplay: 'ok', - responseParts: [ - { functionResponse: { name: LS_TOOL_NAME, response: {}, id } }, - ], - error: undefined, - errorType: undefined, - contentLength: undefined, - }, - }); + ]); }; it('should attempt to compress chat history on each turn', async () => { diff --git a/packages/core/src/agents/local-executor.ts b/packages/core/src/agents/local-executor.ts index a75a92a4ec..e22143ac54 100644 --- a/packages/core/src/agents/local-executor.ts +++ b/packages/core/src/agents/local-executor.ts @@ -15,7 +15,6 @@ import type { FunctionDeclaration, Schema, } from '@google/genai'; -import { executeToolCall } from '../core/nonInteractiveToolExecutor.js'; import { ToolRegistry } from '../tools/tool-registry.js'; import { CompressionStatus } from '../core/turn.js'; import { type ToolCallRequestInfo } from '../scheduler/types.js'; @@ -48,7 +47,8 @@ import { zodToJsonSchema } from 'zod-to-json-schema'; import { debugLogger } from '../utils/debugLogger.js'; import { getModelConfigAlias } from './registry.js'; import { getVersion } from '../utils/version.js'; -import { ApprovalMode } from '../policy/types.js'; +import { getToolCallContext } from '../utils/toolCallContext.js'; +import { scheduleAgentTools } from './agent-scheduler.js'; /** A callback function to report on agent activity. */ export type ActivityCallback = (activity: SubagentActivityEvent) => void; @@ -86,6 +86,7 @@ export class LocalAgentExecutor { private readonly runtimeContext: Config; private readonly onActivity?: ActivityCallback; private readonly compressionService: ChatCompressionService; + private readonly parentCallId?: string; private hasFailedCompressionAttempt = false; /** @@ -158,11 +159,16 @@ export class LocalAgentExecutor { // Get the parent prompt ID from context const parentPromptId = promptIdContext.getStore(); + // Get the parent tool call ID from context + const toolContext = getToolCallContext(); + const parentCallId = toolContext?.callId; + return new LocalAgentExecutor( definition, runtimeContext, agentToolRegistry, parentPromptId, + parentCallId, onActivity, ); } @@ -178,6 +184,7 @@ export class LocalAgentExecutor { runtimeContext: Config, toolRegistry: ToolRegistry, parentPromptId: string | undefined, + parentCallId: string | undefined, onActivity?: ActivityCallback, ) { this.definition = definition; @@ -185,6 +192,7 @@ export class LocalAgentExecutor { this.toolRegistry = toolRegistry; this.onActivity = onActivity; this.compressionService = new ChatCompressionService(); + this.parentCallId = parentCallId; const randomIdPart = Math.random().toString(36).slice(2, 8); // parentPromptId will be undefined if this agent is invoked directly @@ -763,26 +771,28 @@ export class LocalAgentExecutor { let submittedOutput: string | null = null; let taskCompleted = false; - // We'll collect promises for the tool executions - const toolExecutionPromises: Array> = []; - // And we'll need a place to store the synchronous results (like complete_task or blocked calls) - const syncResponseParts: Part[] = []; + // We'll separate complete_task from other tools + const toolRequests: ToolCallRequestInfo[] = []; + // Map to keep track of tool name by callId for activity emission + const toolNameMap = new Map(); + // Synchronous results (like complete_task or unauthorized calls) + const syncResults = new Map(); for (const [index, functionCall] of functionCalls.entries()) { const callId = functionCall.id ?? `${promptId}-${index}`; const args = functionCall.args ?? {}; + const toolName = functionCall.name as string; this.emitActivity('TOOL_CALL_START', { - name: functionCall.name, + name: toolName, args, }); - if (functionCall.name === TASK_COMPLETE_TOOL_NAME) { + if (toolName === TASK_COMPLETE_TOOL_NAME) { if (taskCompleted) { - // We already have a completion from this turn. Ignore subsequent ones. const error = 'Task already marked complete in this turn. Ignoring duplicate call.'; - syncResponseParts.push({ + syncResults.set(callId, { functionResponse: { name: TASK_COMPLETE_TOOL_NAME, response: { error }, @@ -791,7 +801,7 @@ export class LocalAgentExecutor { }); this.emitActivity('ERROR', { context: 'tool_call', - name: functionCall.name, + name: toolName, error, }); continue; @@ -809,7 +819,7 @@ export class LocalAgentExecutor { if (!validationResult.success) { taskCompleted = false; // Validation failed, revoke completion const error = `Output validation failed: ${JSON.stringify(validationResult.error.flatten())}`; - syncResponseParts.push({ + syncResults.set(callId, { functionResponse: { name: TASK_COMPLETE_TOOL_NAME, response: { error }, @@ -818,7 +828,7 @@ export class LocalAgentExecutor { }); this.emitActivity('ERROR', { context: 'tool_call', - name: functionCall.name, + name: toolName, error, }); continue; @@ -833,7 +843,7 @@ export class LocalAgentExecutor { ? outputValue : JSON.stringify(outputValue, null, 2); } - syncResponseParts.push({ + syncResults.set(callId, { functionResponse: { name: TASK_COMPLETE_TOOL_NAME, response: { result: 'Output submitted and task completed.' }, @@ -841,14 +851,14 @@ export class LocalAgentExecutor { }, }); this.emitActivity('TOOL_CALL_END', { - name: functionCall.name, + name: toolName, output: 'Output submitted and task completed.', }); } else { // Failed to provide required output. taskCompleted = false; // Revoke completion status const error = `Missing required argument '${outputName}' for completion.`; - syncResponseParts.push({ + syncResults.set(callId, { functionResponse: { name: TASK_COMPLETE_TOOL_NAME, response: { error }, @@ -857,7 +867,7 @@ export class LocalAgentExecutor { }); this.emitActivity('ERROR', { context: 'tool_call', - name: functionCall.name, + name: toolName, error, }); } @@ -873,7 +883,7 @@ export class LocalAgentExecutor { typeof resultArg === 'string' ? resultArg : JSON.stringify(resultArg, null, 2); - syncResponseParts.push({ + syncResults.set(callId, { functionResponse: { name: TASK_COMPLETE_TOOL_NAME, response: { status: 'Result submitted and task completed.' }, @@ -881,7 +891,7 @@ export class LocalAgentExecutor { }, }); this.emitActivity('TOOL_CALL_END', { - name: functionCall.name, + name: toolName, output: 'Result submitted and task completed.', }); } else { @@ -889,7 +899,7 @@ export class LocalAgentExecutor { taskCompleted = false; // Revoke completion const error = 'Missing required "result" argument. You must provide your findings when calling complete_task.'; - syncResponseParts.push({ + syncResults.set(callId, { functionResponse: { name: TASK_COMPLETE_TOOL_NAME, response: { error }, @@ -898,7 +908,7 @@ export class LocalAgentExecutor { }); this.emitActivity('ERROR', { context: 'tool_call', - name: functionCall.name, + name: toolName, error, }); } @@ -907,14 +917,13 @@ export class LocalAgentExecutor { } // Handle standard tools - if (!allowedToolNames.has(functionCall.name as string)) { - const error = createUnauthorizedToolError(functionCall.name as string); - + if (!allowedToolNames.has(toolName)) { + const error = createUnauthorizedToolError(toolName); debugLogger.warn(`[LocalAgentExecutor] Blocked call: ${error}`); - syncResponseParts.push({ + syncResults.set(callId, { functionResponse: { - name: functionCall.name as string, + name: toolName, id: callId, response: { error }, }, @@ -922,7 +931,7 @@ export class LocalAgentExecutor { this.emitActivity('ERROR', { context: 'tool_call_unauthorized', - name: functionCall.name, + name: toolName, callId, error, }); @@ -930,53 +939,63 @@ export class LocalAgentExecutor { continue; } - const requestInfo: ToolCallRequestInfo = { + toolRequests.push({ callId, - name: functionCall.name as string, + name: toolName, args, - isClientInitiated: true, + isClientInitiated: false, // These are coming from the subagent (the "model") prompt_id: promptId, - }; + }); + toolNameMap.set(callId, toolName); + } - // Create a promise for the tool execution - const executionPromise = (async () => { - const agentContext = Object.create(this.runtimeContext); - agentContext.getToolRegistry = () => this.toolRegistry; - agentContext.getApprovalMode = () => ApprovalMode.YOLO; - - const { response: toolResponse } = await executeToolCall( - agentContext, - requestInfo, + // Execute standard tool calls using the new scheduler + if (toolRequests.length > 0) { + const completedCalls = await scheduleAgentTools( + this.runtimeContext, + toolRequests, + { + schedulerId: this.agentId, + parentCallId: this.parentCallId, + toolRegistry: this.toolRegistry, signal, - ); + }, + ); - if (toolResponse.error) { + for (const call of completedCalls) { + const toolName = + toolNameMap.get(call.request.callId) || call.request.name; + if (call.status === 'success') { + this.emitActivity('TOOL_CALL_END', { + name: toolName, + output: call.response.resultDisplay, + }); + } else if (call.status === 'error') { this.emitActivity('ERROR', { context: 'tool_call', - name: functionCall.name, - error: toolResponse.error.message, + name: toolName, + error: call.response.error?.message || 'Unknown error', }); - } else { - this.emitActivity('TOOL_CALL_END', { - name: functionCall.name, - output: toolResponse.resultDisplay, + } else if (call.status === 'cancelled') { + this.emitActivity('ERROR', { + context: 'tool_call', + name: toolName, + error: 'Tool call was cancelled.', }); } - return toolResponse.responseParts; - })(); - - toolExecutionPromises.push(executionPromise); + // Add result to syncResults to preserve order later + syncResults.set(call.request.callId, call.response.responseParts[0]); + } } - // Wait for all tool executions to complete - const asyncResults = await Promise.all(toolExecutionPromises); - - // Combine all response parts - const toolResponseParts: Part[] = [...syncResponseParts]; - for (const result of asyncResults) { - if (result) { - toolResponseParts.push(...result); + // Reconstruct toolResponseParts in the original order + const toolResponseParts: Part[] = []; + for (const [index, functionCall] of functionCalls.entries()) { + const callId = functionCall.id ?? `${promptId}-${index}`; + const part = syncResults.get(callId); + if (part) { + toolResponseParts.push(part); } } diff --git a/packages/core/src/scheduler/scheduler.test.ts b/packages/core/src/scheduler/scheduler.test.ts index 95b6470d1b..45884f1de0 100644 --- a/packages/core/src/scheduler/scheduler.test.ts +++ b/packages/core/src/scheduler/scheduler.test.ts @@ -70,6 +70,10 @@ import { ROOT_SCHEDULER_ID } from './types.js'; import { ToolErrorType } from '../tools/tool-error.js'; import * as ToolUtils from '../utils/tool-utils.js'; import type { EditorType } from '../utils/editor.js'; +import { + getToolCallContext, + type ToolCallContext, +} from '../utils/toolCallContext.js'; describe('Scheduler (Orchestrator)', () => { let scheduler: Scheduler; @@ -1010,4 +1014,68 @@ describe('Scheduler (Orchestrator)', () => { expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1'); }); }); + + describe('Tool Call Context Propagation', () => { + it('should propagate context to the tool executor', async () => { + const schedulerId = 'custom-scheduler'; + const parentCallId = 'parent-call'; + const customScheduler = new Scheduler({ + config: mockConfig, + messageBus: mockMessageBus, + getPreferredEditor, + schedulerId, + parentCallId, + }); + + const validatingCall: ValidatingToolCall = { + status: 'validating', + request: req1, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + // Mock queueLength to run the loop once + Object.defineProperty(mockStateManager, 'queueLength', { + get: vi.fn().mockReturnValueOnce(1).mockReturnValue(0), + configurable: true, + }); + + vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall); + Object.defineProperty(mockStateManager, 'firstActiveCall', { + get: vi.fn().mockReturnValue(validatingCall), + configurable: true, + }); + vi.mocked(mockStateManager.getToolCall).mockReturnValue(validatingCall); + + mockToolRegistry.getTool.mockReturnValue(mockTool); + mockPolicyEngine.check.mockResolvedValue({ + decision: PolicyDecision.ALLOW, + }); + + let capturedContext: ToolCallContext | undefined; + mockExecutor.execute.mockImplementation(async () => { + capturedContext = getToolCallContext(); + return { + status: 'success', + request: req1, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + response: { + callId: req1.callId, + responseParts: [], + resultDisplay: 'ok', + error: undefined, + errorType: undefined, + }, + } as unknown as SuccessfulToolCall; + }); + + await customScheduler.schedule(req1, signal); + + expect(capturedContext).toBeDefined(); + expect(capturedContext!.callId).toBe(req1.callId); + expect(capturedContext!.schedulerId).toBe(schedulerId); + expect(capturedContext!.parentCallId).toBe(parentCallId); + }); + }); }); diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts index a8d295b1f9..5853736a01 100644 --- a/packages/core/src/scheduler/scheduler.ts +++ b/packages/core/src/scheduler/scheduler.ts @@ -36,6 +36,7 @@ import { type SerializableConfirmationDetails, type ToolConfirmationRequest, } from '../confirmation-bus/types.js'; +import { runWithToolCallContext } from '../utils/toolCallContext.js'; interface SchedulerQueueItem { requests: ToolCallRequestInfo[]; @@ -256,6 +257,7 @@ export class Scheduler { return this.state.completedBatch; } finally { this.isProcessing = false; + this.state.clearBatch(); this._processNextInRequestQueue(); } } @@ -282,30 +284,39 @@ export class Scheduler { request: ToolCallRequestInfo, tool: AnyDeclarativeTool, ): ValidatingToolCall | ErroredToolCall { - try { - const invocation = tool.build(request.args); - return { - status: 'validating', - request, - tool, - invocation, - startTime: Date.now(), + return runWithToolCallContext( + { + callId: request.callId, schedulerId: this.schedulerId, - }; - } catch (e) { - return { - status: 'error', - request, - tool, - response: createErrorResponse( - request, - e instanceof Error ? e : new Error(String(e)), - ToolErrorType.INVALID_TOOL_PARAMS, - ), - durationMs: 0, - schedulerId: this.schedulerId, - }; - } + parentCallId: this.parentCallId, + }, + () => { + try { + const invocation = tool.build(request.args); + return { + status: 'validating', + request, + tool, + invocation, + startTime: Date.now(), + schedulerId: this.schedulerId, + }; + } catch (e) { + return { + status: 'error', + request, + tool, + response: createErrorResponse( + request, + e instanceof Error ? e : new Error(String(e)), + ToolErrorType.INVALID_TOOL_PARAMS, + ), + durationMs: 0, + schedulerId: this.schedulerId, + }; + } + }, + ); } // --- Phase 2: Processing Loop --- @@ -460,17 +471,29 @@ export class Scheduler { if (signal.aborted) throw new Error('Operation cancelled'); this.state.updateStatus(callId, 'executing'); - const result = await this.executor.execute({ - call: this.state.firstActiveCall as ExecutingToolCall, - signal, - outputUpdateHandler: (id, out) => - this.state.updateStatus(id, 'executing', { liveOutput: out }), - onUpdateToolCall: (updated) => { - if (updated.status === 'executing' && updated.pid) { - this.state.updateStatus(callId, 'executing', { pid: updated.pid }); - } + const activeCall = this.state.firstActiveCall as ExecutingToolCall; + + const result = await runWithToolCallContext( + { + callId: activeCall.request.callId, + schedulerId: this.schedulerId, + parentCallId: this.parentCallId, }, - }); + () => + this.executor.execute({ + call: activeCall, + signal, + outputUpdateHandler: (id, out) => + this.state.updateStatus(id, 'executing', { liveOutput: out }), + onUpdateToolCall: (updated) => { + if (updated.status === 'executing' && updated.pid) { + this.state.updateStatus(callId, 'executing', { + pid: updated.pid, + }); + } + }, + }), + ); if (result.status === 'success') { this.state.updateStatus(callId, 'success', result.response); diff --git a/packages/core/src/utils/toolCallContext.test.ts b/packages/core/src/utils/toolCallContext.test.ts new file mode 100644 index 0000000000..e649a216c7 --- /dev/null +++ b/packages/core/src/utils/toolCallContext.test.ts @@ -0,0 +1,84 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { + runWithToolCallContext, + getToolCallContext, +} from './toolCallContext.js'; + +describe('toolCallContext', () => { + it('should store and retrieve tool call context', () => { + const context = { + callId: 'test-call-id', + schedulerId: 'test-scheduler-id', + }; + + runWithToolCallContext(context, () => { + const storedContext = getToolCallContext(); + expect(storedContext).toEqual(context); + }); + }); + + it('should return undefined when no context is set', () => { + expect(getToolCallContext()).toBeUndefined(); + }); + + it('should support nested contexts', () => { + const parentContext = { + callId: 'parent-call-id', + schedulerId: 'parent-scheduler-id', + }; + + const childContext = { + callId: 'child-call-id', + schedulerId: 'child-scheduler-id', + parentCallId: 'parent-call-id', + }; + + runWithToolCallContext(parentContext, () => { + expect(getToolCallContext()).toEqual(parentContext); + + runWithToolCallContext(childContext, () => { + expect(getToolCallContext()).toEqual(childContext); + }); + + expect(getToolCallContext()).toEqual(parentContext); + }); + }); + + it('should maintain isolation between parallel executions', async () => { + const context1 = { + callId: 'call-1', + schedulerId: 'scheduler-1', + }; + + const context2 = { + callId: 'call-2', + schedulerId: 'scheduler-2', + }; + + const promise1 = new Promise((resolve) => { + runWithToolCallContext(context1, () => { + setTimeout(() => { + expect(getToolCallContext()).toEqual(context1); + resolve(); + }, 10); + }); + }); + + const promise2 = new Promise((resolve) => { + runWithToolCallContext(context2, () => { + setTimeout(() => { + expect(getToolCallContext()).toEqual(context2); + resolve(); + }, 5); + }); + }); + + await Promise.all([promise1, promise2]); + }); +}); diff --git a/packages/core/src/utils/toolCallContext.ts b/packages/core/src/utils/toolCallContext.ts new file mode 100644 index 0000000000..c371d23783 --- /dev/null +++ b/packages/core/src/utils/toolCallContext.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { AsyncLocalStorage } from 'node:async_hooks'; + +/** + * Contextual information for a tool call execution. + */ +export interface ToolCallContext { + /** The unique ID of the tool call. */ + callId: string; + /** The ID of the scheduler managing the execution. */ + schedulerId: string; + /** The ID of the parent tool call, if this is a nested execution (e.g., in a subagent). */ + parentCallId?: string; +} + +/** + * AsyncLocalStorage instance for tool call context. + */ +export const toolCallContext = new AsyncLocalStorage(); + +/** + * Runs a function within a tool call context. + * + * @param context The context to set. + * @param fn The function to run. + * @returns The result of the function. + */ +export function runWithToolCallContext( + context: ToolCallContext, + fn: () => T, +): T { + return toolCallContext.run(context, fn); +} + +/** + * Retrieves the current tool call context. + * + * @returns The current ToolCallContext, or undefined if not in a context. + */ +export function getToolCallContext(): ToolCallContext | undefined { + return toolCallContext.getStore(); +}