fix(core): Resolve race condition in tool response reporting (#16557)

This commit is contained in:
Abhi
2026-01-13 20:37:10 -05:00
committed by GitHub
parent 8dbaa2bcea
commit eda47f587c
2 changed files with 193 additions and 7 deletions

View File

@@ -1904,4 +1904,175 @@ describe('CoreToolScheduler Sequential Execution', () => {
serverName,
);
});
it('should not double-report completed tools when concurrent completions occur', async () => {
// Arrange
const executeFn = vi.fn().mockResolvedValue({ llmContent: 'success' });
const mockTool = new MockTool({ name: 'mockTool', execute: executeFn });
const declarativeTool = mockTool;
const mockToolRegistry = {
getTool: () => declarativeTool,
getToolByName: () => declarativeTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {},
registerTool: () => {},
getToolByDisplayName: () => declarativeTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
getToolsByServer: () => [],
} as unknown as ToolRegistry;
let completionCallCount = 0;
const onAllToolCallsComplete = vi.fn().mockImplementation(async () => {
completionCallCount++;
// Simulate slow reporting (e.g. Gemini API call)
await new Promise((resolve) => setTimeout(resolve, 50));
});
const mockConfig = createMockConfig({
getToolRegistry: () => mockToolRegistry,
getApprovalMode: () => ApprovalMode.YOLO,
isInteractive: () => false,
});
const mockMessageBus = createMockMessageBus();
mockConfig.getMessageBus = vi.fn().mockReturnValue(mockMessageBus);
mockConfig.getEnableHooks = vi.fn().mockReturnValue(false);
mockConfig.getHookSystem = vi
.fn()
.mockReturnValue(new HookSystem(mockConfig));
const scheduler = new CoreToolScheduler({
config: mockConfig,
onAllToolCallsComplete,
getPreferredEditor: () => 'vscode',
});
const abortController = new AbortController();
const request = {
callId: '1',
name: 'mockTool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-1',
};
// Act
// 1. Start execution
const schedulePromise = scheduler.schedule(
[request],
abortController.signal,
);
// 2. Wait just enough for it to finish and enter checkAndNotifyCompletion
// (awaiting our slow mock)
await vi.waitFor(() => {
expect(completionCallCount).toBe(1);
});
// 3. Trigger a concurrent completion event (e.g. via cancelAll)
scheduler.cancelAll(abortController.signal);
await schedulePromise;
// Assert
// Even though cancelAll was called while the first completion was in progress,
// it should not have triggered a SECOND completion call because the first one
// was still 'finalizing' and will drain any new tools.
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(1);
});
it('should complete reporting all tools even mid-callback during abort', async () => {
// Arrange
const onAllToolCallsComplete = vi.fn().mockImplementation(async () => {
// Simulate slow reporting
await new Promise((resolve) => setTimeout(resolve, 50));
});
const mockTool = new MockTool({ name: 'mockTool' });
const mockToolRegistry = {
getTool: () => mockTool,
getToolByName: () => mockTool,
getFunctionDeclarations: () => [],
tools: new Map(),
discovery: {},
registerTool: () => {},
getToolByDisplayName: () => mockTool,
getTools: () => [],
discoverTools: async () => {},
getAllTools: () => [],
getToolsByServer: () => [],
} as unknown as ToolRegistry;
const mockConfig = createMockConfig({
getToolRegistry: () => mockToolRegistry,
getApprovalMode: () => ApprovalMode.YOLO,
isInteractive: () => false,
});
const scheduler = new CoreToolScheduler({
config: mockConfig,
onAllToolCallsComplete,
getPreferredEditor: () => 'vscode',
});
const abortController = new AbortController();
const signal = abortController.signal;
// Act
// 1. Start execution of two tools
const schedulePromise = scheduler.schedule(
[
{
callId: '1',
name: 'mockTool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-1',
},
{
callId: '2',
name: 'mockTool',
args: {},
isClientInitiated: false,
prompt_id: 'prompt-1',
},
],
signal,
);
// 2. Wait for reporting to start
await vi.waitFor(() => {
expect(onAllToolCallsComplete).toHaveBeenCalled();
});
// 3. Abort the signal while reporting is in progress
abortController.abort();
await schedulePromise;
// Assert
// Verify that onAllToolCallsComplete was called and processed the tools,
// and that the scheduler didn't just drop them because of the abort.
expect(onAllToolCallsComplete).toHaveBeenCalled();
const reportedTools = onAllToolCallsComplete.mock.calls.flatMap((call) =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
call[0].map((t: any) => t.request.callId),
);
// Both tools should have been reported exactly once with success status
expect(reportedTools).toContain('1');
expect(reportedTools).toContain('2');
const allStatuses = onAllToolCallsComplete.mock.calls.flatMap((call) =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
call[0].map((t: any) => t.status),
);
expect(allStatuses).toEqual(['success', 'success']);
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(1);
});
});

View File

@@ -909,21 +909,36 @@ export class CoreToolScheduler {
this._cancelAllQueuedCalls();
}
// If we are already finalizing, another concurrent call to
// checkAndNotifyCompletion will just return. The ongoing finalized loop
// will pick up any new tools added to completedToolCallsForBatch.
if (this.isFinalizingToolCalls) {
return;
}
// If there's nothing to report and we weren't cancelled, we can stop.
// But if we were cancelled, we must proceed to potentially start the next queued request.
if (this.completedToolCallsForBatch.length === 0 && !signal.aborted) {
return;
}
if (this.onAllToolCallsComplete) {
this.isFinalizingToolCalls = true;
// Use the batch array, not the (now empty) active array.
await this.onAllToolCallsComplete(this.completedToolCallsForBatch);
this.completedToolCallsForBatch = []; // Clear after reporting.
this.isFinalizingToolCalls = true;
try {
// We use a while loop here to ensure that if new tools are added to the
// batch (e.g., via cancellation) while we are awaiting
// onAllToolCallsComplete, they are also reported before we finish.
while (this.completedToolCallsForBatch.length > 0) {
const batchToReport = [...this.completedToolCallsForBatch];
this.completedToolCallsForBatch = [];
if (this.onAllToolCallsComplete) {
await this.onAllToolCallsComplete(batchToReport);
}
}
} finally {
this.isFinalizingToolCalls = false;
this.isCancelling = false;
this.notifyToolCallsUpdate();
}
this.isCancelling = false;
this.notifyToolCallsUpdate();
// After completion of the entire batch, process the next item in the main request queue.
if (this.requestQueue.length > 0) {