From 66c2184fe55fe863929ea79908aa1d8a36876f32 Mon Sep 17 00:00:00 2001 From: fuyou Date: Thu, 25 Sep 2025 03:10:55 +0800 Subject: [PATCH] feat: Add AbortSignal support for retry logic and tool execution (#9196) Co-authored-by: Sandy Tao --- packages/core/src/core/coreToolScheduler.ts | 30 ++-- packages/core/src/tools/mcp-tool.test.ts | 155 ++++++++++++++++++++ packages/core/src/tools/mcp-tool.ts | 33 ++++- 3 files changed, 205 insertions(+), 13 deletions(-) diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 689e5abc15..6743c0a1d6 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -1078,17 +1078,25 @@ export class CoreToolScheduler { } }) .catch((executionError: Error) => { - this.setStatusInternal( - callId, - 'error', - createErrorResponse( - scheduledCall.request, - executionError instanceof Error - ? executionError - : new Error(String(executionError)), - ToolErrorType.UNHANDLED_EXCEPTION, - ), - ); + if (signal.aborted) { + this.setStatusInternal( + callId, + 'cancelled', + 'User cancelled tool execution.', + ); + } else { + this.setStatusInternal( + callId, + 'error', + createErrorResponse( + scheduledCall.request, + executionError instanceof Error + ? executionError + : new Error(String(executionError)), + ToolErrorType.UNHANDLED_EXCEPTION, + ), + ); + } }); }); } diff --git a/packages/core/src/tools/mcp-tool.test.ts b/packages/core/src/tools/mcp-tool.test.ts index 9fb155e4c8..680fa92998 100644 --- a/packages/core/src/tools/mcp-tool.test.ts +++ b/packages/core/src/tools/mcp-tool.test.ts @@ -572,6 +572,161 @@ describe('DiscoveredMCPTool', () => { 'Here is a resource.\n[Link to My Resource: file:///path/to/resource]\nEmbedded text content.\n[Image: image/jpeg]', ); }); + + describe('AbortSignal support', () => { + it('should abort immediately if signal is already aborted', async () => { + const params = { param: 'test' }; + const controller = new AbortController(); + controller.abort(); + + const invocation = tool.build(params); + + await expect(invocation.execute(controller.signal)).rejects.toThrow( + 'Tool call aborted', + ); + + // Tool should not be called if signal is already aborted + expect(mockCallTool).not.toHaveBeenCalled(); + }); + + it('should abort during tool execution', async () => { + const params = { param: 'test' }; + const controller = new AbortController(); + + // Mock a delayed response to simulate long-running tool + mockCallTool.mockImplementation( + () => + new Promise((resolve) => { + setTimeout(() => { + resolve([ + { + functionResponse: { + name: serverToolName, + response: { + content: [{ type: 'text', text: 'Success' }], + }, + }, + }, + ]); + }, 1000); + }), + ); + + const invocation = tool.build(params); + const promise = invocation.execute(controller.signal); + + // Abort after a short delay to simulate cancellation during execution + setTimeout(() => controller.abort(), 50); + + await expect(promise).rejects.toThrow('Tool call aborted'); + }); + + it('should complete successfully if not aborted', async () => { + const params = { param: 'test' }; + const controller = new AbortController(); + const successResponse = [ + { + functionResponse: { + name: serverToolName, + response: { + content: [{ type: 'text', text: 'Success' }], + }, + }, + }, + ]; + + mockCallTool.mockResolvedValue(successResponse); + + const invocation = tool.build(params); + const result = await invocation.execute(controller.signal); + + expect(result.llmContent).toEqual([{ text: 'Success' }]); + expect(result.returnDisplay).toBe('Success'); + expect(mockCallTool).toHaveBeenCalledWith([ + { name: serverToolName, args: params }, + ]); + }); + + it('should handle tool error even when abort signal is provided', async () => { + const params = { param: 'test' }; + const controller = new AbortController(); + const errorResponse = [ + { + functionResponse: { + name: serverToolName, + response: { error: { isError: true } }, + }, + }, + ]; + + mockCallTool.mockResolvedValue(errorResponse); + + const invocation = tool.build(params); + const result = await invocation.execute(controller.signal); + + expect(result.error?.type).toBe(ToolErrorType.MCP_TOOL_ERROR); + expect(result.returnDisplay).toContain( + `Error: MCP tool '${serverToolName}' reported an error.`, + ); + }); + + it('should handle callTool rejection with abort signal', async () => { + const params = { param: 'test' }; + const controller = new AbortController(); + const expectedError = new Error('Network error'); + + mockCallTool.mockRejectedValue(expectedError); + + const invocation = tool.build(params); + + await expect(invocation.execute(controller.signal)).rejects.toThrow( + expectedError, + ); + }); + + it('should cleanup event listeners properly on successful completion', async () => { + const params = { param: 'test' }; + const controller = new AbortController(); + const successResponse = [ + { + functionResponse: { + name: serverToolName, + response: { + content: [{ type: 'text', text: 'Success' }], + }, + }, + }, + ]; + + mockCallTool.mockResolvedValue(successResponse); + + const invocation = tool.build(params); + await invocation.execute(controller.signal); + + controller.abort(); + expect(controller.signal.aborted).toBe(true); + }); + + it('should cleanup event listeners properly on error', async () => { + const params = { param: 'test' }; + const controller = new AbortController(); + const expectedError = new Error('Tool execution failed'); + + mockCallTool.mockRejectedValue(expectedError); + + const invocation = tool.build(params); + + try { + await invocation.execute(controller.signal); + } catch (error) { + expect(error).toBe(expectedError); + } + + // Verify cleanup by aborting after error + controller.abort(); + expect(controller.signal.aborted).toBe(true); + }); + }); }); describe('shouldConfirmExecute', () => { diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index b53487799b..afffa103e5 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -131,7 +131,7 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< return false; } - async execute(): Promise { + async execute(signal: AbortSignal): Promise { const functionCalls: FunctionCall[] = [ { name: this.serverToolName, @@ -139,7 +139,36 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< }, ]; - const rawResponseParts = await this.mcpTool.callTool(functionCalls); + // Race MCP tool call with abort signal to respect cancellation + const rawResponseParts = await new Promise((resolve, reject) => { + if (signal.aborted) { + const error = new Error('Tool call aborted'); + error.name = 'AbortError'; + reject(error); + return; + } + const onAbort = () => { + cleanup(); + const error = new Error('Tool call aborted'); + error.name = 'AbortError'; + reject(error); + }; + const cleanup = () => { + signal.removeEventListener('abort', onAbort); + }; + signal.addEventListener('abort', onAbort, { once: true }); + + this.mcpTool + .callTool(functionCalls) + .then((res) => { + cleanup(); + resolve(res); + }) + .catch((err) => { + cleanup(); + reject(err); + }); + }); // Ensure the response is not an error if (this.isMCPToolError(rawResponseParts)) {