mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 06:31:01 -07:00
feat: Add AbortSignal support for retry logic and tool execution (#9196)
Co-authored-by: Sandy Tao <sandytao520@icloud.com>
This commit is contained in:
@@ -1078,17 +1078,25 @@ export class CoreToolScheduler {
|
||||
}
|
||||
})
|
||||
.catch((executionError: Error) => {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'error',
|
||||
createErrorResponse(
|
||||
scheduledCall.request,
|
||||
executionError instanceof Error
|
||||
? executionError
|
||||
: new Error(String(executionError)),
|
||||
ToolErrorType.UNHANDLED_EXCEPTION,
|
||||
),
|
||||
);
|
||||
if (signal.aborted) {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'cancelled',
|
||||
'User cancelled tool execution.',
|
||||
);
|
||||
} else {
|
||||
this.setStatusInternal(
|
||||
callId,
|
||||
'error',
|
||||
createErrorResponse(
|
||||
scheduledCall.request,
|
||||
executionError instanceof Error
|
||||
? executionError
|
||||
: new Error(String(executionError)),
|
||||
ToolErrorType.UNHANDLED_EXCEPTION,
|
||||
),
|
||||
);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -572,6 +572,161 @@ describe('DiscoveredMCPTool', () => {
|
||||
'Here is a resource.\n[Link to My Resource: file:///path/to/resource]\nEmbedded text content.\n[Image: image/jpeg]',
|
||||
);
|
||||
});
|
||||
|
||||
describe('AbortSignal support', () => {
|
||||
it('should abort immediately if signal is already aborted', async () => {
|
||||
const params = { param: 'test' };
|
||||
const controller = new AbortController();
|
||||
controller.abort();
|
||||
|
||||
const invocation = tool.build(params);
|
||||
|
||||
await expect(invocation.execute(controller.signal)).rejects.toThrow(
|
||||
'Tool call aborted',
|
||||
);
|
||||
|
||||
// Tool should not be called if signal is already aborted
|
||||
expect(mockCallTool).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should abort during tool execution', async () => {
|
||||
const params = { param: 'test' };
|
||||
const controller = new AbortController();
|
||||
|
||||
// Mock a delayed response to simulate long-running tool
|
||||
mockCallTool.mockImplementation(
|
||||
() =>
|
||||
new Promise((resolve) => {
|
||||
setTimeout(() => {
|
||||
resolve([
|
||||
{
|
||||
functionResponse: {
|
||||
name: serverToolName,
|
||||
response: {
|
||||
content: [{ type: 'text', text: 'Success' }],
|
||||
},
|
||||
},
|
||||
},
|
||||
]);
|
||||
}, 1000);
|
||||
}),
|
||||
);
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const promise = invocation.execute(controller.signal);
|
||||
|
||||
// Abort after a short delay to simulate cancellation during execution
|
||||
setTimeout(() => controller.abort(), 50);
|
||||
|
||||
await expect(promise).rejects.toThrow('Tool call aborted');
|
||||
});
|
||||
|
||||
it('should complete successfully if not aborted', async () => {
|
||||
const params = { param: 'test' };
|
||||
const controller = new AbortController();
|
||||
const successResponse = [
|
||||
{
|
||||
functionResponse: {
|
||||
name: serverToolName,
|
||||
response: {
|
||||
content: [{ type: 'text', text: 'Success' }],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
mockCallTool.mockResolvedValue(successResponse);
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(controller.signal);
|
||||
|
||||
expect(result.llmContent).toEqual([{ text: 'Success' }]);
|
||||
expect(result.returnDisplay).toBe('Success');
|
||||
expect(mockCallTool).toHaveBeenCalledWith([
|
||||
{ name: serverToolName, args: params },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should handle tool error even when abort signal is provided', async () => {
|
||||
const params = { param: 'test' };
|
||||
const controller = new AbortController();
|
||||
const errorResponse = [
|
||||
{
|
||||
functionResponse: {
|
||||
name: serverToolName,
|
||||
response: { error: { isError: true } },
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
mockCallTool.mockResolvedValue(errorResponse);
|
||||
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(controller.signal);
|
||||
|
||||
expect(result.error?.type).toBe(ToolErrorType.MCP_TOOL_ERROR);
|
||||
expect(result.returnDisplay).toContain(
|
||||
`Error: MCP tool '${serverToolName}' reported an error.`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle callTool rejection with abort signal', async () => {
|
||||
const params = { param: 'test' };
|
||||
const controller = new AbortController();
|
||||
const expectedError = new Error('Network error');
|
||||
|
||||
mockCallTool.mockRejectedValue(expectedError);
|
||||
|
||||
const invocation = tool.build(params);
|
||||
|
||||
await expect(invocation.execute(controller.signal)).rejects.toThrow(
|
||||
expectedError,
|
||||
);
|
||||
});
|
||||
|
||||
it('should cleanup event listeners properly on successful completion', async () => {
|
||||
const params = { param: 'test' };
|
||||
const controller = new AbortController();
|
||||
const successResponse = [
|
||||
{
|
||||
functionResponse: {
|
||||
name: serverToolName,
|
||||
response: {
|
||||
content: [{ type: 'text', text: 'Success' }],
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
mockCallTool.mockResolvedValue(successResponse);
|
||||
|
||||
const invocation = tool.build(params);
|
||||
await invocation.execute(controller.signal);
|
||||
|
||||
controller.abort();
|
||||
expect(controller.signal.aborted).toBe(true);
|
||||
});
|
||||
|
||||
it('should cleanup event listeners properly on error', async () => {
|
||||
const params = { param: 'test' };
|
||||
const controller = new AbortController();
|
||||
const expectedError = new Error('Tool execution failed');
|
||||
|
||||
mockCallTool.mockRejectedValue(expectedError);
|
||||
|
||||
const invocation = tool.build(params);
|
||||
|
||||
try {
|
||||
await invocation.execute(controller.signal);
|
||||
} catch (error) {
|
||||
expect(error).toBe(expectedError);
|
||||
}
|
||||
|
||||
// Verify cleanup by aborting after error
|
||||
controller.abort();
|
||||
expect(controller.signal.aborted).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('shouldConfirmExecute', () => {
|
||||
|
||||
@@ -131,7 +131,7 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
return false;
|
||||
}
|
||||
|
||||
async execute(): Promise<ToolResult> {
|
||||
async execute(signal: AbortSignal): Promise<ToolResult> {
|
||||
const functionCalls: FunctionCall[] = [
|
||||
{
|
||||
name: this.serverToolName,
|
||||
@@ -139,7 +139,36 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
},
|
||||
];
|
||||
|
||||
const rawResponseParts = await this.mcpTool.callTool(functionCalls);
|
||||
// Race MCP tool call with abort signal to respect cancellation
|
||||
const rawResponseParts = await new Promise<Part[]>((resolve, reject) => {
|
||||
if (signal.aborted) {
|
||||
const error = new Error('Tool call aborted');
|
||||
error.name = 'AbortError';
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
const onAbort = () => {
|
||||
cleanup();
|
||||
const error = new Error('Tool call aborted');
|
||||
error.name = 'AbortError';
|
||||
reject(error);
|
||||
};
|
||||
const cleanup = () => {
|
||||
signal.removeEventListener('abort', onAbort);
|
||||
};
|
||||
signal.addEventListener('abort', onAbort, { once: true });
|
||||
|
||||
this.mcpTool
|
||||
.callTool(functionCalls)
|
||||
.then((res) => {
|
||||
cleanup();
|
||||
resolve(res);
|
||||
})
|
||||
.catch((err) => {
|
||||
cleanup();
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
|
||||
// Ensure the response is not an error
|
||||
if (this.isMCPToolError(rawResponseParts)) {
|
||||
|
||||
Reference in New Issue
Block a user