From 1c4335686f3ac2d73fb1b283e7f02f57352f2939 Mon Sep 17 00:00:00 2001 From: Abhi Date: Tue, 20 Jan 2026 15:49:31 -0500 Subject: [PATCH] feat(a2a): switch from callback-based to event-driven tool scheduler This change transitions packages/a2a-server to use the event-driven Scheduler by default. It replaces the legacy direct callback mechanism with a MessageBus listener in the Task class to handle tool status updates, live output, and confirmations. - Added experimental.enableEventDrivenScheduler setting (defaults to true). - Refactored Task.ts to support both legacy and event-driven schedulers. - Implemented bus-based tool confirmation responses using correlationId. - Exported Scheduler from packages/core. - Added unit tests for the event-driven flow in A2A. --- .../src/agent/task-event-driven.test.ts | 173 ++++++++++++++ packages/a2a-server/src/agent/task.ts | 214 ++++++++++++++++-- packages/a2a-server/src/config/config.ts | 2 + packages/a2a-server/src/config/settings.ts | 3 + .../a2a-server/src/utils/testing_utils.ts | 1 + packages/core/src/index.ts | 1 + 6 files changed, 372 insertions(+), 22 deletions(-) create mode 100644 packages/a2a-server/src/agent/task-event-driven.test.ts diff --git a/packages/a2a-server/src/agent/task-event-driven.test.ts b/packages/a2a-server/src/agent/task-event-driven.test.ts new file mode 100644 index 0000000000..31e5e258f1 --- /dev/null +++ b/packages/a2a-server/src/agent/task-event-driven.test.ts @@ -0,0 +1,173 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ +import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest'; +import { Task } from './task.js'; +import { + type Config, + MessageBusType, + ToolConfirmationOutcome, + Scheduler, + type MessageBus, +} from '@google/gemini-cli-core'; +import { createMockConfig } from '../utils/testing_utils.js'; +import type { ExecutionEventBus } from '@a2a-js/sdk/server'; + +describe('Task Event-Driven Scheduler', () => { + let mockConfig: Config; + let mockEventBus: ExecutionEventBus; + let messageBus: MessageBus; + + beforeEach(() => { + vi.clearAllMocks(); + mockConfig = createMockConfig({ + isEventDrivenSchedulerEnabled: () => true, + }) as Config; + messageBus = mockConfig.getMessageBus(); + mockEventBus = { + publish: vi.fn(), + on: vi.fn(), + off: vi.fn(), + once: vi.fn(), + removeAllListeners: vi.fn(), + finished: vi.fn(), + }; + }); + + it('should instantiate Scheduler when enabled', () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + expect(task.scheduler).toBeInstanceOf(Scheduler); + }); + + it('should subscribe to TOOL_CALLS_UPDATE and map status changes', async () => { + // @ts-expect-error - Calling private constructor + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCall = { + request: { callId: '1', name: 'ls', args: {} }, + status: 'executing', + }; + + // Simulate MessageBus event + // Simulate MessageBus event + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + + if (!handler) { + throw new Error('TOOL_CALLS_UPDATE handler not found'); + } + + handler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + toolCalls: [toolCall], + }); + + expect(mockEventBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + status: expect.objectContaining({ + state: 'submitted', // initial task state + }), + metadata: expect.objectContaining({ + coderAgent: expect.objectContaining({ + kind: 'tool-call-update', + }), + }), + }), + ); + }); + + it('should handle tool confirmations by publishing to MessageBus', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCall = { + request: { callId: '1', name: 'ls', args: {} }, + status: 'awaiting_approval', + correlationId: 'corr-1', + confirmationDetails: { type: 'info', title: 'test', prompt: 'test' }, + }; + + // Simulate MessageBus event to stash the correlationId + // Simulate MessageBus event + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + + if (!handler) { + throw new Error('TOOL_CALLS_UPDATE handler not found'); + } + + handler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + toolCalls: [toolCall], + }); + + // Simulate A2A client confirmation + const part = { + kind: 'data', + data: { + callId: '1', + outcome: 'proceed_once', + }, + }; + + const handled = await ( + task as unknown as { + _handleToolConfirmationPart: (part: unknown) => Promise; + } + )._handleToolConfirmationPart(part); + expect(handled).toBe(true); + + expect(messageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'corr-1', + confirmed: true, + outcome: ToolConfirmationOutcome.ProceedOnce, + }), + ); + }); + + it('should handle output updates via the message bus', async () => { + // @ts-expect-error - Calling private constructor + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCall = { + request: { callId: '1', name: 'ls', args: {} }, + status: 'executing', + liveOutput: 'chunk1', + }; + + // Simulate MessageBus event + // Simulate MessageBus event + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + + if (!handler) { + throw new Error('TOOL_CALLS_UPDATE handler not found'); + } + + handler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + toolCalls: [toolCall], + }); + + // Should publish artifact update for output + expect(mockEventBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + kind: 'artifact-update', + artifact: expect.objectContaining({ + artifactId: 'tool-1-output', + parts: [{ kind: 'text', text: 'chunk1' }], + }), + }), + ); + }); +}); diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index 6fefd84919..c625eebd74 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -5,6 +5,7 @@ */ import { + Scheduler, CoreToolScheduler, type GeminiClient, GeminiEventType, @@ -29,6 +30,9 @@ import { type AnsiOutput, EDIT_TOOL_NAMES, processRestorableToolCalls, + MessageBusType, + type ToolCallsUpdateMessage, + type SerializableConfirmationDetails, } from '@google/gemini-cli-core'; import type { RequestContext } from '@a2a-js/sdk/server'; import { type ExecutionEventBus } from '@a2a-js/sdk/server'; @@ -62,10 +66,11 @@ type UnionKeys = T extends T ? keyof T : never; export class Task { id: string; contextId: string; - scheduler: CoreToolScheduler; + scheduler: Scheduler | CoreToolScheduler; config: Config; geminiClient: GeminiClient; pendingToolConfirmationDetails: Map; + pendingCorrelationIds: Map = new Map(); taskState: TaskState; eventBus?: ExecutionEventBus; completedToolCalls: CompletedToolCall[]; @@ -93,7 +98,13 @@ export class Task { this.id = id; this.contextId = contextId; this.config = config; - this.scheduler = this.createScheduler(); + + if (this.config.isEventDrivenSchedulerEnabled()) { + this.scheduler = this._setupEventDrivenScheduler(); + } else { + this.scheduler = this.createLegacyScheduler(); + } + this.geminiClient = this.config.getGeminiClient(); this.pendingToolConfirmationDetails = new Map(); this.taskState = 'submitted'; @@ -206,6 +217,13 @@ export class Task { this.toolCompletionNotifier.reject(new Error(reason)); } this.pendingToolCalls.clear(); + this.pendingCorrelationIds.clear(); + + if (this.scheduler instanceof Scheduler) { + this.scheduler.cancelAll(); + } else { + this.scheduler.cancelAll(new AbortController().signal); + } // Reset the promise for any future operations, ensuring it's in a clean state. this._resetToolCompletionPromise(); } @@ -450,7 +468,7 @@ export class Task { } } - private createScheduler(): CoreToolScheduler { + private createLegacyScheduler(): CoreToolScheduler { const scheduler = new CoreToolScheduler({ outputUpdateHandler: this._schedulerOutputUpdate.bind(this), onAllToolCallsComplete: this._schedulerAllToolCallsComplete.bind(this), @@ -461,6 +479,134 @@ export class Task { return scheduler; } + private _setupEventDrivenScheduler(): Scheduler { + const messageBus = this.config.getMessageBus(); + const scheduler = new Scheduler({ + config: this.config, + messageBus, + getPreferredEditor: () => DEFAULT_GUI_EDITOR, + }); + + messageBus.subscribe( + MessageBusType.TOOL_CALLS_UPDATE, + (message: unknown) => { + const event = message as ToolCallsUpdateMessage; + if (event.type !== MessageBusType.TOOL_CALLS_UPDATE) { + return; + } + + const toolCalls = event.toolCalls; + + toolCalls.forEach((tc) => { + const callId = tc.request.callId; + const previousStatus = this.pendingToolCalls.get(callId); + const hasChanged = previousStatus !== tc.status; + + // 1. Handle Output + if (tc.status === 'executing' && tc.liveOutput) { + this._schedulerOutputUpdate(callId, tc.liveOutput); + } + + // 2. Handle terminal states + if (['success', 'error', 'cancelled'].includes(tc.status)) { + if (hasChanged) { + const completedCall = tc as CompletedToolCall; + logger.info( + `[Task] Tool call ${callId} completed with status: ${tc.status}`, + ); + this.completedToolCalls.push(completedCall); + this._resolveToolCall(callId); + } + } else { + // Keep track of pending tools + this._registerToolCall(callId, tc.status); + } + + // 3. Handle Confirmation Stash + if (tc.status === 'awaiting_approval' && tc.confirmationDetails) { + // Bridge the new serializable details back to the legacy shape for A2A UI + const details = + tc.confirmationDetails as SerializableConfirmationDetails; + + if (tc.correlationId) { + this.pendingCorrelationIds.set(callId, tc.correlationId); + } + + // In A2A, we just need to store the details so the client can fetch them. + // The actual confirmation will be handled by _handleToolConfirmationPart + // publishing back to the bus using the correlationId. + this.pendingToolConfirmationDetails.set(callId, { + ...details, + // Inject a dummy onConfirm for legacy UI compatibility if needed, + // though A2A should use the correlationId-based path now. + onConfirm: async () => {}, + } as ToolCallConfirmationDetails); + } + + // 4. Publish Status Updates to A2A event bus + if (hasChanged) { + const coderAgentMessage: CoderAgentMessage = + tc.status === 'awaiting_approval' + ? { kind: CoderAgentEvent.ToolCallConfirmationEvent } + : { kind: CoderAgentEvent.ToolCallUpdateEvent }; + + const message = this.toolStatusMessage(tc, this.id, this.contextId); + const statusUpdate = this._createStatusUpdateEvent( + this.taskState, + coderAgentMessage, + message, + false, + ); + this.eventBus?.publish(statusUpdate); + } + + // 5. Handle Auto-Execution (YOLO) + if ( + tc.status === 'awaiting_approval' && + tc.correlationId && + (this.autoExecute || + this.config.getApprovalMode() === ApprovalMode.YOLO) + ) { + logger.info(`[Task] Auto-approving tool call ${callId}`); + void messageBus.publish({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: tc.correlationId, + confirmed: true, + outcome: ToolConfirmationOutcome.ProceedOnce, + }); + this.pendingToolConfirmationDetails.delete(callId); + } + }); + + // 6. Handle Input Required State + const allPendingStatuses = Array.from(this.pendingToolCalls.values()); + const isAwaitingApproval = allPendingStatuses.some( + (status) => status === 'awaiting_approval', + ); + const isExecuting = allPendingStatuses.some( + (status) => status === 'executing', + ); + + if ( + isAwaitingApproval && + !isExecuting && + !this.skipFinalTrueAfterInlineEdit + ) { + this.skipFinalTrueAfterInlineEdit = false; + this.setTaskStateAndPublishUpdate( + 'input-required', + { kind: CoderAgentEvent.StateChangeEvent }, + undefined, + undefined, + /*final*/ true, + ); + } + }, + ); + + return scheduler; + } + private _pickFields< T extends ToolCall | AnyDeclarativeTool, K extends UnionKeys, @@ -640,7 +786,11 @@ export class Task { }; this.setTaskStateAndPublishUpdate('working', stateChange); - await this.scheduler.schedule(updatedRequests, abortSignal); + if (this.scheduler instanceof Scheduler) { + await this.scheduler.schedule(updatedRequests, abortSignal); + } else { + await this.scheduler.schedule(updatedRequests, abortSignal); + } } async acceptAgentMessage(event: ServerGeminiStreamEvent): Promise { @@ -780,8 +930,9 @@ export class Task { } const confirmationDetails = this.pendingToolConfirmationDetails.get(callId); + const correlationId = this.pendingCorrelationIds.get(callId); - if (!confirmationDetails) { + if (!confirmationDetails && !correlationId) { logger.warn( `[Task] Received tool confirmation for unknown or already processed callId: ${callId}`, ); @@ -803,24 +954,42 @@ export class Task { // This will trigger the scheduler to continue or cancel the specific tool. // The scheduler's onToolCallsUpdate will then reflect the new state (e.g., executing or cancelled). - // If `edit` tool call, pass updated payload if presesent - if (confirmationDetails.type === 'edit') { - const payload = part.data['newContent'] - ? ({ - newContent: part.data['newContent'] as string, - } as ToolConfirmationPayload) - : undefined; - this.skipFinalTrueAfterInlineEdit = !!payload; - try { - await confirmationDetails.onConfirm(confirmationOutcome, payload); - } finally { - // Once confirmationDetails.onConfirm finishes (or fails) with a payload, - // reset skipFinalTrueAfterInlineEdit so that external callers receive - // their call has been completed. - this.skipFinalTrueAfterInlineEdit = false; + if (correlationId) { + const payload = + confirmationDetails?.type === 'edit' && part.data['newContent'] + ? ({ + newContent: part.data['newContent'] as string, + } as ToolConfirmationPayload) + : undefined; + + await this.config.getMessageBus().publish({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId, + confirmed: confirmationOutcome !== ToolConfirmationOutcome.Cancel, + outcome: confirmationOutcome, + payload, + }); + } else if (confirmationDetails) { + // Legacy path + // If `edit` tool call, pass updated payload if presesent + if (confirmationDetails.type === 'edit') { + const payload = part.data['newContent'] + ? ({ + newContent: part.data['newContent'] as string, + } as ToolConfirmationPayload) + : undefined; + this.skipFinalTrueAfterInlineEdit = !!payload; + try { + await confirmationDetails.onConfirm(confirmationOutcome, payload); + } finally { + // Once confirmationDetails.onConfirm finishes (or fails) with a payload, + // reset skipFinalTrueAfterInlineEdit so that external callers receive + // their call has been completed. + this.skipFinalTrueAfterInlineEdit = false; + } + } else { + await confirmationDetails.onConfirm(confirmationOutcome); } - } else { - await confirmationDetails.onConfirm(confirmationOutcome); } } finally { if (gcpProject) { @@ -836,6 +1005,7 @@ export class Task { // Note !== ToolConfirmationOutcome.ModifyWithEditor does not work! if (confirmationOutcome !== 'modify_with_editor') { this.pendingToolConfirmationDetails.delete(callId); + this.pendingCorrelationIds.delete(callId); } // If outcome is Cancel, scheduler should update status to 'cancelled', which then resolves the tool. diff --git a/packages/a2a-server/src/config/config.ts b/packages/a2a-server/src/config/config.ts index b9e895dde0..094055ca4f 100644 --- a/packages/a2a-server/src/config/config.ts +++ b/packages/a2a-server/src/config/config.ts @@ -95,6 +95,8 @@ export async function loadConfig( extensionLoader, checkpointing, previewFeatures: settings.general?.previewFeatures, + enableEventDrivenScheduler: + settings.experimental?.enableEventDrivenScheduler ?? true, interactive: true, enableInteractiveShell: true, }; diff --git a/packages/a2a-server/src/config/settings.ts b/packages/a2a-server/src/config/settings.ts index 7040a80d4e..73886d0eb5 100644 --- a/packages/a2a-server/src/config/settings.ts +++ b/packages/a2a-server/src/config/settings.ts @@ -34,6 +34,9 @@ export interface Settings { general?: { previewFeatures?: boolean; }; + experimental?: { + enableEventDrivenScheduler?: boolean; + }; // Git-aware file filtering settings fileFiltering?: { diff --git a/packages/a2a-server/src/utils/testing_utils.ts b/packages/a2a-server/src/utils/testing_utils.ts index 87c7315f82..4b08552b0b 100644 --- a/packages/a2a-server/src/utils/testing_utils.ts +++ b/packages/a2a-server/src/utils/testing_utils.ts @@ -60,6 +60,7 @@ export function createMockConfig( getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'), getSessionId: vi.fn().mockReturnValue('test-session-id'), getUserTier: vi.fn(), + isEventDrivenSchedulerEnabled: vi.fn().mockReturnValue(false), getMessageBus: vi.fn(), getPolicyEngine: vi.fn(), getEnableExtensionReloading: vi.fn().mockReturnValue(false), diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 506e602ebf..cf6cc9ea16 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -36,6 +36,7 @@ export * from './core/tokenLimits.js'; export * from './core/turn.js'; export * from './core/geminiRequest.js'; export * from './core/coreToolScheduler.js'; +export * from './scheduler/scheduler.js'; export * from './scheduler/types.js'; export * from './scheduler/tool-executor.js'; export * from './core/nonInteractiveToolExecutor.js';