mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-13 23:51:16 -07:00
feat(a2a): switch from callback-based to event-driven tool scheduler
This change transitions packages/a2a-server to use the event-driven Scheduler by default. It replaces the legacy direct callback mechanism with a MessageBus listener in the Task class to handle tool status updates, live output, and confirmations. - Added experimental.enableEventDrivenScheduler setting (defaults to true). - Refactored Task.ts to support both legacy and event-driven schedulers. - Implemented bus-based tool confirmation responses using correlationId. - Exported Scheduler from packages/core. - Added unit tests for the event-driven flow in A2A.
This commit is contained in:
173
packages/a2a-server/src/agent/task-event-driven.test.ts
Normal file
173
packages/a2a-server/src/agent/task-event-driven.test.ts
Normal file
@@ -0,0 +1,173 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
|
||||
import { Task } from './task.js';
|
||||
import {
|
||||
type Config,
|
||||
MessageBusType,
|
||||
ToolConfirmationOutcome,
|
||||
Scheduler,
|
||||
type MessageBus,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { createMockConfig } from '../utils/testing_utils.js';
|
||||
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
|
||||
|
||||
describe('Task Event-Driven Scheduler', () => {
|
||||
let mockConfig: Config;
|
||||
let mockEventBus: ExecutionEventBus;
|
||||
let messageBus: MessageBus;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockConfig = createMockConfig({
|
||||
isEventDrivenSchedulerEnabled: () => true,
|
||||
}) as Config;
|
||||
messageBus = mockConfig.getMessageBus();
|
||||
mockEventBus = {
|
||||
publish: vi.fn(),
|
||||
on: vi.fn(),
|
||||
off: vi.fn(),
|
||||
once: vi.fn(),
|
||||
removeAllListeners: vi.fn(),
|
||||
finished: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
it('should instantiate Scheduler when enabled', () => {
|
||||
// @ts-expect-error - Calling private constructor
|
||||
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
|
||||
expect(task.scheduler).toBeInstanceOf(Scheduler);
|
||||
});
|
||||
|
||||
it('should subscribe to TOOL_CALLS_UPDATE and map status changes', async () => {
|
||||
// @ts-expect-error - Calling private constructor
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
|
||||
|
||||
const toolCall = {
|
||||
request: { callId: '1', name: 'ls', args: {} },
|
||||
status: 'executing',
|
||||
};
|
||||
|
||||
// Simulate MessageBus event
|
||||
// Simulate MessageBus event
|
||||
const handler = (messageBus.subscribe as Mock).mock.calls.find(
|
||||
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
|
||||
)?.[1];
|
||||
|
||||
if (!handler) {
|
||||
throw new Error('TOOL_CALLS_UPDATE handler not found');
|
||||
}
|
||||
|
||||
handler({
|
||||
type: MessageBusType.TOOL_CALLS_UPDATE,
|
||||
toolCalls: [toolCall],
|
||||
});
|
||||
|
||||
expect(mockEventBus.publish).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
status: expect.objectContaining({
|
||||
state: 'submitted', // initial task state
|
||||
}),
|
||||
metadata: expect.objectContaining({
|
||||
coderAgent: expect.objectContaining({
|
||||
kind: 'tool-call-update',
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle tool confirmations by publishing to MessageBus', async () => {
|
||||
// @ts-expect-error - Calling private constructor
|
||||
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
|
||||
|
||||
const toolCall = {
|
||||
request: { callId: '1', name: 'ls', args: {} },
|
||||
status: 'awaiting_approval',
|
||||
correlationId: 'corr-1',
|
||||
confirmationDetails: { type: 'info', title: 'test', prompt: 'test' },
|
||||
};
|
||||
|
||||
// Simulate MessageBus event to stash the correlationId
|
||||
// Simulate MessageBus event
|
||||
const handler = (messageBus.subscribe as Mock).mock.calls.find(
|
||||
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
|
||||
)?.[1];
|
||||
|
||||
if (!handler) {
|
||||
throw new Error('TOOL_CALLS_UPDATE handler not found');
|
||||
}
|
||||
|
||||
handler({
|
||||
type: MessageBusType.TOOL_CALLS_UPDATE,
|
||||
toolCalls: [toolCall],
|
||||
});
|
||||
|
||||
// Simulate A2A client confirmation
|
||||
const part = {
|
||||
kind: 'data',
|
||||
data: {
|
||||
callId: '1',
|
||||
outcome: 'proceed_once',
|
||||
},
|
||||
};
|
||||
|
||||
const handled = await (
|
||||
task as unknown as {
|
||||
_handleToolConfirmationPart: (part: unknown) => Promise<boolean>;
|
||||
}
|
||||
)._handleToolConfirmationPart(part);
|
||||
expect(handled).toBe(true);
|
||||
|
||||
expect(messageBus.publish).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: 'corr-1',
|
||||
confirmed: true,
|
||||
outcome: ToolConfirmationOutcome.ProceedOnce,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle output updates via the message bus', async () => {
|
||||
// @ts-expect-error - Calling private constructor
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
|
||||
|
||||
const toolCall = {
|
||||
request: { callId: '1', name: 'ls', args: {} },
|
||||
status: 'executing',
|
||||
liveOutput: 'chunk1',
|
||||
};
|
||||
|
||||
// Simulate MessageBus event
|
||||
// Simulate MessageBus event
|
||||
const handler = (messageBus.subscribe as Mock).mock.calls.find(
|
||||
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
|
||||
)?.[1];
|
||||
|
||||
if (!handler) {
|
||||
throw new Error('TOOL_CALLS_UPDATE handler not found');
|
||||
}
|
||||
|
||||
handler({
|
||||
type: MessageBusType.TOOL_CALLS_UPDATE,
|
||||
toolCalls: [toolCall],
|
||||
});
|
||||
|
||||
// Should publish artifact update for output
|
||||
expect(mockEventBus.publish).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
kind: 'artifact-update',
|
||||
artifact: expect.objectContaining({
|
||||
artifactId: 'tool-1-output',
|
||||
parts: [{ kind: 'text', text: 'chunk1' }],
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import {
|
||||
Scheduler,
|
||||
CoreToolScheduler,
|
||||
type GeminiClient,
|
||||
GeminiEventType,
|
||||
@@ -29,6 +30,9 @@ import {
|
||||
type AnsiOutput,
|
||||
EDIT_TOOL_NAMES,
|
||||
processRestorableToolCalls,
|
||||
MessageBusType,
|
||||
type ToolCallsUpdateMessage,
|
||||
type SerializableConfirmationDetails,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { RequestContext } from '@a2a-js/sdk/server';
|
||||
import { type ExecutionEventBus } from '@a2a-js/sdk/server';
|
||||
@@ -62,10 +66,11 @@ type UnionKeys<T> = T extends T ? keyof T : never;
|
||||
export class Task {
|
||||
id: string;
|
||||
contextId: string;
|
||||
scheduler: CoreToolScheduler;
|
||||
scheduler: Scheduler | CoreToolScheduler;
|
||||
config: Config;
|
||||
geminiClient: GeminiClient;
|
||||
pendingToolConfirmationDetails: Map<string, ToolCallConfirmationDetails>;
|
||||
pendingCorrelationIds: Map<string, string> = new Map();
|
||||
taskState: TaskState;
|
||||
eventBus?: ExecutionEventBus;
|
||||
completedToolCalls: CompletedToolCall[];
|
||||
@@ -93,7 +98,13 @@ export class Task {
|
||||
this.id = id;
|
||||
this.contextId = contextId;
|
||||
this.config = config;
|
||||
this.scheduler = this.createScheduler();
|
||||
|
||||
if (this.config.isEventDrivenSchedulerEnabled()) {
|
||||
this.scheduler = this._setupEventDrivenScheduler();
|
||||
} else {
|
||||
this.scheduler = this.createLegacyScheduler();
|
||||
}
|
||||
|
||||
this.geminiClient = this.config.getGeminiClient();
|
||||
this.pendingToolConfirmationDetails = new Map();
|
||||
this.taskState = 'submitted';
|
||||
@@ -206,6 +217,13 @@ export class Task {
|
||||
this.toolCompletionNotifier.reject(new Error(reason));
|
||||
}
|
||||
this.pendingToolCalls.clear();
|
||||
this.pendingCorrelationIds.clear();
|
||||
|
||||
if (this.scheduler instanceof Scheduler) {
|
||||
this.scheduler.cancelAll();
|
||||
} else {
|
||||
this.scheduler.cancelAll(new AbortController().signal);
|
||||
}
|
||||
// Reset the promise for any future operations, ensuring it's in a clean state.
|
||||
this._resetToolCompletionPromise();
|
||||
}
|
||||
@@ -450,7 +468,7 @@ export class Task {
|
||||
}
|
||||
}
|
||||
|
||||
private createScheduler(): CoreToolScheduler {
|
||||
private createLegacyScheduler(): CoreToolScheduler {
|
||||
const scheduler = new CoreToolScheduler({
|
||||
outputUpdateHandler: this._schedulerOutputUpdate.bind(this),
|
||||
onAllToolCallsComplete: this._schedulerAllToolCallsComplete.bind(this),
|
||||
@@ -461,6 +479,134 @@ export class Task {
|
||||
return scheduler;
|
||||
}
|
||||
|
||||
private _setupEventDrivenScheduler(): Scheduler {
|
||||
const messageBus = this.config.getMessageBus();
|
||||
const scheduler = new Scheduler({
|
||||
config: this.config,
|
||||
messageBus,
|
||||
getPreferredEditor: () => DEFAULT_GUI_EDITOR,
|
||||
});
|
||||
|
||||
messageBus.subscribe(
|
||||
MessageBusType.TOOL_CALLS_UPDATE,
|
||||
(message: unknown) => {
|
||||
const event = message as ToolCallsUpdateMessage;
|
||||
if (event.type !== MessageBusType.TOOL_CALLS_UPDATE) {
|
||||
return;
|
||||
}
|
||||
|
||||
const toolCalls = event.toolCalls;
|
||||
|
||||
toolCalls.forEach((tc) => {
|
||||
const callId = tc.request.callId;
|
||||
const previousStatus = this.pendingToolCalls.get(callId);
|
||||
const hasChanged = previousStatus !== tc.status;
|
||||
|
||||
// 1. Handle Output
|
||||
if (tc.status === 'executing' && tc.liveOutput) {
|
||||
this._schedulerOutputUpdate(callId, tc.liveOutput);
|
||||
}
|
||||
|
||||
// 2. Handle terminal states
|
||||
if (['success', 'error', 'cancelled'].includes(tc.status)) {
|
||||
if (hasChanged) {
|
||||
const completedCall = tc as CompletedToolCall;
|
||||
logger.info(
|
||||
`[Task] Tool call ${callId} completed with status: ${tc.status}`,
|
||||
);
|
||||
this.completedToolCalls.push(completedCall);
|
||||
this._resolveToolCall(callId);
|
||||
}
|
||||
} else {
|
||||
// Keep track of pending tools
|
||||
this._registerToolCall(callId, tc.status);
|
||||
}
|
||||
|
||||
// 3. Handle Confirmation Stash
|
||||
if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
|
||||
// Bridge the new serializable details back to the legacy shape for A2A UI
|
||||
const details =
|
||||
tc.confirmationDetails as SerializableConfirmationDetails;
|
||||
|
||||
if (tc.correlationId) {
|
||||
this.pendingCorrelationIds.set(callId, tc.correlationId);
|
||||
}
|
||||
|
||||
// In A2A, we just need to store the details so the client can fetch them.
|
||||
// The actual confirmation will be handled by _handleToolConfirmationPart
|
||||
// publishing back to the bus using the correlationId.
|
||||
this.pendingToolConfirmationDetails.set(callId, {
|
||||
...details,
|
||||
// Inject a dummy onConfirm for legacy UI compatibility if needed,
|
||||
// though A2A should use the correlationId-based path now.
|
||||
onConfirm: async () => {},
|
||||
} as ToolCallConfirmationDetails);
|
||||
}
|
||||
|
||||
// 4. Publish Status Updates to A2A event bus
|
||||
if (hasChanged) {
|
||||
const coderAgentMessage: CoderAgentMessage =
|
||||
tc.status === 'awaiting_approval'
|
||||
? { kind: CoderAgentEvent.ToolCallConfirmationEvent }
|
||||
: { kind: CoderAgentEvent.ToolCallUpdateEvent };
|
||||
|
||||
const message = this.toolStatusMessage(tc, this.id, this.contextId);
|
||||
const statusUpdate = this._createStatusUpdateEvent(
|
||||
this.taskState,
|
||||
coderAgentMessage,
|
||||
message,
|
||||
false,
|
||||
);
|
||||
this.eventBus?.publish(statusUpdate);
|
||||
}
|
||||
|
||||
// 5. Handle Auto-Execution (YOLO)
|
||||
if (
|
||||
tc.status === 'awaiting_approval' &&
|
||||
tc.correlationId &&
|
||||
(this.autoExecute ||
|
||||
this.config.getApprovalMode() === ApprovalMode.YOLO)
|
||||
) {
|
||||
logger.info(`[Task] Auto-approving tool call ${callId}`);
|
||||
void messageBus.publish({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: tc.correlationId,
|
||||
confirmed: true,
|
||||
outcome: ToolConfirmationOutcome.ProceedOnce,
|
||||
});
|
||||
this.pendingToolConfirmationDetails.delete(callId);
|
||||
}
|
||||
});
|
||||
|
||||
// 6. Handle Input Required State
|
||||
const allPendingStatuses = Array.from(this.pendingToolCalls.values());
|
||||
const isAwaitingApproval = allPendingStatuses.some(
|
||||
(status) => status === 'awaiting_approval',
|
||||
);
|
||||
const isExecuting = allPendingStatuses.some(
|
||||
(status) => status === 'executing',
|
||||
);
|
||||
|
||||
if (
|
||||
isAwaitingApproval &&
|
||||
!isExecuting &&
|
||||
!this.skipFinalTrueAfterInlineEdit
|
||||
) {
|
||||
this.skipFinalTrueAfterInlineEdit = false;
|
||||
this.setTaskStateAndPublishUpdate(
|
||||
'input-required',
|
||||
{ kind: CoderAgentEvent.StateChangeEvent },
|
||||
undefined,
|
||||
undefined,
|
||||
/*final*/ true,
|
||||
);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
return scheduler;
|
||||
}
|
||||
|
||||
private _pickFields<
|
||||
T extends ToolCall | AnyDeclarativeTool,
|
||||
K extends UnionKeys<T>,
|
||||
@@ -640,7 +786,11 @@ export class Task {
|
||||
};
|
||||
this.setTaskStateAndPublishUpdate('working', stateChange);
|
||||
|
||||
await this.scheduler.schedule(updatedRequests, abortSignal);
|
||||
if (this.scheduler instanceof Scheduler) {
|
||||
await this.scheduler.schedule(updatedRequests, abortSignal);
|
||||
} else {
|
||||
await this.scheduler.schedule(updatedRequests, abortSignal);
|
||||
}
|
||||
}
|
||||
|
||||
async acceptAgentMessage(event: ServerGeminiStreamEvent): Promise<void> {
|
||||
@@ -780,8 +930,9 @@ export class Task {
|
||||
}
|
||||
|
||||
const confirmationDetails = this.pendingToolConfirmationDetails.get(callId);
|
||||
const correlationId = this.pendingCorrelationIds.get(callId);
|
||||
|
||||
if (!confirmationDetails) {
|
||||
if (!confirmationDetails && !correlationId) {
|
||||
logger.warn(
|
||||
`[Task] Received tool confirmation for unknown or already processed callId: ${callId}`,
|
||||
);
|
||||
@@ -803,24 +954,42 @@ export class Task {
|
||||
// This will trigger the scheduler to continue or cancel the specific tool.
|
||||
// The scheduler's onToolCallsUpdate will then reflect the new state (e.g., executing or cancelled).
|
||||
|
||||
// If `edit` tool call, pass updated payload if presesent
|
||||
if (confirmationDetails.type === 'edit') {
|
||||
const payload = part.data['newContent']
|
||||
? ({
|
||||
newContent: part.data['newContent'] as string,
|
||||
} as ToolConfirmationPayload)
|
||||
: undefined;
|
||||
this.skipFinalTrueAfterInlineEdit = !!payload;
|
||||
try {
|
||||
await confirmationDetails.onConfirm(confirmationOutcome, payload);
|
||||
} finally {
|
||||
// Once confirmationDetails.onConfirm finishes (or fails) with a payload,
|
||||
// reset skipFinalTrueAfterInlineEdit so that external callers receive
|
||||
// their call has been completed.
|
||||
this.skipFinalTrueAfterInlineEdit = false;
|
||||
if (correlationId) {
|
||||
const payload =
|
||||
confirmationDetails?.type === 'edit' && part.data['newContent']
|
||||
? ({
|
||||
newContent: part.data['newContent'] as string,
|
||||
} as ToolConfirmationPayload)
|
||||
: undefined;
|
||||
|
||||
await this.config.getMessageBus().publish({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId,
|
||||
confirmed: confirmationOutcome !== ToolConfirmationOutcome.Cancel,
|
||||
outcome: confirmationOutcome,
|
||||
payload,
|
||||
});
|
||||
} else if (confirmationDetails) {
|
||||
// Legacy path
|
||||
// If `edit` tool call, pass updated payload if presesent
|
||||
if (confirmationDetails.type === 'edit') {
|
||||
const payload = part.data['newContent']
|
||||
? ({
|
||||
newContent: part.data['newContent'] as string,
|
||||
} as ToolConfirmationPayload)
|
||||
: undefined;
|
||||
this.skipFinalTrueAfterInlineEdit = !!payload;
|
||||
try {
|
||||
await confirmationDetails.onConfirm(confirmationOutcome, payload);
|
||||
} finally {
|
||||
// Once confirmationDetails.onConfirm finishes (or fails) with a payload,
|
||||
// reset skipFinalTrueAfterInlineEdit so that external callers receive
|
||||
// their call has been completed.
|
||||
this.skipFinalTrueAfterInlineEdit = false;
|
||||
}
|
||||
} else {
|
||||
await confirmationDetails.onConfirm(confirmationOutcome);
|
||||
}
|
||||
} else {
|
||||
await confirmationDetails.onConfirm(confirmationOutcome);
|
||||
}
|
||||
} finally {
|
||||
if (gcpProject) {
|
||||
@@ -836,6 +1005,7 @@ export class Task {
|
||||
// Note !== ToolConfirmationOutcome.ModifyWithEditor does not work!
|
||||
if (confirmationOutcome !== 'modify_with_editor') {
|
||||
this.pendingToolConfirmationDetails.delete(callId);
|
||||
this.pendingCorrelationIds.delete(callId);
|
||||
}
|
||||
|
||||
// If outcome is Cancel, scheduler should update status to 'cancelled', which then resolves the tool.
|
||||
|
||||
@@ -95,6 +95,8 @@ export async function loadConfig(
|
||||
extensionLoader,
|
||||
checkpointing,
|
||||
previewFeatures: settings.general?.previewFeatures,
|
||||
enableEventDrivenScheduler:
|
||||
settings.experimental?.enableEventDrivenScheduler ?? true,
|
||||
interactive: true,
|
||||
enableInteractiveShell: true,
|
||||
};
|
||||
|
||||
@@ -34,6 +34,9 @@ export interface Settings {
|
||||
general?: {
|
||||
previewFeatures?: boolean;
|
||||
};
|
||||
experimental?: {
|
||||
enableEventDrivenScheduler?: boolean;
|
||||
};
|
||||
|
||||
// Git-aware file filtering settings
|
||||
fileFiltering?: {
|
||||
|
||||
@@ -60,6 +60,7 @@ export function createMockConfig(
|
||||
getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'),
|
||||
getSessionId: vi.fn().mockReturnValue('test-session-id'),
|
||||
getUserTier: vi.fn(),
|
||||
isEventDrivenSchedulerEnabled: vi.fn().mockReturnValue(false),
|
||||
getMessageBus: vi.fn(),
|
||||
getPolicyEngine: vi.fn(),
|
||||
getEnableExtensionReloading: vi.fn().mockReturnValue(false),
|
||||
|
||||
@@ -36,6 +36,7 @@ export * from './core/tokenLimits.js';
|
||||
export * from './core/turn.js';
|
||||
export * from './core/geminiRequest.js';
|
||||
export * from './core/coreToolScheduler.js';
|
||||
export * from './scheduler/scheduler.js';
|
||||
export * from './scheduler/types.js';
|
||||
export * from './scheduler/tool-executor.js';
|
||||
export * from './core/nonInteractiveToolExecutor.js';
|
||||
|
||||
Reference in New Issue
Block a user