mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-12 15:10:59 -07:00
fix(core): Resolve race condition in tool response reporting (#16557)
This commit is contained in:
@@ -1904,4 +1904,175 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
|||||||
serverName,
|
serverName,
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should not double-report completed tools when concurrent completions occur', async () => {
|
||||||
|
// Arrange
|
||||||
|
const executeFn = vi.fn().mockResolvedValue({ llmContent: 'success' });
|
||||||
|
const mockTool = new MockTool({ name: 'mockTool', execute: executeFn });
|
||||||
|
const declarativeTool = mockTool;
|
||||||
|
|
||||||
|
const mockToolRegistry = {
|
||||||
|
getTool: () => declarativeTool,
|
||||||
|
getToolByName: () => declarativeTool,
|
||||||
|
getFunctionDeclarations: () => [],
|
||||||
|
tools: new Map(),
|
||||||
|
discovery: {},
|
||||||
|
registerTool: () => {},
|
||||||
|
getToolByDisplayName: () => declarativeTool,
|
||||||
|
getTools: () => [],
|
||||||
|
discoverTools: async () => {},
|
||||||
|
getAllTools: () => [],
|
||||||
|
getToolsByServer: () => [],
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
|
|
||||||
|
let completionCallCount = 0;
|
||||||
|
const onAllToolCallsComplete = vi.fn().mockImplementation(async () => {
|
||||||
|
completionCallCount++;
|
||||||
|
// Simulate slow reporting (e.g. Gemini API call)
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||||
|
});
|
||||||
|
|
||||||
|
const mockConfig = createMockConfig({
|
||||||
|
getToolRegistry: () => mockToolRegistry,
|
||||||
|
getApprovalMode: () => ApprovalMode.YOLO,
|
||||||
|
isInteractive: () => false,
|
||||||
|
});
|
||||||
|
const mockMessageBus = createMockMessageBus();
|
||||||
|
mockConfig.getMessageBus = vi.fn().mockReturnValue(mockMessageBus);
|
||||||
|
mockConfig.getEnableHooks = vi.fn().mockReturnValue(false);
|
||||||
|
mockConfig.getHookSystem = vi
|
||||||
|
.fn()
|
||||||
|
.mockReturnValue(new HookSystem(mockConfig));
|
||||||
|
|
||||||
|
const scheduler = new CoreToolScheduler({
|
||||||
|
config: mockConfig,
|
||||||
|
onAllToolCallsComplete,
|
||||||
|
getPreferredEditor: () => 'vscode',
|
||||||
|
});
|
||||||
|
|
||||||
|
const abortController = new AbortController();
|
||||||
|
const request = {
|
||||||
|
callId: '1',
|
||||||
|
name: 'mockTool',
|
||||||
|
args: {},
|
||||||
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'prompt-1',
|
||||||
|
};
|
||||||
|
|
||||||
|
// Act
|
||||||
|
// 1. Start execution
|
||||||
|
const schedulePromise = scheduler.schedule(
|
||||||
|
[request],
|
||||||
|
abortController.signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
// 2. Wait just enough for it to finish and enter checkAndNotifyCompletion
|
||||||
|
// (awaiting our slow mock)
|
||||||
|
await vi.waitFor(() => {
|
||||||
|
expect(completionCallCount).toBe(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
// 3. Trigger a concurrent completion event (e.g. via cancelAll)
|
||||||
|
scheduler.cancelAll(abortController.signal);
|
||||||
|
|
||||||
|
await schedulePromise;
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
// Even though cancelAll was called while the first completion was in progress,
|
||||||
|
// it should not have triggered a SECOND completion call because the first one
|
||||||
|
// was still 'finalizing' and will drain any new tools.
|
||||||
|
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should complete reporting all tools even mid-callback during abort', async () => {
|
||||||
|
// Arrange
|
||||||
|
const onAllToolCallsComplete = vi.fn().mockImplementation(async () => {
|
||||||
|
// Simulate slow reporting
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 50));
|
||||||
|
});
|
||||||
|
|
||||||
|
const mockTool = new MockTool({ name: 'mockTool' });
|
||||||
|
const mockToolRegistry = {
|
||||||
|
getTool: () => mockTool,
|
||||||
|
getToolByName: () => mockTool,
|
||||||
|
getFunctionDeclarations: () => [],
|
||||||
|
tools: new Map(),
|
||||||
|
discovery: {},
|
||||||
|
registerTool: () => {},
|
||||||
|
getToolByDisplayName: () => mockTool,
|
||||||
|
getTools: () => [],
|
||||||
|
discoverTools: async () => {},
|
||||||
|
getAllTools: () => [],
|
||||||
|
getToolsByServer: () => [],
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
|
|
||||||
|
const mockConfig = createMockConfig({
|
||||||
|
getToolRegistry: () => mockToolRegistry,
|
||||||
|
getApprovalMode: () => ApprovalMode.YOLO,
|
||||||
|
isInteractive: () => false,
|
||||||
|
});
|
||||||
|
|
||||||
|
const scheduler = new CoreToolScheduler({
|
||||||
|
config: mockConfig,
|
||||||
|
onAllToolCallsComplete,
|
||||||
|
getPreferredEditor: () => 'vscode',
|
||||||
|
});
|
||||||
|
|
||||||
|
const abortController = new AbortController();
|
||||||
|
const signal = abortController.signal;
|
||||||
|
|
||||||
|
// Act
|
||||||
|
// 1. Start execution of two tools
|
||||||
|
const schedulePromise = scheduler.schedule(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
callId: '1',
|
||||||
|
name: 'mockTool',
|
||||||
|
args: {},
|
||||||
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'prompt-1',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
callId: '2',
|
||||||
|
name: 'mockTool',
|
||||||
|
args: {},
|
||||||
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'prompt-1',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
// 2. Wait for reporting to start
|
||||||
|
await vi.waitFor(() => {
|
||||||
|
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
// 3. Abort the signal while reporting is in progress
|
||||||
|
abortController.abort();
|
||||||
|
|
||||||
|
await schedulePromise;
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
// Verify that onAllToolCallsComplete was called and processed the tools,
|
||||||
|
// and that the scheduler didn't just drop them because of the abort.
|
||||||
|
expect(onAllToolCallsComplete).toHaveBeenCalled();
|
||||||
|
|
||||||
|
const reportedTools = onAllToolCallsComplete.mock.calls.flatMap((call) =>
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
call[0].map((t: any) => t.request.callId),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Both tools should have been reported exactly once with success status
|
||||||
|
expect(reportedTools).toContain('1');
|
||||||
|
expect(reportedTools).toContain('2');
|
||||||
|
|
||||||
|
const allStatuses = onAllToolCallsComplete.mock.calls.flatMap((call) =>
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
call[0].map((t: any) => t.status),
|
||||||
|
);
|
||||||
|
expect(allStatuses).toEqual(['success', 'success']);
|
||||||
|
|
||||||
|
expect(onAllToolCallsComplete).toHaveBeenCalledTimes(1);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -909,21 +909,36 @@ export class CoreToolScheduler {
|
|||||||
this._cancelAllQueuedCalls();
|
this._cancelAllQueuedCalls();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we are already finalizing, another concurrent call to
|
||||||
|
// checkAndNotifyCompletion will just return. The ongoing finalized loop
|
||||||
|
// will pick up any new tools added to completedToolCallsForBatch.
|
||||||
|
if (this.isFinalizingToolCalls) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// If there's nothing to report and we weren't cancelled, we can stop.
|
// If there's nothing to report and we weren't cancelled, we can stop.
|
||||||
// But if we were cancelled, we must proceed to potentially start the next queued request.
|
// But if we were cancelled, we must proceed to potentially start the next queued request.
|
||||||
if (this.completedToolCallsForBatch.length === 0 && !signal.aborted) {
|
if (this.completedToolCallsForBatch.length === 0 && !signal.aborted) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.onAllToolCallsComplete) {
|
this.isFinalizingToolCalls = true;
|
||||||
this.isFinalizingToolCalls = true;
|
try {
|
||||||
// Use the batch array, not the (now empty) active array.
|
// We use a while loop here to ensure that if new tools are added to the
|
||||||
await this.onAllToolCallsComplete(this.completedToolCallsForBatch);
|
// batch (e.g., via cancellation) while we are awaiting
|
||||||
this.completedToolCallsForBatch = []; // Clear after reporting.
|
// onAllToolCallsComplete, they are also reported before we finish.
|
||||||
|
while (this.completedToolCallsForBatch.length > 0) {
|
||||||
|
const batchToReport = [...this.completedToolCallsForBatch];
|
||||||
|
this.completedToolCallsForBatch = [];
|
||||||
|
if (this.onAllToolCallsComplete) {
|
||||||
|
await this.onAllToolCallsComplete(batchToReport);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
this.isFinalizingToolCalls = false;
|
this.isFinalizingToolCalls = false;
|
||||||
|
this.isCancelling = false;
|
||||||
|
this.notifyToolCallsUpdate();
|
||||||
}
|
}
|
||||||
this.isCancelling = false;
|
|
||||||
this.notifyToolCallsUpdate();
|
|
||||||
|
|
||||||
// After completion of the entire batch, process the next item in the main request queue.
|
// After completion of the entire batch, process the next item in the main request queue.
|
||||||
if (this.requestQueue.length > 0) {
|
if (this.requestQueue.length > 0) {
|
||||||
|
|||||||
Reference in New Issue
Block a user