mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 13:22:35 -07:00
fix(a2a-server): Resolve race condition in tool completion waiting (#26568)
This commit is contained in:
@@ -0,0 +1,173 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
|
||||
import { Task } from './task.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
CoreToolCallStatus,
|
||||
type Config,
|
||||
type MessageBus,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { createMockConfig } from '../utils/testing_utils.js';
|
||||
import type { RequestContext } from '@a2a-js/sdk/server';
|
||||
|
||||
/**
 * Regression tests for waiting on pending tool calls while confirmations,
 * cancellations and additional scheduling happen concurrently.
 */
describe('Task Race Condition', () => {
  let mockConfig: Config;
  let messageBus: MessageBus;

  beforeEach(() => {
    messageBus = {
      subscribe: vi.fn(),
      unsubscribe: vi.fn(),
      publish: vi.fn(),
    } as unknown as MessageBus;
    mockConfig = createMockConfig({
      messageBus,
    }) as Config;
  });

  /** Builds a Task on top of the mocked config, bypassing the private constructor. */
  function createTask(): Task {
    // @ts-expect-error - private constructor
    return new Task('task-id', 'context-id', mockConfig);
  }

  /**
   * Returns the callback that was passed to `messageBus.subscribe` for
   * TOOL_CALLS_UPDATE messages.
   *
   * Throws a descriptive error when no such subscription was recorded, so a
   * missing subscription is not reported as "updateHandler is not a function".
   */
  function getToolCallsUpdateHandler(): (message: unknown) => void {
    const subscription = (messageBus.subscribe as Mock).mock.calls.find(
      (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
    );
    const handler: unknown = subscription?.[1];
    if (typeof handler !== 'function') {
      throw new Error(
        'Expected a TOOL_CALLS_UPDATE subscription on the mocked message bus',
      );
    }
    return handler as (message: unknown) => void;
  }

  /**
   * Sends a user message carrying a `proceed_once` confirmation for `callId`
   * and drains the resulting event stream.
   */
  async function confirmTool(task: Task, callId: string): Promise<void> {
    const requestContext = {
      userMessage: {
        parts: [
          {
            kind: 'data',
            data: { callId, outcome: 'proceed_once' },
          },
        ],
      },
    } as unknown as RequestContext;

    for await (const _ of task.acceptUserMessage(
      requestContext,
      new AbortController().signal,
    )) {
      // consume generator
    }
  }

  it('should not hang when multiple tool confirmations are processed while waiting', async () => {
    const task = createTask();

    // 1. Register two tools as scheduled
    task['_registerToolCall']('tool-1', 'scheduled');
    task['_registerToolCall']('tool-2', 'scheduled');

    // 2. Both transition to awaiting_approval
    const updateHandler = getToolCallsUpdateHandler();

    updateHandler({
      type: MessageBusType.TOOL_CALLS_UPDATE,
      schedulerId: 'task-id',
      toolCalls: [
        {
          request: { callId: 'tool-1', name: 't1' },
          status: CoreToolCallStatus.AwaitingApproval,
          correlationId: 'corr-1',
          confirmationDetails: { type: 'info' },
        },
        {
          request: { callId: 'tool-2', name: 't2' },
          status: CoreToolCallStatus.AwaitingApproval,
          correlationId: 'corr-2',
          confirmationDetails: { type: 'info' },
        },
      ],
    });

    // 3. Confirm Tool 1. This makes isAwaitingApprovalOnly() return false.
    await confirmTool(task, 'tool-1');

    // 4. Start waiting. This should now block because Tool 1 is confirmed (so we are waiting for its execution).
    const waitPromise = task.waitForPendingTools();

    // 5. Confirm Tool 2 while waiting.
    await confirmTool(task, 'tool-2');

    // 6. Both tools complete successfully
    updateHandler({
      type: MessageBusType.TOOL_CALLS_UPDATE,
      schedulerId: 'task-id',
      toolCalls: [
        {
          request: { callId: 'tool-1', name: 't1' },
          status: CoreToolCallStatus.Success,
          response: { responseParts: [] },
        },
        {
          request: { callId: 'tool-2', name: 't2' },
          status: CoreToolCallStatus.Success,
          response: { responseParts: [] },
        },
      ],
    });

    // 7. Verify that the original waitPromise resolves.
    await expect(waitPromise).resolves.toBeUndefined();
  });

  it('should reject waitForPendingTools when tools are cancelled', async () => {
    const task = createTask();

    // 1. Register a tool
    task['_registerToolCall']('tool-1', 'scheduled');

    // 2. Start waiting
    const waitPromise = task.waitForPendingTools();

    // 3. Cancel pending tools
    task.cancelPendingTools('User requested cancellation');

    // 4. Verify waitPromise rejects with the reason
    await expect(waitPromise).rejects.toThrow('User requested cancellation');
  });

  it('should handle concurrent tool scheduling correctly', async () => {
    const task = createTask();

    // 1. Register a tool and start waiting
    task['_registerToolCall']('tool-1', 'scheduled');
    const waitPromise = task.waitForPendingTools();

    // 2. Schedule another tool concurrently (e.g. from a secondary user message)
    // This should NOT resolve the current waitPromise until both are done
    await task.scheduleToolCalls(
      [{ callId: 'tool-2', name: 't2', args: {} }],
      new AbortController().signal,
    );

    expect(task['pendingToolCalls'].size).toBe(2);

    // 3. Resolve tool 1
    task['_resolveToolCall']('tool-1');

    // 4. Verify waitPromise is still pending.
    // Both outcomes are observed so that an early rejection is caught here
    // (instead of looking like "still pending") and the derived promise never
    // produces an unhandled rejection.
    let settled = false;
    const markSettled = (): void => {
      settled = true;
    };
    void waitPromise.then(markSettled, markSettled);
    await new Promise((resolve) => setTimeout(resolve, 10));
    expect(settled).toBe(false);

    // 5. Resolve tool 2
    task['_resolveToolCall']('tool-2');

    // 6. Now it should resolve
    await expect(waitPromise).resolves.toBeUndefined();
  });
});
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
ApprovalMode,
|
||||
Scheduler,
|
||||
type MessageBus,
|
||||
type ToolLiveOutput,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { createMockConfig } from '../utils/testing_utils.js';
|
||||
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
|
||||
@@ -608,6 +609,74 @@ describe('Task Event-Driven Scheduler', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle multi-turn tool resolution correctly', async () => {
  // @ts-expect-error - Calling private constructor
  const task = new Task('task-id', 'context-id', mockConfig);

  task['_registerToolCall']('1', 'scheduled');
  task['_registerToolCall']('2', 'scheduled');

  const handler = (messageBus.subscribe as Mock).mock.calls.find(
    (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
  )?.[1];
  // Fail with a readable assertion instead of "handler is not a function"
  // when the subscription lookup comes back empty.
  expect(handler).toBeDefined();

  // Publishes a TOOL_CALLS_UPDATE that reports a single call as successful.
  const reportSuccess = (callId: string, name: string): void => {
    handler({
      type: MessageBusType.TOOL_CALLS_UPDATE,
      toolCalls: [
        {
          request: { callId, name },
          status: 'success',
          response: { responseParts: [] },
        },
      ],
      schedulerId: 'task-id',
    });
  };

  // Turn 1: Resolve tool 1
  reportSuccess('1', 't1');

  expect(task['pendingToolCalls'].size).toBe(1);
  expect(task['pendingToolCalls'].has('2')).toBe(true);

  // Turn 2: Resolve tool 2
  reportSuccess('2', 't2');

  expect(task['pendingToolCalls'].size).toBe(0);
});
|
||||
|
||||
it('should handle subagent progress events from the scheduler', async () => {
  // @ts-expect-error - Calling private constructor
  const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);

  // Feed a subagent progress payload through the scheduler output callback.
  const subagentProgress = {
    isSubagentProgress: true,
    agentName: 'researcher',
    recentActivity: [],
  } as ToolLiveOutput;
  task['_schedulerOutputUpdate']('tool-1', subagentProgress);

  // The published artifact must consist of exactly one part whose text
  // mentions the agent name.
  const partMentioningAgent = expect.objectContaining({
    text: expect.stringContaining('researcher'),
  });
  const artifactWithThatPart = expect.objectContaining({
    parts: [partMentioningAgent],
  });

  expect(mockEventBus.publish).toHaveBeenCalledWith(
    expect.objectContaining({
      kind: 'artifact-update',
      artifact: artifactWithThatPart,
    }),
  );
});
|
||||
|
||||
it('should wait for executing tools before transitioning to input-required state', async () => {
|
||||
// @ts-expect-error - Calling private constructor
|
||||
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
|
||||
|
||||
@@ -52,6 +52,7 @@ import type {
|
||||
Artifact,
|
||||
} from '@a2a-js/sdk';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { EventEmitter } from 'node:events';
|
||||
import { logger } from '../utils/logger.js';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
@@ -99,11 +100,8 @@ export class Task {
|
||||
private pendingOutcomes: Map<string, ToolConfirmationOutcome | undefined> =
|
||||
new Map(); // toolCallId --> outcome
|
||||
private toolsAlreadyConfirmed: Set<string> = new Set();
|
||||
private toolCompletionPromise?: Promise<void>;
|
||||
private toolCompletionNotifier?: {
|
||||
resolve: () => void;
|
||||
reject: (reason?: Error) => void;
|
||||
};
|
||||
private toolUpdateEmitter = new EventEmitter();
|
||||
private cancellationError?: Error;
|
||||
|
||||
private constructor(
|
||||
id: string,
|
||||
@@ -124,7 +122,6 @@ export class Task {
|
||||
this.taskState = 'submitted';
|
||||
this.eventBus = eventBus;
|
||||
this.completedToolCalls = [];
|
||||
this._resetToolCompletionPromise();
|
||||
this.autoExecute = autoExecute;
|
||||
this.config.setFallbackModelHandler(
|
||||
// For a2a-server, we want to automatically switch to the fallback model
|
||||
@@ -179,22 +176,9 @@ export class Task {
|
||||
return metadata;
|
||||
}
|
||||
|
||||
/**
 * Installs a fresh `toolCompletionPromise` and stores its settle callbacks in
 * `toolCompletionNotifier`.
 *
 * When no tool calls are pending at the time of the reset, the new promise is
 * resolved straight away.
 */
private _resetToolCompletionPromise(): void {
  this.toolCompletionPromise = new Promise<void>((onResolve, onReject) => {
    this.toolCompletionNotifier = { resolve: onResolve, reject: onReject };
  });

  // Nothing outstanding: settle the promise that was just created.
  const nothingPending = this.pendingToolCalls.size === 0;
  if (nothingPending && this.toolCompletionNotifier) {
    this.toolCompletionNotifier.resolve();
  }
}
|
||||
|
||||
private _registerToolCall(toolCallId: string, status: string): void {
|
||||
const wasEmpty = this.pendingToolCalls.size === 0;
|
||||
this.pendingToolCalls.set(toolCallId, status);
|
||||
if (wasEmpty) {
|
||||
this._resetToolCompletionPromise();
|
||||
}
|
||||
this.toolUpdateEmitter.emit('update');
|
||||
logger.info(
|
||||
`[Task] Registered tool call: ${toolCallId}. Pending: ${this.pendingToolCalls.size}`,
|
||||
);
|
||||
@@ -203,23 +187,47 @@ export class Task {
|
||||
/**
 * Removes `toolCallId` from the pending set and notifies waiters.
 *
 * Ids that are not currently pending are ignored. Otherwise an `update` event
 * is emitted, the removal is logged, and the completion notifier is resolved
 * once no pending calls remain.
 */
private _resolveToolCall(toolCallId: string): void {
  // Unknown id: nothing to do.
  if (!this.pendingToolCalls.has(toolCallId)) {
    return;
  }

  this.pendingToolCalls.delete(toolCallId);
  this.toolUpdateEmitter.emit('update');
  logger.info(
    `[Task] Resolved tool call: ${toolCallId}. Pending: ${this.pendingToolCalls.size}`,
  );

  const allDone = this.pendingToolCalls.size === 0;
  if (allDone && this.toolCompletionNotifier) {
    this.toolCompletionNotifier.resolve();
  }
}
|
||||
|
||||
async waitForPendingTools(): Promise<void> {
|
||||
private isAwaitingApprovalOnly(): boolean {
|
||||
if (this.pendingToolCalls.size === 0) {
|
||||
return Promise.resolve();
|
||||
return false;
|
||||
}
|
||||
for (const [callId, status] of this.pendingToolCalls.entries()) {
|
||||
if (
|
||||
status !== CoreToolCallStatus.AwaitingApproval ||
|
||||
this.toolsAlreadyConfirmed.has(callId)
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
async waitForPendingTools(): Promise<void> {
|
||||
while (this.pendingToolCalls.size > 0 && !this.isAwaitingApprovalOnly()) {
|
||||
if (this.cancellationError) {
|
||||
const error = this.cancellationError;
|
||||
this.cancellationError = undefined;
|
||||
throw error;
|
||||
}
|
||||
logger.info(
|
||||
`[Task] Waiting for ${this.pendingToolCalls.size} pending tool(s)...`,
|
||||
);
|
||||
await new Promise((resolve) =>
|
||||
this.toolUpdateEmitter.once('update', resolve),
|
||||
);
|
||||
}
|
||||
if (this.cancellationError) {
|
||||
const error = this.cancellationError;
|
||||
this.cancellationError = undefined;
|
||||
throw error;
|
||||
}
|
||||
logger.info(
|
||||
`[Task] Waiting for ${this.pendingToolCalls.size} pending tool(s)...`,
|
||||
);
|
||||
await this.toolCompletionPromise;
|
||||
}
|
||||
|
||||
cancelPendingTools(reason: string): void {
|
||||
@@ -228,15 +236,13 @@ export class Task {
|
||||
`[Task] Cancelling all ${this.pendingToolCalls.size} pending tool calls. Reason: ${reason}`,
|
||||
);
|
||||
}
|
||||
if (this.toolCompletionNotifier) {
|
||||
this.toolCompletionNotifier.reject(new Error(reason));
|
||||
}
|
||||
this.cancellationError = new Error(reason);
|
||||
this.pendingToolCalls.clear();
|
||||
this.pendingCorrelationIds.clear();
|
||||
this.toolsAlreadyConfirmed.clear();
|
||||
|
||||
this.scheduler.cancelAll();
|
||||
// Reset the promise for any future operations, ensuring it's in a clean state.
|
||||
this._resetToolCompletionPromise();
|
||||
this.toolUpdateEmitter.emit('update');
|
||||
}
|
||||
|
||||
private _createTextMessage(
|
||||
@@ -552,8 +558,8 @@ export class Task {
|
||||
|
||||
// Unblock waitForPendingTools to correctly end the executor loop and release the HTTP response stream.
|
||||
// The IDE client will open a new stream with the confirmation reply.
|
||||
if (!wasAlreadyInputRequired && this.toolCompletionNotifier) {
|
||||
this.toolCompletionNotifier.resolve();
|
||||
if (!wasAlreadyInputRequired) {
|
||||
this.toolUpdateEmitter.emit('update');
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -917,6 +923,7 @@ export class Task {
|
||||
const outcomeString = part.data['outcome'];
|
||||
|
||||
this.toolsAlreadyConfirmed.add(callId);
|
||||
this.toolUpdateEmitter.emit('update');
|
||||
|
||||
let confirmationOutcome: ToolConfirmationOutcome | undefined;
|
||||
|
||||
@@ -1130,10 +1137,6 @@ export class Task {
|
||||
if (confirmationHandled) {
|
||||
anyConfirmationHandled = true;
|
||||
// If a confirmation was handled, the scheduler will now run the tool (or cancel it).
|
||||
// We resolve the toolCompletionPromise manually in checkInputRequiredState
|
||||
// to break the original execution loop, so we must reset it here so the
|
||||
// new loop correctly awaits the tool's final execution.
|
||||
this._resetToolCompletionPromise();
|
||||
// We don't send anything to the LLM for this part.
|
||||
// The subsequent tool execution will eventually lead to resolveToolCall.
|
||||
continue;
|
||||
|
||||
Reference in New Issue
Block a user