mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
feat(agents): migrate subagents to event-driven scheduler (#17567)
This commit is contained in:
74
packages/core/src/agents/agent-scheduler.test.ts
Normal file
74
packages/core/src/agents/agent-scheduler.test.ts
Normal file
@@ -0,0 +1,74 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, type Mocked } from 'vitest';
|
||||
import { scheduleAgentTools } from './agent-scheduler.js';
|
||||
import { Scheduler } from '../scheduler/scheduler.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||
import type { ToolCallRequestInfo } from '../scheduler/types.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
|
||||
vi.mock('../scheduler/scheduler.js', () => ({
|
||||
Scheduler: vi.fn().mockImplementation(() => ({
|
||||
schedule: vi.fn().mockResolvedValue([{ status: 'success' }]),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('agent-scheduler', () => {
|
||||
let mockConfig: Mocked<Config>;
|
||||
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||
let mockMessageBus: Mocked<MessageBus>;
|
||||
|
||||
beforeEach(() => {
|
||||
mockMessageBus = {} as Mocked<MessageBus>;
|
||||
mockToolRegistry = {
|
||||
getTool: vi.fn(),
|
||||
} as unknown as Mocked<ToolRegistry>;
|
||||
mockConfig = {
|
||||
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
} as unknown as Mocked<Config>;
|
||||
});
|
||||
|
||||
it('should create a scheduler with agent-specific config', async () => {
|
||||
const requests: ToolCallRequestInfo[] = [
|
||||
{
|
||||
callId: 'call-1',
|
||||
name: 'test-tool',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-1',
|
||||
},
|
||||
];
|
||||
|
||||
const options = {
|
||||
schedulerId: 'subagent-1',
|
||||
parentCallId: 'parent-1',
|
||||
toolRegistry: mockToolRegistry as unknown as ToolRegistry,
|
||||
signal: new AbortController().signal,
|
||||
};
|
||||
|
||||
const results = await scheduleAgentTools(
|
||||
mockConfig as unknown as Config,
|
||||
requests,
|
||||
options,
|
||||
);
|
||||
|
||||
expect(results).toEqual([{ status: 'success' }]);
|
||||
expect(Scheduler).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
schedulerId: 'subagent-1',
|
||||
parentCallId: 'parent-1',
|
||||
messageBus: mockMessageBus,
|
||||
}),
|
||||
);
|
||||
|
||||
// Verify that the scheduler's config has the overridden tool registry
|
||||
const schedulerConfig = vi.mocked(Scheduler).mock.calls[0][0].config;
|
||||
expect(schedulerConfig.getToolRegistry()).toBe(mockToolRegistry);
|
||||
});
|
||||
});
|
||||
66
packages/core/src/agents/agent-scheduler.ts
Normal file
66
packages/core/src/agents/agent-scheduler.ts
Normal file
@@ -0,0 +1,66 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import { Scheduler } from '../scheduler/scheduler.js';
|
||||
import type {
|
||||
ToolCallRequestInfo,
|
||||
CompletedToolCall,
|
||||
} from '../scheduler/types.js';
|
||||
import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||
import type { EditorType } from '../utils/editor.js';
|
||||
|
||||
/**
|
||||
* Options for scheduling agent tools.
|
||||
*/
|
||||
export interface AgentSchedulingOptions {
|
||||
/** The unique ID for this agent's scheduler. */
|
||||
schedulerId: string;
|
||||
/** The ID of the tool call that invoked this agent. */
|
||||
parentCallId?: string;
|
||||
/** The tool registry specific to this agent. */
|
||||
toolRegistry: ToolRegistry;
|
||||
/** AbortSignal for cancellation. */
|
||||
signal: AbortSignal;
|
||||
/** Optional function to get the preferred editor for tool modifications. */
|
||||
getPreferredEditor?: () => EditorType | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Schedules a batch of tool calls for an agent using the new event-driven Scheduler.
|
||||
*
|
||||
* @param config The global runtime configuration.
|
||||
* @param requests The list of tool call requests from the agent.
|
||||
* @param options Scheduling options including registry and IDs.
|
||||
* @returns A promise that resolves to the completed tool calls.
|
||||
*/
|
||||
export async function scheduleAgentTools(
|
||||
config: Config,
|
||||
requests: ToolCallRequestInfo[],
|
||||
options: AgentSchedulingOptions,
|
||||
): Promise<CompletedToolCall[]> {
|
||||
const {
|
||||
schedulerId,
|
||||
parentCallId,
|
||||
toolRegistry,
|
||||
signal,
|
||||
getPreferredEditor,
|
||||
} = options;
|
||||
|
||||
// Create a proxy/override of the config to provide the agent-specific tool registry.
|
||||
const agentConfig: Config = Object.create(config);
|
||||
agentConfig.getToolRegistry = () => toolRegistry;
|
||||
|
||||
const scheduler = new Scheduler({
|
||||
config: agentConfig,
|
||||
messageBus: config.getMessageBus(),
|
||||
getPreferredEditor: getPreferredEditor ?? (() => undefined),
|
||||
schedulerId,
|
||||
parentCallId,
|
||||
});
|
||||
|
||||
return scheduler.schedule(requests, signal);
|
||||
}
|
||||
@@ -55,6 +55,7 @@ import type {
|
||||
} from './types.js';
|
||||
import { AgentTerminateMode } from './types.js';
|
||||
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
|
||||
import type { ToolCallRequestInfo } from '../scheduler/types.js';
|
||||
import { CompressionStatus } from '../core/turn.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
import type {
|
||||
@@ -67,12 +68,12 @@ import type { ModelRouterService } from '../routing/modelRouterService.js';
|
||||
|
||||
const {
|
||||
mockSendMessageStream,
|
||||
mockExecuteToolCall,
|
||||
mockScheduleAgentTools,
|
||||
mockSetSystemInstruction,
|
||||
mockCompress,
|
||||
} = vi.hoisted(() => ({
|
||||
mockSendMessageStream: vi.fn(),
|
||||
mockExecuteToolCall: vi.fn(),
|
||||
mockScheduleAgentTools: vi.fn(),
|
||||
mockSetSystemInstruction: vi.fn(),
|
||||
mockCompress: vi.fn(),
|
||||
}));
|
||||
@@ -101,8 +102,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../core/nonInteractiveToolExecutor.js', () => ({
|
||||
executeToolCall: mockExecuteToolCall,
|
||||
vi.mock('./agent-scheduler.js', () => ({
|
||||
scheduleAgentTools: mockScheduleAgentTools,
|
||||
}));
|
||||
|
||||
vi.mock('../utils/version.js', () => ({
|
||||
@@ -275,7 +276,7 @@ describe('LocalAgentExecutor', () => {
|
||||
mockSetHistory.mockClear();
|
||||
mockSendMessageStream.mockReset();
|
||||
mockSetSystemInstruction.mockReset();
|
||||
mockExecuteToolCall.mockReset();
|
||||
mockScheduleAgentTools.mockReset();
|
||||
mockedLogAgentStart.mockReset();
|
||||
mockedLogAgentFinish.mockReset();
|
||||
mockedPromptIdContext.getStore.mockReset();
|
||||
@@ -540,34 +541,36 @@ describe('LocalAgentExecutor', () => {
|
||||
[{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }],
|
||||
'T1: Listing',
|
||||
);
|
||||
mockExecuteToolCall.mockResolvedValueOnce({
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: 'call1',
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
resultDisplay: 'file1.txt',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: LS_TOOL_NAME,
|
||||
response: { result: 'file1.txt' },
|
||||
id: 'call1',
|
||||
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||
{
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: 'call1',
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
resultDisplay: 'file1.txt',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: LS_TOOL_NAME,
|
||||
response: { result: 'file1.txt' },
|
||||
id: 'call1',
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
},
|
||||
});
|
||||
]);
|
||||
|
||||
// Turn 2: Model calls complete_task with required output
|
||||
mockModelResponse(
|
||||
@@ -686,34 +689,36 @@ describe('LocalAgentExecutor', () => {
|
||||
mockModelResponse([
|
||||
{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' },
|
||||
]);
|
||||
mockExecuteToolCall.mockResolvedValueOnce({
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: 'call1',
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: LS_TOOL_NAME,
|
||||
response: {},
|
||||
id: 'call1',
|
||||
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||
{
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: 'call1',
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: LS_TOOL_NAME,
|
||||
response: {},
|
||||
id: 'call1',
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
},
|
||||
});
|
||||
]);
|
||||
|
||||
mockModelResponse(
|
||||
[
|
||||
@@ -759,34 +764,36 @@ describe('LocalAgentExecutor', () => {
|
||||
mockModelResponse([
|
||||
{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' },
|
||||
]);
|
||||
mockExecuteToolCall.mockResolvedValueOnce({
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: 'call1',
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: LS_TOOL_NAME,
|
||||
response: {},
|
||||
id: 'call1',
|
||||
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||
{
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: 'call1',
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: LS_TOOL_NAME,
|
||||
response: {},
|
||||
id: 'call1',
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
},
|
||||
});
|
||||
]);
|
||||
|
||||
// Turn 2 (protocol violation)
|
||||
mockModelResponse([], 'I think I am done.');
|
||||
@@ -959,33 +966,40 @@ describe('LocalAgentExecutor', () => {
|
||||
resolveCalls = r;
|
||||
});
|
||||
|
||||
mockExecuteToolCall.mockImplementation(async (_ctx, reqInfo) => {
|
||||
callsStarted++;
|
||||
if (callsStarted === 2) resolveCalls();
|
||||
await vi.advanceTimersByTimeAsync(100);
|
||||
return {
|
||||
status: 'success',
|
||||
request: reqInfo,
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: reqInfo.callId,
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: reqInfo.name,
|
||||
response: {},
|
||||
id: reqInfo.callId,
|
||||
mockScheduleAgentTools.mockImplementation(
|
||||
async (_ctx, requests: ToolCallRequestInfo[]) => {
|
||||
const results = await Promise.all(
|
||||
requests.map(async (reqInfo) => {
|
||||
callsStarted++;
|
||||
if (callsStarted === 2) resolveCalls();
|
||||
await vi.advanceTimersByTimeAsync(100);
|
||||
return {
|
||||
status: 'success',
|
||||
request: reqInfo,
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: reqInfo.callId,
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: reqInfo.name,
|
||||
response: {},
|
||||
id: reqInfo.callId,
|
||||
},
|
||||
},
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
},
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
};
|
||||
});
|
||||
};
|
||||
}),
|
||||
);
|
||||
return results;
|
||||
},
|
||||
);
|
||||
|
||||
// Turn 2: Completion
|
||||
mockModelResponse([
|
||||
@@ -1005,7 +1019,7 @@ describe('LocalAgentExecutor', () => {
|
||||
|
||||
const output = await runPromise;
|
||||
|
||||
expect(mockExecuteToolCall).toHaveBeenCalledTimes(2);
|
||||
expect(mockScheduleAgentTools).toHaveBeenCalledTimes(1);
|
||||
expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);
|
||||
|
||||
// Safe access to message parts
|
||||
@@ -1059,7 +1073,7 @@ describe('LocalAgentExecutor', () => {
|
||||
await executor.run({ goal: 'Sec test' }, signal);
|
||||
|
||||
// Verify external executor was not called (Security held)
|
||||
expect(mockExecuteToolCall).not.toHaveBeenCalled();
|
||||
expect(mockScheduleAgentTools).not.toHaveBeenCalled();
|
||||
|
||||
// 2. Verify console warning
|
||||
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
||||
@@ -1215,37 +1229,36 @@ describe('LocalAgentExecutor', () => {
|
||||
mockModelResponse([
|
||||
{ name: LS_TOOL_NAME, args: { path: '/fake' }, id: 'call1' },
|
||||
]);
|
||||
mockExecuteToolCall.mockResolvedValueOnce({
|
||||
status: 'error',
|
||||
request: {
|
||||
callId: 'call1',
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '/fake' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
resultDisplay: '',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: LS_TOOL_NAME,
|
||||
response: { error: toolErrorMessage },
|
||||
id: 'call1',
|
||||
},
|
||||
},
|
||||
],
|
||||
error: {
|
||||
type: 'ToolError',
|
||||
message: toolErrorMessage,
|
||||
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||
{
|
||||
status: 'error',
|
||||
request: {
|
||||
callId: 'call1',
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '/fake' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 'call1',
|
||||
resultDisplay: '',
|
||||
responseParts: [
|
||||
{
|
||||
functionResponse: {
|
||||
name: LS_TOOL_NAME,
|
||||
response: { error: toolErrorMessage },
|
||||
id: 'call1',
|
||||
},
|
||||
},
|
||||
],
|
||||
error: new Error(toolErrorMessage),
|
||||
errorType: 'ToolError',
|
||||
contentLength: 0,
|
||||
},
|
||||
errorType: 'ToolError',
|
||||
contentLength: 0,
|
||||
},
|
||||
});
|
||||
]);
|
||||
|
||||
// Turn 2: Model sees the error and completes
|
||||
mockModelResponse([
|
||||
@@ -1258,7 +1271,7 @@ describe('LocalAgentExecutor', () => {
|
||||
|
||||
const output = await executor.run({ goal: 'Tool failure test' }, signal);
|
||||
|
||||
expect(mockExecuteToolCall).toHaveBeenCalledTimes(1);
|
||||
expect(mockScheduleAgentTools).toHaveBeenCalledTimes(1);
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
|
||||
// Verify the error was reported in the activity stream
|
||||
@@ -1391,28 +1404,30 @@ describe('LocalAgentExecutor', () => {
|
||||
describe('run (Termination Conditions)', () => {
|
||||
const mockWorkResponse = (id: string) => {
|
||||
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
||||
mockExecuteToolCall.mockResolvedValueOnce({
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: id,
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||
{
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: id,
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: id,
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: id,
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
});
|
||||
]);
|
||||
};
|
||||
|
||||
it('should terminate when max_turns is reached', async () => {
|
||||
@@ -1505,23 +1520,27 @@ describe('LocalAgentExecutor', () => {
|
||||
]);
|
||||
|
||||
// Long running tool
|
||||
mockExecuteToolCall.mockImplementationOnce(async (_ctx, reqInfo) => {
|
||||
await vi.advanceTimersByTimeAsync(61 * 1000);
|
||||
return {
|
||||
status: 'success',
|
||||
request: reqInfo,
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 't1',
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
};
|
||||
});
|
||||
mockScheduleAgentTools.mockImplementationOnce(
|
||||
async (_ctx, requests: ToolCallRequestInfo[]) => {
|
||||
await vi.advanceTimersByTimeAsync(61 * 1000);
|
||||
return [
|
||||
{
|
||||
status: 'success',
|
||||
request: requests[0],
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: 't1',
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
},
|
||||
];
|
||||
},
|
||||
);
|
||||
|
||||
// Recovery turn
|
||||
mockModelResponse([], 'I give up');
|
||||
@@ -1557,28 +1576,30 @@ describe('LocalAgentExecutor', () => {
|
||||
describe('run (Recovery Turns)', () => {
|
||||
const mockWorkResponse = (id: string) => {
|
||||
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
||||
mockExecuteToolCall.mockResolvedValueOnce({
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: id,
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||
{
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: id,
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: id,
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: id,
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
});
|
||||
]);
|
||||
};
|
||||
|
||||
it('should recover successfully if complete_task is called during the grace turn after MAX_TURNS', async () => {
|
||||
@@ -1873,28 +1894,30 @@ describe('LocalAgentExecutor', () => {
|
||||
describe('Telemetry and Logging', () => {
|
||||
const mockWorkResponse = (id: string) => {
|
||||
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
||||
mockExecuteToolCall.mockResolvedValueOnce({
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: id,
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||
{
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: id,
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: id,
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: id,
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
});
|
||||
]);
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -1960,28 +1983,30 @@ describe('LocalAgentExecutor', () => {
|
||||
describe('Chat Compression', () => {
|
||||
const mockWorkResponse = (id: string) => {
|
||||
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
||||
mockExecuteToolCall.mockResolvedValueOnce({
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: id,
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||
{
|
||||
status: 'success',
|
||||
request: {
|
||||
callId: id,
|
||||
name: LS_TOOL_NAME,
|
||||
args: { path: '.' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'test-prompt',
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: id,
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
},
|
||||
tool: {} as AnyDeclarativeTool,
|
||||
invocation: {} as AnyToolInvocation,
|
||||
response: {
|
||||
callId: id,
|
||||
resultDisplay: 'ok',
|
||||
responseParts: [
|
||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||
],
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
contentLength: undefined,
|
||||
},
|
||||
});
|
||||
]);
|
||||
};
|
||||
|
||||
it('should attempt to compress chat history on each turn', async () => {
|
||||
|
||||
@@ -15,7 +15,6 @@ import type {
|
||||
FunctionDeclaration,
|
||||
Schema,
|
||||
} from '@google/genai';
|
||||
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
|
||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
||||
import { CompressionStatus } from '../core/turn.js';
|
||||
import { type ToolCallRequestInfo } from '../scheduler/types.js';
|
||||
@@ -48,7 +47,8 @@ import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { getModelConfigAlias } from './registry.js';
|
||||
import { getVersion } from '../utils/version.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
import { getToolCallContext } from '../utils/toolCallContext.js';
|
||||
import { scheduleAgentTools } from './agent-scheduler.js';
|
||||
|
||||
/** A callback function to report on agent activity. */
|
||||
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
|
||||
@@ -86,6 +86,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
private readonly runtimeContext: Config;
|
||||
private readonly onActivity?: ActivityCallback;
|
||||
private readonly compressionService: ChatCompressionService;
|
||||
private readonly parentCallId?: string;
|
||||
private hasFailedCompressionAttempt = false;
|
||||
|
||||
/**
|
||||
@@ -158,11 +159,16 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
// Get the parent prompt ID from context
|
||||
const parentPromptId = promptIdContext.getStore();
|
||||
|
||||
// Get the parent tool call ID from context
|
||||
const toolContext = getToolCallContext();
|
||||
const parentCallId = toolContext?.callId;
|
||||
|
||||
return new LocalAgentExecutor(
|
||||
definition,
|
||||
runtimeContext,
|
||||
agentToolRegistry,
|
||||
parentPromptId,
|
||||
parentCallId,
|
||||
onActivity,
|
||||
);
|
||||
}
|
||||
@@ -178,6 +184,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
runtimeContext: Config,
|
||||
toolRegistry: ToolRegistry,
|
||||
parentPromptId: string | undefined,
|
||||
parentCallId: string | undefined,
|
||||
onActivity?: ActivityCallback,
|
||||
) {
|
||||
this.definition = definition;
|
||||
@@ -185,6 +192,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.onActivity = onActivity;
|
||||
this.compressionService = new ChatCompressionService();
|
||||
this.parentCallId = parentCallId;
|
||||
|
||||
const randomIdPart = Math.random().toString(36).slice(2, 8);
|
||||
// parentPromptId will be undefined if this agent is invoked directly
|
||||
@@ -763,26 +771,28 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
let submittedOutput: string | null = null;
|
||||
let taskCompleted = false;
|
||||
|
||||
// We'll collect promises for the tool executions
|
||||
const toolExecutionPromises: Array<Promise<Part[] | void>> = [];
|
||||
// And we'll need a place to store the synchronous results (like complete_task or blocked calls)
|
||||
const syncResponseParts: Part[] = [];
|
||||
// We'll separate complete_task from other tools
|
||||
const toolRequests: ToolCallRequestInfo[] = [];
|
||||
// Map to keep track of tool name by callId for activity emission
|
||||
const toolNameMap = new Map<string, string>();
|
||||
// Synchronous results (like complete_task or unauthorized calls)
|
||||
const syncResults = new Map<string, Part>();
|
||||
|
||||
for (const [index, functionCall] of functionCalls.entries()) {
|
||||
const callId = functionCall.id ?? `${promptId}-${index}`;
|
||||
const args = functionCall.args ?? {};
|
||||
const toolName = functionCall.name as string;
|
||||
|
||||
this.emitActivity('TOOL_CALL_START', {
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
args,
|
||||
});
|
||||
|
||||
if (functionCall.name === TASK_COMPLETE_TOOL_NAME) {
|
||||
if (toolName === TASK_COMPLETE_TOOL_NAME) {
|
||||
if (taskCompleted) {
|
||||
// We already have a completion from this turn. Ignore subsequent ones.
|
||||
const error =
|
||||
'Task already marked complete in this turn. Ignoring duplicate call.';
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { error },
|
||||
@@ -791,7 +801,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
});
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
error,
|
||||
});
|
||||
continue;
|
||||
@@ -809,7 +819,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
if (!validationResult.success) {
|
||||
taskCompleted = false; // Validation failed, revoke completion
|
||||
const error = `Output validation failed: ${JSON.stringify(validationResult.error.flatten())}`;
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { error },
|
||||
@@ -818,7 +828,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
});
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
error,
|
||||
});
|
||||
continue;
|
||||
@@ -833,7 +843,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
? outputValue
|
||||
: JSON.stringify(outputValue, null, 2);
|
||||
}
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { result: 'Output submitted and task completed.' },
|
||||
@@ -841,14 +851,14 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
},
|
||||
});
|
||||
this.emitActivity('TOOL_CALL_END', {
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
output: 'Output submitted and task completed.',
|
||||
});
|
||||
} else {
|
||||
// Failed to provide required output.
|
||||
taskCompleted = false; // Revoke completion status
|
||||
const error = `Missing required argument '${outputName}' for completion.`;
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { error },
|
||||
@@ -857,7 +867,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
});
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
error,
|
||||
});
|
||||
}
|
||||
@@ -873,7 +883,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
typeof resultArg === 'string'
|
||||
? resultArg
|
||||
: JSON.stringify(resultArg, null, 2);
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { status: 'Result submitted and task completed.' },
|
||||
@@ -881,7 +891,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
},
|
||||
});
|
||||
this.emitActivity('TOOL_CALL_END', {
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
output: 'Result submitted and task completed.',
|
||||
});
|
||||
} else {
|
||||
@@ -889,7 +899,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
taskCompleted = false; // Revoke completion
|
||||
const error =
|
||||
'Missing required "result" argument. You must provide your findings when calling complete_task.';
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: TASK_COMPLETE_TOOL_NAME,
|
||||
response: { error },
|
||||
@@ -898,7 +908,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
});
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
error,
|
||||
});
|
||||
}
|
||||
@@ -907,14 +917,13 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
}
|
||||
|
||||
// Handle standard tools
|
||||
if (!allowedToolNames.has(functionCall.name as string)) {
|
||||
const error = createUnauthorizedToolError(functionCall.name as string);
|
||||
|
||||
if (!allowedToolNames.has(toolName)) {
|
||||
const error = createUnauthorizedToolError(toolName);
|
||||
debugLogger.warn(`[LocalAgentExecutor] Blocked call: ${error}`);
|
||||
|
||||
syncResponseParts.push({
|
||||
syncResults.set(callId, {
|
||||
functionResponse: {
|
||||
name: functionCall.name as string,
|
||||
name: toolName,
|
||||
id: callId,
|
||||
response: { error },
|
||||
},
|
||||
@@ -922,7 +931,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call_unauthorized',
|
||||
name: functionCall.name,
|
||||
name: toolName,
|
||||
callId,
|
||||
error,
|
||||
});
|
||||
@@ -930,53 +939,63 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
||||
continue;
|
||||
}
|
||||
|
||||
const requestInfo: ToolCallRequestInfo = {
|
||||
toolRequests.push({
|
||||
callId,
|
||||
name: functionCall.name as string,
|
||||
name: toolName,
|
||||
args,
|
||||
isClientInitiated: true,
|
||||
isClientInitiated: false, // These are coming from the subagent (the "model")
|
||||
prompt_id: promptId,
|
||||
};
|
||||
});
|
||||
toolNameMap.set(callId, toolName);
|
||||
}
|
||||
|
||||
// Create a promise for the tool execution
|
||||
const executionPromise = (async () => {
|
||||
const agentContext = Object.create(this.runtimeContext);
|
||||
agentContext.getToolRegistry = () => this.toolRegistry;
|
||||
agentContext.getApprovalMode = () => ApprovalMode.YOLO;
|
||||
|
||||
const { response: toolResponse } = await executeToolCall(
|
||||
agentContext,
|
||||
requestInfo,
|
||||
// Execute standard tool calls using the new scheduler
|
||||
if (toolRequests.length > 0) {
|
||||
const completedCalls = await scheduleAgentTools(
|
||||
this.runtimeContext,
|
||||
toolRequests,
|
||||
{
|
||||
schedulerId: this.agentId,
|
||||
parentCallId: this.parentCallId,
|
||||
toolRegistry: this.toolRegistry,
|
||||
signal,
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
if (toolResponse.error) {
|
||||
for (const call of completedCalls) {
|
||||
const toolName =
|
||||
toolNameMap.get(call.request.callId) || call.request.name;
|
||||
if (call.status === 'success') {
|
||||
this.emitActivity('TOOL_CALL_END', {
|
||||
name: toolName,
|
||||
output: call.response.resultDisplay,
|
||||
});
|
||||
} else if (call.status === 'error') {
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: functionCall.name,
|
||||
error: toolResponse.error.message,
|
||||
name: toolName,
|
||||
error: call.response.error?.message || 'Unknown error',
|
||||
});
|
||||
} else {
|
||||
this.emitActivity('TOOL_CALL_END', {
|
||||
name: functionCall.name,
|
||||
output: toolResponse.resultDisplay,
|
||||
} else if (call.status === 'cancelled') {
|
||||
this.emitActivity('ERROR', {
|
||||
context: 'tool_call',
|
||||
name: toolName,
|
||||
error: 'Tool call was cancelled.',
|
||||
});
|
||||
}
|
||||
|
||||
return toolResponse.responseParts;
|
||||
})();
|
||||
|
||||
toolExecutionPromises.push(executionPromise);
|
||||
// Add result to syncResults to preserve order later
|
||||
syncResults.set(call.request.callId, call.response.responseParts[0]);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for all tool executions to complete
|
||||
const asyncResults = await Promise.all(toolExecutionPromises);
|
||||
|
||||
// Combine all response parts
|
||||
const toolResponseParts: Part[] = [...syncResponseParts];
|
||||
for (const result of asyncResults) {
|
||||
if (result) {
|
||||
toolResponseParts.push(...result);
|
||||
// Reconstruct toolResponseParts in the original order
|
||||
const toolResponseParts: Part[] = [];
|
||||
for (const [index, functionCall] of functionCalls.entries()) {
|
||||
const callId = functionCall.id ?? `${promptId}-${index}`;
|
||||
const part = syncResults.get(callId);
|
||||
if (part) {
|
||||
toolResponseParts.push(part);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -70,6 +70,10 @@ import { ROOT_SCHEDULER_ID } from './types.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
import * as ToolUtils from '../utils/tool-utils.js';
|
||||
import type { EditorType } from '../utils/editor.js';
|
||||
import {
|
||||
getToolCallContext,
|
||||
type ToolCallContext,
|
||||
} from '../utils/toolCallContext.js';
|
||||
|
||||
describe('Scheduler (Orchestrator)', () => {
|
||||
let scheduler: Scheduler;
|
||||
@@ -1010,4 +1014,68 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Tool Call Context Propagation', () => {
|
||||
it('should propagate context to the tool executor', async () => {
|
||||
const schedulerId = 'custom-scheduler';
|
||||
const parentCallId = 'parent-call';
|
||||
const customScheduler = new Scheduler({
|
||||
config: mockConfig,
|
||||
messageBus: mockMessageBus,
|
||||
getPreferredEditor,
|
||||
schedulerId,
|
||||
parentCallId,
|
||||
});
|
||||
|
||||
const validatingCall: ValidatingToolCall = {
|
||||
status: 'validating',
|
||||
request: req1,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
};
|
||||
|
||||
// Mock queueLength to run the loop once
|
||||
Object.defineProperty(mockStateManager, 'queueLength', {
|
||||
get: vi.fn().mockReturnValueOnce(1).mockReturnValue(0),
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall);
|
||||
Object.defineProperty(mockStateManager, 'firstActiveCall', {
|
||||
get: vi.fn().mockReturnValue(validatingCall),
|
||||
configurable: true,
|
||||
});
|
||||
vi.mocked(mockStateManager.getToolCall).mockReturnValue(validatingCall);
|
||||
|
||||
mockToolRegistry.getTool.mockReturnValue(mockTool);
|
||||
mockPolicyEngine.check.mockResolvedValue({
|
||||
decision: PolicyDecision.ALLOW,
|
||||
});
|
||||
|
||||
let capturedContext: ToolCallContext | undefined;
|
||||
mockExecutor.execute.mockImplementation(async () => {
|
||||
capturedContext = getToolCallContext();
|
||||
return {
|
||||
status: 'success',
|
||||
request: req1,
|
||||
tool: mockTool,
|
||||
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||
response: {
|
||||
callId: req1.callId,
|
||||
responseParts: [],
|
||||
resultDisplay: 'ok',
|
||||
error: undefined,
|
||||
errorType: undefined,
|
||||
},
|
||||
} as unknown as SuccessfulToolCall;
|
||||
});
|
||||
|
||||
await customScheduler.schedule(req1, signal);
|
||||
|
||||
expect(capturedContext).toBeDefined();
|
||||
expect(capturedContext!.callId).toBe(req1.callId);
|
||||
expect(capturedContext!.schedulerId).toBe(schedulerId);
|
||||
expect(capturedContext!.parentCallId).toBe(parentCallId);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -36,6 +36,7 @@ import {
|
||||
type SerializableConfirmationDetails,
|
||||
type ToolConfirmationRequest,
|
||||
} from '../confirmation-bus/types.js';
|
||||
import { runWithToolCallContext } from '../utils/toolCallContext.js';
|
||||
|
||||
interface SchedulerQueueItem {
|
||||
requests: ToolCallRequestInfo[];
|
||||
@@ -256,6 +257,7 @@ export class Scheduler {
|
||||
return this.state.completedBatch;
|
||||
} finally {
|
||||
this.isProcessing = false;
|
||||
this.state.clearBatch();
|
||||
this._processNextInRequestQueue();
|
||||
}
|
||||
}
|
||||
@@ -282,30 +284,39 @@ export class Scheduler {
|
||||
request: ToolCallRequestInfo,
|
||||
tool: AnyDeclarativeTool,
|
||||
): ValidatingToolCall | ErroredToolCall {
|
||||
try {
|
||||
const invocation = tool.build(request.args);
|
||||
return {
|
||||
status: 'validating',
|
||||
request,
|
||||
tool,
|
||||
invocation,
|
||||
startTime: Date.now(),
|
||||
return runWithToolCallContext(
|
||||
{
|
||||
callId: request.callId,
|
||||
schedulerId: this.schedulerId,
|
||||
};
|
||||
} catch (e) {
|
||||
return {
|
||||
status: 'error',
|
||||
request,
|
||||
tool,
|
||||
response: createErrorResponse(
|
||||
request,
|
||||
e instanceof Error ? e : new Error(String(e)),
|
||||
ToolErrorType.INVALID_TOOL_PARAMS,
|
||||
),
|
||||
durationMs: 0,
|
||||
schedulerId: this.schedulerId,
|
||||
};
|
||||
}
|
||||
parentCallId: this.parentCallId,
|
||||
},
|
||||
() => {
|
||||
try {
|
||||
const invocation = tool.build(request.args);
|
||||
return {
|
||||
status: 'validating',
|
||||
request,
|
||||
tool,
|
||||
invocation,
|
||||
startTime: Date.now(),
|
||||
schedulerId: this.schedulerId,
|
||||
};
|
||||
} catch (e) {
|
||||
return {
|
||||
status: 'error',
|
||||
request,
|
||||
tool,
|
||||
response: createErrorResponse(
|
||||
request,
|
||||
e instanceof Error ? e : new Error(String(e)),
|
||||
ToolErrorType.INVALID_TOOL_PARAMS,
|
||||
),
|
||||
durationMs: 0,
|
||||
schedulerId: this.schedulerId,
|
||||
};
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// --- Phase 2: Processing Loop ---
|
||||
@@ -460,17 +471,29 @@ export class Scheduler {
|
||||
if (signal.aborted) throw new Error('Operation cancelled');
|
||||
this.state.updateStatus(callId, 'executing');
|
||||
|
||||
const result = await this.executor.execute({
|
||||
call: this.state.firstActiveCall as ExecutingToolCall,
|
||||
signal,
|
||||
outputUpdateHandler: (id, out) =>
|
||||
this.state.updateStatus(id, 'executing', { liveOutput: out }),
|
||||
onUpdateToolCall: (updated) => {
|
||||
if (updated.status === 'executing' && updated.pid) {
|
||||
this.state.updateStatus(callId, 'executing', { pid: updated.pid });
|
||||
}
|
||||
const activeCall = this.state.firstActiveCall as ExecutingToolCall;
|
||||
|
||||
const result = await runWithToolCallContext(
|
||||
{
|
||||
callId: activeCall.request.callId,
|
||||
schedulerId: this.schedulerId,
|
||||
parentCallId: this.parentCallId,
|
||||
},
|
||||
});
|
||||
() =>
|
||||
this.executor.execute({
|
||||
call: activeCall,
|
||||
signal,
|
||||
outputUpdateHandler: (id, out) =>
|
||||
this.state.updateStatus(id, 'executing', { liveOutput: out }),
|
||||
onUpdateToolCall: (updated) => {
|
||||
if (updated.status === 'executing' && updated.pid) {
|
||||
this.state.updateStatus(callId, 'executing', {
|
||||
pid: updated.pid,
|
||||
});
|
||||
}
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
if (result.status === 'success') {
|
||||
this.state.updateStatus(callId, 'success', result.response);
|
||||
|
||||
84
packages/core/src/utils/toolCallContext.test.ts
Normal file
84
packages/core/src/utils/toolCallContext.test.ts
Normal file
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import {
|
||||
runWithToolCallContext,
|
||||
getToolCallContext,
|
||||
} from './toolCallContext.js';
|
||||
|
||||
describe('toolCallContext', () => {
|
||||
it('should store and retrieve tool call context', () => {
|
||||
const context = {
|
||||
callId: 'test-call-id',
|
||||
schedulerId: 'test-scheduler-id',
|
||||
};
|
||||
|
||||
runWithToolCallContext(context, () => {
|
||||
const storedContext = getToolCallContext();
|
||||
expect(storedContext).toEqual(context);
|
||||
});
|
||||
});
|
||||
|
||||
it('should return undefined when no context is set', () => {
|
||||
expect(getToolCallContext()).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should support nested contexts', () => {
|
||||
const parentContext = {
|
||||
callId: 'parent-call-id',
|
||||
schedulerId: 'parent-scheduler-id',
|
||||
};
|
||||
|
||||
const childContext = {
|
||||
callId: 'child-call-id',
|
||||
schedulerId: 'child-scheduler-id',
|
||||
parentCallId: 'parent-call-id',
|
||||
};
|
||||
|
||||
runWithToolCallContext(parentContext, () => {
|
||||
expect(getToolCallContext()).toEqual(parentContext);
|
||||
|
||||
runWithToolCallContext(childContext, () => {
|
||||
expect(getToolCallContext()).toEqual(childContext);
|
||||
});
|
||||
|
||||
expect(getToolCallContext()).toEqual(parentContext);
|
||||
});
|
||||
});
|
||||
|
||||
it('should maintain isolation between parallel executions', async () => {
|
||||
const context1 = {
|
||||
callId: 'call-1',
|
||||
schedulerId: 'scheduler-1',
|
||||
};
|
||||
|
||||
const context2 = {
|
||||
callId: 'call-2',
|
||||
schedulerId: 'scheduler-2',
|
||||
};
|
||||
|
||||
const promise1 = new Promise<void>((resolve) => {
|
||||
runWithToolCallContext(context1, () => {
|
||||
setTimeout(() => {
|
||||
expect(getToolCallContext()).toEqual(context1);
|
||||
resolve();
|
||||
}, 10);
|
||||
});
|
||||
});
|
||||
|
||||
const promise2 = new Promise<void>((resolve) => {
|
||||
runWithToolCallContext(context2, () => {
|
||||
setTimeout(() => {
|
||||
expect(getToolCallContext()).toEqual(context2);
|
||||
resolve();
|
||||
}, 5);
|
||||
});
|
||||
});
|
||||
|
||||
await Promise.all([promise1, promise2]);
|
||||
});
|
||||
});
|
||||
47
packages/core/src/utils/toolCallContext.ts
Normal file
47
packages/core/src/utils/toolCallContext.ts
Normal file
@@ -0,0 +1,47 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { AsyncLocalStorage } from 'node:async_hooks';
|
||||
|
||||
/**
|
||||
* Contextual information for a tool call execution.
|
||||
*/
|
||||
export interface ToolCallContext {
|
||||
/** The unique ID of the tool call. */
|
||||
callId: string;
|
||||
/** The ID of the scheduler managing the execution. */
|
||||
schedulerId: string;
|
||||
/** The ID of the parent tool call, if this is a nested execution (e.g., in a subagent). */
|
||||
parentCallId?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* AsyncLocalStorage instance for tool call context.
|
||||
*/
|
||||
export const toolCallContext = new AsyncLocalStorage<ToolCallContext>();
|
||||
|
||||
/**
|
||||
* Runs a function within a tool call context.
|
||||
*
|
||||
* @param context The context to set.
|
||||
* @param fn The function to run.
|
||||
* @returns The result of the function.
|
||||
*/
|
||||
export function runWithToolCallContext<T>(
|
||||
context: ToolCallContext,
|
||||
fn: () => T,
|
||||
): T {
|
||||
return toolCallContext.run(context, fn);
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieves the current tool call context.
|
||||
*
|
||||
* @returns The current ToolCallContext, or undefined if not in a context.
|
||||
*/
|
||||
export function getToolCallContext(): ToolCallContext | undefined {
|
||||
return toolCallContext.getStore();
|
||||
}
|
||||
Reference in New Issue
Block a user