feat(a2a): switch from callback-based to event-driven tool scheduler

This change transitions packages/a2a-server to use the event-driven
Scheduler by default. It replaces the legacy direct callback mechanism
with a MessageBus listener in the Task class to handle tool status
updates, live output, and confirmations.

- Added experimental.enableEventDrivenScheduler setting (defaults to true).
- Refactored Task.ts to support both legacy and event-driven schedulers.
- Implemented bus-based tool confirmation responses using correlationId.
- Exported Scheduler from packages/core.
- Added unit tests for the event-driven flow in A2A.
This commit is contained in:
Abhi
2026-01-20 15:49:31 -05:00
parent f42b4c80ac
commit 1c4335686f
6 changed files with 372 additions and 22 deletions

View File

@@ -0,0 +1,173 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
import { Task } from './task.js';
import {
type Config,
MessageBusType,
ToolConfirmationOutcome,
Scheduler,
type MessageBus,
} from '@google/gemini-cli-core';
import { createMockConfig } from '../utils/testing_utils.js';
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
describe('Task Event-Driven Scheduler', () => {
let mockConfig: Config;
let mockEventBus: ExecutionEventBus;
let messageBus: MessageBus;
beforeEach(() => {
vi.clearAllMocks();
mockConfig = createMockConfig({
isEventDrivenSchedulerEnabled: () => true,
}) as Config;
messageBus = mockConfig.getMessageBus();
mockEventBus = {
publish: vi.fn(),
on: vi.fn(),
off: vi.fn(),
once: vi.fn(),
removeAllListeners: vi.fn(),
finished: vi.fn(),
};
});
it('should instantiate Scheduler when enabled', () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
expect(task.scheduler).toBeInstanceOf(Scheduler);
});
it('should subscribe to TOOL_CALLS_UPDATE and map status changes', async () => {
// @ts-expect-error - Calling private constructor
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCall = {
request: { callId: '1', name: 'ls', args: {} },
status: 'executing',
};
// Simulate MessageBus event
// Simulate MessageBus event
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
if (!handler) {
throw new Error('TOOL_CALLS_UPDATE handler not found');
}
handler({
type: MessageBusType.TOOL_CALLS_UPDATE,
toolCalls: [toolCall],
});
expect(mockEventBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
status: expect.objectContaining({
state: 'submitted', // initial task state
}),
metadata: expect.objectContaining({
coderAgent: expect.objectContaining({
kind: 'tool-call-update',
}),
}),
}),
);
});
it('should handle tool confirmations by publishing to MessageBus', async () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCall = {
request: { callId: '1', name: 'ls', args: {} },
status: 'awaiting_approval',
correlationId: 'corr-1',
confirmationDetails: { type: 'info', title: 'test', prompt: 'test' },
};
// Simulate MessageBus event to stash the correlationId
// Simulate MessageBus event
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
if (!handler) {
throw new Error('TOOL_CALLS_UPDATE handler not found');
}
handler({
type: MessageBusType.TOOL_CALLS_UPDATE,
toolCalls: [toolCall],
});
// Simulate A2A client confirmation
const part = {
kind: 'data',
data: {
callId: '1',
outcome: 'proceed_once',
},
};
const handled = await (
task as unknown as {
_handleToolConfirmationPart: (part: unknown) => Promise<boolean>;
}
)._handleToolConfirmationPart(part);
expect(handled).toBe(true);
expect(messageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'corr-1',
confirmed: true,
outcome: ToolConfirmationOutcome.ProceedOnce,
}),
);
});
it('should handle output updates via the message bus', async () => {
// @ts-expect-error - Calling private constructor
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCall = {
request: { callId: '1', name: 'ls', args: {} },
status: 'executing',
liveOutput: 'chunk1',
};
// Simulate MessageBus event
// Simulate MessageBus event
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
if (!handler) {
throw new Error('TOOL_CALLS_UPDATE handler not found');
}
handler({
type: MessageBusType.TOOL_CALLS_UPDATE,
toolCalls: [toolCall],
});
// Should publish artifact update for output
expect(mockEventBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
kind: 'artifact-update',
artifact: expect.objectContaining({
artifactId: 'tool-1-output',
parts: [{ kind: 'text', text: 'chunk1' }],
}),
}),
);
});
});

