mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-25 20:44:46 -07:00
feat(agents): migrate subagents to event-driven scheduler (#17567)
This commit is contained in:
@@ -0,0 +1,74 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2026 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, vi, beforeEach, type Mocked } from 'vitest';
|
||||||
|
import { scheduleAgentTools } from './agent-scheduler.js';
|
||||||
|
import { Scheduler } from '../scheduler/scheduler.js';
|
||||||
|
import type { Config } from '../config/config.js';
|
||||||
|
import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||||
|
import type { ToolCallRequestInfo } from '../scheduler/types.js';
|
||||||
|
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||||
|
|
||||||
|
vi.mock('../scheduler/scheduler.js', () => ({
|
||||||
|
Scheduler: vi.fn().mockImplementation(() => ({
|
||||||
|
schedule: vi.fn().mockResolvedValue([{ status: 'success' }]),
|
||||||
|
})),
|
||||||
|
}));
|
||||||
|
|
||||||
|
describe('agent-scheduler', () => {
|
||||||
|
let mockConfig: Mocked<Config>;
|
||||||
|
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||||
|
let mockMessageBus: Mocked<MessageBus>;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
mockMessageBus = {} as Mocked<MessageBus>;
|
||||||
|
mockToolRegistry = {
|
||||||
|
getTool: vi.fn(),
|
||||||
|
} as unknown as Mocked<ToolRegistry>;
|
||||||
|
mockConfig = {
|
||||||
|
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
|
||||||
|
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||||
|
} as unknown as Mocked<Config>;
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should create a scheduler with agent-specific config', async () => {
|
||||||
|
const requests: ToolCallRequestInfo[] = [
|
||||||
|
{
|
||||||
|
callId: 'call-1',
|
||||||
|
name: 'test-tool',
|
||||||
|
args: {},
|
||||||
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'prompt-1',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const options = {
|
||||||
|
schedulerId: 'subagent-1',
|
||||||
|
parentCallId: 'parent-1',
|
||||||
|
toolRegistry: mockToolRegistry as unknown as ToolRegistry,
|
||||||
|
signal: new AbortController().signal,
|
||||||
|
};
|
||||||
|
|
||||||
|
const results = await scheduleAgentTools(
|
||||||
|
mockConfig as unknown as Config,
|
||||||
|
requests,
|
||||||
|
options,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(results).toEqual([{ status: 'success' }]);
|
||||||
|
expect(Scheduler).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
schedulerId: 'subagent-1',
|
||||||
|
parentCallId: 'parent-1',
|
||||||
|
messageBus: mockMessageBus,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify that the scheduler's config has the overridden tool registry
|
||||||
|
const schedulerConfig = vi.mocked(Scheduler).mock.calls[0][0].config;
|
||||||
|
expect(schedulerConfig.getToolRegistry()).toBe(mockToolRegistry);
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2026 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { Config } from '../config/config.js';
|
||||||
|
import { Scheduler } from '../scheduler/scheduler.js';
|
||||||
|
import type {
|
||||||
|
ToolCallRequestInfo,
|
||||||
|
CompletedToolCall,
|
||||||
|
} from '../scheduler/types.js';
|
||||||
|
import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||||
|
import type { EditorType } from '../utils/editor.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Options for scheduling agent tools.
|
||||||
|
*/
|
||||||
|
export interface AgentSchedulingOptions {
|
||||||
|
/** The unique ID for this agent's scheduler. */
|
||||||
|
schedulerId: string;
|
||||||
|
/** The ID of the tool call that invoked this agent. */
|
||||||
|
parentCallId?: string;
|
||||||
|
/** The tool registry specific to this agent. */
|
||||||
|
toolRegistry: ToolRegistry;
|
||||||
|
/** AbortSignal for cancellation. */
|
||||||
|
signal: AbortSignal;
|
||||||
|
/** Optional function to get the preferred editor for tool modifications. */
|
||||||
|
getPreferredEditor?: () => EditorType | undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Schedules a batch of tool calls for an agent using the new event-driven Scheduler.
|
||||||
|
*
|
||||||
|
* @param config The global runtime configuration.
|
||||||
|
* @param requests The list of tool call requests from the agent.
|
||||||
|
* @param options Scheduling options including registry and IDs.
|
||||||
|
* @returns A promise that resolves to the completed tool calls.
|
||||||
|
*/
|
||||||
|
export async function scheduleAgentTools(
|
||||||
|
config: Config,
|
||||||
|
requests: ToolCallRequestInfo[],
|
||||||
|
options: AgentSchedulingOptions,
|
||||||
|
): Promise<CompletedToolCall[]> {
|
||||||
|
const {
|
||||||
|
schedulerId,
|
||||||
|
parentCallId,
|
||||||
|
toolRegistry,
|
||||||
|
signal,
|
||||||
|
getPreferredEditor,
|
||||||
|
} = options;
|
||||||
|
|
||||||
|
// Create a proxy/override of the config to provide the agent-specific tool registry.
|
||||||
|
const agentConfig: Config = Object.create(config);
|
||||||
|
agentConfig.getToolRegistry = () => toolRegistry;
|
||||||
|
|
||||||
|
const scheduler = new Scheduler({
|
||||||
|
config: agentConfig,
|
||||||
|
messageBus: config.getMessageBus(),
|
||||||
|
getPreferredEditor: getPreferredEditor ?? (() => undefined),
|
||||||
|
schedulerId,
|
||||||
|
parentCallId,
|
||||||
|
});
|
||||||
|
|
||||||
|
return scheduler.schedule(requests, signal);
|
||||||
|
}
|
||||||
@@ -55,6 +55,7 @@ import type {
|
|||||||
} from './types.js';
|
} from './types.js';
|
||||||
import { AgentTerminateMode } from './types.js';
|
import { AgentTerminateMode } from './types.js';
|
||||||
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
|
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
|
||||||
|
import type { ToolCallRequestInfo } from '../scheduler/types.js';
|
||||||
import { CompressionStatus } from '../core/turn.js';
|
import { CompressionStatus } from '../core/turn.js';
|
||||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||||
import type {
|
import type {
|
||||||
@@ -67,12 +68,12 @@ import type { ModelRouterService } from '../routing/modelRouterService.js';
|
|||||||
|
|
||||||
const {
|
const {
|
||||||
mockSendMessageStream,
|
mockSendMessageStream,
|
||||||
mockExecuteToolCall,
|
mockScheduleAgentTools,
|
||||||
mockSetSystemInstruction,
|
mockSetSystemInstruction,
|
||||||
mockCompress,
|
mockCompress,
|
||||||
} = vi.hoisted(() => ({
|
} = vi.hoisted(() => ({
|
||||||
mockSendMessageStream: vi.fn(),
|
mockSendMessageStream: vi.fn(),
|
||||||
mockExecuteToolCall: vi.fn(),
|
mockScheduleAgentTools: vi.fn(),
|
||||||
mockSetSystemInstruction: vi.fn(),
|
mockSetSystemInstruction: vi.fn(),
|
||||||
mockCompress: vi.fn(),
|
mockCompress: vi.fn(),
|
||||||
}));
|
}));
|
||||||
@@ -101,8 +102,8 @@ vi.mock('../core/geminiChat.js', async (importOriginal) => {
|
|||||||
};
|
};
|
||||||
});
|
});
|
||||||
|
|
||||||
vi.mock('../core/nonInteractiveToolExecutor.js', () => ({
|
vi.mock('./agent-scheduler.js', () => ({
|
||||||
executeToolCall: mockExecuteToolCall,
|
scheduleAgentTools: mockScheduleAgentTools,
|
||||||
}));
|
}));
|
||||||
|
|
||||||
vi.mock('../utils/version.js', () => ({
|
vi.mock('../utils/version.js', () => ({
|
||||||
@@ -275,7 +276,7 @@ describe('LocalAgentExecutor', () => {
|
|||||||
mockSetHistory.mockClear();
|
mockSetHistory.mockClear();
|
||||||
mockSendMessageStream.mockReset();
|
mockSendMessageStream.mockReset();
|
||||||
mockSetSystemInstruction.mockReset();
|
mockSetSystemInstruction.mockReset();
|
||||||
mockExecuteToolCall.mockReset();
|
mockScheduleAgentTools.mockReset();
|
||||||
mockedLogAgentStart.mockReset();
|
mockedLogAgentStart.mockReset();
|
||||||
mockedLogAgentFinish.mockReset();
|
mockedLogAgentFinish.mockReset();
|
||||||
mockedPromptIdContext.getStore.mockReset();
|
mockedPromptIdContext.getStore.mockReset();
|
||||||
@@ -540,34 +541,36 @@ describe('LocalAgentExecutor', () => {
|
|||||||
[{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }],
|
[{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' }],
|
||||||
'T1: Listing',
|
'T1: Listing',
|
||||||
);
|
);
|
||||||
mockExecuteToolCall.mockResolvedValueOnce({
|
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||||
status: 'success',
|
{
|
||||||
request: {
|
status: 'success',
|
||||||
callId: 'call1',
|
request: {
|
||||||
name: LS_TOOL_NAME,
|
callId: 'call1',
|
||||||
args: { path: '.' },
|
name: LS_TOOL_NAME,
|
||||||
isClientInitiated: false,
|
args: { path: '.' },
|
||||||
prompt_id: 'test-prompt',
|
isClientInitiated: false,
|
||||||
},
|
prompt_id: 'test-prompt',
|
||||||
tool: {} as AnyDeclarativeTool,
|
},
|
||||||
invocation: {} as AnyToolInvocation,
|
tool: {} as AnyDeclarativeTool,
|
||||||
response: {
|
invocation: {} as AnyToolInvocation,
|
||||||
callId: 'call1',
|
response: {
|
||||||
resultDisplay: 'file1.txt',
|
callId: 'call1',
|
||||||
responseParts: [
|
resultDisplay: 'file1.txt',
|
||||||
{
|
responseParts: [
|
||||||
functionResponse: {
|
{
|
||||||
name: LS_TOOL_NAME,
|
functionResponse: {
|
||||||
response: { result: 'file1.txt' },
|
name: LS_TOOL_NAME,
|
||||||
id: 'call1',
|
response: { result: 'file1.txt' },
|
||||||
|
id: 'call1',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
],
|
||||||
],
|
error: undefined,
|
||||||
error: undefined,
|
errorType: undefined,
|
||||||
errorType: undefined,
|
contentLength: undefined,
|
||||||
contentLength: undefined,
|
},
|
||||||
},
|
},
|
||||||
});
|
]);
|
||||||
|
|
||||||
// Turn 2: Model calls complete_task with required output
|
// Turn 2: Model calls complete_task with required output
|
||||||
mockModelResponse(
|
mockModelResponse(
|
||||||
@@ -686,34 +689,36 @@ describe('LocalAgentExecutor', () => {
|
|||||||
mockModelResponse([
|
mockModelResponse([
|
||||||
{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' },
|
{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' },
|
||||||
]);
|
]);
|
||||||
mockExecuteToolCall.mockResolvedValueOnce({
|
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||||
status: 'success',
|
{
|
||||||
request: {
|
status: 'success',
|
||||||
callId: 'call1',
|
request: {
|
||||||
name: LS_TOOL_NAME,
|
callId: 'call1',
|
||||||
args: { path: '.' },
|
name: LS_TOOL_NAME,
|
||||||
isClientInitiated: false,
|
args: { path: '.' },
|
||||||
prompt_id: 'test-prompt',
|
isClientInitiated: false,
|
||||||
},
|
prompt_id: 'test-prompt',
|
||||||
tool: {} as AnyDeclarativeTool,
|
},
|
||||||
invocation: {} as AnyToolInvocation,
|
tool: {} as AnyDeclarativeTool,
|
||||||
response: {
|
invocation: {} as AnyToolInvocation,
|
||||||
callId: 'call1',
|
response: {
|
||||||
resultDisplay: 'ok',
|
callId: 'call1',
|
||||||
responseParts: [
|
resultDisplay: 'ok',
|
||||||
{
|
responseParts: [
|
||||||
functionResponse: {
|
{
|
||||||
name: LS_TOOL_NAME,
|
functionResponse: {
|
||||||
response: {},
|
name: LS_TOOL_NAME,
|
||||||
id: 'call1',
|
response: {},
|
||||||
|
id: 'call1',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
],
|
||||||
],
|
error: undefined,
|
||||||
error: undefined,
|
errorType: undefined,
|
||||||
errorType: undefined,
|
contentLength: undefined,
|
||||||
contentLength: undefined,
|
},
|
||||||
},
|
},
|
||||||
});
|
]);
|
||||||
|
|
||||||
mockModelResponse(
|
mockModelResponse(
|
||||||
[
|
[
|
||||||
@@ -759,34 +764,36 @@ describe('LocalAgentExecutor', () => {
|
|||||||
mockModelResponse([
|
mockModelResponse([
|
||||||
{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' },
|
{ name: LS_TOOL_NAME, args: { path: '.' }, id: 'call1' },
|
||||||
]);
|
]);
|
||||||
mockExecuteToolCall.mockResolvedValueOnce({
|
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||||
status: 'success',
|
{
|
||||||
request: {
|
status: 'success',
|
||||||
callId: 'call1',
|
request: {
|
||||||
name: LS_TOOL_NAME,
|
callId: 'call1',
|
||||||
args: { path: '.' },
|
name: LS_TOOL_NAME,
|
||||||
isClientInitiated: false,
|
args: { path: '.' },
|
||||||
prompt_id: 'test-prompt',
|
isClientInitiated: false,
|
||||||
},
|
prompt_id: 'test-prompt',
|
||||||
tool: {} as AnyDeclarativeTool,
|
},
|
||||||
invocation: {} as AnyToolInvocation,
|
tool: {} as AnyDeclarativeTool,
|
||||||
response: {
|
invocation: {} as AnyToolInvocation,
|
||||||
callId: 'call1',
|
response: {
|
||||||
resultDisplay: 'ok',
|
callId: 'call1',
|
||||||
responseParts: [
|
resultDisplay: 'ok',
|
||||||
{
|
responseParts: [
|
||||||
functionResponse: {
|
{
|
||||||
name: LS_TOOL_NAME,
|
functionResponse: {
|
||||||
response: {},
|
name: LS_TOOL_NAME,
|
||||||
id: 'call1',
|
response: {},
|
||||||
|
id: 'call1',
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
],
|
||||||
],
|
error: undefined,
|
||||||
error: undefined,
|
errorType: undefined,
|
||||||
errorType: undefined,
|
contentLength: undefined,
|
||||||
contentLength: undefined,
|
},
|
||||||
},
|
},
|
||||||
});
|
]);
|
||||||
|
|
||||||
// Turn 2 (protocol violation)
|
// Turn 2 (protocol violation)
|
||||||
mockModelResponse([], 'I think I am done.');
|
mockModelResponse([], 'I think I am done.');
|
||||||
@@ -959,33 +966,40 @@ describe('LocalAgentExecutor', () => {
|
|||||||
resolveCalls = r;
|
resolveCalls = r;
|
||||||
});
|
});
|
||||||
|
|
||||||
mockExecuteToolCall.mockImplementation(async (_ctx, reqInfo) => {
|
mockScheduleAgentTools.mockImplementation(
|
||||||
callsStarted++;
|
async (_ctx, requests: ToolCallRequestInfo[]) => {
|
||||||
if (callsStarted === 2) resolveCalls();
|
const results = await Promise.all(
|
||||||
await vi.advanceTimersByTimeAsync(100);
|
requests.map(async (reqInfo) => {
|
||||||
return {
|
callsStarted++;
|
||||||
status: 'success',
|
if (callsStarted === 2) resolveCalls();
|
||||||
request: reqInfo,
|
await vi.advanceTimersByTimeAsync(100);
|
||||||
tool: {} as AnyDeclarativeTool,
|
return {
|
||||||
invocation: {} as AnyToolInvocation,
|
status: 'success',
|
||||||
response: {
|
request: reqInfo,
|
||||||
callId: reqInfo.callId,
|
tool: {} as AnyDeclarativeTool,
|
||||||
resultDisplay: 'ok',
|
invocation: {} as AnyToolInvocation,
|
||||||
responseParts: [
|
response: {
|
||||||
{
|
callId: reqInfo.callId,
|
||||||
functionResponse: {
|
resultDisplay: 'ok',
|
||||||
name: reqInfo.name,
|
responseParts: [
|
||||||
response: {},
|
{
|
||||||
id: reqInfo.callId,
|
functionResponse: {
|
||||||
|
name: reqInfo.name,
|
||||||
|
response: {},
|
||||||
|
id: reqInfo.callId,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
error: undefined,
|
||||||
|
errorType: undefined,
|
||||||
|
contentLength: undefined,
|
||||||
},
|
},
|
||||||
},
|
};
|
||||||
],
|
}),
|
||||||
error: undefined,
|
);
|
||||||
errorType: undefined,
|
return results;
|
||||||
contentLength: undefined,
|
},
|
||||||
},
|
);
|
||||||
};
|
|
||||||
});
|
|
||||||
|
|
||||||
// Turn 2: Completion
|
// Turn 2: Completion
|
||||||
mockModelResponse([
|
mockModelResponse([
|
||||||
@@ -1005,7 +1019,7 @@ describe('LocalAgentExecutor', () => {
|
|||||||
|
|
||||||
const output = await runPromise;
|
const output = await runPromise;
|
||||||
|
|
||||||
expect(mockExecuteToolCall).toHaveBeenCalledTimes(2);
|
expect(mockScheduleAgentTools).toHaveBeenCalledTimes(1);
|
||||||
expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);
|
expect(output.terminate_reason).toBe(AgentTerminateMode.GOAL);
|
||||||
|
|
||||||
// Safe access to message parts
|
// Safe access to message parts
|
||||||
@@ -1059,7 +1073,7 @@ describe('LocalAgentExecutor', () => {
|
|||||||
await executor.run({ goal: 'Sec test' }, signal);
|
await executor.run({ goal: 'Sec test' }, signal);
|
||||||
|
|
||||||
// Verify external executor was not called (Security held)
|
// Verify external executor was not called (Security held)
|
||||||
expect(mockExecuteToolCall).not.toHaveBeenCalled();
|
expect(mockScheduleAgentTools).not.toHaveBeenCalled();
|
||||||
|
|
||||||
// 2. Verify console warning
|
// 2. Verify console warning
|
||||||
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
||||||
@@ -1215,37 +1229,36 @@ describe('LocalAgentExecutor', () => {
|
|||||||
mockModelResponse([
|
mockModelResponse([
|
||||||
{ name: LS_TOOL_NAME, args: { path: '/fake' }, id: 'call1' },
|
{ name: LS_TOOL_NAME, args: { path: '/fake' }, id: 'call1' },
|
||||||
]);
|
]);
|
||||||
mockExecuteToolCall.mockResolvedValueOnce({
|
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||||
status: 'error',
|
{
|
||||||
request: {
|
status: 'error',
|
||||||
callId: 'call1',
|
request: {
|
||||||
name: LS_TOOL_NAME,
|
callId: 'call1',
|
||||||
args: { path: '/fake' },
|
name: LS_TOOL_NAME,
|
||||||
isClientInitiated: false,
|
args: { path: '/fake' },
|
||||||
prompt_id: 'test-prompt',
|
isClientInitiated: false,
|
||||||
},
|
prompt_id: 'test-prompt',
|
||||||
tool: {} as AnyDeclarativeTool,
|
},
|
||||||
invocation: {} as AnyToolInvocation,
|
tool: {} as AnyDeclarativeTool,
|
||||||
response: {
|
invocation: {} as AnyToolInvocation,
|
||||||
callId: 'call1',
|
response: {
|
||||||
resultDisplay: '',
|
callId: 'call1',
|
||||||
responseParts: [
|
resultDisplay: '',
|
||||||
{
|
responseParts: [
|
||||||
functionResponse: {
|
{
|
||||||
name: LS_TOOL_NAME,
|
functionResponse: {
|
||||||
response: { error: toolErrorMessage },
|
name: LS_TOOL_NAME,
|
||||||
id: 'call1',
|
response: { error: toolErrorMessage },
|
||||||
},
|
id: 'call1',
|
||||||
},
|
},
|
||||||
],
|
},
|
||||||
error: {
|
],
|
||||||
type: 'ToolError',
|
error: new Error(toolErrorMessage),
|
||||||
message: toolErrorMessage,
|
errorType: 'ToolError',
|
||||||
|
contentLength: 0,
|
||||||
},
|
},
|
||||||
errorType: 'ToolError',
|
|
||||||
contentLength: 0,
|
|
||||||
},
|
},
|
||||||
});
|
]);
|
||||||
|
|
||||||
// Turn 2: Model sees the error and completes
|
// Turn 2: Model sees the error and completes
|
||||||
mockModelResponse([
|
mockModelResponse([
|
||||||
@@ -1258,7 +1271,7 @@ describe('LocalAgentExecutor', () => {
|
|||||||
|
|
||||||
const output = await executor.run({ goal: 'Tool failure test' }, signal);
|
const output = await executor.run({ goal: 'Tool failure test' }, signal);
|
||||||
|
|
||||||
expect(mockExecuteToolCall).toHaveBeenCalledTimes(1);
|
expect(mockScheduleAgentTools).toHaveBeenCalledTimes(1);
|
||||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||||
|
|
||||||
// Verify the error was reported in the activity stream
|
// Verify the error was reported in the activity stream
|
||||||
@@ -1391,28 +1404,30 @@ describe('LocalAgentExecutor', () => {
|
|||||||
describe('run (Termination Conditions)', () => {
|
describe('run (Termination Conditions)', () => {
|
||||||
const mockWorkResponse = (id: string) => {
|
const mockWorkResponse = (id: string) => {
|
||||||
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
||||||
mockExecuteToolCall.mockResolvedValueOnce({
|
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||||
status: 'success',
|
{
|
||||||
request: {
|
status: 'success',
|
||||||
callId: id,
|
request: {
|
||||||
name: LS_TOOL_NAME,
|
callId: id,
|
||||||
args: { path: '.' },
|
name: LS_TOOL_NAME,
|
||||||
isClientInitiated: false,
|
args: { path: '.' },
|
||||||
prompt_id: 'test-prompt',
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'test-prompt',
|
||||||
|
},
|
||||||
|
tool: {} as AnyDeclarativeTool,
|
||||||
|
invocation: {} as AnyToolInvocation,
|
||||||
|
response: {
|
||||||
|
callId: id,
|
||||||
|
resultDisplay: 'ok',
|
||||||
|
responseParts: [
|
||||||
|
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||||
|
],
|
||||||
|
error: undefined,
|
||||||
|
errorType: undefined,
|
||||||
|
contentLength: undefined,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
tool: {} as AnyDeclarativeTool,
|
]);
|
||||||
invocation: {} as AnyToolInvocation,
|
|
||||||
response: {
|
|
||||||
callId: id,
|
|
||||||
resultDisplay: 'ok',
|
|
||||||
responseParts: [
|
|
||||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
|
||||||
],
|
|
||||||
error: undefined,
|
|
||||||
errorType: undefined,
|
|
||||||
contentLength: undefined,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
|
|
||||||
it('should terminate when max_turns is reached', async () => {
|
it('should terminate when max_turns is reached', async () => {
|
||||||
@@ -1505,23 +1520,27 @@ describe('LocalAgentExecutor', () => {
|
|||||||
]);
|
]);
|
||||||
|
|
||||||
// Long running tool
|
// Long running tool
|
||||||
mockExecuteToolCall.mockImplementationOnce(async (_ctx, reqInfo) => {
|
mockScheduleAgentTools.mockImplementationOnce(
|
||||||
await vi.advanceTimersByTimeAsync(61 * 1000);
|
async (_ctx, requests: ToolCallRequestInfo[]) => {
|
||||||
return {
|
await vi.advanceTimersByTimeAsync(61 * 1000);
|
||||||
status: 'success',
|
return [
|
||||||
request: reqInfo,
|
{
|
||||||
tool: {} as AnyDeclarativeTool,
|
status: 'success',
|
||||||
invocation: {} as AnyToolInvocation,
|
request: requests[0],
|
||||||
response: {
|
tool: {} as AnyDeclarativeTool,
|
||||||
callId: 't1',
|
invocation: {} as AnyToolInvocation,
|
||||||
resultDisplay: 'ok',
|
response: {
|
||||||
responseParts: [],
|
callId: 't1',
|
||||||
error: undefined,
|
resultDisplay: 'ok',
|
||||||
errorType: undefined,
|
responseParts: [],
|
||||||
contentLength: undefined,
|
error: undefined,
|
||||||
},
|
errorType: undefined,
|
||||||
};
|
contentLength: undefined,
|
||||||
});
|
},
|
||||||
|
},
|
||||||
|
];
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
// Recovery turn
|
// Recovery turn
|
||||||
mockModelResponse([], 'I give up');
|
mockModelResponse([], 'I give up');
|
||||||
@@ -1557,28 +1576,30 @@ describe('LocalAgentExecutor', () => {
|
|||||||
describe('run (Recovery Turns)', () => {
|
describe('run (Recovery Turns)', () => {
|
||||||
const mockWorkResponse = (id: string) => {
|
const mockWorkResponse = (id: string) => {
|
||||||
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
||||||
mockExecuteToolCall.mockResolvedValueOnce({
|
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||||
status: 'success',
|
{
|
||||||
request: {
|
status: 'success',
|
||||||
callId: id,
|
request: {
|
||||||
name: LS_TOOL_NAME,
|
callId: id,
|
||||||
args: { path: '.' },
|
name: LS_TOOL_NAME,
|
||||||
isClientInitiated: false,
|
args: { path: '.' },
|
||||||
prompt_id: 'test-prompt',
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'test-prompt',
|
||||||
|
},
|
||||||
|
tool: {} as AnyDeclarativeTool,
|
||||||
|
invocation: {} as AnyToolInvocation,
|
||||||
|
response: {
|
||||||
|
callId: id,
|
||||||
|
resultDisplay: 'ok',
|
||||||
|
responseParts: [
|
||||||
|
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||||
|
],
|
||||||
|
error: undefined,
|
||||||
|
errorType: undefined,
|
||||||
|
contentLength: undefined,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
tool: {} as AnyDeclarativeTool,
|
]);
|
||||||
invocation: {} as AnyToolInvocation,
|
|
||||||
response: {
|
|
||||||
callId: id,
|
|
||||||
resultDisplay: 'ok',
|
|
||||||
responseParts: [
|
|
||||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
|
||||||
],
|
|
||||||
error: undefined,
|
|
||||||
errorType: undefined,
|
|
||||||
contentLength: undefined,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
|
|
||||||
it('should recover successfully if complete_task is called during the grace turn after MAX_TURNS', async () => {
|
it('should recover successfully if complete_task is called during the grace turn after MAX_TURNS', async () => {
|
||||||
@@ -1873,28 +1894,30 @@ describe('LocalAgentExecutor', () => {
|
|||||||
describe('Telemetry and Logging', () => {
|
describe('Telemetry and Logging', () => {
|
||||||
const mockWorkResponse = (id: string) => {
|
const mockWorkResponse = (id: string) => {
|
||||||
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
||||||
mockExecuteToolCall.mockResolvedValueOnce({
|
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||||
status: 'success',
|
{
|
||||||
request: {
|
status: 'success',
|
||||||
callId: id,
|
request: {
|
||||||
name: LS_TOOL_NAME,
|
callId: id,
|
||||||
args: { path: '.' },
|
name: LS_TOOL_NAME,
|
||||||
isClientInitiated: false,
|
args: { path: '.' },
|
||||||
prompt_id: 'test-prompt',
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'test-prompt',
|
||||||
|
},
|
||||||
|
tool: {} as AnyDeclarativeTool,
|
||||||
|
invocation: {} as AnyToolInvocation,
|
||||||
|
response: {
|
||||||
|
callId: id,
|
||||||
|
resultDisplay: 'ok',
|
||||||
|
responseParts: [
|
||||||
|
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||||
|
],
|
||||||
|
error: undefined,
|
||||||
|
errorType: undefined,
|
||||||
|
contentLength: undefined,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
tool: {} as AnyDeclarativeTool,
|
]);
|
||||||
invocation: {} as AnyToolInvocation,
|
|
||||||
response: {
|
|
||||||
callId: id,
|
|
||||||
resultDisplay: 'ok',
|
|
||||||
responseParts: [
|
|
||||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
|
||||||
],
|
|
||||||
error: undefined,
|
|
||||||
errorType: undefined,
|
|
||||||
contentLength: undefined,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
@@ -1960,28 +1983,30 @@ describe('LocalAgentExecutor', () => {
|
|||||||
describe('Chat Compression', () => {
|
describe('Chat Compression', () => {
|
||||||
const mockWorkResponse = (id: string) => {
|
const mockWorkResponse = (id: string) => {
|
||||||
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
mockModelResponse([{ name: LS_TOOL_NAME, args: { path: '.' }, id }]);
|
||||||
mockExecuteToolCall.mockResolvedValueOnce({
|
mockScheduleAgentTools.mockResolvedValueOnce([
|
||||||
status: 'success',
|
{
|
||||||
request: {
|
status: 'success',
|
||||||
callId: id,
|
request: {
|
||||||
name: LS_TOOL_NAME,
|
callId: id,
|
||||||
args: { path: '.' },
|
name: LS_TOOL_NAME,
|
||||||
isClientInitiated: false,
|
args: { path: '.' },
|
||||||
prompt_id: 'test-prompt',
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'test-prompt',
|
||||||
|
},
|
||||||
|
tool: {} as AnyDeclarativeTool,
|
||||||
|
invocation: {} as AnyToolInvocation,
|
||||||
|
response: {
|
||||||
|
callId: id,
|
||||||
|
resultDisplay: 'ok',
|
||||||
|
responseParts: [
|
||||||
|
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
||||||
|
],
|
||||||
|
error: undefined,
|
||||||
|
errorType: undefined,
|
||||||
|
contentLength: undefined,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
tool: {} as AnyDeclarativeTool,
|
]);
|
||||||
invocation: {} as AnyToolInvocation,
|
|
||||||
response: {
|
|
||||||
callId: id,
|
|
||||||
resultDisplay: 'ok',
|
|
||||||
responseParts: [
|
|
||||||
{ functionResponse: { name: LS_TOOL_NAME, response: {}, id } },
|
|
||||||
],
|
|
||||||
error: undefined,
|
|
||||||
errorType: undefined,
|
|
||||||
contentLength: undefined,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
|
|
||||||
it('should attempt to compress chat history on each turn', async () => {
|
it('should attempt to compress chat history on each turn', async () => {
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ import type {
|
|||||||
FunctionDeclaration,
|
FunctionDeclaration,
|
||||||
Schema,
|
Schema,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
|
|
||||||
import { ToolRegistry } from '../tools/tool-registry.js';
|
import { ToolRegistry } from '../tools/tool-registry.js';
|
||||||
import { CompressionStatus } from '../core/turn.js';
|
import { CompressionStatus } from '../core/turn.js';
|
||||||
import { type ToolCallRequestInfo } from '../scheduler/types.js';
|
import { type ToolCallRequestInfo } from '../scheduler/types.js';
|
||||||
@@ -48,7 +47,8 @@ import { zodToJsonSchema } from 'zod-to-json-schema';
|
|||||||
import { debugLogger } from '../utils/debugLogger.js';
|
import { debugLogger } from '../utils/debugLogger.js';
|
||||||
import { getModelConfigAlias } from './registry.js';
|
import { getModelConfigAlias } from './registry.js';
|
||||||
import { getVersion } from '../utils/version.js';
|
import { getVersion } from '../utils/version.js';
|
||||||
import { ApprovalMode } from '../policy/types.js';
|
import { getToolCallContext } from '../utils/toolCallContext.js';
|
||||||
|
import { scheduleAgentTools } from './agent-scheduler.js';
|
||||||
|
|
||||||
/** A callback function to report on agent activity. */
|
/** A callback function to report on agent activity. */
|
||||||
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
|
export type ActivityCallback = (activity: SubagentActivityEvent) => void;
|
||||||
@@ -86,6 +86,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
private readonly runtimeContext: Config;
|
private readonly runtimeContext: Config;
|
||||||
private readonly onActivity?: ActivityCallback;
|
private readonly onActivity?: ActivityCallback;
|
||||||
private readonly compressionService: ChatCompressionService;
|
private readonly compressionService: ChatCompressionService;
|
||||||
|
private readonly parentCallId?: string;
|
||||||
private hasFailedCompressionAttempt = false;
|
private hasFailedCompressionAttempt = false;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -158,11 +159,16 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
// Get the parent prompt ID from context
|
// Get the parent prompt ID from context
|
||||||
const parentPromptId = promptIdContext.getStore();
|
const parentPromptId = promptIdContext.getStore();
|
||||||
|
|
||||||
|
// Get the parent tool call ID from context
|
||||||
|
const toolContext = getToolCallContext();
|
||||||
|
const parentCallId = toolContext?.callId;
|
||||||
|
|
||||||
return new LocalAgentExecutor(
|
return new LocalAgentExecutor(
|
||||||
definition,
|
definition,
|
||||||
runtimeContext,
|
runtimeContext,
|
||||||
agentToolRegistry,
|
agentToolRegistry,
|
||||||
parentPromptId,
|
parentPromptId,
|
||||||
|
parentCallId,
|
||||||
onActivity,
|
onActivity,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@@ -178,6 +184,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
runtimeContext: Config,
|
runtimeContext: Config,
|
||||||
toolRegistry: ToolRegistry,
|
toolRegistry: ToolRegistry,
|
||||||
parentPromptId: string | undefined,
|
parentPromptId: string | undefined,
|
||||||
|
parentCallId: string | undefined,
|
||||||
onActivity?: ActivityCallback,
|
onActivity?: ActivityCallback,
|
||||||
) {
|
) {
|
||||||
this.definition = definition;
|
this.definition = definition;
|
||||||
@@ -185,6 +192,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
this.toolRegistry = toolRegistry;
|
this.toolRegistry = toolRegistry;
|
||||||
this.onActivity = onActivity;
|
this.onActivity = onActivity;
|
||||||
this.compressionService = new ChatCompressionService();
|
this.compressionService = new ChatCompressionService();
|
||||||
|
this.parentCallId = parentCallId;
|
||||||
|
|
||||||
const randomIdPart = Math.random().toString(36).slice(2, 8);
|
const randomIdPart = Math.random().toString(36).slice(2, 8);
|
||||||
// parentPromptId will be undefined if this agent is invoked directly
|
// parentPromptId will be undefined if this agent is invoked directly
|
||||||
@@ -763,26 +771,28 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
let submittedOutput: string | null = null;
|
let submittedOutput: string | null = null;
|
||||||
let taskCompleted = false;
|
let taskCompleted = false;
|
||||||
|
|
||||||
// We'll collect promises for the tool executions
|
// We'll separate complete_task from other tools
|
||||||
const toolExecutionPromises: Array<Promise<Part[] | void>> = [];
|
const toolRequests: ToolCallRequestInfo[] = [];
|
||||||
// And we'll need a place to store the synchronous results (like complete_task or blocked calls)
|
// Map to keep track of tool name by callId for activity emission
|
||||||
const syncResponseParts: Part[] = [];
|
const toolNameMap = new Map<string, string>();
|
||||||
|
// Synchronous results (like complete_task or unauthorized calls)
|
||||||
|
const syncResults = new Map<string, Part>();
|
||||||
|
|
||||||
for (const [index, functionCall] of functionCalls.entries()) {
|
for (const [index, functionCall] of functionCalls.entries()) {
|
||||||
const callId = functionCall.id ?? `${promptId}-${index}`;
|
const callId = functionCall.id ?? `${promptId}-${index}`;
|
||||||
const args = functionCall.args ?? {};
|
const args = functionCall.args ?? {};
|
||||||
|
const toolName = functionCall.name as string;
|
||||||
|
|
||||||
this.emitActivity('TOOL_CALL_START', {
|
this.emitActivity('TOOL_CALL_START', {
|
||||||
name: functionCall.name,
|
name: toolName,
|
||||||
args,
|
args,
|
||||||
});
|
});
|
||||||
|
|
||||||
if (functionCall.name === TASK_COMPLETE_TOOL_NAME) {
|
if (toolName === TASK_COMPLETE_TOOL_NAME) {
|
||||||
if (taskCompleted) {
|
if (taskCompleted) {
|
||||||
// We already have a completion from this turn. Ignore subsequent ones.
|
|
||||||
const error =
|
const error =
|
||||||
'Task already marked complete in this turn. Ignoring duplicate call.';
|
'Task already marked complete in this turn. Ignoring duplicate call.';
|
||||||
syncResponseParts.push({
|
syncResults.set(callId, {
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
name: TASK_COMPLETE_TOOL_NAME,
|
name: TASK_COMPLETE_TOOL_NAME,
|
||||||
response: { error },
|
response: { error },
|
||||||
@@ -791,7 +801,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
});
|
});
|
||||||
this.emitActivity('ERROR', {
|
this.emitActivity('ERROR', {
|
||||||
context: 'tool_call',
|
context: 'tool_call',
|
||||||
name: functionCall.name,
|
name: toolName,
|
||||||
error,
|
error,
|
||||||
});
|
});
|
||||||
continue;
|
continue;
|
||||||
@@ -809,7 +819,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
if (!validationResult.success) {
|
if (!validationResult.success) {
|
||||||
taskCompleted = false; // Validation failed, revoke completion
|
taskCompleted = false; // Validation failed, revoke completion
|
||||||
const error = `Output validation failed: ${JSON.stringify(validationResult.error.flatten())}`;
|
const error = `Output validation failed: ${JSON.stringify(validationResult.error.flatten())}`;
|
||||||
syncResponseParts.push({
|
syncResults.set(callId, {
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
name: TASK_COMPLETE_TOOL_NAME,
|
name: TASK_COMPLETE_TOOL_NAME,
|
||||||
response: { error },
|
response: { error },
|
||||||
@@ -818,7 +828,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
});
|
});
|
||||||
this.emitActivity('ERROR', {
|
this.emitActivity('ERROR', {
|
||||||
context: 'tool_call',
|
context: 'tool_call',
|
||||||
name: functionCall.name,
|
name: toolName,
|
||||||
error,
|
error,
|
||||||
});
|
});
|
||||||
continue;
|
continue;
|
||||||
@@ -833,7 +843,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
? outputValue
|
? outputValue
|
||||||
: JSON.stringify(outputValue, null, 2);
|
: JSON.stringify(outputValue, null, 2);
|
||||||
}
|
}
|
||||||
syncResponseParts.push({
|
syncResults.set(callId, {
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
name: TASK_COMPLETE_TOOL_NAME,
|
name: TASK_COMPLETE_TOOL_NAME,
|
||||||
response: { result: 'Output submitted and task completed.' },
|
response: { result: 'Output submitted and task completed.' },
|
||||||
@@ -841,14 +851,14 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
this.emitActivity('TOOL_CALL_END', {
|
this.emitActivity('TOOL_CALL_END', {
|
||||||
name: functionCall.name,
|
name: toolName,
|
||||||
output: 'Output submitted and task completed.',
|
output: 'Output submitted and task completed.',
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
// Failed to provide required output.
|
// Failed to provide required output.
|
||||||
taskCompleted = false; // Revoke completion status
|
taskCompleted = false; // Revoke completion status
|
||||||
const error = `Missing required argument '${outputName}' for completion.`;
|
const error = `Missing required argument '${outputName}' for completion.`;
|
||||||
syncResponseParts.push({
|
syncResults.set(callId, {
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
name: TASK_COMPLETE_TOOL_NAME,
|
name: TASK_COMPLETE_TOOL_NAME,
|
||||||
response: { error },
|
response: { error },
|
||||||
@@ -857,7 +867,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
});
|
});
|
||||||
this.emitActivity('ERROR', {
|
this.emitActivity('ERROR', {
|
||||||
context: 'tool_call',
|
context: 'tool_call',
|
||||||
name: functionCall.name,
|
name: toolName,
|
||||||
error,
|
error,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -873,7 +883,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
typeof resultArg === 'string'
|
typeof resultArg === 'string'
|
||||||
? resultArg
|
? resultArg
|
||||||
: JSON.stringify(resultArg, null, 2);
|
: JSON.stringify(resultArg, null, 2);
|
||||||
syncResponseParts.push({
|
syncResults.set(callId, {
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
name: TASK_COMPLETE_TOOL_NAME,
|
name: TASK_COMPLETE_TOOL_NAME,
|
||||||
response: { status: 'Result submitted and task completed.' },
|
response: { status: 'Result submitted and task completed.' },
|
||||||
@@ -881,7 +891,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
},
|
},
|
||||||
});
|
});
|
||||||
this.emitActivity('TOOL_CALL_END', {
|
this.emitActivity('TOOL_CALL_END', {
|
||||||
name: functionCall.name,
|
name: toolName,
|
||||||
output: 'Result submitted and task completed.',
|
output: 'Result submitted and task completed.',
|
||||||
});
|
});
|
||||||
} else {
|
} else {
|
||||||
@@ -889,7 +899,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
taskCompleted = false; // Revoke completion
|
taskCompleted = false; // Revoke completion
|
||||||
const error =
|
const error =
|
||||||
'Missing required "result" argument. You must provide your findings when calling complete_task.';
|
'Missing required "result" argument. You must provide your findings when calling complete_task.';
|
||||||
syncResponseParts.push({
|
syncResults.set(callId, {
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
name: TASK_COMPLETE_TOOL_NAME,
|
name: TASK_COMPLETE_TOOL_NAME,
|
||||||
response: { error },
|
response: { error },
|
||||||
@@ -898,7 +908,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
});
|
});
|
||||||
this.emitActivity('ERROR', {
|
this.emitActivity('ERROR', {
|
||||||
context: 'tool_call',
|
context: 'tool_call',
|
||||||
name: functionCall.name,
|
name: toolName,
|
||||||
error,
|
error,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -907,14 +917,13 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Handle standard tools
|
// Handle standard tools
|
||||||
if (!allowedToolNames.has(functionCall.name as string)) {
|
if (!allowedToolNames.has(toolName)) {
|
||||||
const error = createUnauthorizedToolError(functionCall.name as string);
|
const error = createUnauthorizedToolError(toolName);
|
||||||
|
|
||||||
debugLogger.warn(`[LocalAgentExecutor] Blocked call: ${error}`);
|
debugLogger.warn(`[LocalAgentExecutor] Blocked call: ${error}`);
|
||||||
|
|
||||||
syncResponseParts.push({
|
syncResults.set(callId, {
|
||||||
functionResponse: {
|
functionResponse: {
|
||||||
name: functionCall.name as string,
|
name: toolName,
|
||||||
id: callId,
|
id: callId,
|
||||||
response: { error },
|
response: { error },
|
||||||
},
|
},
|
||||||
@@ -922,7 +931,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
|
|
||||||
this.emitActivity('ERROR', {
|
this.emitActivity('ERROR', {
|
||||||
context: 'tool_call_unauthorized',
|
context: 'tool_call_unauthorized',
|
||||||
name: functionCall.name,
|
name: toolName,
|
||||||
callId,
|
callId,
|
||||||
error,
|
error,
|
||||||
});
|
});
|
||||||
@@ -930,53 +939,63 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const requestInfo: ToolCallRequestInfo = {
|
toolRequests.push({
|
||||||
callId,
|
callId,
|
||||||
name: functionCall.name as string,
|
name: toolName,
|
||||||
args,
|
args,
|
||||||
isClientInitiated: true,
|
isClientInitiated: false, // These are coming from the subagent (the "model")
|
||||||
prompt_id: promptId,
|
prompt_id: promptId,
|
||||||
};
|
});
|
||||||
|
toolNameMap.set(callId, toolName);
|
||||||
|
}
|
||||||
|
|
||||||
// Create a promise for the tool execution
|
// Execute standard tool calls using the new scheduler
|
||||||
const executionPromise = (async () => {
|
if (toolRequests.length > 0) {
|
||||||
const agentContext = Object.create(this.runtimeContext);
|
const completedCalls = await scheduleAgentTools(
|
||||||
agentContext.getToolRegistry = () => this.toolRegistry;
|
this.runtimeContext,
|
||||||
agentContext.getApprovalMode = () => ApprovalMode.YOLO;
|
toolRequests,
|
||||||
|
{
|
||||||
const { response: toolResponse } = await executeToolCall(
|
schedulerId: this.agentId,
|
||||||
agentContext,
|
parentCallId: this.parentCallId,
|
||||||
requestInfo,
|
toolRegistry: this.toolRegistry,
|
||||||
signal,
|
signal,
|
||||||
);
|
},
|
||||||
|
);
|
||||||
|
|
||||||
if (toolResponse.error) {
|
for (const call of completedCalls) {
|
||||||
|
const toolName =
|
||||||
|
toolNameMap.get(call.request.callId) || call.request.name;
|
||||||
|
if (call.status === 'success') {
|
||||||
|
this.emitActivity('TOOL_CALL_END', {
|
||||||
|
name: toolName,
|
||||||
|
output: call.response.resultDisplay,
|
||||||
|
});
|
||||||
|
} else if (call.status === 'error') {
|
||||||
this.emitActivity('ERROR', {
|
this.emitActivity('ERROR', {
|
||||||
context: 'tool_call',
|
context: 'tool_call',
|
||||||
name: functionCall.name,
|
name: toolName,
|
||||||
error: toolResponse.error.message,
|
error: call.response.error?.message || 'Unknown error',
|
||||||
});
|
});
|
||||||
} else {
|
} else if (call.status === 'cancelled') {
|
||||||
this.emitActivity('TOOL_CALL_END', {
|
this.emitActivity('ERROR', {
|
||||||
name: functionCall.name,
|
context: 'tool_call',
|
||||||
output: toolResponse.resultDisplay,
|
name: toolName,
|
||||||
|
error: 'Tool call was cancelled.',
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
return toolResponse.responseParts;
|
// Add result to syncResults to preserve order later
|
||||||
})();
|
syncResults.set(call.request.callId, call.response.responseParts[0]);
|
||||||
|
}
|
||||||
toolExecutionPromises.push(executionPromise);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for all tool executions to complete
|
// Reconstruct toolResponseParts in the original order
|
||||||
const asyncResults = await Promise.all(toolExecutionPromises);
|
const toolResponseParts: Part[] = [];
|
||||||
|
for (const [index, functionCall] of functionCalls.entries()) {
|
||||||
// Combine all response parts
|
const callId = functionCall.id ?? `${promptId}-${index}`;
|
||||||
const toolResponseParts: Part[] = [...syncResponseParts];
|
const part = syncResults.get(callId);
|
||||||
for (const result of asyncResults) {
|
if (part) {
|
||||||
if (result) {
|
toolResponseParts.push(part);
|
||||||
toolResponseParts.push(...result);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -70,6 +70,10 @@ import { ROOT_SCHEDULER_ID } from './types.js';
|
|||||||
import { ToolErrorType } from '../tools/tool-error.js';
|
import { ToolErrorType } from '../tools/tool-error.js';
|
||||||
import * as ToolUtils from '../utils/tool-utils.js';
|
import * as ToolUtils from '../utils/tool-utils.js';
|
||||||
import type { EditorType } from '../utils/editor.js';
|
import type { EditorType } from '../utils/editor.js';
|
||||||
|
import {
|
||||||
|
getToolCallContext,
|
||||||
|
type ToolCallContext,
|
||||||
|
} from '../utils/toolCallContext.js';
|
||||||
|
|
||||||
describe('Scheduler (Orchestrator)', () => {
|
describe('Scheduler (Orchestrator)', () => {
|
||||||
let scheduler: Scheduler;
|
let scheduler: Scheduler;
|
||||||
@@ -1010,4 +1014,68 @@ describe('Scheduler (Orchestrator)', () => {
|
|||||||
expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1');
|
expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1');
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('Tool Call Context Propagation', () => {
|
||||||
|
it('should propagate context to the tool executor', async () => {
|
||||||
|
const schedulerId = 'custom-scheduler';
|
||||||
|
const parentCallId = 'parent-call';
|
||||||
|
const customScheduler = new Scheduler({
|
||||||
|
config: mockConfig,
|
||||||
|
messageBus: mockMessageBus,
|
||||||
|
getPreferredEditor,
|
||||||
|
schedulerId,
|
||||||
|
parentCallId,
|
||||||
|
});
|
||||||
|
|
||||||
|
const validatingCall: ValidatingToolCall = {
|
||||||
|
status: 'validating',
|
||||||
|
request: req1,
|
||||||
|
tool: mockTool,
|
||||||
|
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Mock queueLength to run the loop once
|
||||||
|
Object.defineProperty(mockStateManager, 'queueLength', {
|
||||||
|
get: vi.fn().mockReturnValueOnce(1).mockReturnValue(0),
|
||||||
|
configurable: true,
|
||||||
|
});
|
||||||
|
|
||||||
|
vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall);
|
||||||
|
Object.defineProperty(mockStateManager, 'firstActiveCall', {
|
||||||
|
get: vi.fn().mockReturnValue(validatingCall),
|
||||||
|
configurable: true,
|
||||||
|
});
|
||||||
|
vi.mocked(mockStateManager.getToolCall).mockReturnValue(validatingCall);
|
||||||
|
|
||||||
|
mockToolRegistry.getTool.mockReturnValue(mockTool);
|
||||||
|
mockPolicyEngine.check.mockResolvedValue({
|
||||||
|
decision: PolicyDecision.ALLOW,
|
||||||
|
});
|
||||||
|
|
||||||
|
let capturedContext: ToolCallContext | undefined;
|
||||||
|
mockExecutor.execute.mockImplementation(async () => {
|
||||||
|
capturedContext = getToolCallContext();
|
||||||
|
return {
|
||||||
|
status: 'success',
|
||||||
|
request: req1,
|
||||||
|
tool: mockTool,
|
||||||
|
invocation: mockInvocation as unknown as AnyToolInvocation,
|
||||||
|
response: {
|
||||||
|
callId: req1.callId,
|
||||||
|
responseParts: [],
|
||||||
|
resultDisplay: 'ok',
|
||||||
|
error: undefined,
|
||||||
|
errorType: undefined,
|
||||||
|
},
|
||||||
|
} as unknown as SuccessfulToolCall;
|
||||||
|
});
|
||||||
|
|
||||||
|
await customScheduler.schedule(req1, signal);
|
||||||
|
|
||||||
|
expect(capturedContext).toBeDefined();
|
||||||
|
expect(capturedContext!.callId).toBe(req1.callId);
|
||||||
|
expect(capturedContext!.schedulerId).toBe(schedulerId);
|
||||||
|
expect(capturedContext!.parentCallId).toBe(parentCallId);
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ import {
|
|||||||
type SerializableConfirmationDetails,
|
type SerializableConfirmationDetails,
|
||||||
type ToolConfirmationRequest,
|
type ToolConfirmationRequest,
|
||||||
} from '../confirmation-bus/types.js';
|
} from '../confirmation-bus/types.js';
|
||||||
|
import { runWithToolCallContext } from '../utils/toolCallContext.js';
|
||||||
|
|
||||||
interface SchedulerQueueItem {
|
interface SchedulerQueueItem {
|
||||||
requests: ToolCallRequestInfo[];
|
requests: ToolCallRequestInfo[];
|
||||||
@@ -256,6 +257,7 @@ export class Scheduler {
|
|||||||
return this.state.completedBatch;
|
return this.state.completedBatch;
|
||||||
} finally {
|
} finally {
|
||||||
this.isProcessing = false;
|
this.isProcessing = false;
|
||||||
|
this.state.clearBatch();
|
||||||
this._processNextInRequestQueue();
|
this._processNextInRequestQueue();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -282,30 +284,39 @@ export class Scheduler {
|
|||||||
request: ToolCallRequestInfo,
|
request: ToolCallRequestInfo,
|
||||||
tool: AnyDeclarativeTool,
|
tool: AnyDeclarativeTool,
|
||||||
): ValidatingToolCall | ErroredToolCall {
|
): ValidatingToolCall | ErroredToolCall {
|
||||||
try {
|
return runWithToolCallContext(
|
||||||
const invocation = tool.build(request.args);
|
{
|
||||||
return {
|
callId: request.callId,
|
||||||
status: 'validating',
|
|
||||||
request,
|
|
||||||
tool,
|
|
||||||
invocation,
|
|
||||||
startTime: Date.now(),
|
|
||||||
schedulerId: this.schedulerId,
|
schedulerId: this.schedulerId,
|
||||||
};
|
parentCallId: this.parentCallId,
|
||||||
} catch (e) {
|
},
|
||||||
return {
|
() => {
|
||||||
status: 'error',
|
try {
|
||||||
request,
|
const invocation = tool.build(request.args);
|
||||||
tool,
|
return {
|
||||||
response: createErrorResponse(
|
status: 'validating',
|
||||||
request,
|
request,
|
||||||
e instanceof Error ? e : new Error(String(e)),
|
tool,
|
||||||
ToolErrorType.INVALID_TOOL_PARAMS,
|
invocation,
|
||||||
),
|
startTime: Date.now(),
|
||||||
durationMs: 0,
|
schedulerId: this.schedulerId,
|
||||||
schedulerId: this.schedulerId,
|
};
|
||||||
};
|
} catch (e) {
|
||||||
}
|
return {
|
||||||
|
status: 'error',
|
||||||
|
request,
|
||||||
|
tool,
|
||||||
|
response: createErrorResponse(
|
||||||
|
request,
|
||||||
|
e instanceof Error ? e : new Error(String(e)),
|
||||||
|
ToolErrorType.INVALID_TOOL_PARAMS,
|
||||||
|
),
|
||||||
|
durationMs: 0,
|
||||||
|
schedulerId: this.schedulerId,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Phase 2: Processing Loop ---
|
// --- Phase 2: Processing Loop ---
|
||||||
@@ -460,17 +471,29 @@ export class Scheduler {
|
|||||||
if (signal.aborted) throw new Error('Operation cancelled');
|
if (signal.aborted) throw new Error('Operation cancelled');
|
||||||
this.state.updateStatus(callId, 'executing');
|
this.state.updateStatus(callId, 'executing');
|
||||||
|
|
||||||
const result = await this.executor.execute({
|
const activeCall = this.state.firstActiveCall as ExecutingToolCall;
|
||||||
call: this.state.firstActiveCall as ExecutingToolCall,
|
|
||||||
signal,
|
const result = await runWithToolCallContext(
|
||||||
outputUpdateHandler: (id, out) =>
|
{
|
||||||
this.state.updateStatus(id, 'executing', { liveOutput: out }),
|
callId: activeCall.request.callId,
|
||||||
onUpdateToolCall: (updated) => {
|
schedulerId: this.schedulerId,
|
||||||
if (updated.status === 'executing' && updated.pid) {
|
parentCallId: this.parentCallId,
|
||||||
this.state.updateStatus(callId, 'executing', { pid: updated.pid });
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
});
|
() =>
|
||||||
|
this.executor.execute({
|
||||||
|
call: activeCall,
|
||||||
|
signal,
|
||||||
|
outputUpdateHandler: (id, out) =>
|
||||||
|
this.state.updateStatus(id, 'executing', { liveOutput: out }),
|
||||||
|
onUpdateToolCall: (updated) => {
|
||||||
|
if (updated.status === 'executing' && updated.pid) {
|
||||||
|
this.state.updateStatus(callId, 'executing', {
|
||||||
|
pid: updated.pid,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
if (result.status === 'success') {
|
if (result.status === 'success') {
|
||||||
this.state.updateStatus(callId, 'success', result.response);
|
this.state.updateStatus(callId, 'success', result.response);
|
||||||
|
|||||||
@@ -0,0 +1,84 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2026 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect } from 'vitest';
|
||||||
|
import {
|
||||||
|
runWithToolCallContext,
|
||||||
|
getToolCallContext,
|
||||||
|
} from './toolCallContext.js';
|
||||||
|
|
||||||
|
describe('toolCallContext', () => {
|
||||||
|
it('should store and retrieve tool call context', () => {
|
||||||
|
const context = {
|
||||||
|
callId: 'test-call-id',
|
||||||
|
schedulerId: 'test-scheduler-id',
|
||||||
|
};
|
||||||
|
|
||||||
|
runWithToolCallContext(context, () => {
|
||||||
|
const storedContext = getToolCallContext();
|
||||||
|
expect(storedContext).toEqual(context);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return undefined when no context is set', () => {
|
||||||
|
expect(getToolCallContext()).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should support nested contexts', () => {
|
||||||
|
const parentContext = {
|
||||||
|
callId: 'parent-call-id',
|
||||||
|
schedulerId: 'parent-scheduler-id',
|
||||||
|
};
|
||||||
|
|
||||||
|
const childContext = {
|
||||||
|
callId: 'child-call-id',
|
||||||
|
schedulerId: 'child-scheduler-id',
|
||||||
|
parentCallId: 'parent-call-id',
|
||||||
|
};
|
||||||
|
|
||||||
|
runWithToolCallContext(parentContext, () => {
|
||||||
|
expect(getToolCallContext()).toEqual(parentContext);
|
||||||
|
|
||||||
|
runWithToolCallContext(childContext, () => {
|
||||||
|
expect(getToolCallContext()).toEqual(childContext);
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(getToolCallContext()).toEqual(parentContext);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should maintain isolation between parallel executions', async () => {
|
||||||
|
const context1 = {
|
||||||
|
callId: 'call-1',
|
||||||
|
schedulerId: 'scheduler-1',
|
||||||
|
};
|
||||||
|
|
||||||
|
const context2 = {
|
||||||
|
callId: 'call-2',
|
||||||
|
schedulerId: 'scheduler-2',
|
||||||
|
};
|
||||||
|
|
||||||
|
const promise1 = new Promise<void>((resolve) => {
|
||||||
|
runWithToolCallContext(context1, () => {
|
||||||
|
setTimeout(() => {
|
||||||
|
expect(getToolCallContext()).toEqual(context1);
|
||||||
|
resolve();
|
||||||
|
}, 10);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
const promise2 = new Promise<void>((resolve) => {
|
||||||
|
runWithToolCallContext(context2, () => {
|
||||||
|
setTimeout(() => {
|
||||||
|
expect(getToolCallContext()).toEqual(context2);
|
||||||
|
resolve();
|
||||||
|
}, 5);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
await Promise.all([promise1, promise2]);
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2026 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { AsyncLocalStorage } from 'node:async_hooks';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Contextual information for a tool call execution.
|
||||||
|
*/
|
||||||
|
export interface ToolCallContext {
|
||||||
|
/** The unique ID of the tool call. */
|
||||||
|
callId: string;
|
||||||
|
/** The ID of the scheduler managing the execution. */
|
||||||
|
schedulerId: string;
|
||||||
|
/** The ID of the parent tool call, if this is a nested execution (e.g., in a subagent). */
|
||||||
|
parentCallId?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AsyncLocalStorage instance for tool call context.
|
||||||
|
*/
|
||||||
|
export const toolCallContext = new AsyncLocalStorage<ToolCallContext>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs a function within a tool call context.
|
||||||
|
*
|
||||||
|
* @param context The context to set.
|
||||||
|
* @param fn The function to run.
|
||||||
|
* @returns The result of the function.
|
||||||
|
*/
|
||||||
|
export function runWithToolCallContext<T>(
|
||||||
|
context: ToolCallContext,
|
||||||
|
fn: () => T,
|
||||||
|
): T {
|
||||||
|
return toolCallContext.run(context, fn);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retrieves the current tool call context.
|
||||||
|
*
|
||||||
|
* @returns The current ToolCallContext, or undefined if not in a context.
|
||||||
|
*/
|
||||||
|
export function getToolCallContext(): ToolCallContext | undefined {
|
||||||
|
return toolCallContext.getStore();
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user