mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-13 12:57:12 -07:00
fix(a2a-server): Resolve race condition in tool completion waiting (#26568)
This commit is contained in:
@@ -0,0 +1,173 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2026 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
|
||||||
|
import { Task } from './task.js';
|
||||||
|
import {
|
||||||
|
MessageBusType,
|
||||||
|
CoreToolCallStatus,
|
||||||
|
type Config,
|
||||||
|
type MessageBus,
|
||||||
|
} from '@google/gemini-cli-core';
|
||||||
|
import { createMockConfig } from '../utils/testing_utils.js';
|
||||||
|
import type { RequestContext } from '@a2a-js/sdk/server';
|
||||||
|
|
||||||
|
describe('Task Race Condition', () => {
|
||||||
|
let mockConfig: Config;
|
||||||
|
let messageBus: MessageBus;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
messageBus = {
|
||||||
|
subscribe: vi.fn(),
|
||||||
|
unsubscribe: vi.fn(),
|
||||||
|
publish: vi.fn(),
|
||||||
|
} as unknown as MessageBus;
|
||||||
|
mockConfig = createMockConfig({
|
||||||
|
messageBus,
|
||||||
|
}) as Config;
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not hang when multiple tool confirmations are processed while waiting', async () => {
|
||||||
|
// @ts-expect-error - private constructor
|
||||||
|
const task = new Task('task-id', 'context-id', mockConfig);
|
||||||
|
|
||||||
|
// 1. Register two tools as scheduled
|
||||||
|
task['_registerToolCall']('tool-1', 'scheduled');
|
||||||
|
task['_registerToolCall']('tool-2', 'scheduled');
|
||||||
|
|
||||||
|
// 2. Both transition to awaiting_approval
|
||||||
|
const updateHandler = (messageBus.subscribe as Mock).mock.calls.find(
|
||||||
|
(c: unknown[]) => c[0] === MessageBusType.TOOL_CALLS_UPDATE,
|
||||||
|
)?.[1];
|
||||||
|
|
||||||
|
updateHandler({
|
||||||
|
type: MessageBusType.TOOL_CALLS_UPDATE,
|
||||||
|
schedulerId: 'task-id',
|
||||||
|
toolCalls: [
|
||||||
|
{
|
||||||
|
request: { callId: 'tool-1', name: 't1' },
|
||||||
|
status: CoreToolCallStatus.AwaitingApproval,
|
||||||
|
correlationId: 'corr-1',
|
||||||
|
confirmationDetails: { type: 'info' },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
request: { callId: 'tool-2', name: 't2' },
|
||||||
|
status: CoreToolCallStatus.AwaitingApproval,
|
||||||
|
correlationId: 'corr-2',
|
||||||
|
confirmationDetails: { type: 'info' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
// 3. Confirm Tool 1. This makes isAwaitingApprovalOnly() return false.
|
||||||
|
for await (const _ of task.acceptUserMessage(
|
||||||
|
{
|
||||||
|
userMessage: {
|
||||||
|
parts: [
|
||||||
|
{
|
||||||
|
kind: 'data',
|
||||||
|
data: { callId: 'tool-1', outcome: 'proceed_once' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
} as unknown as RequestContext,
|
||||||
|
new AbortController().signal,
|
||||||
|
)) {
|
||||||
|
// consume generator
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Start waiting. This should now block because Tool 1 is confirmed (so we are waiting for its execution).
|
||||||
|
const waitPromise = task.waitForPendingTools();
|
||||||
|
|
||||||
|
// 5. Confirm Tool 2 while waiting.
|
||||||
|
for await (const _ of task.acceptUserMessage(
|
||||||
|
{
|
||||||
|
userMessage: {
|
||||||
|
parts: [
|
||||||
|
{
|
||||||
|
kind: 'data',
|
||||||
|
data: { callId: 'tool-2', outcome: 'proceed_once' },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
} as unknown as RequestContext,
|
||||||
|
new AbortController().signal,
|
||||||
|
)) {
|
||||||
|
// consume generator
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Both tools complete successfully
|
||||||
|
updateHandler({
|
||||||
|
type: MessageBusType.TOOL_CALLS_UPDATE,
|
||||||
|
schedulerId: 'task-id',
|
||||||
|
toolCalls: [
|
||||||
|
{
|
||||||
|
request: { callId: 'tool-1', name: 't1' },
|
||||||
|
status: CoreToolCallStatus.Success,
|
||||||
|
response: { responseParts: [] },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
request: { callId: 'tool-2', name: 't2' },
|
||||||
|
status: CoreToolCallStatus.Success,
|
||||||
|
response: { responseParts: [] },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
// 7. Verify that the original waitPromise resolves.
|
||||||
|
await expect(waitPromise).resolves.toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should reject waitForPendingTools when tools are cancelled', async () => {
|
||||||
|
// @ts-expect-error - private constructor
|
||||||
|
const task = new Task('task-id', 'context-id', mockConfig);
|
||||||
|
|
||||||
|
// 1. Register a tool
|
||||||
|
task['_registerToolCall']('tool-1', 'scheduled');
|
||||||
|
|
||||||
|
// 2. Start waiting
|
||||||
|
const waitPromise = task.waitForPendingTools();
|
||||||
|
|
||||||
|
// 3. Cancel pending tools
|
||||||
|
task.cancelPendingTools('User requested cancellation');
|
||||||
|
|
||||||
|
// 4. Verify waitPromise rejects with the reason
|
||||||
|
await expect(waitPromise).rejects.toThrow('User requested cancellation');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle concurrent tool scheduling correctly', async () => {
|
||||||
|
// @ts-expect-error - private constructor
|
||||||
|
const task = new Task('task-id', 'context-id', mockConfig);
|
||||||
|
|
||||||
|
// 1. Register a tool and start waiting
|
||||||
|
task['_registerToolCall']('tool-1', 'scheduled');
|
||||||
|
const waitPromise = task.waitForPendingTools();
|
||||||
|
|
||||||
|
// 2. Schedule another tool concurrently (e.g. from a secondary user message)
|
||||||
|
// This should NOT resolve the current waitPromise until both are done
|
||||||
|
await task.scheduleToolCalls(
|
||||||
|
[{ callId: 'tool-2', name: 't2', args: {} }],
|
||||||
|
new AbortController().signal,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(task['pendingToolCalls'].size).toBe(2);
|
||||||
|
|
||||||
|
// 3. Resolve tool 1
|
||||||
|
task['_resolveToolCall']('tool-1');
|
||||||
|
|
||||||
|
// 4. Verify waitPromise is still pending
|
||||||
|
let resolved = false;
|
||||||
|
waitPromise.then(() => (resolved = true));
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||||
|
expect(resolved).toBe(false);
|
||||||
|
|
||||||
|
// 5. Resolve tool 2
|
||||||
|
task['_resolveToolCall']('tool-2');
|
||||||
|
|
||||||
|
// 6. Now it should resolve
|
||||||
|
await expect(waitPromise).resolves.toBeUndefined();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -12,6 +12,7 @@ import {
|
|||||||
ApprovalMode,
|
ApprovalMode,
|
||||||
Scheduler,
|
Scheduler,
|
||||||
type MessageBus,
|
type MessageBus,
|
||||||
|
type ToolLiveOutput,
|
||||||
} from '@google/gemini-cli-core';
|
} from '@google/gemini-cli-core';
|
||||||
import { createMockConfig } from '../utils/testing_utils.js';
|
import { createMockConfig } from '../utils/testing_utils.js';
|
||||||
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
|
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
|
||||||
@@ -608,6 +609,74 @@ describe('Task Event-Driven Scheduler', () => {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should handle multi-turn tool resolution correctly', async () => {
|
||||||
|
// @ts-expect-error - Calling private constructor
|
||||||
|
const task = new Task('task-id', 'context-id', mockConfig);
|
||||||
|
|
||||||
|
task['_registerToolCall']('1', 'scheduled');
|
||||||
|
task['_registerToolCall']('2', 'scheduled');
|
||||||
|
|
||||||
|
const handler = (messageBus.subscribe as Mock).mock.calls.find(
|
||||||
|
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
|
||||||
|
)?.[1];
|
||||||
|
|
||||||
|
// Turn 1: Resolve tool 1
|
||||||
|
handler({
|
||||||
|
type: MessageBusType.TOOL_CALLS_UPDATE,
|
||||||
|
toolCalls: [
|
||||||
|
{
|
||||||
|
request: { callId: '1', name: 't1' },
|
||||||
|
status: 'success',
|
||||||
|
response: { responseParts: [] },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
schedulerId: 'task-id',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(task['pendingToolCalls'].size).toBe(1);
|
||||||
|
expect(task['pendingToolCalls'].has('2')).toBe(true);
|
||||||
|
|
||||||
|
// Turn 2: Resolve tool 2
|
||||||
|
handler({
|
||||||
|
type: MessageBusType.TOOL_CALLS_UPDATE,
|
||||||
|
toolCalls: [
|
||||||
|
{
|
||||||
|
request: { callId: '2', name: 't2' },
|
||||||
|
status: 'success',
|
||||||
|
response: { responseParts: [] },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
schedulerId: 'task-id',
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(task['pendingToolCalls'].size).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle subagent progress events from the scheduler', async () => {
|
||||||
|
// @ts-expect-error - Calling private constructor
|
||||||
|
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
|
||||||
|
|
||||||
|
// Trigger _schedulerOutputUpdate with subagent progress
|
||||||
|
task['_schedulerOutputUpdate']('tool-1', {
|
||||||
|
isSubagentProgress: true,
|
||||||
|
agentName: 'researcher',
|
||||||
|
recentActivity: [],
|
||||||
|
} as ToolLiveOutput);
|
||||||
|
|
||||||
|
expect(mockEventBus.publish).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
kind: 'artifact-update',
|
||||||
|
artifact: expect.objectContaining({
|
||||||
|
parts: [
|
||||||
|
expect.objectContaining({
|
||||||
|
text: expect.stringContaining('researcher'),
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
it('should wait for executing tools before transitioning to input-required state', async () => {
|
it('should wait for executing tools before transitioning to input-required state', async () => {
|
||||||
// @ts-expect-error - Calling private constructor
|
// @ts-expect-error - Calling private constructor
|
||||||
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
|
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ import type {
|
|||||||
Artifact,
|
Artifact,
|
||||||
} from '@a2a-js/sdk';
|
} from '@a2a-js/sdk';
|
||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
|
import { EventEmitter } from 'node:events';
|
||||||
import { logger } from '../utils/logger.js';
|
import { logger } from '../utils/logger.js';
|
||||||
import * as fs from 'node:fs/promises';
|
import * as fs from 'node:fs/promises';
|
||||||
import * as path from 'node:path';
|
import * as path from 'node:path';
|
||||||
@@ -99,11 +100,8 @@ export class Task {
|
|||||||
private pendingOutcomes: Map<string, ToolConfirmationOutcome | undefined> =
|
private pendingOutcomes: Map<string, ToolConfirmationOutcome | undefined> =
|
||||||
new Map(); // toolCallId --> outcome
|
new Map(); // toolCallId --> outcome
|
||||||
private toolsAlreadyConfirmed: Set<string> = new Set();
|
private toolsAlreadyConfirmed: Set<string> = new Set();
|
||||||
private toolCompletionPromise?: Promise<void>;
|
private toolUpdateEmitter = new EventEmitter();
|
||||||
private toolCompletionNotifier?: {
|
private cancellationError?: Error;
|
||||||
resolve: () => void;
|
|
||||||
reject: (reason?: Error) => void;
|
|
||||||
};
|
|
||||||
|
|
||||||
private constructor(
|
private constructor(
|
||||||
id: string,
|
id: string,
|
||||||
@@ -124,7 +122,6 @@ export class Task {
|
|||||||
this.taskState = 'submitted';
|
this.taskState = 'submitted';
|
||||||
this.eventBus = eventBus;
|
this.eventBus = eventBus;
|
||||||
this.completedToolCalls = [];
|
this.completedToolCalls = [];
|
||||||
this._resetToolCompletionPromise();
|
|
||||||
this.autoExecute = autoExecute;
|
this.autoExecute = autoExecute;
|
||||||
this.config.setFallbackModelHandler(
|
this.config.setFallbackModelHandler(
|
||||||
// For a2a-server, we want to automatically switch to the fallback model
|
// For a2a-server, we want to automatically switch to the fallback model
|
||||||
@@ -179,22 +176,9 @@ export class Task {
|
|||||||
return metadata;
|
return metadata;
|
||||||
}
|
}
|
||||||
|
|
||||||
private _resetToolCompletionPromise(): void {
|
|
||||||
this.toolCompletionPromise = new Promise((resolve, reject) => {
|
|
||||||
this.toolCompletionNotifier = { resolve, reject };
|
|
||||||
});
|
|
||||||
// If there are no pending calls when reset, resolve immediately.
|
|
||||||
if (this.pendingToolCalls.size === 0 && this.toolCompletionNotifier) {
|
|
||||||
this.toolCompletionNotifier.resolve();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private _registerToolCall(toolCallId: string, status: string): void {
|
private _registerToolCall(toolCallId: string, status: string): void {
|
||||||
const wasEmpty = this.pendingToolCalls.size === 0;
|
|
||||||
this.pendingToolCalls.set(toolCallId, status);
|
this.pendingToolCalls.set(toolCallId, status);
|
||||||
if (wasEmpty) {
|
this.toolUpdateEmitter.emit('update');
|
||||||
this._resetToolCompletionPromise();
|
|
||||||
}
|
|
||||||
logger.info(
|
logger.info(
|
||||||
`[Task] Registered tool call: ${toolCallId}. Pending: ${this.pendingToolCalls.size}`,
|
`[Task] Registered tool call: ${toolCallId}. Pending: ${this.pendingToolCalls.size}`,
|
||||||
);
|
);
|
||||||
@@ -203,23 +187,47 @@ export class Task {
|
|||||||
private _resolveToolCall(toolCallId: string): void {
|
private _resolveToolCall(toolCallId: string): void {
|
||||||
if (this.pendingToolCalls.has(toolCallId)) {
|
if (this.pendingToolCalls.has(toolCallId)) {
|
||||||
this.pendingToolCalls.delete(toolCallId);
|
this.pendingToolCalls.delete(toolCallId);
|
||||||
|
this.toolUpdateEmitter.emit('update');
|
||||||
logger.info(
|
logger.info(
|
||||||
`[Task] Resolved tool call: ${toolCallId}. Pending: ${this.pendingToolCalls.size}`,
|
`[Task] Resolved tool call: ${toolCallId}. Pending: ${this.pendingToolCalls.size}`,
|
||||||
);
|
);
|
||||||
if (this.pendingToolCalls.size === 0 && this.toolCompletionNotifier) {
|
|
||||||
this.toolCompletionNotifier.resolve();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async waitForPendingTools(): Promise<void> {
|
private isAwaitingApprovalOnly(): boolean {
|
||||||
if (this.pendingToolCalls.size === 0) {
|
if (this.pendingToolCalls.size === 0) {
|
||||||
return Promise.resolve();
|
return false;
|
||||||
|
}
|
||||||
|
for (const [callId, status] of this.pendingToolCalls.entries()) {
|
||||||
|
if (
|
||||||
|
status !== CoreToolCallStatus.AwaitingApproval ||
|
||||||
|
this.toolsAlreadyConfirmed.has(callId)
|
||||||
|
) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
async waitForPendingTools(): Promise<void> {
|
||||||
|
while (this.pendingToolCalls.size > 0 && !this.isAwaitingApprovalOnly()) {
|
||||||
|
if (this.cancellationError) {
|
||||||
|
const error = this.cancellationError;
|
||||||
|
this.cancellationError = undefined;
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
logger.info(
|
||||||
|
`[Task] Waiting for ${this.pendingToolCalls.size} pending tool(s)...`,
|
||||||
|
);
|
||||||
|
await new Promise((resolve) =>
|
||||||
|
this.toolUpdateEmitter.once('update', resolve),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if (this.cancellationError) {
|
||||||
|
const error = this.cancellationError;
|
||||||
|
this.cancellationError = undefined;
|
||||||
|
throw error;
|
||||||
}
|
}
|
||||||
logger.info(
|
|
||||||
`[Task] Waiting for ${this.pendingToolCalls.size} pending tool(s)...`,
|
|
||||||
);
|
|
||||||
await this.toolCompletionPromise;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cancelPendingTools(reason: string): void {
|
cancelPendingTools(reason: string): void {
|
||||||
@@ -228,15 +236,13 @@ export class Task {
|
|||||||
`[Task] Cancelling all ${this.pendingToolCalls.size} pending tool calls. Reason: ${reason}`,
|
`[Task] Cancelling all ${this.pendingToolCalls.size} pending tool calls. Reason: ${reason}`,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
if (this.toolCompletionNotifier) {
|
this.cancellationError = new Error(reason);
|
||||||
this.toolCompletionNotifier.reject(new Error(reason));
|
|
||||||
}
|
|
||||||
this.pendingToolCalls.clear();
|
this.pendingToolCalls.clear();
|
||||||
this.pendingCorrelationIds.clear();
|
this.pendingCorrelationIds.clear();
|
||||||
|
this.toolsAlreadyConfirmed.clear();
|
||||||
|
|
||||||
this.scheduler.cancelAll();
|
this.scheduler.cancelAll();
|
||||||
// Reset the promise for any future operations, ensuring it's in a clean state.
|
this.toolUpdateEmitter.emit('update');
|
||||||
this._resetToolCompletionPromise();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private _createTextMessage(
|
private _createTextMessage(
|
||||||
@@ -552,8 +558,8 @@ export class Task {
|
|||||||
|
|
||||||
// Unblock waitForPendingTools to correctly end the executor loop and release the HTTP response stream.
|
// Unblock waitForPendingTools to correctly end the executor loop and release the HTTP response stream.
|
||||||
// The IDE client will open a new stream with the confirmation reply.
|
// The IDE client will open a new stream with the confirmation reply.
|
||||||
if (!wasAlreadyInputRequired && this.toolCompletionNotifier) {
|
if (!wasAlreadyInputRequired) {
|
||||||
this.toolCompletionNotifier.resolve();
|
this.toolUpdateEmitter.emit('update');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -917,6 +923,7 @@ export class Task {
|
|||||||
const outcomeString = part.data['outcome'];
|
const outcomeString = part.data['outcome'];
|
||||||
|
|
||||||
this.toolsAlreadyConfirmed.add(callId);
|
this.toolsAlreadyConfirmed.add(callId);
|
||||||
|
this.toolUpdateEmitter.emit('update');
|
||||||
|
|
||||||
let confirmationOutcome: ToolConfirmationOutcome | undefined;
|
let confirmationOutcome: ToolConfirmationOutcome | undefined;
|
||||||
|
|
||||||
@@ -1130,10 +1137,6 @@ export class Task {
|
|||||||
if (confirmationHandled) {
|
if (confirmationHandled) {
|
||||||
anyConfirmationHandled = true;
|
anyConfirmationHandled = true;
|
||||||
// If a confirmation was handled, the scheduler will now run the tool (or cancel it).
|
// If a confirmation was handled, the scheduler will now run the tool (or cancel it).
|
||||||
// We resolve the toolCompletionPromise manually in checkInputRequiredState
|
|
||||||
// to break the original execution loop, so we must reset it here so the
|
|
||||||
// new loop correctly awaits the tool's final execution.
|
|
||||||
this._resetToolCompletionPromise();
|
|
||||||
// We don't send anything to the LLM for this part.
|
// We don't send anything to the LLM for this part.
|
||||||
// The subsequent tool execution will eventually lead to resolveToolCall.
|
// The subsequent tool execution will eventually lead to resolveToolCall.
|
||||||
continue;
|
continue;
|
||||||
|
|||||||
Reference in New Issue
Block a user