feat(agents): migrate subagents to event-driven scheduler (#17567)

This commit is contained in:
Abhi
2026-01-26 17:12:55 -05:00
committed by GitHub
parent 13bc5f620c
commit 9d34ae52d6
8 changed files with 741 additions and 335 deletions
@@ -0,0 +1,74 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, type Mocked } from 'vitest';
import { scheduleAgentTools } from './agent-scheduler.js';
import { Scheduler } from '../scheduler/scheduler.js';
import type { Config } from '../config/config.js';
import type { ToolRegistry } from '../tools/tool-registry.js';
import type { ToolCallRequestInfo } from '../scheduler/types.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
vi.mock('../scheduler/scheduler.js', () => ({
Scheduler: vi.fn().mockImplementation(() => ({
schedule: vi.fn().mockResolvedValue([{ status: 'success' }]),
})),
}));
describe('agent-scheduler', () => {
let mockConfig: Mocked<Config>;
let mockToolRegistry: Mocked<ToolRegistry>;
let mockMessageBus: Mocked<MessageBus>;
beforeEach(() => {
mockMessageBus = {} as Mocked<MessageBus>;
mockToolRegistry = {
getTool: vi.fn(),
} as unknown as Mocked<ToolRegistry>;
mockConfig = {
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
} as unknown as Mocked<Config>;
});
it('should create a scheduler with agent-specific config', async () => {
const requests: ToolCallRequestInfo[] = [
{
callId: 'call-1',
name: 'test-tool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-1',
},
];
const options = {
schedulerId: 'subagent-1',
parentCallId: 'parent-1',
toolRegistry: mockToolRegistry as unknown as ToolRegistry,
signal: new AbortController().signal,
};
const results = await scheduleAgentTools(
mockConfig as unknown as Config,
requests,
options,
);
expect(results).toEqual([{ status: 'success' }]);
expect(Scheduler).toHaveBeenCalledWith(
expect.objectContaining({
schedulerId: 'subagent-1',
parentCallId: 'parent-1',
messageBus: mockMessageBus,
}),
);
// Verify that the scheduler's config has the overridden tool registry
const schedulerConfig = vi.mocked(Scheduler).mock.calls[0][0].config;
expect(schedulerConfig.getToolRegistry()).toBe(mockToolRegistry);
});
});
@@ -0,0 +1,66 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '../config/config.js';
import { Scheduler } from '../scheduler/scheduler.js';
import type {
ToolCallRequestInfo,
CompletedToolCall,
} from '../scheduler/types.js';
import type { ToolRegistry } from '../tools/tool-registry.js';
import type { EditorType } from '../utils/editor.js';
/**
* Options for scheduling agent tools.
*/
export interface AgentSchedulingOptions {
/** The unique ID for this agent's scheduler. */
schedulerId: string;
/** The ID of the tool call that invoked this agent. */
parentCallId?: string;
/** The tool registry specific to this agent. */
toolRegistry: ToolRegistry;
/** AbortSignal for cancellation. */
signal: AbortSignal;
/** Optional function to get the preferred editor for tool modifications. */
getPreferredEditor?: () => EditorType | undefined;
}
/**
* Schedules a batch of tool calls for an agent using the new event-driven Scheduler.
*
* @param config The global runtime configuration.
* @param requests The list of tool call requests from the agent.
* @param options Scheduling options including registry and IDs.
* @returns A promise that resolves to the completed tool calls.
*/
export async function scheduleAgentTools(
config: Config,
requests: ToolCallRequestInfo[],
options: AgentSchedulingOptions,
): Promise<CompletedToolCall[]> {
const {
schedulerId,
parentCallId,
toolRegistry,
signal,
getPreferredEditor,
} = options;
// Create a proxy/override of the config to provide the agent-specific tool registry.
const agentConfig: Config = Object.create(config);
agentConfig.getToolRegistry = () => toolRegistry;
const scheduler = new Scheduler({
config: agentConfig,
messageBus: config.getMessageBus(),
getPreferredEditor: getPreferredEditor ?? (() => undefined),
schedulerId,
parentCallId,
});
return scheduler.schedule(requests, signal);
}
+267 -242
View File
@@ -55,6 +55,7 @@ import type {
} from './types.js'; } from './types.js';
import { AgentTerminateMode } from './types.js'; import { AgentTerminateMode } from './types.js';
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js'; import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
import type { ToolCallRequestInfo } from '../scheduler/types.js';
import { CompressionStatus } from '../core/turn.js'; import { CompressionStatus } from '../core/turn.js';
import { ChatCompressionService } from '../services/chatCompressionService.js'; import { ChatCompressionService } from '../services/chatCompressionService.js';
import type { import type {
@@ -67,12 +68,12 @@ import type { ModelRouterService } from '../routing/modelRouterService.js';
const { const {
mockSendMessageStream, mockSendMessageStream,
mockExecuteToolCall, mockScheduleAgentTools,
mockSetSystemInstruction, mockSetSystemInstruction,
mockCompress, mockCompress,
} = vi.hoisted(() => ({ } = vi.hoisted(() => ({
mockSendMessageStream: vi.fn(), mockSendMessageStream: vi.fn(),
mockExecuteToolCall: vi.fn(), mockScheduleAgentTools: vi.fn(),
mockSetSystemInstruction: vi.fn(), mockSetSystemInstruction: vi.fn(),
mockCompress: vi.fn(), mockCompress: vi.fn(),
})); }));
@@ -101,8 +102,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
}; };
}); });
vi.mock('../core/nonInteractiveToolExecutor.js', () => ({ vi.mock('./agent-scheduler.js', () => ({
executeToolCall: mockExecuteToolCall, scheduleAgentTools: mockScheduleAgentTools,
})); }));
vi.mock('../utils/version.js', () => ({ vi.mock('../utils/version.js', () => ({
@@ -275,7 +276,7 @@ describe('LocalAgentExecutor', () => {
mockSetHistory.mockClear(); mockSetHistory.mockClear();
mockSendMessageStream.mockReset(); mockSendMessageStream.mockReset();
mockSetSystemInstruction.mockReset(); mockSetSystemInstruction.mockReset();
mockExecuteToolCall.mockReset(); mockScheduleAgentTools.mockReset();
mockedLogAgentStart.mockReset(); mockedLogAgentStart.mockReset();
mockedLogAgentFinish.mockReset(); mockedLogAgentFinish.mockReset();
mockedPromptIdContext.getStore.mockReset(); mockedPromptIdContext.getStore.mockReset();
@@ -540,34 +541,36 @@ describe('LocalAgentExecutor', () => {
[{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }], [{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }],
'T1: Listing', 'T1: Listing',
); );
mockExecuteToolCall.mockResolvedValueOnce({ mockScheduleAgentTools.mockResolvedValueOnce([
status: 'success', {
request: { status: 'success',
callId: 'call1', request: {
name: LS_TOOL_NAME, callId: 'call1',
args: { path: '.' }, name: LS_TOOL_NAME,
isClientInitiated: false, args: { path: '.' },
prompt_id: 'test-prompt', isClientInitiated: false,
}, prompt_id: 'test-prompt',
tool: {} as AnyDeclarativeTool, },
invocation: {} as AnyToolInvocation, tool: {} as AnyDeclarativeTool,
response: { invocation: {} as AnyToolInvocation,
callId: 'call1', response: {
resultDisplay: 'file1.txt', callId: 'call1',
responseParts: [ resultDisplay: 'file1.txt',
{ responseParts: [
functionResponse: { {
name: LS_TOOL_NAME, functionResponse: {
response: { result: 'file1.txt' }, name: LS_TOOL_NAME,
id: 'call1', response: { result: 'file1.txt' },
id: 'call1',
},
}, },
}, ],
], error: undefined,
error: undefined, errorType: undefined,
errorType: undefined, contentLength: undefined,
contentLength: undefined, },
}, },
}); ]);
// Turn 2: Model calls complete_task with required output // Turn 2: Model calls complete_task with required output
mockModelResponse( mockModelResponse(
@@ -686,34 +689,36 @@ describe('LocalAgentExecutor', () => {
mockModelResponse([ mockModelResponse([
{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }, { name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' },
]); ]);
mockExecuteToolCall.mockResolvedValueOnce({ mockScheduleAgentTools.mockResolvedValueOnce([
status: 'success', {
request: { status: 'success',
callId: 'call1', request: {
name: LS_TOOL_NAME, callId: 'call1',
args: { path: '.' }, name: LS_TOOL_NAME,
isClientInitiated: false, args: { path: '.' },
prompt_id: 'test-prompt', isClientInitiated: false,
}, prompt_id: 'test-prompt',
tool: {} as AnyDeclarativeTool, },
invocation: {} as AnyToolInvocation, tool: {} as AnyDeclarativeTool,
response: { invocation: {} as AnyToolInvocation,
callId: 'call1', response: {
resultDisplay: 'ok', callId: 'call1',
responseParts: [ resultDisplay: 'ok',
{ responseParts: [
functionResponse: { {
name: LS_TOOL_NAME, functionResponse: {
response: {}, name: LS_TOOL_NAME,
id: 'call1', response: {},
id: 'call1',
},
}, },
}, ],
], error: undefined,
error: undefined, errorType: undefined,
errorType: undefined, contentLength: undefined,
contentLength: undefined, },
}, },
}); ]);
mockModelResponse( mockModelResponse(
[ [
@@ -759,34 +764,36 @@ describe('LocalAgentExecutor', () => {
mockModelResponse([ mockModelResponse([
{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }, { name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' },
]); ]);
mockExecuteToolCall.mockResolvedValueOnce({ mockScheduleAgentTools.mockResolvedValueOnce([
status: 'success', {
request: { status: 'success',
callId: 'call1', request: {
name: LS_TOOL_NAME, callId: 'call1',
args: { path: '.' }, name: LS_TOOL_NAME,
isClientInitiated: false, args: { path: '.' },
prompt_id: 'test-prompt', isClientInitiated: false,
}, prompt_id: 'test-prompt',
tool: {} as AnyDeclarativeTool, },
invocation: {} as AnyToolInvocation, tool: {} as AnyDeclarativeTool,
response: { invocation: {} as AnyToolInvocation,
callId: 'call1', response: {
resultDisplay: 'ok', callId: 'call1',
responseParts: [ resultDisplay: 'ok',
{ responseParts: [
functionResponse: { {
name: LS_TOOL_NAME, functionResponse: {
response: {}, name: LS_TOOL_NAME,
id: 'call1', response: {},
id: 'call1',
},
}, },
}, ],
], error: undefined,
error: undefined, errorType: undefined,
errorType: undefined, contentLength: undefined,
contentLength: undefined, },
}, },
}); ]);
// Turn 2 (protocol violation) // Turn 2 (protocol violation)
mockModelResponse([], 'I think I am done.'); mockModelResponse([], 'I think I am done.');
@@ -959,33 +966,40 @@ describe('LocalAgentExecutor', () => {
resolveCalls = r; resolveCalls = r;
}); });
mockExecuteToolCall.mockImplementation(async (_ctx, reqInfo) => { mockScheduleAgentTools.mockImplementation(
callsStarted++; async (_ctx, requests: ToolCallRequestInfo[]) => {
if (callsStarted === 2) resolveCalls(); const results = await Promise.all(
await vi.advanceTimersByTimeAsync(100); requests.map(async (reqInfo) => {
return { callsStarted++;
status: 'success', if (callsStarted === 2) resolveCalls();
request: reqInfo, await vi.advanceTimersByTimeAsync(100);
tool: {} as AnyDeclarativeTool, return {
invocation: {} as AnyToolInvocation, status: 'success',
response: { request: reqInfo,
callId: reqInfo.callId, tool: {} as AnyDeclarativeTool,
resultDisplay: 'ok', invocation: {} as AnyToolInvocation,
responseParts: [ response: {
{ callId: reqInfo.callId,
functionResponse: { resultDisplay: 'ok',
name: reqInfo.name, responseParts: [
response: {}, {
id: reqInfo.callId, functionResponse: {
name: reqInfo.name,
response: {},
id: reqInfo.callId,
},
},
],
error: undefined,
errorType: undefined,
contentLength: undefined,
}, },
}, };
], }),
error: undefined, );
errorType: undefined, return results;
contentLength: undefined, },
}, );
};
});
// Turn 2: Completion // Turn 2: Completion
mockModelResponse([ mockModelResponse([
@@ -1005,7 +1019,7 @@ describe('LocalAgentExecutor', () => {
const output = await runPromise; const output = await runPromise;
expect(mockExecuteToolCall).toHaveBeenCalledTimes(2); expect(mockScheduleAgentTools).toHaveBeenCalledTimes(1);
expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL); expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);
// Safe access to message parts // Safe access to message parts
@@ -1059,7 +1073,7 @@ describe('LocalAgentExecutor', () => {
await executor.run({ goal: 'Sec test' }, signal); await executor.run({ goal: 'Sec test' }, signal);
// Verify external executor was not called (Security held) // Verify external executor was not called (Security held)
expect(mockExecuteToolCall).not.toHaveBeenCalled(); expect(mockScheduleAgentTools).not.toHaveBeenCalled();
// 2. Verify console warning // 2. Verify console warning
expect(consoleWarnSpy).toHaveBeenCalledWith( expect(consoleWarnSpy).toHaveBeenCalledWith(
@@ -1215,37 +1229,36 @@ describe('LocalAgentExecutor', () => {
mockModelResponse([ mockModelResponse([
{ name: LS_TOOL_NAME, args: { path: '/fake' }, id: 'call1' }, { name: LS_TOOL_NAME, args: { path: '/fake' }, id: 'call1' },
]); ]);
mockExecuteToolCall.mockResolvedValueOnce({ mockScheduleAgentTools.mockResolvedValueOnce([
status: 'error', {
request: { status: 'error',
callId: 'call1', request: {
name: LS_TOOL_NAME, callId: 'call1',
args: { path: '/fake' }, name: LS_TOOL_NAME,
isClientInitiated: false, args: { path: '/fake' },
prompt_id: 'test-prompt', isClientInitiated: false,
}, prompt_id: 'test-prompt',
tool: {} as AnyDeclarativeTool, },
invocation: {} as AnyToolInvocation, tool: {} as AnyDeclarativeTool,
response: { invocation: {} as AnyToolInvocation,
callId: 'call1', response: {
resultDisplay: '', callId: 'call1',
responseParts: [ resultDisplay: '',
{ responseParts: [
functionResponse: { {
name: LS_TOOL_NAME, functionResponse: {
response: { error: toolErrorMessage }, name: LS_TOOL_NAME,
id: 'call1', response: { error: toolErrorMessage },
}, id: 'call1',
}, },
], },
error: { ],
type: 'ToolError', error: new Error(toolErrorMessage),
message: toolErrorMessage, errorType: 'ToolError',
contentLength: 0,
}, },
errorType: 'ToolError',
contentLength: 0,
}, },
}); ]);
// Turn 2: Model sees the error and completes // Turn 2: Model sees the error and completes
mockModelResponse([ mockModelResponse([
@@ -1258,7 +1271,7 @@ describe('LocalAgentExecutor', () => {
const output = await executor.run({ goal: 'Tool failure test' }, signal); const output = await executor.run({ goal: 'Tool failure test' }, signal);
expect(mockExecuteToolCall).toHaveBeenCalledTimes(1); expect(mockScheduleAgentTools).toHaveBeenCalledTimes(1);
expect(mockSendMessageStream).toHaveBeenCalledTimes(2); expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
// Verify the error was reported in the activity stream // Verify the error was reported in the activity stream
@@ -1391,28 +1404,30 @@ describe('LocalAgentExecutor', () => {
describe('run (Termination Conditions)', () => { describe('run (Termination Conditions)', () => {
const mockWorkResponse = (id: string) => { const mockWorkResponse = (id: string) => {
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]); mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
mockExecuteToolCall.mockResolvedValueOnce({ mockScheduleAgentTools.mockResolvedValueOnce([
status: 'success', {
request: { status: 'success',
callId: id, request: {
name: LS_TOOL_NAME, callId: id,
args: { path: '.' }, name: LS_TOOL_NAME,
isClientInitiated: false, args: { path: '.' },
prompt_id: 'test-prompt', isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
}, },
tool: {} as AnyDeclarativeTool, ]);
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
});
}; };
it('should terminate when max_turns is reached', async () => { it('should terminate when max_turns is reached', async () => {
@@ -1505,23 +1520,27 @@ describe('LocalAgentExecutor', () => {
]); ]);
// Long running tool // Long running tool
mockExecuteToolCall.mockImplementationOnce(async (_ctx, reqInfo) => { mockScheduleAgentTools.mockImplementationOnce(
await vi.advanceTimersByTimeAsync(61 * 1000); async (_ctx, requests: ToolCallRequestInfo[]) => {
return { await vi.advanceTimersByTimeAsync(61 * 1000);
status: 'success', return [
request: reqInfo, {
tool: {} as AnyDeclarativeTool, status: 'success',
invocation: {} as AnyToolInvocation, request: requests[0],
response: { tool: {} as AnyDeclarativeTool,
callId: 't1', invocation: {} as AnyToolInvocation,
resultDisplay: 'ok', response: {
responseParts: [], callId: 't1',
error: undefined, resultDisplay: 'ok',
errorType: undefined, responseParts: [],
contentLength: undefined, error: undefined,
}, errorType: undefined,
}; contentLength: undefined,
}); },
},
];
},
);
// Recovery turn // Recovery turn
mockModelResponse([], 'I give up'); mockModelResponse([], 'I give up');
@@ -1557,28 +1576,30 @@ describe('LocalAgentExecutor', () => {
describe('run (Recovery Turns)', () => { describe('run (Recovery Turns)', () => {
const mockWorkResponse = (id: string) => { const mockWorkResponse = (id: string) => {
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]); mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
mockExecuteToolCall.mockResolvedValueOnce({ mockScheduleAgentTools.mockResolvedValueOnce([
status: 'success', {
request: { status: 'success',
callId: id, request: {
name: LS_TOOL_NAME, callId: id,
args: { path: '.' }, name: LS_TOOL_NAME,
isClientInitiated: false, args: { path: '.' },
prompt_id: 'test-prompt', isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
}, },
tool: {} as AnyDeclarativeTool, ]);
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
});
}; };
it('should recover successfully if complete_task is called during the grace turn after MAX_TURNS', async () => { it('should recover successfully if complete_task is called during the grace turn after MAX_TURNS', async () => {
@@ -1873,28 +1894,30 @@ describe('LocalAgentExecutor', () => {
describe('Telemetry and Logging', () => { describe('Telemetry and Logging', () => {
const mockWorkResponse = (id: string) => { const mockWorkResponse = (id: string) => {
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]); mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
mockExecuteToolCall.mockResolvedValueOnce({ mockScheduleAgentTools.mockResolvedValueOnce([
status: 'success', {
request: { status: 'success',
callId: id, request: {
name: LS_TOOL_NAME, callId: id,
args: { path: '.' }, name: LS_TOOL_NAME,
isClientInitiated: false, args: { path: '.' },
prompt_id: 'test-prompt', isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
}, },
tool: {} as AnyDeclarativeTool, ]);
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
});
}; };
beforeEach(() => { beforeEach(() => {
@@ -1960,28 +1983,30 @@ describe('LocalAgentExecutor', () => {
describe('Chat Compression', () => { describe('Chat Compression', () => {
const mockWorkResponse = (id: string) => { const mockWorkResponse = (id: string) => {
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]); mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
mockExecuteToolCall.mockResolvedValueOnce({ mockScheduleAgentTools.mockResolvedValueOnce([
status: 'success', {
request: { status: 'success',
callId: id, request: {
name: LS_TOOL_NAME, callId: id,
args: { path: '.' }, name: LS_TOOL_NAME,
isClientInitiated: false, args: { path: '.' },
prompt_id: 'test-prompt', isClientInitiated: false,
prompt_id: 'test-prompt',
},
tool: {} as AnyDeclarativeTool,
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
}, },
tool: {} as AnyDeclarativeTool, ]);
invocation: {} as AnyToolInvocation,
response: {
callId: id,
resultDisplay: 'ok',
responseParts: [
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
],
error: undefined,
errorType: undefined,
contentLength: undefined,
},
});
}; };
it('should attempt to compress chat history on each turn', async () => { it('should attempt to compress chat history on each turn', async () => {
+79 -60
View File
@@ -15,7 +15,6 @@ import type {
FunctionDeclaration, FunctionDeclaration,
Schema, Schema,
} from '@google/genai'; } from '@google/genai';
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
import { ToolRegistry } from '../tools/tool-registry.js'; import { ToolRegistry } from '../tools/tool-registry.js';
import { CompressionStatus } from '../core/turn.js'; import { CompressionStatus } from '../core/turn.js';
import { type ToolCallRequestInfo } from '../scheduler/types.js'; import { type ToolCallRequestInfo } from '../scheduler/types.js';
@@ -48,7 +47,8 @@ import { zodToJsonSchema } from 'zod-to-json-schema';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import { getModelConfigAlias } from './registry.js'; import { getModelConfigAlias } from './registry.js';
import { getVersion } from '../utils/version.js'; import { getVersion } from '../utils/version.js';
import { ApprovalMode } from '../policy/types.js'; import { getToolCallContext } from '../utils/toolCallContext.js';
import { scheduleAgentTools } from './agent-scheduler.js';
/** A callback function to report on agent activity. */ /** A callback function to report on agent activity. */
export type ActivityCallback = (activity: SubagentActivityEvent) => void; export type ActivityCallback = (activity: SubagentActivityEvent) => void;
@@ -86,6 +86,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
private readonly runtimeContext: Config; private readonly runtimeContext: Config;
private readonly onActivity?: ActivityCallback; private readonly onActivity?: ActivityCallback;
private readonly compressionService: ChatCompressionService; private readonly compressionService: ChatCompressionService;
private readonly parentCallId?: string;
private hasFailedCompressionAttempt = false; private hasFailedCompressionAttempt = false;
/** /**
@@ -158,11 +159,16 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
// Get the parent prompt ID from context // Get the parent prompt ID from context
const parentPromptId = promptIdContext.getStore(); const parentPromptId = promptIdContext.getStore();
// Get the parent tool call ID from context
const toolContext = getToolCallContext();
const parentCallId = toolContext?.callId;
return new LocalAgentExecutor( return new LocalAgentExecutor(
definition, definition,
runtimeContext, runtimeContext,
agentToolRegistry, agentToolRegistry,
parentPromptId, parentPromptId,
parentCallId,
onActivity, onActivity,
); );
} }
@@ -178,6 +184,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
runtimeContext: Config, runtimeContext: Config,
toolRegistry: ToolRegistry, toolRegistry: ToolRegistry,
parentPromptId: string | undefined, parentPromptId: string | undefined,
parentCallId: string | undefined,
onActivity?: ActivityCallback, onActivity?: ActivityCallback,
) { ) {
this.definition = definition; this.definition = definition;
@@ -185,6 +192,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
this.toolRegistry = toolRegistry; this.toolRegistry = toolRegistry;
this.onActivity = onActivity; this.onActivity = onActivity;
this.compressionService = new ChatCompressionService(); this.compressionService = new ChatCompressionService();
this.parentCallId = parentCallId;
const randomIdPart = Math.random().toString(36).slice(2, 8); const randomIdPart = Math.random().toString(36).slice(2, 8);
// parentPromptId will be undefined if this agent is invoked directly // parentPromptId will be undefined if this agent is invoked directly
@@ -763,26 +771,28 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
let submittedOutput: string | null = null; let submittedOutput: string | null = null;
let taskCompleted = false; let taskCompleted = false;
// We'll collect promises for the tool executions // We'll separate complete_task from other tools
const toolExecutionPromises: Array<Promise<Part[] | void>> = []; const toolRequests: ToolCallRequestInfo[] = [];
// And we'll need a place to store the synchronous results (like complete_task or blocked calls) // Map to keep track of tool name by callId for activity emission
const syncResponseParts: Part[] = []; const toolNameMap = new Map<string, string>();
// Synchronous results (like complete_task or unauthorized calls)
const syncResults = new Map<string, Part>();
for (const [index, functionCall] of functionCalls.entries()) { for (const [index, functionCall] of functionCalls.entries()) {
const callId = functionCall.id ?? `${promptId}-${index}`; const callId = functionCall.id ?? `${promptId}-${index}`;
const args = functionCall.args ?? {}; const args = functionCall.args ?? {};
const toolName = functionCall.name as string;
this.emitActivity('TOOL_CALL_START', { this.emitActivity('TOOL_CALL_START', {
name: functionCall.name, name: toolName,
args, args,
}); });
if (functionCall.name === TASK_COMPLETE_TOOL_NAME) { if (toolName === TASK_COMPLETE_TOOL_NAME) {
if (taskCompleted) { if (taskCompleted) {
// We already have a completion from this turn. Ignore subsequent ones.
const error = const error =
'Task already marked complete in this turn. Ignoring duplicate call.'; 'Task already marked complete in this turn. Ignoring duplicate call.';
syncResponseParts.push({ syncResults.set(callId, {
functionResponse: { functionResponse: {
name: TASK_COMPLETE_TOOL_NAME, name: TASK_COMPLETE_TOOL_NAME,
response: { error }, response: { error },
@@ -791,7 +801,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
}); });
this.emitActivity('ERROR', { this.emitActivity('ERROR', {
context: 'tool_call', context: 'tool_call',
name: functionCall.name, name: toolName,
error, error,
}); });
continue; continue;
@@ -809,7 +819,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
if (!validationResult.success) { if (!validationResult.success) {
taskCompleted = false; // Validation failed, revoke completion taskCompleted = false; // Validation failed, revoke completion
const error = `Output validation failed: ${JSON.stringify(validationResult.error.flatten())}`; const error = `Output validation failed: ${JSON.stringify(validationResult.error.flatten())}`;
syncResponseParts.push({ syncResults.set(callId, {
functionResponse: { functionResponse: {
name: TASK_COMPLETE_TOOL_NAME, name: TASK_COMPLETE_TOOL_NAME,
response: { error }, response: { error },
@@ -818,7 +828,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
}); });
this.emitActivity('ERROR', { this.emitActivity('ERROR', {
context: 'tool_call', context: 'tool_call',
name: functionCall.name, name: toolName,
error, error,
}); });
continue; continue;
@@ -833,7 +843,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
? outputValue ? outputValue
: JSON.stringify(outputValue, null, 2); : JSON.stringify(outputValue, null, 2);
} }
syncResponseParts.push({ syncResults.set(callId, {
functionResponse: { functionResponse: {
name: TASK_COMPLETE_TOOL_NAME, name: TASK_COMPLETE_TOOL_NAME,
response: { result: 'Output submitted and task completed.' }, response: { result: 'Output submitted and task completed.' },
@@ -841,14 +851,14 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
}, },
}); });
this.emitActivity('TOOL_CALL_END', { this.emitActivity('TOOL_CALL_END', {
name: functionCall.name, name: toolName,
output: 'Output submitted and task completed.', output: 'Output submitted and task completed.',
}); });
} else { } else {
// Failed to provide required output. // Failed to provide required output.
taskCompleted = false; // Revoke completion status taskCompleted = false; // Revoke completion status
const error = `Missing required argument '${outputName}' for completion.`; const error = `Missing required argument '${outputName}' for completion.`;
syncResponseParts.push({ syncResults.set(callId, {
functionResponse: { functionResponse: {
name: TASK_COMPLETE_TOOL_NAME, name: TASK_COMPLETE_TOOL_NAME,
response: { error }, response: { error },
@@ -857,7 +867,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
}); });
this.emitActivity('ERROR', { this.emitActivity('ERROR', {
context: 'tool_call', context: 'tool_call',
name: functionCall.name, name: toolName,
error, error,
}); });
} }
@@ -873,7 +883,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
typeof resultArg === 'string' typeof resultArg === 'string'
? resultArg ? resultArg
: JSON.stringify(resultArg, null, 2); : JSON.stringify(resultArg, null, 2);
syncResponseParts.push({ syncResults.set(callId, {
functionResponse: { functionResponse: {
name: TASK_COMPLETE_TOOL_NAME, name: TASK_COMPLETE_TOOL_NAME,
response: { status: 'Result submitted and task completed.' }, response: { status: 'Result submitted and task completed.' },
@@ -881,7 +891,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
}, },
}); });
this.emitActivity('TOOL_CALL_END', { this.emitActivity('TOOL_CALL_END', {
name: functionCall.name, name: toolName,
output: 'Result submitted and task completed.', output: 'Result submitted and task completed.',
}); });
} else { } else {
@@ -889,7 +899,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
taskCompleted = false; // Revoke completion taskCompleted = false; // Revoke completion
const error = const error =
'Missing required "result" argument. You must provide your findings when calling complete_task.'; 'Missing required "result" argument. You must provide your findings when calling complete_task.';
syncResponseParts.push({ syncResults.set(callId, {
functionResponse: { functionResponse: {
name: TASK_COMPLETE_TOOL_NAME, name: TASK_COMPLETE_TOOL_NAME,
response: { error }, response: { error },
@@ -898,7 +908,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
}); });
this.emitActivity('ERROR', { this.emitActivity('ERROR', {
context: 'tool_call', context: 'tool_call',
name: functionCall.name, name: toolName,
error, error,
}); });
} }
@@ -907,14 +917,13 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
} }
// Handle standard tools // Handle standard tools
if (!allowedToolNames.has(functionCall.name as string)) { if (!allowedToolNames.has(toolName)) {
const error = createUnauthorizedToolError(functionCall.name as string); const error = createUnauthorizedToolError(toolName);
debugLogger.warn(`[LocalAgentExecutor] Blocked call: ${error}`); debugLogger.warn(`[LocalAgentExecutor] Blocked call: ${error}`);
syncResponseParts.push({ syncResults.set(callId, {
functionResponse: { functionResponse: {
name: functionCall.name as string, name: toolName,
id: callId, id: callId,
response: { error }, response: { error },
}, },
@@ -922,7 +931,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
this.emitActivity('ERROR', { this.emitActivity('ERROR', {
context: 'tool_call_unauthorized', context: 'tool_call_unauthorized',
name: functionCall.name, name: toolName,
callId, callId,
error, error,
}); });
@@ -930,53 +939,63 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
continue; continue;
} }
const requestInfo: ToolCallRequestInfo = { toolRequests.push({
callId, callId,
name: functionCall.name as string, name: toolName,
args, args,
isClientInitiated: true, isClientInitiated: false, // These are coming from the subagent (the "model")
prompt_id: promptId, prompt_id: promptId,
}; });
toolNameMap.set(callId, toolName);
}
// Create a promise for the tool execution // Execute standard tool calls using the new scheduler
const executionPromise = (async () => { if (toolRequests.length > 0) {
const agentContext = Object.create(this.runtimeContext); const completedCalls = await scheduleAgentTools(
agentContext.getToolRegistry = () => this.toolRegistry; this.runtimeContext,
agentContext.getApprovalMode = () => ApprovalMode.YOLO; toolRequests,
{
const { response: toolResponse } = await executeToolCall( schedulerId: this.agentId,
agentContext, parentCallId: this.parentCallId,
requestInfo, toolRegistry: this.toolRegistry,
signal, signal,
); },
);
if (toolResponse.error) { for (const call of completedCalls) {
const toolName =
toolNameMap.get(call.request.callId) || call.request.name;
if (call.status === 'success') {
this.emitActivity('TOOL_CALL_END', {
name: toolName,
output: call.response.resultDisplay,
});
} else if (call.status === 'error') {
this.emitActivity('ERROR', { this.emitActivity('ERROR', {
context: 'tool_call', context: 'tool_call',
name: functionCall.name, name: toolName,
error: toolResponse.error.message, error: call.response.error?.message || 'Unknown error',
}); });
} else { } else if (call.status === 'cancelled') {
this.emitActivity('TOOL_CALL_END', { this.emitActivity('ERROR', {
name: functionCall.name, context: 'tool_call',
output: toolResponse.resultDisplay, name: toolName,
error: 'Tool call was cancelled.',
}); });
} }
return toolResponse.responseParts; // Add result to syncResults to preserve order later
})(); syncResults.set(call.request.callId, call.response.responseParts[0]);
}
toolExecutionPromises.push(executionPromise);
} }
// Wait for all tool executions to complete // Reconstruct toolResponseParts in the original order
const asyncResults = await Promise.all(toolExecutionPromises); const toolResponseParts: Part[] = [];
for (const [index, functionCall] of functionCalls.entries()) {
// Combine all response parts const callId = functionCall.id ?? `${promptId}-${index}`;
const toolResponseParts: Part[] = [...syncResponseParts]; const part = syncResults.get(callId);
for (const result of asyncResults) { if (part) {
if (result) { toolResponseParts.push(part);
toolResponseParts.push(...result);
} }
} }
@@ -70,6 +70,10 @@ import { ROOT_SCHEDULER_ID } from './types.js';
import { ToolErrorType } from '../tools/tool-error.js'; import { ToolErrorType } from '../tools/tool-error.js';
import * as ToolUtils from '../utils/tool-utils.js'; import * as ToolUtils from '../utils/tool-utils.js';
import type { EditorType } from '../utils/editor.js'; import type { EditorType } from '../utils/editor.js';
import {
getToolCallContext,
type ToolCallContext,
} from '../utils/toolCallContext.js';
describe('Scheduler (Orchestrator)', () => { describe('Scheduler (Orchestrator)', () => {
let scheduler: Scheduler; let scheduler: Scheduler;
@@ -1010,4 +1014,68 @@ describe('Scheduler (Orchestrator)', () => {
expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1'); expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1');
}); });
}); });
describe('Tool Call Context Propagation', () => {
it('should propagate context to the tool executor', async () => {
const schedulerId = 'custom-scheduler';
const parentCallId = 'parent-call';
const customScheduler = new Scheduler({
config: mockConfig,
messageBus: mockMessageBus,
getPreferredEditor,
schedulerId,
parentCallId,
});
const validatingCall: ValidatingToolCall = {
status: 'validating',
request: req1,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
};
// Mock queueLength to run the loop once
Object.defineProperty(mockStateManager, 'queueLength', {
get: vi.fn().mockReturnValueOnce(1).mockReturnValue(0),
configurable: true,
});
vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall);
Object.defineProperty(mockStateManager, 'firstActiveCall', {
get: vi.fn().mockReturnValue(validatingCall),
configurable: true,
});
vi.mocked(mockStateManager.getToolCall).mockReturnValue(validatingCall);
mockToolRegistry.getTool.mockReturnValue(mockTool);
mockPolicyEngine.check.mockResolvedValue({
decision: PolicyDecision.ALLOW,
});
let capturedContext: ToolCallContext | undefined;
mockExecutor.execute.mockImplementation(async () => {
capturedContext = getToolCallContext();
return {
status: 'success',
request: req1,
tool: mockTool,
invocation: mockInvocation as unknown as AnyToolInvocation,
response: {
callId: req1.callId,
responseParts: [],
resultDisplay: 'ok',
error: undefined,
errorType: undefined,
},
} as unknown as SuccessfulToolCall;
});
await customScheduler.schedule(req1, signal);
expect(capturedContext).toBeDefined();
expect(capturedContext!.callId).toBe(req1.callId);
expect(capturedContext!.schedulerId).toBe(schedulerId);
expect(capturedContext!.parentCallId).toBe(parentCallId);
});
});
}); });
+56 -33
View File
@@ -36,6 +36,7 @@ import {
type SerializableConfirmationDetails, type SerializableConfirmationDetails,
type ToolConfirmationRequest, type ToolConfirmationRequest,
} from '../confirmation-bus/types.js'; } from '../confirmation-bus/types.js';
import { runWithToolCallContext } from '../utils/toolCallContext.js';
interface SchedulerQueueItem { interface SchedulerQueueItem {
requests: ToolCallRequestInfo[]; requests: ToolCallRequestInfo[];
@@ -256,6 +257,7 @@ export class Scheduler {
return this.state.completedBatch; return this.state.completedBatch;
} finally { } finally {
this.isProcessing = false; this.isProcessing = false;
this.state.clearBatch();
this._processNextInRequestQueue(); this._processNextInRequestQueue();
} }
} }
@@ -282,30 +284,39 @@ export class Scheduler {
request: ToolCallRequestInfo, request: ToolCallRequestInfo,
tool: AnyDeclarativeTool, tool: AnyDeclarativeTool,
): ValidatingToolCall | ErroredToolCall { ): ValidatingToolCall | ErroredToolCall {
try { return runWithToolCallContext(
const invocation = tool.build(request.args); {
return { callId: request.callId,
status: 'validating',
request,
tool,
invocation,
startTime: Date.now(),
schedulerId: this.schedulerId, schedulerId: this.schedulerId,
}; parentCallId: this.parentCallId,
} catch (e) { },
return { () => {
status: 'error', try {
request, const invocation = tool.build(request.args);
tool, return {
response: createErrorResponse( status: 'validating',
request, request,
e instanceof Error ? e : new Error(String(e)), tool,
ToolErrorType.INVALID_TOOL_PARAMS, invocation,
), startTime: Date.now(),
durationMs: 0, schedulerId: this.schedulerId,
schedulerId: this.schedulerId, };
}; } catch (e) {
} return {
status: 'error',
request,
tool,
response: createErrorResponse(
request,
e instanceof Error ? e : new Error(String(e)),
ToolErrorType.INVALID_TOOL_PARAMS,
),
durationMs: 0,
schedulerId: this.schedulerId,
};
}
},
);
} }
// --- Phase 2: Processing Loop --- // --- Phase 2: Processing Loop ---
@@ -460,17 +471,29 @@ export class Scheduler {
if (signal.aborted) throw new Error('Operation cancelled'); if (signal.aborted) throw new Error('Operation cancelled');
this.state.updateStatus(callId, 'executing'); this.state.updateStatus(callId, 'executing');
const result = await this.executor.execute({ const activeCall = this.state.firstActiveCall as ExecutingToolCall;
call: this.state.firstActiveCall as ExecutingToolCall,
signal, const result = await runWithToolCallContext(
outputUpdateHandler: (id, out) => {
this.state.updateStatus(id, 'executing', { liveOutput: out }), callId: activeCall.request.callId,
onUpdateToolCall: (updated) => { schedulerId: this.schedulerId,
if (updated.status === 'executing' && updated.pid) { parentCallId: this.parentCallId,
this.state.updateStatus(callId, 'executing', { pid: updated.pid });
}
}, },
}); () =>
this.executor.execute({
call: activeCall,
signal,
outputUpdateHandler: (id, out) =>
this.state.updateStatus(id, 'executing', { liveOutput: out }),
onUpdateToolCall: (updated) => {
if (updated.status === 'executing' && updated.pid) {
this.state.updateStatus(callId, 'executing', {
pid: updated.pid,
});
}
},
}),
);
if (result.status === 'success') { if (result.status === 'success') {
this.state.updateStatus(callId, 'success', result.response); this.state.updateStatus(callId, 'success', result.response);
@@ -0,0 +1,84 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import {
runWithToolCallContext,
getToolCallContext,
} from './toolCallContext.js';
describe('toolCallContext', () => {
it('should store and retrieve tool call context', () => {
const context = {
callId: 'test-call-id',
schedulerId: 'test-scheduler-id',
};
runWithToolCallContext(context, () => {
const storedContext = getToolCallContext();
expect(storedContext).toEqual(context);
});
});
it('should return undefined when no context is set', () => {
expect(getToolCallContext()).toBeUndefined();
});
it('should support nested contexts', () => {
const parentContext = {
callId: 'parent-call-id',
schedulerId: 'parent-scheduler-id',
};
const childContext = {
callId: 'child-call-id',
schedulerId: 'child-scheduler-id',
parentCallId: 'parent-call-id',
};
runWithToolCallContext(parentContext, () => {
expect(getToolCallContext()).toEqual(parentContext);
runWithToolCallContext(childContext, () => {
expect(getToolCallContext()).toEqual(childContext);
});
expect(getToolCallContext()).toEqual(parentContext);
});
});
it('should maintain isolation between parallel executions', async () => {
const context1 = {
callId: 'call-1',
schedulerId: 'scheduler-1',
};
const context2 = {
callId: 'call-2',
schedulerId: 'scheduler-2',
};
const promise1 = new Promise<void>((resolve) => {
runWithToolCallContext(context1, () => {
setTimeout(() => {
expect(getToolCallContext()).toEqual(context1);
resolve();
}, 10);
});
});
const promise2 = new Promise<void>((resolve) => {
runWithToolCallContext(context2, () => {
setTimeout(() => {
expect(getToolCallContext()).toEqual(context2);
resolve();
}, 5);
});
});
await Promise.all([promise1, promise2]);
});
});
@@ -0,0 +1,47 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { AsyncLocalStorage } from 'node:async_hooks';
/**
* Contextual information for a tool call execution.
*/
export interface ToolCallContext {
/** The unique ID of the tool call. */
callId: string;
/** The ID of the scheduler managing the execution. */
schedulerId: string;
/** The ID of the parent tool call, if this is a nested execution (e.g., in a subagent). */
parentCallId?: string;
}
/**
* AsyncLocalStorage instance for tool call context.
*/
export const toolCallContext = new AsyncLocalStorage<ToolCallContext>();
/**
* Runs a function within a tool call context.
*
* @param context The context to set.
* @param fn The function to run.
* @returns The result of the function.
*/
export function runWithToolCallContext<T>(
context: ToolCallContext,
fn: () => T,
): T {
return toolCallContext.run(context, fn);
}
/**
* Retrieves the current tool call context.
*
* @returns The current ToolCallContext, or undefined if not in a context.
*/
export function getToolCallContext(): ToolCallContext | undefined {
return toolCallContext.getStore();
}