From 89aba7cbcda666877ce010c281cc72fa4449f0e9 Mon Sep 17 00:00:00 2001 From: Sandy Tao Date: Wed, 24 Sep 2025 12:16:00 -0700 Subject: [PATCH] fix(core): Fix unable to cancel edit tool (#9299) --- .../core/src/core/coreToolScheduler.test.ts | 138 ++++++++++++++++++ packages/core/src/core/coreToolScheduler.ts | 9 ++ packages/core/src/tools/edit.test.ts | 55 +++++++ packages/core/src/tools/edit.ts | 6 + packages/core/src/tools/smart-edit.test.ts | 63 ++++++++ packages/core/src/tools/smart-edit.ts | 6 + 6 files changed, 277 insertions(+) diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index d9ce66e133..4aaa0b3d45 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -109,6 +109,65 @@ class TestApprovalInvocation extends BaseToolInvocation< } } +class AbortDuringConfirmationInvocation extends BaseToolInvocation< + Record, + ToolResult +> { + constructor( + private readonly abortController: AbortController, + private readonly abortError: Error, + params: Record, + ) { + super(params); + } + + override async shouldConfirmExecute( + _signal: AbortSignal, + ): Promise { + this.abortController.abort(); + throw this.abortError; + } + + async execute(_abortSignal: AbortSignal): Promise { + throw new Error('execute should not be called when confirmation fails'); + } + + getDescription(): string { + return 'Abort during confirmation invocation'; + } +} + +class AbortDuringConfirmationTool extends BaseDeclarativeTool< + Record, + ToolResult +> { + constructor( + private readonly abortController: AbortController, + private readonly abortError: Error, + ) { + super( + 'abortDuringConfirmationTool', + 'Abort During Confirmation Tool', + 'A tool that aborts while confirming execution.', + Kind.Other, + { + type: 'object', + properties: {}, + }, + ); + } + + protected createInvocation( + params: Record, + ): ToolInvocation, ToolResult> { + return new AbortDuringConfirmationInvocation( + this.abortController, + this.abortError, + params, + ); + } +} + async function waitForStatus( onToolCallsUpdate: Mock, status: 'awaiting_approval' | 'executing' | 'success' | 'error' | 'cancelled', @@ -218,6 +277,85 @@ describe('CoreToolScheduler', () => { expect(completedCalls[0].status).toBe('cancelled'); }); + it('should mark tool call as cancelled when abort happens during confirmation error', async () => { + const abortController = new AbortController(); + const abortError = new Error('Abort requested during confirmation'); + const declarativeTool = new AbortDuringConfirmationTool( + abortController, + abortError, + ); + + const mockToolRegistry = { + getTool: () => declarativeTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByName: () => declarativeTool, + getToolByDisplayName: () => declarativeTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const onAllToolCallsComplete = vi.fn(); + const onToolCallsUpdate = vi.fn(); + + const mockConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getApprovalMode: () => ApprovalMode.DEFAULT, + getAllowedTools: () => [], + getContentGeneratorConfig: () => ({ + model: 'test-model', + authType: 'oauth-personal', + }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), + storage: { + getProjectTempDir: () => '/tmp', + }, + getTruncateToolOutputThreshold: () => + DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, + getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + getToolRegistry: () => mockToolRegistry, + getUseSmartEdit: () => false, + getUseModelRouter: () => false, + getGeminiClient: () => null, + } as unknown as Config; + + const scheduler = new CoreToolScheduler({ + config: mockConfig, + onAllToolCallsComplete, + onToolCallsUpdate, + getPreferredEditor: () => 'vscode', + onEditorClose: vi.fn(), + }); + + const request = { + callId: 'abort-1', + name: 'abortDuringConfirmationTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-abort', + }; + + await scheduler.schedule([request], abortController.signal); + + expect(onAllToolCallsComplete).toHaveBeenCalled(); + const completedCalls = onAllToolCallsComplete.mock + .calls[0][0] as ToolCall[]; + expect(completedCalls[0].status).toBe('cancelled'); + const statuses = onToolCallsUpdate.mock.calls.flatMap((call) => + (call[0] as ToolCall[]).map((toolCall) => toolCall.status), + ); + expect(statuses).not.toContain('error'); + }); + describe('getToolSuggestion', () => { it('should suggest the top N closest tool names for a typo', () => { // Create mocked tool registry diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 6743c0a1d6..e1972a671a 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -799,6 +799,15 @@ export class CoreToolScheduler { ); } } catch (error) { + if (signal.aborted) { + this.setStatusInternal( + reqInfo.callId, + 'cancelled', + 'Tool call cancelled by user.', + ); + continue; + } + this.setStatusInternal( reqInfo.callId, 'error', diff --git a/packages/core/src/tools/edit.test.ts b/packages/core/src/tools/edit.test.ts index d2ca62d929..8d81c3e844 100644 --- a/packages/core/src/tools/edit.test.ts +++ b/packages/core/src/tools/edit.test.ts @@ -471,6 +471,34 @@ describe('EditTool', () => { ); expect(patchedContent).toBe(expectedFinalContent); }); + + it('should rethrow calculateEdit errors when the abort signal is triggered', async () => { + const filePath = path.join(rootDir, 'abort-confirmation.txt'); + const params: EditToolParams = { + file_path: filePath, + old_string: 'old', + new_string: 'new', + }; + + const invocation = tool.build(params); + const abortController = new AbortController(); + const abortError = new Error('Abort requested'); + + const calculateSpy = vi + .spyOn(invocation as any, 'calculateEdit') + .mockImplementation(async () => { + if (!abortController.signal.aborted) { + abortController.abort(); + } + throw abortError; + }); + + await expect( + invocation.shouldConfirmExecute(abortController.signal), + ).rejects.toBe(abortError); + + calculateSpy.mockRestore(); + }); }); describe('execute', () => { @@ -515,6 +543,33 @@ describe('EditTool', () => { ); }); + it('should reject when calculateEdit fails after an abort signal', async () => { + const params: EditToolParams = { + file_path: path.join(rootDir, 'abort-execute.txt'), + old_string: 'old', + new_string: 'new', + }; + + const invocation = tool.build(params); + const abortController = new AbortController(); + const abortError = new Error('Abort requested during execute'); + + const calculateSpy = vi + .spyOn(invocation as any, 'calculateEdit') + .mockImplementation(async () => { + if (!abortController.signal.aborted) { + abortController.abort(); + } + throw abortError; + }); + + await expect(invocation.execute(abortController.signal)).rejects.toBe( + abortError, + ); + + calculateSpy.mockRestore(); + }); + it('should edit an existing file and return diff with fileName', async () => { const initialContent = 'This is some old text.'; const newContent = 'This is some new text.'; // old -> new diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index cf7029355e..e4f6d4a1c3 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -251,6 +251,9 @@ class EditToolInvocation implements ToolInvocation { try { editData = await this.calculateEdit(this.params, abortSignal); } catch (error) { + if (abortSignal.aborted) { + throw error; + } const errorMsg = error instanceof Error ? error.message : String(error); console.log(`Error preparing edit: ${errorMsg}`); return false; @@ -336,6 +339,9 @@ class EditToolInvocation implements ToolInvocation { try { editData = await this.calculateEdit(this.params, signal); } catch (error) { + if (signal.aborted) { + throw error; + } const errorMsg = error instanceof Error ? error.message : String(error); return { llmContent: `Error preparing edit: ${errorMsg}`, diff --git a/packages/core/src/tools/smart-edit.test.ts b/packages/core/src/tools/smart-edit.test.ts index 7a60614293..cba9c4890f 100644 --- a/packages/core/src/tools/smart-edit.test.ts +++ b/packages/core/src/tools/smart-edit.test.ts @@ -274,6 +274,36 @@ describe('SmartEditTool', () => { filePath = path.join(rootDir, testFile); }); + it('should reject when calculateEdit fails after an abort signal', async () => { + const params: EditToolParams = { + file_path: path.join(rootDir, 'abort-execute.txt'), + instruction: 'Abort during execute', + old_string: 'old', + new_string: 'new', + }; + + const invocation = tool.build(params); + const abortController = new AbortController(); + const abortError = new Error( + 'Abort requested during smart edit execution', + ); + + const calculateSpy = vi + .spyOn(invocation as any, 'calculateEdit') + .mockImplementation(async () => { + if (!abortController.signal.aborted) { + abortController.abort(); + } + throw abortError; + }); + + await expect(invocation.execute(abortController.signal)).rejects.toBe( + abortError, + ); + + calculateSpy.mockRestore(); + }); + it('should edit an existing file and return diff with fileName', async () => { const initialContent = 'This is some old text.'; const newContent = 'This is some new text.'; @@ -511,4 +541,37 @@ describe('SmartEditTool', () => { expect(params.new_string).toBe(modifiedContent); }); }); + + describe('shouldConfirmExecute', () => { + it('should rethrow calculateEdit errors when the abort signal is triggered', async () => { + const filePath = path.join(rootDir, 'abort-confirmation.txt'); + const params: EditToolParams = { + file_path: filePath, + instruction: 'Abort during confirmation', + old_string: 'old', + new_string: 'new', + }; + + const invocation = tool.build(params); + const abortController = new AbortController(); + const abortError = new Error( + 'Abort requested during smart edit confirmation', + ); + + const calculateSpy = vi + .spyOn(invocation as any, 'calculateEdit') + .mockImplementation(async () => { + if (!abortController.signal.aborted) { + abortController.abort(); + } + throw abortError; + }); + + await expect( + invocation.shouldConfirmExecute(abortController.signal), + ).rejects.toBe(abortError); + + calculateSpy.mockRestore(); + }); + }); }); diff --git a/packages/core/src/tools/smart-edit.ts b/packages/core/src/tools/smart-edit.ts index f1537958a2..cfd48e01f0 100644 --- a/packages/core/src/tools/smart-edit.ts +++ b/packages/core/src/tools/smart-edit.ts @@ -490,6 +490,9 @@ class EditToolInvocation implements ToolInvocation { try { editData = await this.calculateEdit(this.params, abortSignal); } catch (error) { + if (abortSignal.aborted) { + throw error; + } const errorMsg = error instanceof Error ? error.message : String(error); console.log(`Error preparing edit: ${errorMsg}`); return false; @@ -575,6 +578,9 @@ class EditToolInvocation implements ToolInvocation { try { editData = await this.calculateEdit(this.params, signal); } catch (error) { + if (signal.aborted) { + throw error; + } const errorMsg = error instanceof Error ? error.message : String(error); return { llmContent: `Error preparing edit: ${errorMsg}`,