diff --git a/packages/core/src/core/coreToolHookTriggers.test.ts b/packages/core/src/core/coreToolHookTriggers.test.ts index 2a654042c6..ff9601fc33 100644 --- a/packages/core/src/core/coreToolHookTriggers.test.ts +++ b/packages/core/src/core/coreToolHookTriggers.test.ts @@ -11,6 +11,7 @@ import { BaseToolInvocation, type ToolResult, type AnyDeclarativeTool, + type ToolLiveOutput, } from '../tools/tools.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { HookSystem } from '../hooks/hookSystem.js'; @@ -37,6 +38,30 @@ class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> { } } +class MockBackgroundableInvocation extends BaseToolInvocation< + { key?: string }, + ToolResult +> { + constructor(params: { key?: string }, messageBus: MessageBus) { + super(params, messageBus); + } + getDescription() { + return 'mock-pid'; + } + async execute( + _signal: AbortSignal, + _updateOutput?: (output: ToolLiveOutput) => void, + _shellExecutionConfig?: unknown, + setExecutionIdCallback?: (executionId: number) => void, + ) { + setExecutionIdCallback?.(4242); + return { + llmContent: 'pid', + returnDisplay: 'pid', + }; + } +} + describe('executeToolWithHooks', () => { let messageBus: MessageBus; let mockTool: AnyDeclarativeTool; @@ -258,4 +283,26 @@ describe('executeToolWithHooks', () => { expect(invocation.params.key).toBe('original'); expect(mockTool.build).not.toHaveBeenCalled(); }); + + it('should pass execution ID callback through for non-shell invocations', async () => { + const invocation = new MockBackgroundableInvocation({}, messageBus); + const abortSignal = new AbortController().signal; + const setExecutionIdCallback = vi.fn(); + + vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue(undefined); + vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue(undefined); + + await executeToolWithHooks( + invocation, + 'test_tool', + abortSignal, + mockTool, + undefined, + undefined, + setExecutionIdCallback, + mockConfig, + ); + + expect(setExecutionIdCallback).toHaveBeenCalledWith(4242); + }); }); diff --git a/packages/core/src/scheduler/tool-executor.test.ts b/packages/core/src/scheduler/tool-executor.test.ts index b382b2b208..9e26ff4b3e 100644 --- a/packages/core/src/scheduler/tool-executor.test.ts +++ b/packages/core/src/scheduler/tool-executor.test.ts @@ -534,6 +534,59 @@ describe('ToolExecutor', () => { ); }); + it('should report execution ID updates for non-shell backgroundable tools', async () => { + const mockTool = new MockTool({ + name: 'remote_agent_call', + description: 'Remote agent call', + }); + const invocation = mockTool.build({}); + + const testExecutionId = 67890; + vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation( + async ( + _inv, + _name, + _sig, + _tool, + _liveCb, + _shellCfg, + setExecutionIdCallback, + ) => { + setExecutionIdCallback?.(testExecutionId); + return { llmContent: 'done', returnDisplay: 'done' }; + }, + ); + + const scheduledCall: ScheduledToolCall = { + status: CoreToolCallStatus.Scheduled, + request: { + callId: 'call-remote-pid', + name: 'remote_agent_call', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-remote-pid', + }, + tool: mockTool, + invocation: invocation as unknown as AnyToolInvocation, + startTime: Date.now(), + }; + + const onUpdateToolCall = vi.fn(); + + await executor.execute({ + call: scheduledCall, + signal: new AbortController().signal, + onUpdateToolCall, + }); + + expect(onUpdateToolCall).toHaveBeenCalledWith( + expect.objectContaining({ + status: CoreToolCallStatus.Executing, + pid: testExecutionId, + }), + ); + }); + it('should return cancelled result with partial output when signal is aborted', async () => { const mockTool = new MockTool({ name: 'slowTool', diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index da0e03d10b..585805dd5f 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -96,20 +96,30 @@ export function isBackgroundExecutionData( return false; } - const value = data as Partial; + const executionId = 'executionId' in data ? data.executionId : undefined; + const pid = 'pid' in data ? data.pid : undefined; + const command = 'command' in data ? data.command : undefined; + const initialOutput = + 'initialOutput' in data ? data.initialOutput : undefined; + return ( - (value.executionId === undefined || typeof value.executionId === 'number') && - (value.pid === undefined || typeof value.pid === 'number') && - (value.command === undefined || typeof value.command === 'string') && - (value.initialOutput === undefined || - typeof value.initialOutput === 'string') + (executionId === undefined || typeof executionId === 'number') && + (pid === undefined || typeof pid === 'number') && + (command === undefined || typeof command === 'string') && + (initialOutput === undefined || typeof initialOutput === 'string') ); } export function getBackgroundExecutionId( data: BackgroundExecutionData, ): number | undefined { - return data.executionId ?? data.pid; + if (typeof data.executionId === 'number') { + return data.executionId; + } + if (typeof data.pid === 'number') { + return data.pid; + } + return undefined; } /**