From 37ca643a6440a5db00cd7ae08725483fc75e1b26 Mon Sep 17 00:00:00 2001 From: cornmander Date: Mon, 10 Nov 2025 16:19:32 -0500 Subject: [PATCH] Fix external editor diff drift (#12846) --- packages/cli/src/gemini.tsx | 1 - .../core/src/core/coreToolScheduler.test.ts | 536 ++++++------------ packages/core/src/core/coreToolScheduler.ts | 10 + .../core/src/tools/modifiable-tool.test.ts | 58 ++ packages/core/src/tools/modifiable-tool.ts | 25 +- 5 files changed, 252 insertions(+), 378 deletions(-) diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 9ef0287be7..aa9b7e3588 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -234,7 +234,6 @@ export async function startInteractiveUI( } }, alternateBuffer: settings.merged.ui?.useAlternateBuffer, - alternateBufferAlreadyActive: settings.merged.ui?.useAlternateBuffer, }, ); diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 9b7aefa8bd..073ff712d5 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -36,6 +36,7 @@ import { MockTool, MOCK_TOOL_SHOULD_CONFIRM_EXECUTE, } from '../test-utils/mock-tool.js'; +import * as modifiableToolModule from '../tools/modifiable-tool.js'; import * as fs from 'node:fs/promises'; import * as path from 'node:path'; import { isShellInvocationAllowlisted } from '../utils/shell-utils.js'; @@ -209,6 +210,54 @@ async function waitForStatus( }); } +function createMockConfig(overrides: Partial = {}): Config { + const defaultToolRegistry = { + getTool: () => undefined, + getToolByName: () => undefined, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => undefined, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const baseConfig = { + getSessionId: () => 'test-session-id', + getUsageStatisticsEnabled: () => true, + getDebugMode: () => false, + getApprovalMode: () => ApprovalMode.DEFAULT, + setApprovalMode: () => {}, + getAllowedTools: () => [], + getContentGeneratorConfig: () => ({ + model: 'test-model', + authType: 'oauth-personal', + }), + getShellExecutionConfig: () => ({ + terminalWidth: 90, + terminalHeight: 30, + }), + storage: { + getProjectTempDir: () => '/tmp', + }, + getTruncateToolOutputThreshold: () => + DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, + getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + getToolRegistry: () => defaultToolRegistry, + getUseSmartEdit: () => false, + getUseModelRouter: () => false, + getGeminiClient: () => null, + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, + } as unknown as Config; + + return { ...baseConfig, ...overrides } as Config; +} + describe('CoreToolScheduler', () => { it('should cancel a tool call if the signal is aborted before confirmation', async () => { const mockTool = new MockTool({ @@ -233,34 +282,9 @@ describe('CoreToolScheduler', () => { const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.DEFAULT, - getAllowedTools: () => [], - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), - getShellExecutionConfig: () => ({ - terminalWidth: 90, - terminalHeight: 30, - }), - storage: { - getProjectTempDir: () => '/tmp', - }, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + const mockConfig = createMockConfig({ getToolRegistry: () => mockToolRegistry, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, // No client needed for these tests - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -323,34 +347,9 @@ describe('CoreToolScheduler', () => { const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.DEFAULT, - getAllowedTools: () => [], - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), - getShellExecutionConfig: () => ({ - terminalWidth: 90, - terminalHeight: 30, - }), - storage: { - getProjectTempDir: () => '/tmp', - }, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + const mockConfig = createMockConfig({ getToolRegistry: () => mockToolRegistry, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, // No client needed for these tests - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -449,34 +448,9 @@ describe('CoreToolScheduler', () => { const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.DEFAULT, - getAllowedTools: () => [], - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), - getShellExecutionConfig: () => ({ - terminalWidth: 90, - terminalHeight: 30, - }), - storage: { - getProjectTempDir: () => '/tmp', - }, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + const mockConfig = createMockConfig({ getToolRegistry: () => mockToolRegistry, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, // No client needed for these tests - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -570,34 +544,9 @@ describe('CoreToolScheduler', () => { const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.DEFAULT, - getAllowedTools: () => [], - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), - getShellExecutionConfig: () => ({ - terminalWidth: 90, - terminalHeight: 30, - }), - storage: { - getProjectTempDir: () => '/tmp', - }, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + const mockConfig = createMockConfig({ getToolRegistry: () => mockToolRegistry, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -633,15 +582,9 @@ describe('CoreToolScheduler', () => { const mockToolRegistry = { getAllToolNames: () => ['list_files', 'read_file', 'write_file'], } as unknown as ToolRegistry; - const mockConfig = { + const mockConfig = createMockConfig({ getToolRegistry: () => mockToolRegistry, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, // No client needed for these tests - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + }); // Create scheduler const scheduler = new CoreToolScheduler({ @@ -692,34 +635,9 @@ describe('CoreToolScheduler with payload', () => { const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.DEFAULT, - getAllowedTools: () => [], - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), - getShellExecutionConfig: () => ({ - terminalWidth: 90, - terminalHeight: 30, - }), - storage: { - getProjectTempDir: () => '/tmp', - }, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + const mockConfig = createMockConfig({ getToolRegistry: () => mockToolRegistry, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, // No client needed for these tests - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -1018,31 +936,9 @@ describe('CoreToolScheduler edit cancellation', () => { const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.DEFAULT, - getAllowedTools: () => [], - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), - getShellExecutionConfig: () => ({ - terminalWidth: 90, - terminalHeight: 30, - }), - storage: { - getProjectTempDir: () => '/tmp', - }, + const mockConfig = createMockConfig({ getToolRegistry: () => mockToolRegistry, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, // No client needed for these tests - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -1124,34 +1020,10 @@ describe('CoreToolScheduler YOLO mode', () => { const onToolCallsUpdate = vi.fn(); // Configure the scheduler for YOLO mode. - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.YOLO, - getAllowedTools: () => [], - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), - getShellExecutionConfig: () => ({ - terminalWidth: 90, - terminalHeight: 30, - }), - storage: { - getProjectTempDir: () => '/tmp', - }, + const mockConfig = createMockConfig({ getToolRegistry: () => mockToolRegistry, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, // No client needed for these tests - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + getApprovalMode: () => ApprovalMode.YOLO, + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -1234,34 +1106,10 @@ describe('CoreToolScheduler request queueing', () => { const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts - getAllowedTools: () => [], - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), - getShellExecutionConfig: () => ({ - terminalWidth: 90, - terminalHeight: 30, - }), - storage: { - getProjectTempDir: () => '/tmp', - }, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + const mockConfig = createMockConfig({ getToolRegistry: () => mockToolRegistry, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, // No client needed for these tests - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -1361,42 +1209,20 @@ describe('CoreToolScheduler request queueing', () => { discoverTools: async () => {}, getAllTools: () => [], getToolsByServer: () => [], - }; + } as unknown as ToolRegistry; const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); // Configure the scheduler to auto-approve the specific tool call. - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.DEFAULT, // Not YOLO mode + const mockConfig = createMockConfig({ getAllowedTools: () => ['mockTool'], // Auto-approve this tool getToolRegistry: () => toolRegistry, - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), getShellExecutionConfig: () => ({ terminalWidth: 80, terminalHeight: 24, }), - getTerminalWidth: vi.fn(() => 80), - getTerminalHeight: vi.fn(() => 24), - storage: { - getProjectTempDir: () => '/tmp', - }, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, // No client needed for these tests - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -1491,41 +1317,19 @@ describe('CoreToolScheduler request queueing', () => { discoverTools: async () => {}, getAllTools: () => [], getToolsByServer: () => [], - }; + } as unknown as ToolRegistry; const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.DEFAULT, + const mockConfig = createMockConfig({ getAllowedTools: () => ['run_shell_command(git)'], - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), getShellExecutionConfig: () => ({ terminalWidth: 80, terminalHeight: 24, }), - getTerminalWidth: vi.fn(() => 80), - getTerminalHeight: vi.fn(() => 24), - storage: { - getProjectTempDir: () => '/tmp', - }, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, getToolRegistry: () => toolRegistry, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -1578,34 +1382,10 @@ describe('CoreToolScheduler request queueing', () => { const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.YOLO, - getAllowedTools: () => [], - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), - getShellExecutionConfig: () => ({ - terminalWidth: 90, - terminalHeight: 30, - }), - storage: { - getProjectTempDir: () => '/tmp', - }, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, + const mockConfig = createMockConfig({ getToolRegistry: () => mockToolRegistry, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, // No client needed for these tests - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + getApprovalMode: () => ApprovalMode.YOLO, + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -1655,32 +1435,12 @@ describe('CoreToolScheduler request queueing', () => { it('should auto-approve remaining tool calls when first tool call is approved with ProceedAlways', async () => { let approvalMode = ApprovalMode.DEFAULT; - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, + const mockConfig = createMockConfig({ getApprovalMode: () => approvalMode, - getAllowedTools: () => [], setApprovalMode: (mode: ApprovalMode) => { approvalMode = mode; }, - getShellExecutionConfig: () => ({ - terminalWidth: 90, - terminalHeight: 30, - }), - storage: { - getProjectTempDir: () => '/tmp', - }, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, // No client needed for these tests - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - getPolicyEngine: () => null, - } as unknown as Config; + }); const testTool = new TestApprovalTool(mockConfig); const toolRegistry = { @@ -1848,33 +1608,10 @@ describe('CoreToolScheduler Sequential Execution', () => { const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts - getAllowedTools: () => [], - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), - getShellExecutionConfig: () => ({ - terminalWidth: 90, - terminalHeight: 30, - }), - storage: { - getProjectTempDir: () => '/tmp', - }, + const mockConfig = createMockConfig({ getToolRegistry: () => mockToolRegistry, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - } as unknown as Config; + getApprovalMode: () => ApprovalMode.YOLO, // Use YOLO to avoid confirmation prompts + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -1970,33 +1707,10 @@ describe('CoreToolScheduler Sequential Execution', () => { const onAllToolCallsComplete = vi.fn(); const onToolCallsUpdate = vi.fn(); - const mockConfig = { - getSessionId: () => 'test-session-id', - getUsageStatisticsEnabled: () => true, - getDebugMode: () => false, - getApprovalMode: () => ApprovalMode.YOLO, - getAllowedTools: () => [], - getContentGeneratorConfig: () => ({ - model: 'test-model', - authType: 'oauth-personal', - }), - getShellExecutionConfig: () => ({ - terminalWidth: 90, - terminalHeight: 30, - }), - storage: { - getProjectTempDir: () => '/tmp', - }, + const mockConfig = createMockConfig({ getToolRegistry: () => mockToolRegistry, - getTruncateToolOutputThreshold: () => - DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, - getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES, - getUseSmartEdit: () => false, - getUseModelRouter: () => false, - getGeminiClient: () => null, - getEnableMessageBusIntegration: () => false, - getMessageBus: () => null, - } as unknown as Config; + getApprovalMode: () => ApprovalMode.YOLO, + }); const scheduler = new CoreToolScheduler({ config: mockConfig, @@ -2066,6 +1780,84 @@ describe('CoreToolScheduler Sequential Execution', () => { expect(call2?.status).toBe('cancelled'); expect(call3?.status).toBe('cancelled'); }); + + it('should pass confirmation diff data into modifyWithEditor overrides', async () => { + const modifyWithEditorSpy = vi + .spyOn(modifiableToolModule, 'modifyWithEditor') + .mockResolvedValue({ + updatedParams: { param: 'updated' }, + updatedDiff: 'updated diff', + }); + + const mockModifiableTool = new MockModifiableTool('mockModifiableTool'); + const mockToolRegistry = { + getTool: () => mockModifiableTool, + getToolByName: () => mockModifiableTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByDisplayName: () => mockModifiableTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const onAllToolCallsComplete = vi.fn(); + const onToolCallsUpdate = vi.fn(); + + const mockConfig = createMockConfig({ + getToolRegistry: () => mockToolRegistry, + }); + + const scheduler = new CoreToolScheduler({ + config: mockConfig, + onAllToolCallsComplete, + onToolCallsUpdate, + getPreferredEditor: () => 'vscode', + onEditorClose: vi.fn(), + }); + + const abortController = new AbortController(); + + await scheduler.schedule( + [ + { + callId: '1', + name: 'mockModifiableTool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-1', + }, + ], + abortController.signal, + ); + + const toolCall = (scheduler as unknown as { toolCalls: ToolCall[] }) + .toolCalls[0] as WaitingToolCall; + expect(toolCall.status).toBe('awaiting_approval'); + + const confirmationSignal = new AbortController().signal; + await scheduler.handleConfirmationResponse( + toolCall.request.callId, + async () => {}, + ToolConfirmationOutcome.ModifyWithEditor, + confirmationSignal, + ); + + expect(modifyWithEditorSpy).toHaveBeenCalled(); + const overrides = + modifyWithEditorSpy.mock.calls[ + modifyWithEditorSpy.mock.calls.length - 1 + ][5]; + expect(overrides).toEqual({ + currentContent: 'originalContent', + proposedContent: 'newContent', + }); + + modifyWithEditorSpy.mockRestore(); + }); }); describe('truncateAndSaveToFile', () => { diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 0cc1adf7a1..6e3b799277 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -979,6 +979,15 @@ export class CoreToolScheduler { isModifying: true, } as ToolCallConfirmationDetails); + const contentOverrides = + waitingToolCall.confirmationDetails.type === 'edit' + ? { + currentContent: + waitingToolCall.confirmationDetails.originalContent, + proposedContent: waitingToolCall.confirmationDetails.newContent, + } + : undefined; + const { updatedParams, updatedDiff } = await modifyWithEditor< typeof waitingToolCall.request.args >( @@ -987,6 +996,7 @@ export class CoreToolScheduler { editorType, signal, this.onEditorClose, + contentOverrides, ); this.setArgsInternal(callId, updatedParams); this.setStatusInternal(callId, 'awaiting_approval', signal, { diff --git a/packages/core/src/tools/modifiable-tool.test.ts b/packages/core/src/tools/modifiable-tool.test.ts index 0ed23d610c..ec6caf9290 100644 --- a/packages/core/src/tools/modifiable-tool.test.ts +++ b/packages/core/src/tools/modifiable-tool.test.ts @@ -235,6 +235,64 @@ describe('modifyWithEditor', () => { expect(result.updatedDiff).toBe('mock diff content'); }); + it('should honor override content values when provided', async () => { + const overrideCurrent = 'override current content'; + const overrideProposed = 'override proposed content'; + mockModifyContext.getCurrentContent = vi.fn(); + mockModifyContext.getProposedContent = vi.fn(); + + await modifyWithEditor( + mockParams, + mockModifyContext, + 'vscode' as EditorType, + abortSignal, + vi.fn(), + { + currentContent: overrideCurrent, + proposedContent: overrideProposed, + }, + ); + + expect(mockModifyContext.getCurrentContent).not.toHaveBeenCalled(); + expect(mockModifyContext.getProposedContent).not.toHaveBeenCalled(); + expect(mockCreatePatch).toHaveBeenCalledWith( + path.basename(mockParams.filePath), + overrideCurrent, + modifiedContent, + 'Current', + 'Proposed', + expect.any(Object), + ); + }); + + it('should treat null override as explicit empty content', async () => { + mockModifyContext.getCurrentContent = vi.fn(); + mockModifyContext.getProposedContent = vi.fn(); + + await modifyWithEditor( + mockParams, + mockModifyContext, + 'vscode' as EditorType, + abortSignal, + vi.fn(), + { + currentContent: null, + proposedContent: 'override proposed content', + }, + ); + + expect(mockModifyContext.getCurrentContent).not.toHaveBeenCalled(); + expect(mockModifyContext.getProposedContent).not.toHaveBeenCalled(); + expect(mockCreatePatch).toHaveBeenCalledWith( + path.basename(mockParams.filePath), + '', + modifiedContent, + 'Current', + 'Proposed', + expect.any(Object), + ); + }); + it('should clean up temp files even if editor fails', async () => { const editorError = new Error('Editor failed to open'); mockOpenDiff.mockRejectedValue(editorError); diff --git a/packages/core/src/tools/modifiable-tool.ts b/packages/core/src/tools/modifiable-tool.ts index 0857d86884..b96ad04d44 100644 --- a/packages/core/src/tools/modifiable-tool.ts +++ b/packages/core/src/tools/modifiable-tool.ts @@ -46,6 +46,11 @@ export interface ModifyResult { updatedDiff: string; } +export interface ModifyContentOverrides { + currentContent?: string | null; + proposedContent?: string; +} + /** * Type guard to check if a declarative tool is modifiable. */ @@ -172,14 +177,24 @@ export async function modifyWithEditor( editorType: EditorType, _abortSignal: AbortSignal, onEditorClose: () => void, + overrides?: ModifyContentOverrides, ): Promise> { - const currentContent = await modifyContext.getCurrentContent(originalParams); - const proposedContent = - await modifyContext.getProposedContent(originalParams); + const hasCurrentOverride = + overrides !== undefined && 'currentContent' in overrides; + const hasProposedOverride = + overrides !== undefined && 'proposedContent' in overrides; + + const currentContent = hasCurrentOverride + ? (overrides!.currentContent ?? '') + : await modifyContext.getCurrentContent(originalParams); + + const proposedContent = hasProposedOverride + ? (overrides!.proposedContent ?? '') + : await modifyContext.getProposedContent(originalParams); const { oldPath, newPath, dirPath } = createTempFilesForModify( - currentContent, - proposedContent, + currentContent ?? '', + proposedContent ?? '', modifyContext.getFilePath(originalParams), );