From 02995ba939bcc592ac1ad9486f74a5708219a993 Mon Sep 17 00:00:00 2001 From: Keith Schaab Date: Wed, 6 May 2026 16:20:22 +0000 Subject: [PATCH] fix(a2a-server): Resolve race condition in tool completion waiting (#26568) --- .../src/agent/race-condition.test.ts | 173 ++++++++++++++++++ .../src/agent/task-event-driven.test.ts | 69 +++++++ packages/a2a-server/src/agent/task.ts | 83 +++++---- 3 files changed, 285 insertions(+), 40 deletions(-) create mode 100644 packages/a2a-server/src/agent/race-condition.test.ts diff --git a/packages/a2a-server/src/agent/race-condition.test.ts b/packages/a2a-server/src/agent/race-condition.test.ts new file mode 100644 index 0000000000..3906c43a68 --- /dev/null +++ b/packages/a2a-server/src/agent/race-condition.test.ts @@ -0,0 +1,173 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest'; +import { Task } from './task.js'; +import { + MessageBusType, + CoreToolCallStatus, + type Config, + type MessageBus, +} from '@google/gemini-cli-core'; +import { createMockConfig } from '../utils/testing_utils.js'; +import type { RequestContext } from '@a2a-js/sdk/server'; + +describe('Task Race Condition', () => { + let mockConfig: Config; + let messageBus: MessageBus; + + beforeEach(() => { + messageBus = { + subscribe: vi.fn(), + unsubscribe: vi.fn(), + publish: vi.fn(), + } as unknown as MessageBus; + mockConfig = createMockConfig({ + messageBus, + }) as Config; + }); + + it('should not hang when multiple tool confirmations are processed while waiting', async () => { + // @ts-expect-error - private constructor + const task = new Task('task-id', 'context-id', mockConfig); + + // 1. Register two tools as scheduled + task['_registerToolCall']('tool-1', 'scheduled'); + task['_registerToolCall']('tool-2', 'scheduled'); + + // 2. Both transition to awaiting_approval + const updateHandler = (messageBus.subscribe as Mock).mock.calls.find( + (c: unknown[]) => c[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + + updateHandler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + schedulerId: 'task-id', + toolCalls: [ + { + request: { callId: 'tool-1', name: 't1' }, + status: CoreToolCallStatus.AwaitingApproval, + correlationId: 'corr-1', + confirmationDetails: { type: 'info' }, + }, + { + request: { callId: 'tool-2', name: 't2' }, + status: CoreToolCallStatus.AwaitingApproval, + correlationId: 'corr-2', + confirmationDetails: { type: 'info' }, + }, + ], + }); + + // 3. Confirm Tool 1. This makes isAwaitingApprovalOnly() return false. + for await (const _ of task.acceptUserMessage( + { + userMessage: { + parts: [ + { + kind: 'data', + data: { callId: 'tool-1', outcome: 'proceed_once' }, + }, + ], + }, + } as unknown as RequestContext, + new AbortController().signal, + )) { + // consume generator + } + + // 4. Start waiting. This should now block because Tool 1 is confirmed (so we are waiting for its execution). + const waitPromise = task.waitForPendingTools(); + + // 5. Confirm Tool 2 while waiting. + for await (const _ of task.acceptUserMessage( + { + userMessage: { + parts: [ + { + kind: 'data', + data: { callId: 'tool-2', outcome: 'proceed_once' }, + }, + ], + }, + } as unknown as RequestContext, + new AbortController().signal, + )) { + // consume generator + } + + // 6. Both tools complete successfully + updateHandler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + schedulerId: 'task-id', + toolCalls: [ + { + request: { callId: 'tool-1', name: 't1' }, + status: CoreToolCallStatus.Success, + response: { responseParts: [] }, + }, + { + request: { callId: 'tool-2', name: 't2' }, + status: CoreToolCallStatus.Success, + response: { responseParts: [] }, + }, + ], + }); + + // 7. Verify that the original waitPromise resolves. + await expect(waitPromise).resolves.toBeUndefined(); + }); + + it('should reject waitForPendingTools when tools are cancelled', async () => { + // @ts-expect-error - private constructor + const task = new Task('task-id', 'context-id', mockConfig); + + // 1. Register a tool + task['_registerToolCall']('tool-1', 'scheduled'); + + // 2. Start waiting + const waitPromise = task.waitForPendingTools(); + + // 3. Cancel pending tools + task.cancelPendingTools('User requested cancellation'); + + // 4. Verify waitPromise rejects with the reason + await expect(waitPromise).rejects.toThrow('User requested cancellation'); + }); + + it('should handle concurrent tool scheduling correctly', async () => { + // @ts-expect-error - private constructor + const task = new Task('task-id', 'context-id', mockConfig); + + // 1. Register a tool and start waiting + task['_registerToolCall']('tool-1', 'scheduled'); + const waitPromise = task.waitForPendingTools(); + + // 2. Schedule another tool concurrently (e.g. from a secondary user message) + // This should NOT resolve the current waitPromise until both are done + await task.scheduleToolCalls( + [{ callId: 'tool-2', name: 't2', args: {} }], + new AbortController().signal, + ); + + expect(task['pendingToolCalls'].size).toBe(2); + + // 3. Resolve tool 1 + task['_resolveToolCall']('tool-1'); + + // 4. Verify waitPromise is still pending + let resolved = false; + waitPromise.then(() => (resolved = true)); + await new Promise((resolve) => setTimeout(resolve, 10)); + expect(resolved).toBe(false); + + // 5. Resolve tool 2 + task['_resolveToolCall']('tool-2'); + + // 6. Now it should resolve + await expect(waitPromise).resolves.toBeUndefined(); + }); +}); diff --git a/packages/a2a-server/src/agent/task-event-driven.test.ts b/packages/a2a-server/src/agent/task-event-driven.test.ts index 5fc548a8f4..a67a2bee13 100644 --- a/packages/a2a-server/src/agent/task-event-driven.test.ts +++ b/packages/a2a-server/src/agent/task-event-driven.test.ts @@ -12,6 +12,7 @@ import { ApprovalMode, Scheduler, type MessageBus, + type ToolLiveOutput, } from '@google/gemini-cli-core'; import { createMockConfig } from '../utils/testing_utils.js'; import type { ExecutionEventBus } from '@a2a-js/sdk/server'; @@ -608,6 +609,74 @@ describe('Task Event-Driven Scheduler', () => { ); }); + it('should handle multi-turn tool resolution correctly', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig); + + task['_registerToolCall']('1', 'scheduled'); + task['_registerToolCall']('2', 'scheduled'); + + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + + // Turn 1: Resolve tool 1 + handler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + toolCalls: [ + { + request: { callId: '1', name: 't1' }, + status: 'success', + response: { responseParts: [] }, + }, + ], + schedulerId: 'task-id', + }); + + expect(task['pendingToolCalls'].size).toBe(1); + expect(task['pendingToolCalls'].has('2')).toBe(true); + + // Turn 2: Resolve tool 2 + handler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + toolCalls: [ + { + request: { callId: '2', name: 't2' }, + status: 'success', + response: { responseParts: [] }, + }, + ], + schedulerId: 'task-id', + }); + + expect(task['pendingToolCalls'].size).toBe(0); + }); + + it('should handle subagent progress events from the scheduler', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + // Trigger _schedulerOutputUpdate with subagent progress + task['_schedulerOutputUpdate']('tool-1', { + isSubagentProgress: true, + agentName: 'researcher', + recentActivity: [], + } as ToolLiveOutput); + + expect(mockEventBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + kind: 'artifact-update', + artifact: expect.objectContaining({ + parts: [ + expect.objectContaining({ + text: expect.stringContaining('researcher'), + }), + ], + }), + }), + ); + }); + it('should wait for executing tools before transitioning to input-required state', async () => { // @ts-expect-error - Calling private constructor const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index 3fcb5c3ef5..6ecea06c60 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -52,6 +52,7 @@ import type { Artifact, } from '@a2a-js/sdk'; import { v4 as uuidv4 } from 'uuid'; +import { EventEmitter } from 'node:events'; import { logger } from '../utils/logger.js'; import * as fs from 'node:fs/promises'; import * as path from 'node:path'; @@ -99,11 +100,8 @@ export class Task { private pendingOutcomes: Map = new Map(); // toolCallId --> outcome private toolsAlreadyConfirmed: Set = new Set(); - private toolCompletionPromise?: Promise; - private toolCompletionNotifier?: { - resolve: () => void; - reject: (reason?: Error) => void; - }; + private toolUpdateEmitter = new EventEmitter(); + private cancellationError?: Error; private constructor( id: string, @@ -124,7 +122,6 @@ export class Task { this.taskState = 'submitted'; this.eventBus = eventBus; this.completedToolCalls = []; - this._resetToolCompletionPromise(); this.autoExecute = autoExecute; this.config.setFallbackModelHandler( // For a2a-server, we want to automatically switch to the fallback model @@ -179,22 +176,9 @@ export class Task { return metadata; } - private _resetToolCompletionPromise(): void { - this.toolCompletionPromise = new Promise((resolve, reject) => { - this.toolCompletionNotifier = { resolve, reject }; - }); - // If there are no pending calls when reset, resolve immediately. - if (this.pendingToolCalls.size === 0 && this.toolCompletionNotifier) { - this.toolCompletionNotifier.resolve(); - } - } - private _registerToolCall(toolCallId: string, status: string): void { - const wasEmpty = this.pendingToolCalls.size === 0; this.pendingToolCalls.set(toolCallId, status); - if (wasEmpty) { - this._resetToolCompletionPromise(); - } + this.toolUpdateEmitter.emit('update'); logger.info( `[Task] Registered tool call: ${toolCallId}. Pending: ${this.pendingToolCalls.size}`, ); @@ -203,23 +187,47 @@ export class Task { private _resolveToolCall(toolCallId: string): void { if (this.pendingToolCalls.has(toolCallId)) { this.pendingToolCalls.delete(toolCallId); + this.toolUpdateEmitter.emit('update'); logger.info( `[Task] Resolved tool call: ${toolCallId}. Pending: ${this.pendingToolCalls.size}`, ); - if (this.pendingToolCalls.size === 0 && this.toolCompletionNotifier) { - this.toolCompletionNotifier.resolve(); - } } } - async waitForPendingTools(): Promise { + private isAwaitingApprovalOnly(): boolean { if (this.pendingToolCalls.size === 0) { - return Promise.resolve(); + return false; + } + for (const [callId, status] of this.pendingToolCalls.entries()) { + if ( + status !== CoreToolCallStatus.AwaitingApproval || + this.toolsAlreadyConfirmed.has(callId) + ) { + return false; + } + } + return true; + } + + async waitForPendingTools(): Promise { + while (this.pendingToolCalls.size > 0 && !this.isAwaitingApprovalOnly()) { + if (this.cancellationError) { + const error = this.cancellationError; + this.cancellationError = undefined; + throw error; + } + logger.info( + `[Task] Waiting for ${this.pendingToolCalls.size} pending tool(s)...`, + ); + await new Promise((resolve) => + this.toolUpdateEmitter.once('update', resolve), + ); + } + if (this.cancellationError) { + const error = this.cancellationError; + this.cancellationError = undefined; + throw error; } - logger.info( - `[Task] Waiting for ${this.pendingToolCalls.size} pending tool(s)...`, - ); - await this.toolCompletionPromise; } cancelPendingTools(reason: string): void { @@ -228,15 +236,13 @@ export class Task { `[Task] Cancelling all ${this.pendingToolCalls.size} pending tool calls. Reason: ${reason}`, ); } - if (this.toolCompletionNotifier) { - this.toolCompletionNotifier.reject(new Error(reason)); - } + this.cancellationError = new Error(reason); this.pendingToolCalls.clear(); this.pendingCorrelationIds.clear(); + this.toolsAlreadyConfirmed.clear(); this.scheduler.cancelAll(); - // Reset the promise for any future operations, ensuring it's in a clean state. - this._resetToolCompletionPromise(); + this.toolUpdateEmitter.emit('update'); } private _createTextMessage( @@ -552,8 +558,8 @@ export class Task { // Unblock waitForPendingTools to correctly end the executor loop and release the HTTP response stream. // The IDE client will open a new stream with the confirmation reply. - if (!wasAlreadyInputRequired && this.toolCompletionNotifier) { - this.toolCompletionNotifier.resolve(); + if (!wasAlreadyInputRequired) { + this.toolUpdateEmitter.emit('update'); } } } @@ -917,6 +923,7 @@ export class Task { const outcomeString = part.data['outcome']; this.toolsAlreadyConfirmed.add(callId); + this.toolUpdateEmitter.emit('update'); let confirmationOutcome: ToolConfirmationOutcome | undefined; @@ -1130,10 +1137,6 @@ export class Task { if (confirmationHandled) { anyConfirmationHandled = true; // If a confirmation was handled, the scheduler will now run the tool (or cancel it). - // We resolve the toolCompletionPromise manually in checkInputRequiredState - // to break the original execution loop, so we must reset it here so the - // new loop correctly awaits the tool's final execution. - this._resetToolCompletionPromise(); // We don't send anything to the LLM for this part. // The subsequent tool execution will eventually lead to resolveToolCall. continue;