View File

@@ -5,6 +5,7 @@
*/
import {
Scheduler,
CoreToolScheduler,
type GeminiClient,
GeminiEventType,
@@ -29,6 +30,9 @@ import {
type AnsiOutput,
EDIT_TOOL_NAMES,
processRestorableToolCalls,
MessageBusType,
type ToolCallsUpdateMessage,
type SerializableConfirmationDetails,
} from '@google/gemini-cli-core';
import type { RequestContext } from '@a2a-js/sdk/server';
import { type ExecutionEventBus } from '@a2a-js/sdk/server';
@@ -62,10 +66,11 @@ type UnionKeys<T> = T extends T ? keyof T : never;
export class Task {
id: string;
contextId: string;
scheduler: CoreToolScheduler;
scheduler: Scheduler | CoreToolScheduler;
config: Config;
geminiClient: GeminiClient;
pendingToolConfirmationDetails: Map<string, ToolCallConfirmationDetails>;
pendingCorrelationIds: Map<string, string> = new Map();
taskState: TaskState;
eventBus?: ExecutionEventBus;
completedToolCalls: CompletedToolCall[];
@@ -93,7 +98,13 @@ export class Task {
this.id = id;
this.contextId = contextId;
this.config = config;
this.scheduler = this.createScheduler();
if (this.config.isEventDrivenSchedulerEnabled()) {
this.scheduler = this._setupEventDrivenScheduler();
} else {
this.scheduler = this.createLegacyScheduler();
}
this.geminiClient = this.config.getGeminiClient();
this.pendingToolConfirmationDetails = new Map();
this.taskState = 'submitted';
@@ -206,6 +217,13 @@ export class Task {
this.toolCompletionNotifier.reject(new Error(reason));
}
this.pendingToolCalls.clear();
this.pendingCorrelationIds.clear();
if (this.scheduler instanceof Scheduler) {
this.scheduler.cancelAll();
} else {
this.scheduler.cancelAll(new AbortController().signal);
}
// Reset the promise for any future operations, ensuring it's in a clean state.
this._resetToolCompletionPromise();
}
@@ -450,7 +468,7 @@ export class Task {
}
}
private createScheduler(): CoreToolScheduler {
private createLegacyScheduler(): CoreToolScheduler {
const scheduler = new CoreToolScheduler({
outputUpdateHandler: this._schedulerOutputUpdate.bind(this),
onAllToolCallsComplete: this._schedulerAllToolCallsComplete.bind(this),
@@ -461,6 +479,134 @@ export class Task {
return scheduler;
}
private _setupEventDrivenScheduler(): Scheduler {
const messageBus = this.config.getMessageBus();
const scheduler = new Scheduler({
config: this.config,
messageBus,
getPreferredEditor: () => DEFAULT_GUI_EDITOR,
});
messageBus.subscribe(
MessageBusType.TOOL_CALLS_UPDATE,
(message: unknown) => {
const event = message as ToolCallsUpdateMessage;
if (event.type !== MessageBusType.TOOL_CALLS_UPDATE) {
return;
}
const toolCalls = event.toolCalls;
toolCalls.forEach((tc) => {
const callId = tc.request.callId;
const previousStatus = this.pendingToolCalls.get(callId);
const hasChanged = previousStatus !== tc.status;
// 1. Handle Output
if (tc.status === 'executing' && tc.liveOutput) {
this._schedulerOutputUpdate(callId, tc.liveOutput);
}
// 2. Handle terminal states
if (['success', 'error', 'cancelled'].includes(tc.status)) {
if (hasChanged) {
const completedCall = tc as CompletedToolCall;
logger.info(
`[Task] Tool call ${callId} completed with status: ${tc.status}`,
);
this.completedToolCalls.push(completedCall);
this._resolveToolCall(callId);
}
} else {
// Keep track of pending tools
this._registerToolCall(callId, tc.status);
}
// 3. Handle Confirmation Stash
if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
// Bridge the new serializable details back to the legacy shape for A2A UI
const details =
tc.confirmationDetails as SerializableConfirmationDetails;
if (tc.correlationId) {
this.pendingCorrelationIds.set(callId, tc.correlationId);
}
// In A2A, we just need to store the details so the client can fetch them.
// The actual confirmation will be handled by _handleToolConfirmationPart
// publishing back to the bus using the correlationId.
this.pendingToolConfirmationDetails.set(callId, {
...details,
// Inject a dummy onConfirm for legacy UI compatibility if needed,
// though A2A should use the correlationId-based path now.
onConfirm: async () => {},
} as ToolCallConfirmationDetails);
}
// 4. Publish Status Updates to A2A event bus
if (hasChanged) {
const coderAgentMessage: CoderAgentMessage =
tc.status === 'awaiting_approval'
? { kind: CoderAgentEvent.ToolCallConfirmationEvent }
: { kind: CoderAgentEvent.ToolCallUpdateEvent };
const message = this.toolStatusMessage(tc, this.id, this.contextId);
const statusUpdate = this._createStatusUpdateEvent(
this.taskState,
coderAgentMessage,
message,
false,
);
this.eventBus?.publish(statusUpdate);
}
// 5. Handle Auto-Execution (YOLO)
if (
tc.status === 'awaiting_approval' &&
tc.correlationId &&
(this.autoExecute ||
this.config.getApprovalMode() === ApprovalMode.YOLO)
) {
logger.info(`[Task] Auto-approving tool call ${callId}`);
void messageBus.publish({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: tc.correlationId,
confirmed: true,
outcome: ToolConfirmationOutcome.ProceedOnce,
});
this.pendingToolConfirmationDetails.delete(callId);
}
});
// 6. Handle Input Required State
const allPendingStatuses = Array.from(this.pendingToolCalls.values());
const isAwaitingApproval = allPendingStatuses.some(
(status) => status === 'awaiting_approval',
);
const isExecuting = allPendingStatuses.some(
(status) => status === 'executing',
);
if (
isAwaitingApproval &&
!isExecuting &&
!this.skipFinalTrueAfterInlineEdit
) {
this.skipFinalTrueAfterInlineEdit = false;
this.setTaskStateAndPublishUpdate(
'input-required',
{ kind: CoderAgentEvent.StateChangeEvent },
undefined,
undefined,
/*final*/ true,
);
}
},
);
return scheduler;
}
private _pickFields<
T extends ToolCall | AnyDeclarativeTool,
K extends UnionKeys<T>,
@@ -640,7 +786,11 @@ export class Task {
};
this.setTaskStateAndPublishUpdate('working', stateChange);
await this.scheduler.schedule(updatedRequests, abortSignal);
if (this.scheduler instanceof Scheduler) {
await this.scheduler.schedule(updatedRequests, abortSignal);
} else {
await this.scheduler.schedule(updatedRequests, abortSignal);
}
}
async acceptAgentMessage(event: ServerGeminiStreamEvent): Promise<void> {
@@ -780,8 +930,9 @@ export class Task {
}
const confirmationDetails = this.pendingToolConfirmationDetails.get(callId);
const correlationId = this.pendingCorrelationIds.get(callId);
if (!confirmationDetails) {
if (!confirmationDetails && !correlationId) {
logger.warn(
`[Task] Received tool confirmation for unknown or already processed callId: ${callId}`,
);
@@ -803,24 +954,42 @@ export class Task {
// This will trigger the scheduler to continue or cancel the specific tool.
// The scheduler's onToolCallsUpdate will then reflect the new state (e.g., executing or cancelled).
// If `edit` tool call, pass updated payload if presesent
if (confirmationDetails.type === 'edit') {
const payload = part.data['newContent']
? ({
newContent: part.data['newContent'] as string,
} as ToolConfirmationPayload)
: undefined;
this.skipFinalTrueAfterInlineEdit = !!payload;
try {
await confirmationDetails.onConfirm(confirmationOutcome, payload);
} finally {
// Once confirmationDetails.onConfirm finishes (or fails) with a payload,
// reset skipFinalTrueAfterInlineEdit so that external callers receive
// their call has been completed.
this.skipFinalTrueAfterInlineEdit = false;
if (correlationId) {
const payload =
confirmationDetails?.type === 'edit' && part.data['newContent']
? ({
newContent: part.data['newContent'] as string,
} as ToolConfirmationPayload)
: undefined;
await this.config.getMessageBus().publish({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId,
confirmed: confirmationOutcome !== ToolConfirmationOutcome.Cancel,
outcome: confirmationOutcome,
payload,
});
} else if (confirmationDetails) {
// Legacy path
// If `edit` tool call, pass updated payload if presesent
if (confirmationDetails.type === 'edit') {
const payload = part.data['newContent']
? ({
newContent: part.data['newContent'] as string,
} as ToolConfirmationPayload)
: undefined;
this.skipFinalTrueAfterInlineEdit = !!payload;
try {
await confirmationDetails.onConfirm(confirmationOutcome, payload);
} finally {
// Once confirmationDetails.onConfirm finishes (or fails) with a payload,
// reset skipFinalTrueAfterInlineEdit so that external callers receive
// their call has been completed.
this.skipFinalTrueAfterInlineEdit = false;
}
} else {
await confirmationDetails.onConfirm(confirmationOutcome);
}
} else {
await confirmationDetails.onConfirm(confirmationOutcome);
}
} finally {
if (gcpProject) {
@@ -836,6 +1005,7 @@ export class Task {
// Note !== ToolConfirmationOutcome.ModifyWithEditor does not work!
if (confirmationOutcome !== 'modify_with_editor') {
this.pendingToolConfirmationDetails.delete(callId);
this.pendingCorrelationIds.delete(callId);
}
// If outcome is Cancel, scheduler should update status to 'cancelled', which then resolves the tool.

View File

@@ -95,6 +95,8 @@ export async function loadConfig(
extensionLoader,
checkpointing,
previewFeatures: settings.general?.previewFeatures,
enableEventDrivenScheduler:
settings.experimental?.enableEventDrivenScheduler ?? true,
interactive: true,
enableInteractiveShell: true,
};

View File

@@ -34,6 +34,9 @@ export interface Settings {
general?: {
previewFeatures?: boolean;
};
experimental?: {
enableEventDrivenScheduler?: boolean;
};
// Git-aware file filtering settings
fileFiltering?: {

View File

@@ -60,6 +60,7 @@ export function createMockConfig(
getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'),
getSessionId: vi.fn().mockReturnValue('test-session-id'),
getUserTier: vi.fn(),
isEventDrivenSchedulerEnabled: vi.fn().mockReturnValue(false),
getMessageBus: vi.fn(),
getPolicyEngine: vi.fn(),
getEnableExtensionReloading: vi.fn().mockReturnValue(false),

View File

@@ -36,6 +36,7 @@ export * from './core/tokenLimits.js';
export * from './core/turn.js';
export * from './core/geminiRequest.js';
export * from './core/coreToolScheduler.js';
export * from './scheduler/scheduler.js';
export * from './scheduler/types.js';
export * from './scheduler/tool-executor.js';
export * from './core/nonInteractiveToolExecutor.js';