diff --git a/packages/a2a-server/src/agent/executor.test.ts b/packages/a2a-server/src/agent/executor.test.ts new file mode 100644 index 0000000000..2b77f3006c --- /dev/null +++ b/packages/a2a-server/src/agent/executor.test.ts @@ -0,0 +1,248 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest'; +import { CoderAgentExecutor } from './executor.js'; +import type { + ExecutionEventBus, + RequestContext, + TaskStore, +} from '@a2a-js/sdk/server'; +import { EventEmitter } from 'node:events'; +import { requestStorage } from '../http/requestStorage.js'; + +// Mocks for constructor dependencies +vi.mock('../config/config.js', () => ({ + loadConfig: vi.fn().mockReturnValue({ + getSessionId: () => 'test-session', + getTargetDir: () => '/tmp', + getCheckpointingEnabled: () => false, + }), + loadEnvironment: vi.fn(), + setTargetDir: vi.fn().mockReturnValue('/tmp'), +})); + +vi.mock('../config/settings.js', () => ({ + loadSettings: vi.fn().mockReturnValue({}), +})); + +vi.mock('../config/extension.js', () => ({ + loadExtensions: vi.fn().mockReturnValue([]), +})); + +vi.mock('../http/requestStorage.js', () => ({ + requestStorage: { + getStore: vi.fn(), + }, +})); + +vi.mock('./task.js', () => { + const mockTaskInstance = (taskId: string, contextId: string) => ({ + id: taskId, + contextId, + taskState: 'working', + acceptUserMessage: vi + .fn() + .mockImplementation(async function* (context, aborted) { + const isConfirmation = ( + context.userMessage.parts as Array<{ kind: string }> + ).some((p) => p.kind === 'confirmation'); + // Hang only for main user messages (text), allow confirmations to finish quickly + if (!isConfirmation && aborted) { + await new Promise((resolve) => { + aborted.addEventListener('abort', resolve, { once: true }); + }); + } + yield { type: 'content', value: 'hello' }; + }), + acceptAgentMessage: vi.fn().mockResolvedValue(undefined), + scheduleToolCalls: vi.fn().mockResolvedValue(undefined), + waitForPendingTools: vi.fn().mockResolvedValue(undefined), + getAndClearCompletedTools: vi.fn().mockReturnValue([]), + addToolResponsesToHistory: vi.fn(), + sendCompletedToolsToLlm: vi.fn().mockImplementation(async function* () {}), + cancelPendingTools: vi.fn(), + setTaskStateAndPublishUpdate: vi.fn(), + dispose: vi.fn(), + getMetadata: vi.fn().mockResolvedValue({}), + geminiClient: { + initialize: vi.fn().mockResolvedValue(undefined), + }, + toSDKTask: () => ({ + id: taskId, + contextId, + kind: 'task', + status: { state: 'working', timestamp: new Date().toISOString() }, + metadata: {}, + history: [], + artifacts: [], + }), + }); + + const MockTask = vi.fn().mockImplementation(mockTaskInstance); + (MockTask as unknown as { create: Mock }).create = vi + .fn() + .mockImplementation(async (taskId: string, contextId: string) => + mockTaskInstance(taskId, contextId), + ); + + return { Task: MockTask }; +}); + +describe('CoderAgentExecutor', () => { + let executor: CoderAgentExecutor; + let mockTaskStore: TaskStore; + let mockEventBus: ExecutionEventBus; + + beforeEach(() => { + vi.clearAllMocks(); + mockTaskStore = { + save: vi.fn().mockResolvedValue(undefined), + load: vi.fn().mockResolvedValue(undefined), + delete: vi.fn().mockResolvedValue(undefined), + list: vi.fn().mockResolvedValue([]), + } as unknown as TaskStore; + + mockEventBus = new EventEmitter() as unknown as ExecutionEventBus; + mockEventBus.publish = vi.fn(); + mockEventBus.finished = vi.fn(); + + executor = new CoderAgentExecutor(mockTaskStore); + }); + + it('should distinguish between primary and secondary execution', async () => { + const taskId = 'test-task'; + const contextId = 'test-context'; + + const mockSocket = new EventEmitter(); + const requestContext = { + userMessage: { + messageId: 'msg-1', + taskId, + contextId, + parts: [{ kind: 'text', text: 'hi' }], + metadata: { + coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' }, + }, + }, + } as unknown as RequestContext; + + // Mock requestStorage for primary + (requestStorage.getStore as Mock).mockReturnValue({ + req: { socket: mockSocket }, + }); + + // First execution (Primary) + const primaryPromise = executor.execute(requestContext, mockEventBus); + + // Give it enough time to reach line 490 in executor.ts + await new Promise((resolve) => setTimeout(resolve, 50)); + + expect( + ( + executor as unknown as { executingTasks: Set } + ).executingTasks.has(taskId), + ).toBe(true); + const wrapper = executor.getTask(taskId); + expect(wrapper).toBeDefined(); + + // Mock requestStorage for secondary + const secondarySocket = new EventEmitter(); + (requestStorage.getStore as Mock).mockReturnValue({ + req: { socket: secondarySocket }, + }); + + const secondaryRequestContext = { + userMessage: { + messageId: 'msg-2', + taskId, + contextId, + parts: [{ kind: 'confirmation', callId: '1', outcome: 'proceed' }], + metadata: { + coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' }, + }, + }, + } as unknown as RequestContext; + + const secondaryPromise = executor.execute( + secondaryRequestContext, + mockEventBus, + ); + + // Secondary execution should NOT add to executingTasks (already there) + // and should return early after its loop + await secondaryPromise; + + // Task should still be in executingTasks and NOT disposed + expect( + ( + executor as unknown as { executingTasks: Set } + ).executingTasks.has(taskId), + ).toBe(true); + expect(wrapper?.task.dispose).not.toHaveBeenCalled(); + + // Now simulate secondary socket closure - it should NOT affect primary + secondarySocket.emit('end'); + expect( + ( + executor as unknown as { executingTasks: Set } + ).executingTasks.has(taskId), + ).toBe(true); + expect(wrapper?.task.dispose).not.toHaveBeenCalled(); + + // Set to terminal state to verify disposal on finish + wrapper!.task.taskState = 'completed'; + + // Now close primary socket + mockSocket.emit('end'); + + await primaryPromise; + + expect( + ( + executor as unknown as { executingTasks: Set } + ).executingTasks.has(taskId), + ).toBe(false); + expect(wrapper?.task.dispose).toHaveBeenCalled(); + }); + + it('should evict task from cache when it reaches terminal state', async () => { + const taskId = 'test-task-terminal'; + const contextId = 'test-context'; + + const mockSocket = new EventEmitter(); + (requestStorage.getStore as Mock).mockReturnValue({ + req: { socket: mockSocket }, + }); + + const requestContext = { + userMessage: { + messageId: 'msg-1', + taskId, + contextId, + parts: [{ kind: 'text', text: 'hi' }], + metadata: { + coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' }, + }, + }, + } as unknown as RequestContext; + + const primaryPromise = executor.execute(requestContext, mockEventBus); + await new Promise((resolve) => setTimeout(resolve, 50)); + + const wrapper = executor.getTask(taskId)!; + expect(wrapper).toBeDefined(); + // Simulate terminal state + wrapper.task.taskState = 'completed'; + + // Finish primary execution + mockSocket.emit('end'); + await primaryPromise; + + expect(executor.getTask(taskId)).toBeUndefined(); + expect(wrapper.task.dispose).toHaveBeenCalled(); + }); +}); diff --git a/packages/a2a-server/src/agent/executor.ts b/packages/a2a-server/src/agent/executor.ts index 7fc35657fb..dbb8269376 100644 --- a/packages/a2a-server/src/agent/executor.ts +++ b/packages/a2a-server/src/agent/executor.ts @@ -252,6 +252,10 @@ export class CoderAgentExecutor implements AgentExecutor { ); await this.taskStore?.save(wrapper.toSDKTask()); logger.info(`[CoderAgentExecutor] Task ${taskId} state CANCELED saved.`); + + // Cleanup listener subscriptions to avoid memory leaks. + wrapper.task.dispose(); + this.tasks.delete(taskId); } catch (error) { const errorMessage = error instanceof Error ? error.message : 'Unknown error'; @@ -320,23 +324,26 @@ export class CoderAgentExecutor implements AgentExecutor { if (store) { // Grab the raw socket from the request object const socket = store.req.socket; - const onClientEnd = () => { + const onSocketEnd = () => { logger.info( - `[CoderAgentExecutor] Client socket closed for task ${taskId}. Cancelling execution.`, + `[CoderAgentExecutor] Socket ended for message ${userMessage.messageId} (task ${taskId}). Aborting execution loop.`, ); if (!abortController.signal.aborted) { abortController.abort(); } // Clean up the listener to prevent memory leaks - socket.removeListener('close', onClientEnd); + socket.removeListener('end', onSocketEnd); }; // Listen on the socket's 'end' event (remote closed the connection) - socket.on('end', onClientEnd); + socket.on('end', onSocketEnd); + socket.once('close', () => { + socket.removeListener('end', onSocketEnd); + }); // It's also good practice to remove the listener if the task completes successfully abortSignal.addEventListener('abort', () => { - socket.removeListener('end', onClientEnd); + socket.removeListener('end', onSocketEnd); }); logger.info( `[CoderAgentExecutor] Socket close handler set up for task ${taskId}.`, @@ -457,6 +464,26 @@ export class CoderAgentExecutor implements AgentExecutor { return; } + // Check if this is the primary/initial execution for this task + const isPrimaryExecution = !this.executingTasks.has(taskId); + + if (!isPrimaryExecution) { + logger.info( + `[CoderAgentExecutor] Primary execution already active for task ${taskId}. Starting secondary loop for message ${userMessage.messageId}.`, + ); + currentTask.eventBus = eventBus; + for await (const _ of currentTask.acceptUserMessage( + requestContext, + abortController.signal, + )) { + logger.info( + `[CoderAgentExecutor] Processing user message ${userMessage.messageId} in secondary execution loop for task ${taskId}.`, + ); + } + // End this execution-- the original/source will be resumed. + return; + } + logger.info( `[CoderAgentExecutor] Starting main execution for message ${userMessage.messageId} for task ${taskId}.`, ); @@ -598,18 +625,30 @@ export class CoderAgentExecutor implements AgentExecutor { } } } finally { - this.executingTasks.delete(taskId); - logger.info( - `[CoderAgentExecutor] Saving final state for task ${taskId}.`, - ); - try { - await this.taskStore?.save(wrapper.toSDKTask()); - logger.info(`[CoderAgentExecutor] Task ${taskId} state saved.`); - } catch (saveError) { - logger.error( - `[CoderAgentExecutor] Failed to save task ${taskId} state in finally block:`, - saveError, + if (isPrimaryExecution) { + this.executingTasks.delete(taskId); + logger.info( + `[CoderAgentExecutor] Saving final state for task ${taskId}.`, ); + try { + await this.taskStore?.save(wrapper.toSDKTask()); + logger.info(`[CoderAgentExecutor] Task ${taskId} state saved.`); + } catch (saveError) { + logger.error( + `[CoderAgentExecutor] Failed to save task ${taskId} state in finally block:`, + saveError, + ); + } + + if ( + ['canceled', 'failed', 'completed'].includes(currentTask.taskState) + ) { + logger.info( + `[CoderAgentExecutor] Task ${taskId} reached terminal state ${currentTask.taskState}. Evicting and disposing.`, + ); + wrapper.task.dispose(); + this.tasks.delete(taskId); + } } } } diff --git a/packages/a2a-server/src/agent/task-event-driven.test.ts b/packages/a2a-server/src/agent/task-event-driven.test.ts new file mode 100644 index 0000000000..f9dda8a752 --- /dev/null +++ b/packages/a2a-server/src/agent/task-event-driven.test.ts @@ -0,0 +1,655 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ +import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest'; +import { Task } from './task.js'; +import { + type Config, + MessageBusType, + ToolConfirmationOutcome, + ApprovalMode, + Scheduler, + type MessageBus, +} from '@google/gemini-cli-core'; +import { createMockConfig } from '../utils/testing_utils.js'; +import type { ExecutionEventBus } from '@a2a-js/sdk/server'; + +describe('Task Event-Driven Scheduler', () => { + let mockConfig: Config; + let mockEventBus: ExecutionEventBus; + let messageBus: MessageBus; + + beforeEach(() => { + vi.clearAllMocks(); + mockConfig = createMockConfig({ + isEventDrivenSchedulerEnabled: () => true, + }) as Config; + messageBus = mockConfig.getMessageBus(); + mockEventBus = { + publish: vi.fn(), + on: vi.fn(), + off: vi.fn(), + once: vi.fn(), + removeAllListeners: vi.fn(), + finished: vi.fn(), + }; + }); + + it('should instantiate Scheduler when enabled', () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + expect(task.scheduler).toBeInstanceOf(Scheduler); + }); + + it('should subscribe to TOOL_CALLS_UPDATE and map status changes', async () => { + // @ts-expect-error - Calling private constructor + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCall = { + request: { callId: '1', name: 'ls', args: {} }, + status: 'executing', + }; + + // Simulate MessageBus event + // Simulate MessageBus event + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + + if (!handler) { + throw new Error('TOOL_CALLS_UPDATE handler not found'); + } + + handler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + toolCalls: [toolCall], + }); + + expect(mockEventBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + status: expect.objectContaining({ + state: 'submitted', // initial task state + }), + metadata: expect.objectContaining({ + coderAgent: expect.objectContaining({ + kind: 'tool-call-update', + }), + }), + }), + ); + }); + + it('should handle tool confirmations by publishing to MessageBus', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCall = { + request: { callId: '1', name: 'ls', args: {} }, + status: 'awaiting_approval', + correlationId: 'corr-1', + confirmationDetails: { type: 'info', title: 'test', prompt: 'test' }, + }; + + // Simulate MessageBus event to stash the correlationId + // Simulate MessageBus event + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + + if (!handler) { + throw new Error('TOOL_CALLS_UPDATE handler not found'); + } + + handler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + toolCalls: [toolCall], + }); + + // Simulate A2A client confirmation + const part = { + kind: 'data', + data: { + callId: '1', + outcome: 'proceed_once', + }, + }; + + const handled = await ( + task as unknown as { + _handleToolConfirmationPart: (part: unknown) => Promise; + } + )._handleToolConfirmationPart(part); + expect(handled).toBe(true); + + expect(messageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'corr-1', + confirmed: true, + outcome: ToolConfirmationOutcome.ProceedOnce, + }), + ); + }); + + it('should handle Rejection (Cancel) and Modification (ModifyWithEditor)', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCall = { + request: { callId: '1', name: 'ls', args: {} }, + status: 'awaiting_approval', + correlationId: 'corr-1', + confirmationDetails: { type: 'info', title: 'test', prompt: 'test' }, + }; + + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] }); + + // Simulate Rejection (Cancel) + const handled = await ( + task as unknown as { + _handleToolConfirmationPart: (part: unknown) => Promise; + } + )._handleToolConfirmationPart({ + kind: 'data', + data: { callId: '1', outcome: 'cancel' }, + }); + expect(handled).toBe(true); + expect(messageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'corr-1', + confirmed: false, + }), + ); + + const toolCall2 = { + request: { callId: '2', name: 'ls', args: {} }, + status: 'awaiting_approval', + correlationId: 'corr-2', + confirmationDetails: { type: 'info', title: 'test', prompt: 'test' }, + }; + handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall2] }); + + // Simulate ModifyWithEditor + const handled2 = await ( + task as unknown as { + _handleToolConfirmationPart: (part: unknown) => Promise; + } + )._handleToolConfirmationPart({ + kind: 'data', + data: { callId: '2', outcome: 'modify_with_editor' }, + }); + expect(handled2).toBe(true); + expect(messageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'corr-2', + confirmed: false, + outcome: ToolConfirmationOutcome.ModifyWithEditor, + payload: undefined, + }), + ); + }); + + it('should handle MCP Server tool operations correctly', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCall = { + request: { callId: '1', name: 'call_mcp_tool', args: {} }, + status: 'awaiting_approval', + correlationId: 'corr-mcp-1', + confirmationDetails: { + type: 'mcp', + title: 'MCP Server Operation', + prompt: 'test_mcp', + }, + }; + + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] }); + + // Simulate ProceedOnce for MCP + const handled = await ( + task as unknown as { + _handleToolConfirmationPart: (part: unknown) => Promise; + } + )._handleToolConfirmationPart({ + kind: 'data', + data: { callId: '1', outcome: 'proceed_once' }, + }); + expect(handled).toBe(true); + expect(messageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'corr-mcp-1', + confirmed: true, + outcome: ToolConfirmationOutcome.ProceedOnce, + }), + ); + }); + + it('should handle MCP Server tool ProceedAlwaysServer outcome', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCall = { + request: { callId: '1', name: 'call_mcp_tool', args: {} }, + status: 'awaiting_approval', + correlationId: 'corr-mcp-2', + confirmationDetails: { + type: 'mcp', + title: 'MCP Server Operation', + prompt: 'test_mcp', + }, + }; + + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] }); + + const handled = await ( + task as unknown as { + _handleToolConfirmationPart: (part: unknown) => Promise; + } + )._handleToolConfirmationPart({ + kind: 'data', + data: { callId: '1', outcome: 'proceed_always_server' }, + }); + expect(handled).toBe(true); + expect(messageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'corr-mcp-2', + confirmed: true, + outcome: ToolConfirmationOutcome.ProceedAlwaysServer, + }), + ); + }); + + it('should handle MCP Server tool ProceedAlwaysTool outcome', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCall = { + request: { callId: '1', name: 'call_mcp_tool', args: {} }, + status: 'awaiting_approval', + correlationId: 'corr-mcp-3', + confirmationDetails: { + type: 'mcp', + title: 'MCP Server Operation', + prompt: 'test_mcp', + }, + }; + + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] }); + + const handled = await ( + task as unknown as { + _handleToolConfirmationPart: (part: unknown) => Promise; + } + )._handleToolConfirmationPart({ + kind: 'data', + data: { callId: '1', outcome: 'proceed_always_tool' }, + }); + expect(handled).toBe(true); + expect(messageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'corr-mcp-3', + confirmed: true, + outcome: ToolConfirmationOutcome.ProceedAlwaysTool, + }), + ); + }); + + it('should handle MCP Server tool ProceedAlwaysAndSave outcome', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCall = { + request: { callId: '1', name: 'call_mcp_tool', args: {} }, + status: 'awaiting_approval', + correlationId: 'corr-mcp-4', + confirmationDetails: { + type: 'mcp', + title: 'MCP Server Operation', + prompt: 'test_mcp', + }, + }; + + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] }); + + const handled = await ( + task as unknown as { + _handleToolConfirmationPart: (part: unknown) => Promise; + } + )._handleToolConfirmationPart({ + kind: 'data', + data: { callId: '1', outcome: 'proceed_always_and_save' }, + }); + expect(handled).toBe(true); + expect(messageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'corr-mcp-4', + confirmed: true, + outcome: ToolConfirmationOutcome.ProceedAlwaysAndSave, + }), + ); + }); + + it('should execute without confirmation in YOLO mode and not transition to input-required', async () => { + // Enable YOLO mode + const yoloConfig = createMockConfig({ + isEventDrivenSchedulerEnabled: () => true, + getApprovalMode: () => ApprovalMode.YOLO, + }) as Config; + const yoloMessageBus = yoloConfig.getMessageBus(); + + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', yoloConfig, mockEventBus); + task.setTaskStateAndPublishUpdate = vi.fn(); + + const toolCall = { + request: { callId: '1', name: 'ls', args: {} }, + status: 'awaiting_approval', + correlationId: 'corr-1', + confirmationDetails: { type: 'info', title: 'test', prompt: 'test' }, + }; + + const handler = (yoloMessageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] }); + + // Should NOT auto-publish ProceedOnce anymore, because PolicyEngine handles it directly + expect(yoloMessageBus.publish).not.toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + }), + ); + + // Should NOT transition to input-required since it was auto-approved + expect(task.setTaskStateAndPublishUpdate).not.toHaveBeenCalledWith( + 'input-required', + expect.anything(), + undefined, + undefined, + true, + ); + }); + + it('should handle output updates via the message bus', async () => { + // @ts-expect-error - Calling private constructor + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCall = { + request: { callId: '1', name: 'ls', args: {} }, + status: 'executing', + liveOutput: 'chunk1', + }; + + // Simulate MessageBus event + // Simulate MessageBus event + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + + if (!handler) { + throw new Error('TOOL_CALLS_UPDATE handler not found'); + } + + handler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + toolCalls: [toolCall], + }); + + // Should publish artifact update for output + expect(mockEventBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + kind: 'artifact-update', + artifact: expect.objectContaining({ + artifactId: 'tool-1-output', + parts: [{ kind: 'text', text: 'chunk1' }], + }), + }), + ); + }); + + it('should complete artifact creation without hanging', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCallId = 'create-file-123'; + task['_registerToolCall'](toolCallId, 'executing'); + + const toolCall = { + request: { + callId: toolCallId, + name: 'writeFile', + args: { path: 'test.sh' }, + }, + status: 'success', + result: { ok: true }, + }; + + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] }); + + // The tool should be complete and registered appropriately, eventually + // triggering the toolCompletionPromise resolution when all clear. + const internalTask = task as unknown as { + completedToolCalls: unknown[]; + pendingToolCalls: Map; + }; + expect(internalTask.completedToolCalls.length).toBe(1); + expect(internalTask.pendingToolCalls.size).toBe(0); + }); + + it('should preserve messageId across multiple text chunks to prevent UI duplication', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + // Initialize the ID for the first turn (happens internally upon LLM stream) + task.currentAgentMessageId = 'test-id-123'; + + // Simulate sending multiple text chunks + task._sendTextContent('chunk 1'); + task._sendTextContent('chunk 2'); + + // Both text contents should have been published with the same messageId + const textCalls = (mockEventBus.publish as Mock).mock.calls.filter( + (call) => call[0].status?.message?.kind === 'message', + ); + expect(textCalls.length).toBe(2); + expect(textCalls[0][0].status.message.messageId).toBe('test-id-123'); + expect(textCalls[1][0].status.message.messageId).toBe('test-id-123'); + + // Simulate starting a new turn by calling getAndClearCompletedTools + // (which precedes sendCompletedToolsToLlm where a new ID is minted) + task.getAndClearCompletedTools(); + + // sendCompletedToolsToLlm internally rolls the ID forward. + // Simulate what sendCompletedToolsToLlm does: + const internalTask = task as unknown as { + setTaskStateAndPublishUpdate: (state: string, change: unknown) => void; + }; + internalTask.setTaskStateAndPublishUpdate('working', {}); + + // Simulate what sendCompletedToolsToLlm does: generate a new UUID for the next turn + task.currentAgentMessageId = 'test-id-456'; + + task._sendTextContent('chunk 3'); + + const secondTurnCalls = (mockEventBus.publish as Mock).mock.calls.filter( + (call) => call[0].status?.message?.messageId === 'test-id-456', + ); + expect(secondTurnCalls.length).toBe(1); + expect(secondTurnCalls[0][0].status.message.parts[0].text).toBe('chunk 3'); + }); + + it('should handle parallel tool calls correctly', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const toolCall1 = { + request: { callId: '1', name: 'ls', args: {} }, + status: 'awaiting_approval', + correlationId: 'corr-1', + confirmationDetails: { type: 'info', title: 'test 1', prompt: 'test 1' }, + }; + + const toolCall2 = { + request: { callId: '2', name: 'pwd', args: {} }, + status: 'awaiting_approval', + correlationId: 'corr-2', + confirmationDetails: { type: 'info', title: 'test 2', prompt: 'test 2' }, + }; + + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + + // Publish update for both tool calls simultaneously + handler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + toolCalls: [toolCall1, toolCall2], + }); + + // Confirm first tool call + const handled1 = await ( + task as unknown as { + _handleToolConfirmationPart: (part: unknown) => Promise; + } + )._handleToolConfirmationPart({ + kind: 'data', + data: { callId: '1', outcome: 'proceed_once' }, + }); + expect(handled1).toBe(true); + expect(messageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'corr-1', + confirmed: true, + }), + ); + + // Confirm second tool call + const handled2 = await ( + task as unknown as { + _handleToolConfirmationPart: (part: unknown) => Promise; + } + )._handleToolConfirmationPart({ + kind: 'data', + data: { callId: '2', outcome: 'cancel' }, + }); + expect(handled2).toBe(true); + expect(messageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'corr-2', + confirmed: false, + }), + ); + }); + + it('should wait for executing tools before transitioning to input-required state', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + task.setTaskStateAndPublishUpdate = vi.fn(); + + // Register tool 1 as executing + task['_registerToolCall']('1', 'executing'); + + const toolCall1 = { + request: { callId: '1', name: 'ls', args: {} }, + status: 'executing', + }; + + const toolCall2 = { + request: { callId: '2', name: 'pwd', args: {} }, + status: 'awaiting_approval', + correlationId: 'corr-2', + confirmationDetails: { type: 'info', title: 'test 2', prompt: 'test 2' }, + }; + + const handler = (messageBus.subscribe as Mock).mock.calls.find( + (call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE, + )?.[1]; + + handler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + toolCalls: [toolCall1, toolCall2], + }); + + // Should NOT transition to input-required yet + expect(task.setTaskStateAndPublishUpdate).not.toHaveBeenCalledWith( + 'input-required', + expect.anything(), + undefined, + undefined, + true, + ); + + // Complete tool 1 + const toolCall1Complete = { + ...toolCall1, + status: 'success', + result: { ok: true }, + }; + + handler({ + type: MessageBusType.TOOL_CALLS_UPDATE, + toolCalls: [toolCall1Complete, toolCall2], + }); + + // Now it should transition + expect(task.setTaskStateAndPublishUpdate).toHaveBeenCalledWith( + 'input-required', + expect.anything(), + undefined, + undefined, + true, + ); + }); + + it('should ignore confirmations for unknown tool calls', async () => { + // @ts-expect-error - Calling private constructor + const task = new Task('task-id', 'context-id', mockConfig, mockEventBus); + + const handled = await ( + task as unknown as { + _handleToolConfirmationPart: (part: unknown) => Promise; + } + )._handleToolConfirmationPart({ + kind: 'data', + data: { callId: 'unknown-id', outcome: 'proceed_once' }, + }); + + // Should return false for unhandled tool call + expect(handled).toBe(false); + + // Should not publish anything to the message bus + expect(messageBus.publish).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/a2a-server/src/agent/task.test.ts b/packages/a2a-server/src/agent/task.test.ts index e29f669333..bf15d7fc49 100644 --- a/packages/a2a-server/src/agent/task.test.ts +++ b/packages/a2a-server/src/agent/task.test.ts @@ -504,13 +504,14 @@ describe('Task', () => { }); describe('auto-approval', () => { - it('should auto-approve tool calls when autoExecute is true', () => { + it('should NOT publish ToolCallConfirmationEvent when autoExecute is true', () => { task.autoExecute = true; const onConfirmSpy = vi.fn(); const toolCalls = [ { request: { callId: '1' }, status: 'awaiting_approval', + correlationId: 'test-corr-id', confirmationDetails: { type: 'edit', onConfirm: onConfirmSpy, @@ -524,9 +525,17 @@ describe('Task', () => { expect(onConfirmSpy).toHaveBeenCalledWith( ToolConfirmationOutcome.ProceedOnce, ); + const calls = (mockEventBus.publish as Mock).mock.calls; + // Search if ToolCallConfirmationEvent was published + const confEvent = calls.find( + (call) => + call[0].metadata?.coderAgent?.kind === + CoderAgentEvent.ToolCallConfirmationEvent, + ); + expect(confEvent).toBeUndefined(); }); - it('should auto-approve tool calls when approval mode is YOLO', () => { + it('should NOT publish ToolCallConfirmationEvent when approval mode is YOLO', () => { (mockConfig.getApprovalMode as Mock).mockReturnValue(ApprovalMode.YOLO); task.autoExecute = false; const onConfirmSpy = vi.fn(); @@ -534,6 +543,7 @@ describe('Task', () => { { request: { callId: '1' }, status: 'awaiting_approval', + correlationId: 'test-corr-id', confirmationDetails: { type: 'edit', onConfirm: onConfirmSpy, @@ -547,6 +557,14 @@ describe('Task', () => { expect(onConfirmSpy).toHaveBeenCalledWith( ToolConfirmationOutcome.ProceedOnce, ); + const calls = (mockEventBus.publish as Mock).mock.calls; + // Search if ToolCallConfirmationEvent was published + const confEvent = calls.find( + (call) => + call[0].metadata?.coderAgent?.kind === + CoderAgentEvent.ToolCallConfirmationEvent, + ); + expect(confEvent).toBeUndefined(); }); it('should NOT auto-approve when autoExecute is false and mode is not YOLO', () => { @@ -567,6 +585,14 @@ describe('Task', () => { task._schedulerToolCallsUpdate(toolCalls); expect(onConfirmSpy).not.toHaveBeenCalled(); + const calls = (mockEventBus.publish as Mock).mock.calls; + // Search if ToolCallConfirmationEvent was published + const confEvent = calls.find( + (call) => + call[0].metadata?.coderAgent?.kind === + CoderAgentEvent.ToolCallConfirmationEvent, + ); + expect(confEvent).toBeDefined(); }); }); }); diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index ef15a907e6..652635779b 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -5,6 +5,7 @@ */ import { + Scheduler, CoreToolScheduler, type GeminiClient, GeminiEventType, @@ -34,6 +35,8 @@ import { isSubagentProgress, EDIT_TOOL_NAMES, processRestorableToolCalls, + MessageBusType, + type ToolCallsUpdateMessage, } from '@google/gemini-cli-core'; import { type ExecutionEventBus, @@ -96,21 +99,30 @@ function isToolCallConfirmationDetails( export class Task { id: string; contextId: string; - scheduler: CoreToolScheduler; + scheduler: Scheduler | CoreToolScheduler; config: Config; geminiClient: GeminiClient; pendingToolConfirmationDetails: Map; + pendingCorrelationIds: Map = new Map(); taskState: TaskState; eventBus?: ExecutionEventBus; completedToolCalls: CompletedToolCall[]; + processedToolCallIds: Set = new Set(); skipFinalTrueAfterInlineEdit = false; modelInfo?: string; currentPromptId: string | undefined; + currentAgentMessageId = uuidv4(); promptCount = 0; autoExecute: boolean; + private get isYoloMatch(): boolean { + return ( + this.autoExecute || this.config.getApprovalMode() === ApprovalMode.YOLO + ); + } // For tool waiting logic private pendingToolCalls: Map = new Map(); //toolCallId --> status + private toolsAlreadyConfirmed: Set = new Set(); private toolCompletionPromise?: Promise; private toolCompletionNotifier?: { resolve: () => void; @@ -127,7 +139,13 @@ export class Task { this.id = id; this.contextId = contextId; this.config = config; - this.scheduler = this.createScheduler(); + + if (this.config.isEventDrivenSchedulerEnabled()) { + this.scheduler = this.setupEventDrivenScheduler(); + } else { + this.scheduler = this.createLegacyScheduler(); + } + this.geminiClient = this.config.getGeminiClient(); this.pendingToolConfirmationDetails = new Map(); this.taskState = 'submitted'; @@ -227,7 +245,7 @@ export class Task { logger.info( `[Task] Waiting for ${this.pendingToolCalls.size} pending tool(s)...`, ); - return this.toolCompletionPromise; + await this.toolCompletionPromise; } cancelPendingTools(reason: string): void { @@ -240,6 +258,13 @@ export class Task { this.toolCompletionNotifier.reject(new Error(reason)); } this.pendingToolCalls.clear(); + this.pendingCorrelationIds.clear(); + + if (this.scheduler instanceof Scheduler) { + this.scheduler.cancelAll(); + } else { + this.scheduler.cancelAll(new AbortController().signal); + } // Reset the promise for any future operations, ensuring it's in a clean state. this._resetToolCompletionPromise(); } @@ -252,7 +277,7 @@ export class Task { kind: 'message', role, parts: [{ kind: 'text', text }], - messageId: uuidv4(), + messageId: role === 'agent' ? this.currentAgentMessageId : uuidv4(), taskId: this.id, contextId: this.contextId, }; @@ -425,26 +450,34 @@ export class Task { // Only send an update if the status has actually changed. if (hasChanged) { - const coderAgentMessage: CoderAgentMessage = - tc.status === 'awaiting_approval' - ? { kind: CoderAgentEvent.ToolCallConfirmationEvent } - : { kind: CoderAgentEvent.ToolCallUpdateEvent }; - const message = this.toolStatusMessage(tc, this.id, this.contextId); + // Skip sending confirmation event if we are going to auto-approve it anyway + if ( + tc.status === 'awaiting_approval' && + tc.confirmationDetails && + this.isYoloMatch + ) { + logger.info( + `[Task] Skipping ToolCallConfirmationEvent for ${tc.request.callId} due to YOLO mode.`, + ); + } else { + const coderAgentMessage: CoderAgentMessage = + tc.status === 'awaiting_approval' + ? { kind: CoderAgentEvent.ToolCallConfirmationEvent } + : { kind: CoderAgentEvent.ToolCallUpdateEvent }; + const message = this.toolStatusMessage(tc, this.id, this.contextId); - const event = this._createStatusUpdateEvent( - this.taskState, - coderAgentMessage, - message, - false, // Always false for these continuous updates - ); - this.eventBus?.publish(event); + const event = this._createStatusUpdateEvent( + this.taskState, + coderAgentMessage, + message, + false, // Always false for these continuous updates + ); + this.eventBus?.publish(event); + } } }); - if ( - this.autoExecute || - this.config.getApprovalMode() === ApprovalMode.YOLO - ) { + if (this.isYoloMatch) { logger.info( '[Task] ' + (this.autoExecute ? '' : 'YOLO mode enabled. ') + @@ -492,7 +525,7 @@ export class Task { } } - private createScheduler(): CoreToolScheduler { + private createLegacyScheduler(): CoreToolScheduler { const scheduler = new CoreToolScheduler({ outputUpdateHandler: this._schedulerOutputUpdate.bind(this), onAllToolCallsComplete: this._schedulerAllToolCallsComplete.bind(this), @@ -503,6 +536,171 @@ export class Task { return scheduler; } + private messageBusListener?: (message: ToolCallsUpdateMessage) => void; + + private setupEventDrivenScheduler(): Scheduler { + const messageBus = this.config.getMessageBus(); + const scheduler = new Scheduler({ + schedulerId: this.id, + config: this.config, + messageBus, + getPreferredEditor: () => DEFAULT_GUI_EDITOR, + }); + + this.messageBusListener = this.handleEventDrivenToolCallsUpdate.bind(this); + messageBus.subscribe( + MessageBusType.TOOL_CALLS_UPDATE, + this.messageBusListener, + ); + + return scheduler; + } + + dispose(): void { + if (this.messageBusListener) { + this.config + .getMessageBus() + .unsubscribe(MessageBusType.TOOL_CALLS_UPDATE, this.messageBusListener); + this.messageBusListener = undefined; + } + + if (this.scheduler instanceof Scheduler) { + this.scheduler.dispose(); + } + } + + private handleEventDrivenToolCallsUpdate( + event: ToolCallsUpdateMessage, + ): void { + if (event.type !== MessageBusType.TOOL_CALLS_UPDATE) { + return; + } + + const toolCalls = event.toolCalls; + + toolCalls.forEach((tc) => { + this.handleEventDrivenToolCall(tc); + }); + + this.checkInputRequiredState(); + } + + private handleEventDrivenToolCall(tc: ToolCall): void { + const callId = tc.request.callId; + + // Do not process events for tools that have already been finalized. + // This prevents duplicate completions if the state manager emits a snapshot containing + // already resolved tools whose IDs were removed from pendingToolCalls. + if ( + this.processedToolCallIds.has(callId) || + this.completedToolCalls.some((c) => c.request.callId === callId) + ) { + return; + } + + const previousStatus = this.pendingToolCalls.get(callId); + const hasChanged = previousStatus !== tc.status; + + // 1. Handle Output + if (tc.status === 'executing' && tc.liveOutput) { + this._schedulerOutputUpdate(callId, tc.liveOutput); + } + + // 2. Handle terminal states + if ( + tc.status === 'success' || + tc.status === 'error' || + tc.status === 'cancelled' + ) { + this.toolsAlreadyConfirmed.delete(callId); + if (hasChanged) { + logger.info( + `[Task] Tool call ${callId} completed with status: ${tc.status}`, + ); + this.completedToolCalls.push(tc); + this._resolveToolCall(callId); + } + } else { + // Keep track of pending tools + this._registerToolCall(callId, tc.status); + } + + // 3. Handle Confirmation Stash + if (tc.status === 'awaiting_approval' && tc.confirmationDetails) { + const details = tc.confirmationDetails; + + if (tc.correlationId) { + this.pendingCorrelationIds.set(callId, tc.correlationId); + } + + this.pendingToolConfirmationDetails.set(callId, { + ...details, + onConfirm: async () => {}, + } as ToolCallConfirmationDetails); + } + + // 4. Publish Status Updates to A2A event bus + if (hasChanged) { + const coderAgentMessage: CoderAgentMessage = + tc.status === 'awaiting_approval' + ? { kind: CoderAgentEvent.ToolCallConfirmationEvent } + : { kind: CoderAgentEvent.ToolCallUpdateEvent }; + + const message = this.toolStatusMessage(tc, this.id, this.contextId); + const statusUpdate = this._createStatusUpdateEvent( + this.taskState, + coderAgentMessage, + message, + false, + ); + this.eventBus?.publish(statusUpdate); + } + } + + private checkInputRequiredState(): void { + if (this.isYoloMatch) { + return; + } + + // 6. Handle Input Required State + let isAwaitingApproval = false; + let isExecuting = false; + + for (const [callId, status] of this.pendingToolCalls.entries()) { + if (status === 'executing' || status === 'scheduled') { + isExecuting = true; + } else if ( + status === 'awaiting_approval' && + !this.toolsAlreadyConfirmed.has(callId) + ) { + isAwaitingApproval = true; + } + } + + if ( + isAwaitingApproval && + !isExecuting && + !this.skipFinalTrueAfterInlineEdit + ) { + this.skipFinalTrueAfterInlineEdit = false; + const wasAlreadyInputRequired = this.taskState === 'input-required'; + + this.setTaskStateAndPublishUpdate( + 'input-required', + { kind: CoderAgentEvent.StateChangeEvent }, + undefined, + undefined, + /*final*/ true, + ); + + // Unblock waitForPendingTools to correctly end the executor loop and release the HTTP response stream. + // The IDE client will open a new stream with the confirmation reply. + if (!wasAlreadyInputRequired && this.toolCompletionNotifier) { + this.toolCompletionNotifier.resolve(); + } + } + } + private _pickFields< T extends ToolCall | AnyDeclarativeTool, K extends UnionKeys, @@ -713,7 +911,16 @@ export class Task { }; this.setTaskStateAndPublishUpdate('working', stateChange); - await this.scheduler.schedule(updatedRequests, abortSignal); + // Pre-register tools to ensure waitForPendingTools sees them as pending + // before the async scheduler enqueues them and fires the event bus update. + for (const req of updatedRequests) { + if (!this.pendingToolCalls.has(req.callId)) { + this._registerToolCall(req.callId, 'scheduled'); + } + } + + // Fire and forget so we don't block the executor loop before waitForPendingTools can be called + void this.scheduler.schedule(updatedRequests, abortSignal); } async acceptAgentMessage(event: ServerGeminiStreamEvent): Promise { @@ -839,9 +1046,15 @@ export class Task { ) { return false; } + if (!part.data['outcome']) { + return false; + } const callId = part.data['callId']; const outcomeString = part.data['outcome']; + + this.toolsAlreadyConfirmed.add(callId); + let confirmationOutcome: ToolConfirmationOutcome | undefined; if (outcomeString === 'proceed_once') { @@ -854,6 +1067,8 @@ export class Task { confirmationOutcome = ToolConfirmationOutcome.ProceedAlwaysServer; } else if (outcomeString === 'proceed_always_tool') { confirmationOutcome = ToolConfirmationOutcome.ProceedAlwaysTool; + } else if (outcomeString === 'proceed_always_and_save') { + confirmationOutcome = ToolConfirmationOutcome.ProceedAlwaysAndSave; } else if (outcomeString === 'modify_with_editor') { confirmationOutcome = ToolConfirmationOutcome.ModifyWithEditor; } else { @@ -864,8 +1079,9 @@ export class Task { } const confirmationDetails = this.pendingToolConfirmationDetails.get(callId); + const correlationId = this.pendingCorrelationIds.get(callId); - if (!confirmationDetails) { + if (!confirmationDetails && !correlationId) { logger.warn( `[Task] Received tool confirmation for unknown or already processed callId: ${callId}`, ); @@ -887,24 +1103,35 @@ export class Task { // This will trigger the scheduler to continue or cancel the specific tool. // The scheduler's onToolCallsUpdate will then reflect the new state (e.g., executing or cancelled). - // If `edit` tool call, pass updated payload if presesent - if (confirmationDetails.type === 'edit') { - const newContent = part.data['newContent']; - const payload = - typeof newContent === 'string' - ? ({ newContent } as ToolConfirmationPayload) - : undefined; - this.skipFinalTrueAfterInlineEdit = !!payload; - try { + // If `edit` tool call, pass updated payload if present + const newContent = part.data['newContent']; + const payload = + confirmationDetails?.type === 'edit' && typeof newContent === 'string' + ? ({ newContent } as ToolConfirmationPayload) + : undefined; + this.skipFinalTrueAfterInlineEdit = !!payload; + + try { + if (correlationId) { + await this.config.getMessageBus().publish({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId, + confirmed: + confirmationOutcome !== ToolConfirmationOutcome.Cancel && + confirmationOutcome !== + ToolConfirmationOutcome.ModifyWithEditor, + outcome: confirmationOutcome, + payload, + }); + } else if (confirmationDetails?.onConfirm) { + // Fallback for legacy callback-based confirmation await confirmationDetails.onConfirm(confirmationOutcome, payload); - } finally { - // Once confirmationDetails.onConfirm finishes (or fails) with a payload, - // reset skipFinalTrueAfterInlineEdit so that external callers receive - // their call has been completed. - this.skipFinalTrueAfterInlineEdit = false; } - } else { - await confirmationDetails.onConfirm(confirmationOutcome); + } finally { + // Once confirmation payload is sent or callback finishes, + // reset skipFinalTrueAfterInlineEdit so that external callers receive + // their call has been completed. + this.skipFinalTrueAfterInlineEdit = false; } } finally { if (gcpProject) { @@ -920,6 +1147,7 @@ export class Task { // Note !== ToolConfirmationOutcome.ModifyWithEditor does not work! if (confirmationOutcome !== 'modify_with_editor') { this.pendingToolConfirmationDetails.delete(callId); + this.pendingCorrelationIds.delete(callId); } // If outcome is Cancel, scheduler should update status to 'cancelled', which then resolves the tool. @@ -953,6 +1181,9 @@ export class Task { getAndClearCompletedTools(): CompletedToolCall[] { const tools = [...this.completedToolCalls]; + for (const tool of tools) { + this.processedToolCallIds.add(tool.request.callId); + } this.completedToolCalls = []; return tools; } @@ -1013,6 +1244,7 @@ export class Task { }; // Set task state to working as we are about to call LLM this.setTaskStateAndPublishUpdate('working', stateChange); + this.currentAgentMessageId = uuidv4(); yield* this.geminiClient.sendMessageStream( llmParts, aborted, @@ -1034,6 +1266,10 @@ export class Task { if (confirmationHandled) { anyConfirmationHandled = true; // If a confirmation was handled, the scheduler will now run the tool (or cancel it). + // We resolve the toolCompletionPromise manually in checkInputRequiredState + // to break the original execution loop, so we must reset it here so the + // new loop correctly awaits the tool's final execution. + this._resetToolCompletionPromise(); // We don't send anything to the LLM for this part. // The subsequent tool execution will eventually lead to resolveToolCall. continue; @@ -1048,6 +1284,7 @@ export class Task { if (hasContentForLlm) { this.currentPromptId = this.config.getSessionId() + '########' + this.promptCount++; + this.currentAgentMessageId = uuidv4(); logger.info('[Task] Sending new parts to LLM.'); const stateChange: StateChange = { kind: CoderAgentEvent.StateChangeEvent, @@ -1093,7 +1330,6 @@ export class Task { if (content === '') { return; } - logger.info('[Task] Sending text content to event bus.'); const message = this._createTextMessage(content); const textContent: TextContent = { kind: CoderAgentEvent.TextContentEvent, @@ -1125,7 +1361,7 @@ export class Task { data: content, } as Part, ], - messageId: uuidv4(), + messageId: this.currentAgentMessageId, taskId: this.id, contextId: this.contextId, }; diff --git a/packages/a2a-server/src/config/config.ts b/packages/a2a-server/src/config/config.ts index 5b6757701d..229abc65c9 100644 --- a/packages/a2a-server/src/config/config.ts +++ b/packages/a2a-server/src/config/config.ts @@ -106,6 +106,8 @@ export async function loadConfig( trustedFolder: true, extensionLoader, checkpointing, + enableEventDrivenScheduler: + settings.experimental?.enableEventDrivenScheduler ?? true, interactive: !isHeadlessMode(), enableInteractiveShell: !isHeadlessMode(), ptyInfo: 'auto', diff --git a/packages/a2a-server/src/config/settings.ts b/packages/a2a-server/src/config/settings.ts index b3c44cc177..0c353b46aa 100644 --- a/packages/a2a-server/src/config/settings.ts +++ b/packages/a2a-server/src/config/settings.ts @@ -37,6 +37,12 @@ export interface Settings { showMemoryUsage?: boolean; checkpointing?: CheckpointingSettings; folderTrust?: boolean; + general?: { + previewFeatures?: boolean; + }; + experimental?: { + enableEventDrivenScheduler?: boolean; + }; // Git-aware file filtering settings fileFiltering?: { diff --git a/packages/a2a-server/src/utils/testing_utils.ts b/packages/a2a-server/src/utils/testing_utils.ts index 7d77d8dc9a..4981dbbd67 100644 --- a/packages/a2a-server/src/utils/testing_utils.ts +++ b/packages/a2a-server/src/utils/testing_utils.ts @@ -64,6 +64,7 @@ export function createMockConfig( getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'), getSessionId: vi.fn().mockReturnValue('test-session-id'), getUserTier: vi.fn(), + isEventDrivenSchedulerEnabled: vi.fn().mockReturnValue(false), getMessageBus: vi.fn(), getPolicyEngine: vi.fn(), getEnableExtensionReloading: vi.fn().mockReturnValue(false), diff --git a/packages/core/src/policy/policy-engine.test.ts b/packages/core/src/policy/policy-engine.test.ts index baf475701c..a54da32376 100644 --- a/packages/core/src/policy/policy-engine.test.ts +++ b/packages/core/src/policy/policy-engine.test.ts @@ -333,6 +333,48 @@ describe('PolicyEngine', () => { PolicyDecision.ASK_USER, ); }); + + it('should return ALLOW by default in YOLO mode when no rules match', async () => { + engine = new PolicyEngine({ approvalMode: ApprovalMode.YOLO }); + + // No rules defined, should return ALLOW in YOLO mode + const { decision } = await engine.check({ name: 'any-tool' }, undefined); + expect(decision).toBe(PolicyDecision.ALLOW); + }); + + it('should NOT override explicit DENY rules in YOLO mode', async () => { + const rules: PolicyRule[] = [ + { toolName: 'dangerous-tool', decision: PolicyDecision.DENY }, + ]; + engine = new PolicyEngine({ rules, approvalMode: ApprovalMode.YOLO }); + + const { decision } = await engine.check( + { name: 'dangerous-tool' }, + undefined, + ); + expect(decision).toBe(PolicyDecision.DENY); + + // But other tools still allowed + expect( + (await engine.check({ name: 'safe-tool' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); + }); + + it('should respect rule priority in YOLO mode when a match exists', async () => { + const rules: PolicyRule[] = [ + { + toolName: 'test-tool', + decision: PolicyDecision.ASK_USER, + priority: 10, + }, + { toolName: 'test-tool', decision: PolicyDecision.DENY, priority: 20 }, + ]; + engine = new PolicyEngine({ rules, approvalMode: ApprovalMode.YOLO }); + + // Priority 20 (DENY) should win over priority 10 (ASK_USER) + const { decision } = await engine.check({ name: 'test-tool' }, undefined); + expect(decision).toBe(PolicyDecision.DENY); + }); }); describe('addRule', () => { diff --git a/packages/core/src/policy/policy-engine.ts b/packages/core/src/policy/policy-engine.ts index a2f64bf356..b626666370 100644 --- a/packages/core/src/policy/policy-engine.ts +++ b/packages/core/src/policy/policy-engine.ts @@ -466,6 +466,15 @@ export class PolicyEngine { // Default if no rule matched if (decision === undefined) { + if (this.approvalMode === ApprovalMode.YOLO) { + debugLogger.debug( + `[PolicyEngine.check] NO MATCH in YOLO mode - using ALLOW`, + ); + return { + decision: PolicyDecision.ALLOW, + }; + } + debugLogger.debug( `[PolicyEngine.check] NO MATCH - using default decision: ${this.defaultDecision}`, );