mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-12 23:21:27 -07:00
fix(core): Fix unable to cancel edit tool (#9299)
This commit is contained in:
@@ -109,6 +109,65 @@ class TestApprovalInvocation extends BaseToolInvocation<
|
||||
}
|
||||
}
|
||||
|
||||
class AbortDuringConfirmationInvocation extends BaseToolInvocation<
|
||||
Record<string, unknown>,
|
||||
ToolResult
|
||||
> {
|
||||
constructor(
|
||||
private readonly abortController: AbortController,
|
||||
private readonly abortError: Error,
|
||||
params: Record<string, unknown>,
|
||||
) {
|
||||
super(params);
|
||||
}
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
_signal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
this.abortController.abort();
|
||||
throw this.abortError;
|
||||
}
|
||||
|
||||
async execute(_abortSignal: AbortSignal): Promise<ToolResult> {
|
||||
throw new Error('execute should not be called when confirmation fails');
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
return 'Abort during confirmation invocation';
|
||||
}
|
||||
}
|
||||
|
||||
class AbortDuringConfirmationTool extends BaseDeclarativeTool<
|
||||
Record<string, unknown>,
|
||||
ToolResult
|
||||
> {
|
||||
constructor(
|
||||
private readonly abortController: AbortController,
|
||||
private readonly abortError: Error,
|
||||
) {
|
||||
super(
|
||||
'abortDuringConfirmationTool',
|
||||
'Abort During Confirmation Tool',
|
||||
'A tool that aborts while confirming execution.',
|
||||
Kind.Other,
|
||||
{
|
||||
type: 'object',
|
||||
properties: {},
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
protected createInvocation(
|
||||
params: Record<string, unknown>,
|
||||
): ToolInvocation<Record<string, unknown>, ToolResult> {
|
||||
return new AbortDuringConfirmationInvocation(
|
||||
this.abortController,
|
||||
this.abortError,
|
||||
params,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async function waitForStatus(
|
||||
onToolCallsUpdate: Mock,
|
||||
status: 'awaiting_approval' | 'executing' | 'success' | 'error' | 'cancelled',
|
||||
@@ -218,6 +277,85 @@ describe('CoreToolScheduler', () => {
|
||||
expect(completedCalls[0].status).toBe('cancelled');
|
||||
});
|
||||
|
||||
it('should mark tool call as cancelled when abort happens during confirmation error', async () => {
|
||||
const abortController = new AbortController();
|
||||
const abortError = new Error('Abort requested during confirmation');
|
||||
const declarativeTool = new AbortDuringConfirmationTool(
|
||||
abortController,
|
||||
abortError,
|
||||
);
|
||||
|
||||
const mockToolRegistry = {
|
||||
getTool: () => declarativeTool,
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {},
|
||||
registerTool: () => {},
|
||||
getToolByName: () => declarativeTool,
|
||||
getToolByDisplayName: () => declarativeTool,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const onAllToolCallsComplete = vi.fn();
|
||||
const onToolCallsUpdate = vi.fn();
|
||||
|
||||
const mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getApprovalMode: () => ApprovalMode.DEFAULT,
|
||||
getAllowedTools: () => [],
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
getShellExecutionConfig: () => ({
|
||||
terminalWidth: 90,
|
||||
terminalHeight: 30,
|
||||
}),
|
||||
storage: {
|
||||
getProjectTempDir: () => '/tmp',
|
||||
},
|
||||
getTruncateToolOutputThreshold: () =>
|
||||
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
|
||||
getTruncateToolOutputLines: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getUseSmartEdit: () => false,
|
||||
getUseModelRouter: () => false,
|
||||
getGeminiClient: () => null,
|
||||
} as unknown as Config;
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
onAllToolCallsComplete,
|
||||
onToolCallsUpdate,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
onEditorClose: vi.fn(),
|
||||
});
|
||||
|
||||
const request = {
|
||||
callId: 'abort-1',
|
||||
name: 'abortDuringConfirmationTool',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-abort',
|
||||
};
|
||||
|
||||
await scheduler.schedule([request], abortController.signal);
|
||||
|
||||
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||
const completedCalls = onAllToolCallsComplete.mock
|
||||
.calls[0][0] as ToolCall[];
|
||||
expect(completedCalls[0].status).toBe('cancelled');
|
||||
const statuses = onToolCallsUpdate.mock.calls.flatMap((call) =>
|
||||
(call[0] as ToolCall[]).map((toolCall) => toolCall.status),
|
||||
);
|
||||
expect(statuses).not.toContain('error');
|
||||
});
|
||||
|
||||
describe('getToolSuggestion', () => {
|
||||
it('should suggest the top N closest tool names for a typo', () => {
|
||||
// Create mocked tool registry
|
||||
|
||||
@@ -799,6 +799,15 @@ export class CoreToolScheduler {
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'cancelled',
|
||||
'Tool call cancelled by user.',
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
this.setStatusInternal(
|
||||
reqInfo.callId,
|
||||
'error',
|
||||
|
||||
@@ -471,6 +471,34 @@ describe('EditTool', () => {
|
||||
);
|
||||
expect(patchedContent).toBe(expectedFinalContent);
|
||||
});
|
||||
|
||||
it('should rethrow calculateEdit errors when the abort signal is triggered', async () => {
|
||||
const filePath = path.join(rootDir, 'abort-confirmation.txt');
|
||||
const params: EditToolParams = {
|
||||
file_path: filePath,
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
};
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const abortController = new AbortController();
|
||||
const abortError = new Error('Abort requested');
|
||||
|
||||
const calculateSpy = vi
|
||||
.spyOn(invocation as any, 'calculateEdit')
|
||||
.mockImplementation(async () => {
|
||||
if (!abortController.signal.aborted) {
|
||||
abortController.abort();
|
||||
}
|
||||
throw abortError;
|
||||
});
|
||||
|
||||
await expect(
|
||||
invocation.shouldConfirmExecute(abortController.signal),
|
||||
).rejects.toBe(abortError);
|
||||
|
||||
calculateSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('execute', () => {
|
||||
@@ -515,6 +543,33 @@ describe('EditTool', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should reject when calculateEdit fails after an abort signal', async () => {
|
||||
const params: EditToolParams = {
|
||||
file_path: path.join(rootDir, 'abort-execute.txt'),
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
};
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const abortController = new AbortController();
|
||||
const abortError = new Error('Abort requested during execute');
|
||||
|
||||
const calculateSpy = vi
|
||||
.spyOn(invocation as any, 'calculateEdit')
|
||||
.mockImplementation(async () => {
|
||||
if (!abortController.signal.aborted) {
|
||||
abortController.abort();
|
||||
}
|
||||
throw abortError;
|
||||
});
|
||||
|
||||
await expect(invocation.execute(abortController.signal)).rejects.toBe(
|
||||
abortError,
|
||||
);
|
||||
|
||||
calculateSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should edit an existing file and return diff with fileName', async () => {
|
||||
const initialContent = 'This is some old text.';
|
||||
const newContent = 'This is some new text.'; // old -> new
|
||||
|
||||
@@ -251,6 +251,9 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
try {
|
||||
editData = await this.calculateEdit(this.params, abortSignal);
|
||||
} catch (error) {
|
||||
if (abortSignal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
console.log(`Error preparing edit: ${errorMsg}`);
|
||||
return false;
|
||||
@@ -336,6 +339,9 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
try {
|
||||
editData = await this.calculateEdit(this.params, signal);
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
return {
|
||||
llmContent: `Error preparing edit: ${errorMsg}`,
|
||||
|
||||
@@ -274,6 +274,36 @@ describe('SmartEditTool', () => {
|
||||
filePath = path.join(rootDir, testFile);
|
||||
});
|
||||
|
||||
it('should reject when calculateEdit fails after an abort signal', async () => {
|
||||
const params: EditToolParams = {
|
||||
file_path: path.join(rootDir, 'abort-execute.txt'),
|
||||
instruction: 'Abort during execute',
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
};
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const abortController = new AbortController();
|
||||
const abortError = new Error(
|
||||
'Abort requested during smart edit execution',
|
||||
);
|
||||
|
||||
const calculateSpy = vi
|
||||
.spyOn(invocation as any, 'calculateEdit')
|
||||
.mockImplementation(async () => {
|
||||
if (!abortController.signal.aborted) {
|
||||
abortController.abort();
|
||||
}
|
||||
throw abortError;
|
||||
});
|
||||
|
||||
await expect(invocation.execute(abortController.signal)).rejects.toBe(
|
||||
abortError,
|
||||
);
|
||||
|
||||
calculateSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should edit an existing file and return diff with fileName', async () => {
|
||||
const initialContent = 'This is some old text.';
|
||||
const newContent = 'This is some new text.';
|
||||
@@ -511,4 +541,37 @@ describe('SmartEditTool', () => {
|
||||
expect(params.new_string).toBe(modifiedContent);
|
||||
});
|
||||
});
|
||||
|
||||
describe('shouldConfirmExecute', () => {
|
||||
it('should rethrow calculateEdit errors when the abort signal is triggered', async () => {
|
||||
const filePath = path.join(rootDir, 'abort-confirmation.txt');
|
||||
const params: EditToolParams = {
|
||||
file_path: filePath,
|
||||
instruction: 'Abort during confirmation',
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
};
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const abortController = new AbortController();
|
||||
const abortError = new Error(
|
||||
'Abort requested during smart edit confirmation',
|
||||
);
|
||||
|
||||
const calculateSpy = vi
|
||||
.spyOn(invocation as any, 'calculateEdit')
|
||||
.mockImplementation(async () => {
|
||||
if (!abortController.signal.aborted) {
|
||||
abortController.abort();
|
||||
}
|
||||
throw abortError;
|
||||
});
|
||||
|
||||
await expect(
|
||||
invocation.shouldConfirmExecute(abortController.signal),
|
||||
).rejects.toBe(abortError);
|
||||
|
||||
calculateSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -490,6 +490,9 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
try {
|
||||
editData = await this.calculateEdit(this.params, abortSignal);
|
||||
} catch (error) {
|
||||
if (abortSignal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
console.log(`Error preparing edit: ${errorMsg}`);
|
||||
return false;
|
||||
@@ -575,6 +578,9 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
try {
|
||||
editData = await this.calculateEdit(this.params, signal);
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
const errorMsg = error instanceof Error ? error.message : String(error);
|
||||
return {
|
||||
llmContent: `Error preparing edit: ${errorMsg}`,
|
||||
|
||||
Reference in New Issue
Block a user