feat(a2a): switch from callback-based to event-driven tool scheduler (#21467)

Co-authored-by: Abhi <abhipatel@google.com>
Co-authored-by: Adam Weidman <adamfweidman@google.com>
This commit is contained in:
Coco Sheng
2026-03-10 15:36:17 -04:00
committed by GitHub
parent e5615f47c4
commit 1b69637032
10 changed files with 1323 additions and 59 deletions

View File

@@ -0,0 +1,248 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
import { CoderAgentExecutor } from './executor.js';
import type {
ExecutionEventBus,
RequestContext,
TaskStore,
} from '@a2a-js/sdk/server';
import { EventEmitter } from 'node:events';
import { requestStorage } from '../http/requestStorage.js';
// Mocks for constructor dependencies
vi.mock('../config/config.js', () => ({
loadConfig: vi.fn().mockReturnValue({
getSessionId: () => 'test-session',
getTargetDir: () => '/tmp',
getCheckpointingEnabled: () => false,
}),
loadEnvironment: vi.fn(),
setTargetDir: vi.fn().mockReturnValue('/tmp'),
}));
vi.mock('../config/settings.js', () => ({
loadSettings: vi.fn().mockReturnValue({}),
}));
vi.mock('../config/extension.js', () => ({
loadExtensions: vi.fn().mockReturnValue([]),
}));
vi.mock('../http/requestStorage.js', () => ({
requestStorage: {
getStore: vi.fn(),
},
}));
vi.mock('./task.js', () => {
const mockTaskInstance = (taskId: string, contextId: string) => ({
id: taskId,
contextId,
taskState: 'working',
acceptUserMessage: vi
.fn()
.mockImplementation(async function* (context, aborted) {
const isConfirmation = (
context.userMessage.parts as Array<{ kind: string }>
).some((p) => p.kind === 'confirmation');
// Hang only for main user messages (text), allow confirmations to finish quickly
if (!isConfirmation && aborted) {
await new Promise((resolve) => {
aborted.addEventListener('abort', resolve, { once: true });
});
}
yield { type: 'content', value: 'hello' };
}),
acceptAgentMessage: vi.fn().mockResolvedValue(undefined),
scheduleToolCalls: vi.fn().mockResolvedValue(undefined),
waitForPendingTools: vi.fn().mockResolvedValue(undefined),
getAndClearCompletedTools: vi.fn().mockReturnValue([]),
addToolResponsesToHistory: vi.fn(),
sendCompletedToolsToLlm: vi.fn().mockImplementation(async function* () {}),
cancelPendingTools: vi.fn(),
setTaskStateAndPublishUpdate: vi.fn(),
dispose: vi.fn(),
getMetadata: vi.fn().mockResolvedValue({}),
geminiClient: {
initialize: vi.fn().mockResolvedValue(undefined),
},
toSDKTask: () => ({
id: taskId,
contextId,
kind: 'task',
status: { state: 'working', timestamp: new Date().toISOString() },
metadata: {},
history: [],
artifacts: [],
}),
});
const MockTask = vi.fn().mockImplementation(mockTaskInstance);
(MockTask as unknown as { create: Mock }).create = vi
.fn()
.mockImplementation(async (taskId: string, contextId: string) =>
mockTaskInstance(taskId, contextId),
);
return { Task: MockTask };
});
describe('CoderAgentExecutor', () => {
let executor: CoderAgentExecutor;
let mockTaskStore: TaskStore;
let mockEventBus: ExecutionEventBus;
beforeEach(() => {
vi.clearAllMocks();
mockTaskStore = {
save: vi.fn().mockResolvedValue(undefined),
load: vi.fn().mockResolvedValue(undefined),
delete: vi.fn().mockResolvedValue(undefined),
list: vi.fn().mockResolvedValue([]),
} as unknown as TaskStore;
mockEventBus = new EventEmitter() as unknown as ExecutionEventBus;
mockEventBus.publish = vi.fn();
mockEventBus.finished = vi.fn();
executor = new CoderAgentExecutor(mockTaskStore);
});
it('should distinguish between primary and secondary execution', async () => {
const taskId = 'test-task';
const contextId = 'test-context';
const mockSocket = new EventEmitter();
const requestContext = {
userMessage: {
messageId: 'msg-1',
taskId,
contextId,
parts: [{ kind: 'text', text: 'hi' }],
metadata: {
coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' },
},
},
} as unknown as RequestContext;
// Mock requestStorage for primary
(requestStorage.getStore as Mock).mockReturnValue({
req: { socket: mockSocket },
});
// First execution (Primary)
const primaryPromise = executor.execute(requestContext, mockEventBus);
// Give it enough time to reach line 490 in executor.ts
await new Promise((resolve) => setTimeout(resolve, 50));
expect(
(
executor as unknown as { executingTasks: Set<string> }
).executingTasks.has(taskId),
).toBe(true);
const wrapper = executor.getTask(taskId);
expect(wrapper).toBeDefined();
// Mock requestStorage for secondary
const secondarySocket = new EventEmitter();
(requestStorage.getStore as Mock).mockReturnValue({
req: { socket: secondarySocket },
});
const secondaryRequestContext = {
userMessage: {
messageId: 'msg-2',
taskId,
contextId,
parts: [{ kind: 'confirmation', callId: '1', outcome: 'proceed' }],
metadata: {
coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' },
},
},
} as unknown as RequestContext;
const secondaryPromise = executor.execute(
secondaryRequestContext,
mockEventBus,
);
// Secondary execution should NOT add to executingTasks (already there)
// and should return early after its loop
await secondaryPromise;
// Task should still be in executingTasks and NOT disposed
expect(
(
executor as unknown as { executingTasks: Set<string> }
).executingTasks.has(taskId),
).toBe(true);
expect(wrapper?.task.dispose).not.toHaveBeenCalled();
// Now simulate secondary socket closure - it should NOT affect primary
secondarySocket.emit('end');
expect(
(
executor as unknown as { executingTasks: Set<string> }
).executingTasks.has(taskId),
).toBe(true);
expect(wrapper?.task.dispose).not.toHaveBeenCalled();
// Set to terminal state to verify disposal on finish
wrapper!.task.taskState = 'completed';
// Now close primary socket
mockSocket.emit('end');
await primaryPromise;
expect(
(
executor as unknown as { executingTasks: Set<string> }
).executingTasks.has(taskId),
).toBe(false);
expect(wrapper?.task.dispose).toHaveBeenCalled();
});
it('should evict task from cache when it reaches terminal state', async () => {
const taskId = 'test-task-terminal';
const contextId = 'test-context';
const mockSocket = new EventEmitter();
(requestStorage.getStore as Mock).mockReturnValue({
req: { socket: mockSocket },
});
const requestContext = {
userMessage: {
messageId: 'msg-1',
taskId,
contextId,
parts: [{ kind: 'text', text: 'hi' }],
metadata: {
coderAgent: { kind: 'agent-settings', workspacePath: '/tmp' },
},
},
} as unknown as RequestContext;
const primaryPromise = executor.execute(requestContext, mockEventBus);
await new Promise((resolve) => setTimeout(resolve, 50));
const wrapper = executor.getTask(taskId)!;
expect(wrapper).toBeDefined();
// Simulate terminal state
wrapper.task.taskState = 'completed';
// Finish primary execution
mockSocket.emit('end');
await primaryPromise;
expect(executor.getTask(taskId)).toBeUndefined();
expect(wrapper.task.dispose).toHaveBeenCalled();
});
});

View File

@@ -252,6 +252,10 @@ export class CoderAgentExecutor implements AgentExecutor {
);
await this.taskStore?.save(wrapper.toSDKTask());
logger.info(`[CoderAgentExecutor] Task ${taskId} state CANCELED saved.`);
// Cleanup listener subscriptions to avoid memory leaks.
wrapper.task.dispose();
this.tasks.delete(taskId);
} catch (error) {
const errorMessage =
error instanceof Error ? error.message : 'Unknown error';
@@ -320,23 +324,26 @@ export class CoderAgentExecutor implements AgentExecutor {
if (store) {
// Grab the raw socket from the request object
const socket = store.req.socket;
const onClientEnd = () => {
const onSocketEnd = () => {
logger.info(
`[CoderAgentExecutor] Client socket closed for task ${taskId}. Cancelling execution.`,
`[CoderAgentExecutor] Socket ended for message ${userMessage.messageId} (task ${taskId}). Aborting execution loop.`,
);
if (!abortController.signal.aborted) {
abortController.abort();
}
// Clean up the listener to prevent memory leaks
socket.removeListener('close', onClientEnd);
socket.removeListener('end', onSocketEnd);
};
// Listen on the socket's 'end' event (remote closed the connection)
socket.on('end', onClientEnd);
socket.on('end', onSocketEnd);
socket.once('close', () => {
socket.removeListener('end', onSocketEnd);
});
// It's also good practice to remove the listener if the task completes successfully
abortSignal.addEventListener('abort', () => {
socket.removeListener('end', onClientEnd);
socket.removeListener('end', onSocketEnd);
});
logger.info(
`[CoderAgentExecutor] Socket close handler set up for task ${taskId}.`,
@@ -457,6 +464,26 @@ export class CoderAgentExecutor implements AgentExecutor {
return;
}
// Check if this is the primary/initial execution for this task
const isPrimaryExecution = !this.executingTasks.has(taskId);
if (!isPrimaryExecution) {
logger.info(
`[CoderAgentExecutor] Primary execution already active for task ${taskId}. Starting secondary loop for message ${userMessage.messageId}.`,
);
currentTask.eventBus = eventBus;
for await (const _ of currentTask.acceptUserMessage(
requestContext,
abortController.signal,
)) {
logger.info(
`[CoderAgentExecutor] Processing user message ${userMessage.messageId} in secondary execution loop for task ${taskId}.`,
);
}
// End this execution-- the original/source will be resumed.
return;
}
logger.info(
`[CoderAgentExecutor] Starting main execution for message ${userMessage.messageId} for task ${taskId}.`,
);
@@ -598,6 +625,7 @@ export class CoderAgentExecutor implements AgentExecutor {
}
}
} finally {
if (isPrimaryExecution) {
this.executingTasks.delete(taskId);
logger.info(
`[CoderAgentExecutor] Saving final state for task ${taskId}.`,
@@ -611,6 +639,17 @@ export class CoderAgentExecutor implements AgentExecutor {
saveError,
);
}
if (
['canceled', 'failed', 'completed'].includes(currentTask.taskState)
) {
logger.info(
`[CoderAgentExecutor] Task ${taskId} reached terminal state ${currentTask.taskState}. Evicting and disposing.`,
);
wrapper.task.dispose();
this.tasks.delete(taskId);
}
}
}
}
}

View File

@@ -0,0 +1,655 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
import { Task } from './task.js';
import {
type Config,
MessageBusType,
ToolConfirmationOutcome,
ApprovalMode,
Scheduler,
type MessageBus,
} from '@google/gemini-cli-core';
import { createMockConfig } from '../utils/testing_utils.js';
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
describe('Task Event-Driven Scheduler', () => {
let mockConfig: Config;
let mockEventBus: ExecutionEventBus;
let messageBus: MessageBus;
beforeEach(() => {
vi.clearAllMocks();
mockConfig = createMockConfig({
isEventDrivenSchedulerEnabled: () => true,
}) as Config;
messageBus = mockConfig.getMessageBus();
mockEventBus = {
publish: vi.fn(),
on: vi.fn(),
off: vi.fn(),
once: vi.fn(),
removeAllListeners: vi.fn(),
finished: vi.fn(),
};
});
it('should instantiate Scheduler when enabled', () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
expect(task.scheduler).toBeInstanceOf(Scheduler);
});
it('should subscribe to TOOL_CALLS_UPDATE and map status changes', async () => {
// @ts-expect-error - Calling private constructor
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCall = {
request: { callId: '1', name: 'ls', args: {} },
status: 'executing',
};
// Simulate MessageBus event
// Simulate MessageBus event
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
if (!handler) {
throw new Error('TOOL_CALLS_UPDATE handler not found');
}
handler({
type: MessageBusType.TOOL_CALLS_UPDATE,
toolCalls: [toolCall],
});
expect(mockEventBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
status: expect.objectContaining({
state: 'submitted', // initial task state
}),
metadata: expect.objectContaining({
coderAgent: expect.objectContaining({
kind: 'tool-call-update',
}),
}),
}),
);
});
it('should handle tool confirmations by publishing to MessageBus', async () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCall = {
request: { callId: '1', name: 'ls', args: {} },
status: 'awaiting_approval',
correlationId: 'corr-1',
confirmationDetails: { type: 'info', title: 'test', prompt: 'test' },
};
// Simulate MessageBus event to stash the correlationId
// Simulate MessageBus event
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
if (!handler) {
throw new Error('TOOL_CALLS_UPDATE handler not found');
}
handler({
type: MessageBusType.TOOL_CALLS_UPDATE,
toolCalls: [toolCall],
});
// Simulate A2A client confirmation
const part = {
kind: 'data',
data: {
callId: '1',
outcome: 'proceed_once',
},
};
const handled = await (
task as unknown as {
_handleToolConfirmationPart: (part: unknown) => Promise<boolean>;
}
)._handleToolConfirmationPart(part);
expect(handled).toBe(true);
expect(messageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'corr-1',
confirmed: true,
outcome: ToolConfirmationOutcome.ProceedOnce,
}),
);
});
it('should handle Rejection (Cancel) and Modification (ModifyWithEditor)', async () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCall = {
request: { callId: '1', name: 'ls', args: {} },
status: 'awaiting_approval',
correlationId: 'corr-1',
confirmationDetails: { type: 'info', title: 'test', prompt: 'test' },
};
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] });
// Simulate Rejection (Cancel)
const handled = await (
task as unknown as {
_handleToolConfirmationPart: (part: unknown) => Promise<boolean>;
}
)._handleToolConfirmationPart({
kind: 'data',
data: { callId: '1', outcome: 'cancel' },
});
expect(handled).toBe(true);
expect(messageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'corr-1',
confirmed: false,
}),
);
const toolCall2 = {
request: { callId: '2', name: 'ls', args: {} },
status: 'awaiting_approval',
correlationId: 'corr-2',
confirmationDetails: { type: 'info', title: 'test', prompt: 'test' },
};
handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall2] });
// Simulate ModifyWithEditor
const handled2 = await (
task as unknown as {
_handleToolConfirmationPart: (part: unknown) => Promise<boolean>;
}
)._handleToolConfirmationPart({
kind: 'data',
data: { callId: '2', outcome: 'modify_with_editor' },
});
expect(handled2).toBe(true);
expect(messageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'corr-2',
confirmed: false,
outcome: ToolConfirmationOutcome.ModifyWithEditor,
payload: undefined,
}),
);
});
it('should handle MCP Server tool operations correctly', async () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCall = {
request: { callId: '1', name: 'call_mcp_tool', args: {} },
status: 'awaiting_approval',
correlationId: 'corr-mcp-1',
confirmationDetails: {
type: 'mcp',
title: 'MCP Server Operation',
prompt: 'test_mcp',
},
};
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] });
// Simulate ProceedOnce for MCP
const handled = await (
task as unknown as {
_handleToolConfirmationPart: (part: unknown) => Promise<boolean>;
}
)._handleToolConfirmationPart({
kind: 'data',
data: { callId: '1', outcome: 'proceed_once' },
});
expect(handled).toBe(true);
expect(messageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'corr-mcp-1',
confirmed: true,
outcome: ToolConfirmationOutcome.ProceedOnce,
}),
);
});
it('should handle MCP Server tool ProceedAlwaysServer outcome', async () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCall = {
request: { callId: '1', name: 'call_mcp_tool', args: {} },
status: 'awaiting_approval',
correlationId: 'corr-mcp-2',
confirmationDetails: {
type: 'mcp',
title: 'MCP Server Operation',
prompt: 'test_mcp',
},
};
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] });
const handled = await (
task as unknown as {
_handleToolConfirmationPart: (part: unknown) => Promise<boolean>;
}
)._handleToolConfirmationPart({
kind: 'data',
data: { callId: '1', outcome: 'proceed_always_server' },
});
expect(handled).toBe(true);
expect(messageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'corr-mcp-2',
confirmed: true,
outcome: ToolConfirmationOutcome.ProceedAlwaysServer,
}),
);
});
it('should handle MCP Server tool ProceedAlwaysTool outcome', async () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCall = {
request: { callId: '1', name: 'call_mcp_tool', args: {} },
status: 'awaiting_approval',
correlationId: 'corr-mcp-3',
confirmationDetails: {
type: 'mcp',
title: 'MCP Server Operation',
prompt: 'test_mcp',
},
};
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] });
const handled = await (
task as unknown as {
_handleToolConfirmationPart: (part: unknown) => Promise<boolean>;
}
)._handleToolConfirmationPart({
kind: 'data',
data: { callId: '1', outcome: 'proceed_always_tool' },
});
expect(handled).toBe(true);
expect(messageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'corr-mcp-3',
confirmed: true,
outcome: ToolConfirmationOutcome.ProceedAlwaysTool,
}),
);
});
it('should handle MCP Server tool ProceedAlwaysAndSave outcome', async () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCall = {
request: { callId: '1', name: 'call_mcp_tool', args: {} },
status: 'awaiting_approval',
correlationId: 'corr-mcp-4',
confirmationDetails: {
type: 'mcp',
title: 'MCP Server Operation',
prompt: 'test_mcp',
},
};
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] });
const handled = await (
task as unknown as {
_handleToolConfirmationPart: (part: unknown) => Promise<boolean>;
}
)._handleToolConfirmationPart({
kind: 'data',
data: { callId: '1', outcome: 'proceed_always_and_save' },
});
expect(handled).toBe(true);
expect(messageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'corr-mcp-4',
confirmed: true,
outcome: ToolConfirmationOutcome.ProceedAlwaysAndSave,
}),
);
});
it('should execute without confirmation in YOLO mode and not transition to input-required', async () => {
// Enable YOLO mode
const yoloConfig = createMockConfig({
isEventDrivenSchedulerEnabled: () => true,
getApprovalMode: () => ApprovalMode.YOLO,
}) as Config;
const yoloMessageBus = yoloConfig.getMessageBus();
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', yoloConfig, mockEventBus);
task.setTaskStateAndPublishUpdate = vi.fn();
const toolCall = {
request: { callId: '1', name: 'ls', args: {} },
status: 'awaiting_approval',
correlationId: 'corr-1',
confirmationDetails: { type: 'info', title: 'test', prompt: 'test' },
};
const handler = (yoloMessageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] });
// Should NOT auto-publish ProceedOnce anymore, because PolicyEngine handles it directly
expect(yoloMessageBus.publish).not.toHaveBeenCalledWith(
expect.objectContaining({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
}),
);
// Should NOT transition to input-required since it was auto-approved
expect(task.setTaskStateAndPublishUpdate).not.toHaveBeenCalledWith(
'input-required',
expect.anything(),
undefined,
undefined,
true,
);
});
it('should handle output updates via the message bus', async () => {
// @ts-expect-error - Calling private constructor
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCall = {
request: { callId: '1', name: 'ls', args: {} },
status: 'executing',
liveOutput: 'chunk1',
};
// Simulate MessageBus event
// Simulate MessageBus event
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
if (!handler) {
throw new Error('TOOL_CALLS_UPDATE handler not found');
}
handler({
type: MessageBusType.TOOL_CALLS_UPDATE,
toolCalls: [toolCall],
});
// Should publish artifact update for output
expect(mockEventBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
kind: 'artifact-update',
artifact: expect.objectContaining({
artifactId: 'tool-1-output',
parts: [{ kind: 'text', text: 'chunk1' }],
}),
}),
);
});
it('should complete artifact creation without hanging', async () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCallId = 'create-file-123';
task['_registerToolCall'](toolCallId, 'executing');
const toolCall = {
request: {
callId: toolCallId,
name: 'writeFile',
args: { path: 'test.sh' },
},
status: 'success',
result: { ok: true },
};
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
handler({ type: MessageBusType.TOOL_CALLS_UPDATE, toolCalls: [toolCall] });
// The tool should be complete and registered appropriately, eventually
// triggering the toolCompletionPromise resolution when all clear.
const internalTask = task as unknown as {
completedToolCalls: unknown[];
pendingToolCalls: Map<string, string>;
};
expect(internalTask.completedToolCalls.length).toBe(1);
expect(internalTask.pendingToolCalls.size).toBe(0);
});
it('should preserve messageId across multiple text chunks to prevent UI duplication', async () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
// Initialize the ID for the first turn (happens internally upon LLM stream)
task.currentAgentMessageId = 'test-id-123';
// Simulate sending multiple text chunks
task._sendTextContent('chunk 1');
task._sendTextContent('chunk 2');
// Both text contents should have been published with the same messageId
const textCalls = (mockEventBus.publish as Mock).mock.calls.filter(
(call) => call[0].status?.message?.kind === 'message',
);
expect(textCalls.length).toBe(2);
expect(textCalls[0][0].status.message.messageId).toBe('test-id-123');
expect(textCalls[1][0].status.message.messageId).toBe('test-id-123');
// Simulate starting a new turn by calling getAndClearCompletedTools
// (which precedes sendCompletedToolsToLlm where a new ID is minted)
task.getAndClearCompletedTools();
// sendCompletedToolsToLlm internally rolls the ID forward.
// Simulate what sendCompletedToolsToLlm does:
const internalTask = task as unknown as {
setTaskStateAndPublishUpdate: (state: string, change: unknown) => void;
};
internalTask.setTaskStateAndPublishUpdate('working', {});
// Simulate what sendCompletedToolsToLlm does: generate a new UUID for the next turn
task.currentAgentMessageId = 'test-id-456';
task._sendTextContent('chunk 3');
const secondTurnCalls = (mockEventBus.publish as Mock).mock.calls.filter(
(call) => call[0].status?.message?.messageId === 'test-id-456',
);
expect(secondTurnCalls.length).toBe(1);
expect(secondTurnCalls[0][0].status.message.parts[0].text).toBe('chunk 3');
});
it('should handle parallel tool calls correctly', async () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const toolCall1 = {
request: { callId: '1', name: 'ls', args: {} },
status: 'awaiting_approval',
correlationId: 'corr-1',
confirmationDetails: { type: 'info', title: 'test 1', prompt: 'test 1' },
};
const toolCall2 = {
request: { callId: '2', name: 'pwd', args: {} },
status: 'awaiting_approval',
correlationId: 'corr-2',
confirmationDetails: { type: 'info', title: 'test 2', prompt: 'test 2' },
};
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
// Publish update for both tool calls simultaneously
handler({
type: MessageBusType.TOOL_CALLS_UPDATE,
toolCalls: [toolCall1, toolCall2],
});
// Confirm first tool call
const handled1 = await (
task as unknown as {
_handleToolConfirmationPart: (part: unknown) => Promise<boolean>;
}
)._handleToolConfirmationPart({
kind: 'data',
data: { callId: '1', outcome: 'proceed_once' },
});
expect(handled1).toBe(true);
expect(messageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'corr-1',
confirmed: true,
}),
);
// Confirm second tool call
const handled2 = await (
task as unknown as {
_handleToolConfirmationPart: (part: unknown) => Promise<boolean>;
}
)._handleToolConfirmationPart({
kind: 'data',
data: { callId: '2', outcome: 'cancel' },
});
expect(handled2).toBe(true);
expect(messageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'corr-2',
confirmed: false,
}),
);
});
it('should wait for executing tools before transitioning to input-required state', async () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
task.setTaskStateAndPublishUpdate = vi.fn();
// Register tool 1 as executing
task['_registerToolCall']('1', 'executing');
const toolCall1 = {
request: { callId: '1', name: 'ls', args: {} },
status: 'executing',
};
const toolCall2 = {
request: { callId: '2', name: 'pwd', args: {} },
status: 'awaiting_approval',
correlationId: 'corr-2',
confirmationDetails: { type: 'info', title: 'test 2', prompt: 'test 2' },
};
const handler = (messageBus.subscribe as Mock).mock.calls.find(
(call: unknown[]) => call[0] === MessageBusType.TOOL_CALLS_UPDATE,
)?.[1];
handler({
type: MessageBusType.TOOL_CALLS_UPDATE,
toolCalls: [toolCall1, toolCall2],
});
// Should NOT transition to input-required yet
expect(task.setTaskStateAndPublishUpdate).not.toHaveBeenCalledWith(
'input-required',
expect.anything(),
undefined,
undefined,
true,
);
// Complete tool 1
const toolCall1Complete = {
...toolCall1,
status: 'success',
result: { ok: true },
};
handler({
type: MessageBusType.TOOL_CALLS_UPDATE,
toolCalls: [toolCall1Complete, toolCall2],
});
// Now it should transition
expect(task.setTaskStateAndPublishUpdate).toHaveBeenCalledWith(
'input-required',
expect.anything(),
undefined,
undefined,
true,
);
});
it('should ignore confirmations for unknown tool calls', async () => {
// @ts-expect-error - Calling private constructor
const task = new Task('task-id', 'context-id', mockConfig, mockEventBus);
const handled = await (
task as unknown as {
_handleToolConfirmationPart: (part: unknown) => Promise<boolean>;
}
)._handleToolConfirmationPart({
kind: 'data',
data: { callId: 'unknown-id', outcome: 'proceed_once' },
});
// Should return false for unhandled tool call
expect(handled).toBe(false);
// Should not publish anything to the message bus
expect(messageBus.publish).not.toHaveBeenCalled();
});
});

View File

@@ -504,13 +504,14 @@ describe('Task', () => {
});
describe('auto-approval', () => {
it('should auto-approve tool calls when autoExecute is true', () => {
it('should NOT publish ToolCallConfirmationEvent when autoExecute is true', () => {
task.autoExecute = true;
const onConfirmSpy = vi.fn();
const toolCalls = [
{
request: { callId: '1' },
status: 'awaiting_approval',
correlationId: 'test-corr-id',
confirmationDetails: {
type: 'edit',
onConfirm: onConfirmSpy,
@@ -524,9 +525,17 @@ describe('Task', () => {
expect(onConfirmSpy).toHaveBeenCalledWith(
ToolConfirmationOutcome.ProceedOnce,
);
const calls = (mockEventBus.publish as Mock).mock.calls;
// Search if ToolCallConfirmationEvent was published
const confEvent = calls.find(
(call) =>
call[0].metadata?.coderAgent?.kind ===
CoderAgentEvent.ToolCallConfirmationEvent,
);
expect(confEvent).toBeUndefined();
});
it('should auto-approve tool calls when approval mode is YOLO', () => {
it('should NOT publish ToolCallConfirmationEvent when approval mode is YOLO', () => {
(mockConfig.getApprovalMode as Mock).mockReturnValue(ApprovalMode.YOLO);
task.autoExecute = false;
const onConfirmSpy = vi.fn();
@@ -534,6 +543,7 @@ describe('Task', () => {
{
request: { callId: '1' },
status: 'awaiting_approval',
correlationId: 'test-corr-id',
confirmationDetails: {
type: 'edit',
onConfirm: onConfirmSpy,
@@ -547,6 +557,14 @@ describe('Task', () => {
expect(onConfirmSpy).toHaveBeenCalledWith(
ToolConfirmationOutcome.ProceedOnce,
);
const calls = (mockEventBus.publish as Mock).mock.calls;
// Search if ToolCallConfirmationEvent was published
const confEvent = calls.find(
(call) =>
call[0].metadata?.coderAgent?.kind ===
CoderAgentEvent.ToolCallConfirmationEvent,
);
expect(confEvent).toBeUndefined();
});
it('should NOT auto-approve when autoExecute is false and mode is not YOLO', () => {
@@ -567,6 +585,14 @@ describe('Task', () => {
task._schedulerToolCallsUpdate(toolCalls);
expect(onConfirmSpy).not.toHaveBeenCalled();
const calls = (mockEventBus.publish as Mock).mock.calls;
// Search if ToolCallConfirmationEvent was published
const confEvent = calls.find(
(call) =>
call[0].metadata?.coderAgent?.kind ===
CoderAgentEvent.ToolCallConfirmationEvent,
);
expect(confEvent).toBeDefined();
});
});
});

View File

@@ -5,6 +5,7 @@
*/
import {
Scheduler,
CoreToolScheduler,
type GeminiClient,
GeminiEventType,
@@ -34,6 +35,8 @@ import {
isSubagentProgress,
EDIT_TOOL_NAMES,
processRestorableToolCalls,
MessageBusType,
type ToolCallsUpdateMessage,
} from '@google/gemini-cli-core';
import {
type ExecutionEventBus,
@@ -96,21 +99,30 @@ function isToolCallConfirmationDetails(
export class Task {
id: string;
contextId: string;
scheduler: CoreToolScheduler;
scheduler: Scheduler | CoreToolScheduler;
config: Config;
geminiClient: GeminiClient;
pendingToolConfirmationDetails: Map<string, ToolCallConfirmationDetails>;
pendingCorrelationIds: Map<string, string> = new Map();
taskState: TaskState;
eventBus?: ExecutionEventBus;
completedToolCalls: CompletedToolCall[];
processedToolCallIds: Set<string> = new Set();
skipFinalTrueAfterInlineEdit = false;
modelInfo?: string;
currentPromptId: string | undefined;
currentAgentMessageId = uuidv4();
promptCount = 0;
autoExecute: boolean;
private get isYoloMatch(): boolean {
return (
this.autoExecute || this.config.getApprovalMode() === ApprovalMode.YOLO
);
}
// For tool waiting logic
private pendingToolCalls: Map<string, string> = new Map(); //toolCallId --> status
private toolsAlreadyConfirmed: Set<string> = new Set();
private toolCompletionPromise?: Promise<void>;
private toolCompletionNotifier?: {
resolve: () => void;
@@ -127,7 +139,13 @@ export class Task {
this.id = id;
this.contextId = contextId;
this.config = config;
this.scheduler = this.createScheduler();
if (this.config.isEventDrivenSchedulerEnabled()) {
this.scheduler = this.setupEventDrivenScheduler();
} else {
this.scheduler = this.createLegacyScheduler();
}
this.geminiClient = this.config.getGeminiClient();
this.pendingToolConfirmationDetails = new Map();
this.taskState = 'submitted';
@@ -227,7 +245,7 @@ export class Task {
logger.info(
`[Task] Waiting for ${this.pendingToolCalls.size} pending tool(s)...`,
);
return this.toolCompletionPromise;
await this.toolCompletionPromise;
}
cancelPendingTools(reason: string): void {
@@ -240,6 +258,13 @@ export class Task {
this.toolCompletionNotifier.reject(new Error(reason));
}
this.pendingToolCalls.clear();
this.pendingCorrelationIds.clear();
if (this.scheduler instanceof Scheduler) {
this.scheduler.cancelAll();
} else {
this.scheduler.cancelAll(new AbortController().signal);
}
// Reset the promise for any future operations, ensuring it's in a clean state.
this._resetToolCompletionPromise();
}
@@ -252,7 +277,7 @@ export class Task {
kind: 'message',
role,
parts: [{ kind: 'text', text }],
messageId: uuidv4(),
messageId: role === 'agent' ? this.currentAgentMessageId : uuidv4(),
taskId: this.id,
contextId: this.contextId,
};
@@ -425,6 +450,16 @@ export class Task {
// Only send an update if the status has actually changed.
if (hasChanged) {
// Skip sending confirmation event if we are going to auto-approve it anyway
if (
tc.status === 'awaiting_approval' &&
tc.confirmationDetails &&
this.isYoloMatch
) {
logger.info(
`[Task] Skipping ToolCallConfirmationEvent for ${tc.request.callId} due to YOLO mode.`,
);
} else {
const coderAgentMessage: CoderAgentMessage =
tc.status === 'awaiting_approval'
? { kind: CoderAgentEvent.ToolCallConfirmationEvent }
@@ -439,12 +474,10 @@ export class Task {
);
this.eventBus?.publish(event);
}
}
});
if (
this.autoExecute ||
this.config.getApprovalMode() === ApprovalMode.YOLO
) {
if (this.isYoloMatch) {
logger.info(
'[Task] ' +
(this.autoExecute ? '' : 'YOLO mode enabled. ') +
@@ -492,7 +525,7 @@ export class Task {
}
}
private createScheduler(): CoreToolScheduler {
private createLegacyScheduler(): CoreToolScheduler {
const scheduler = new CoreToolScheduler({
outputUpdateHandler: this._schedulerOutputUpdate.bind(this),
onAllToolCallsComplete: this._schedulerAllToolCallsComplete.bind(this),
@@ -503,6 +536,171 @@ export class Task {
return scheduler;
}
private messageBusListener?: (message: ToolCallsUpdateMessage) => void;
private setupEventDrivenScheduler(): Scheduler {
const messageBus = this.config.getMessageBus();
const scheduler = new Scheduler({
schedulerId: this.id,
config: this.config,
messageBus,
getPreferredEditor: () => DEFAULT_GUI_EDITOR,
});
this.messageBusListener = this.handleEventDrivenToolCallsUpdate.bind(this);
messageBus.subscribe<ToolCallsUpdateMessage>(
MessageBusType.TOOL_CALLS_UPDATE,
this.messageBusListener,
);
return scheduler;
}
dispose(): void {
if (this.messageBusListener) {
this.config
.getMessageBus()
.unsubscribe(MessageBusType.TOOL_CALLS_UPDATE, this.messageBusListener);
this.messageBusListener = undefined;
}
if (this.scheduler instanceof Scheduler) {
this.scheduler.dispose();
}
}
private handleEventDrivenToolCallsUpdate(
event: ToolCallsUpdateMessage,
): void {
if (event.type !== MessageBusType.TOOL_CALLS_UPDATE) {
return;
}
const toolCalls = event.toolCalls;
toolCalls.forEach((tc) => {
this.handleEventDrivenToolCall(tc);
});
this.checkInputRequiredState();
}
private handleEventDrivenToolCall(tc: ToolCall): void {
const callId = tc.request.callId;
// Do not process events for tools that have already been finalized.
// This prevents duplicate completions if the state manager emits a snapshot containing
// already resolved tools whose IDs were removed from pendingToolCalls.
if (
this.processedToolCallIds.has(callId) ||
this.completedToolCalls.some((c) => c.request.callId === callId)
) {
return;
}
const previousStatus = this.pendingToolCalls.get(callId);
const hasChanged = previousStatus !== tc.status;
// 1. Handle Output
if (tc.status === 'executing' && tc.liveOutput) {
this._schedulerOutputUpdate(callId, tc.liveOutput);
}
// 2. Handle terminal states
if (
tc.status === 'success' ||
tc.status === 'error' ||
tc.status === 'cancelled'
) {
this.toolsAlreadyConfirmed.delete(callId);
if (hasChanged) {
logger.info(
`[Task] Tool call ${callId} completed with status: ${tc.status}`,
);
this.completedToolCalls.push(tc);
this._resolveToolCall(callId);
}
} else {
// Keep track of pending tools
this._registerToolCall(callId, tc.status);
}
// 3. Handle Confirmation Stash
if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
const details = tc.confirmationDetails;
if (tc.correlationId) {
this.pendingCorrelationIds.set(callId, tc.correlationId);
}
this.pendingToolConfirmationDetails.set(callId, {
...details,
onConfirm: async () => {},
} as ToolCallConfirmationDetails);
}
// 4. Publish Status Updates to A2A event bus
if (hasChanged) {
const coderAgentMessage: CoderAgentMessage =
tc.status === 'awaiting_approval'
? { kind: CoderAgentEvent.ToolCallConfirmationEvent }
: { kind: CoderAgentEvent.ToolCallUpdateEvent };
const message = this.toolStatusMessage(tc, this.id, this.contextId);
const statusUpdate = this._createStatusUpdateEvent(
this.taskState,
coderAgentMessage,
message,
false,
);
this.eventBus?.publish(statusUpdate);
}
}
private checkInputRequiredState(): void {
if (this.isYoloMatch) {
return;
}
// 6. Handle Input Required State
let isAwaitingApproval = false;
let isExecuting = false;
for (const [callId, status] of this.pendingToolCalls.entries()) {
if (status === 'executing' || status === 'scheduled') {
isExecuting = true;
} else if (
status === 'awaiting_approval' &&
!this.toolsAlreadyConfirmed.has(callId)
) {
isAwaitingApproval = true;
}
}
if (
isAwaitingApproval &&
!isExecuting &&
!this.skipFinalTrueAfterInlineEdit
) {
this.skipFinalTrueAfterInlineEdit = false;
const wasAlreadyInputRequired = this.taskState === 'input-required';
this.setTaskStateAndPublishUpdate(
'input-required',
{ kind: CoderAgentEvent.StateChangeEvent },
undefined,
undefined,
/*final*/ true,
);
// Unblock waitForPendingTools to correctly end the executor loop and release the HTTP response stream.
// The IDE client will open a new stream with the confirmation reply.
if (!wasAlreadyInputRequired && this.toolCompletionNotifier) {
this.toolCompletionNotifier.resolve();
}
}
}
private _pickFields<
T extends ToolCall | AnyDeclarativeTool,
K extends UnionKeys<T>,
@@ -713,7 +911,16 @@ export class Task {
};
this.setTaskStateAndPublishUpdate('working', stateChange);
await this.scheduler.schedule(updatedRequests, abortSignal);
// Pre-register tools to ensure waitForPendingTools sees them as pending
// before the async scheduler enqueues them and fires the event bus update.
for (const req of updatedRequests) {
if (!this.pendingToolCalls.has(req.callId)) {
this._registerToolCall(req.callId, 'scheduled');
}
}
// Fire and forget so we don't block the executor loop before waitForPendingTools can be called
void this.scheduler.schedule(updatedRequests, abortSignal);
}
async acceptAgentMessage(event: ServerGeminiStreamEvent): Promise<void> {
@@ -839,9 +1046,15 @@ export class Task {
) {
return false;
}
if (!part.data['outcome']) {
return false;
}
const callId = part.data['callId'];
const outcomeString = part.data['outcome'];
this.toolsAlreadyConfirmed.add(callId);
let confirmationOutcome: ToolConfirmationOutcome | undefined;
if (outcomeString === 'proceed_once') {
@@ -854,6 +1067,8 @@ export class Task {
confirmationOutcome = ToolConfirmationOutcome.ProceedAlwaysServer;
} else if (outcomeString === 'proceed_always_tool') {
confirmationOutcome = ToolConfirmationOutcome.ProceedAlwaysTool;
} else if (outcomeString === 'proceed_always_and_save') {
confirmationOutcome = ToolConfirmationOutcome.ProceedAlwaysAndSave;
} else if (outcomeString === 'modify_with_editor') {
confirmationOutcome = ToolConfirmationOutcome.ModifyWithEditor;
} else {
@@ -864,8 +1079,9 @@ export class Task {
}
const confirmationDetails = this.pendingToolConfirmationDetails.get(callId);
const correlationId = this.pendingCorrelationIds.get(callId);
if (!confirmationDetails) {
if (!confirmationDetails && !correlationId) {
logger.warn(
`[Task] Received tool confirmation for unknown or already processed callId: ${callId}`,
);
@@ -887,25 +1103,36 @@ export class Task {
// This will trigger the scheduler to continue or cancel the specific tool.
// The scheduler's onToolCallsUpdate will then reflect the new state (e.g., executing or cancelled).
// If `edit` tool call, pass updated payload if presesent
if (confirmationDetails.type === 'edit') {
// If `edit` tool call, pass updated payload if present
const newContent = part.data['newContent'];
const payload =
typeof newContent === 'string'
confirmationDetails?.type === 'edit' && typeof newContent === 'string'
? ({ newContent } as ToolConfirmationPayload)
: undefined;
this.skipFinalTrueAfterInlineEdit = !!payload;
try {
if (correlationId) {
await this.config.getMessageBus().publish({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId,
confirmed:
confirmationOutcome !== ToolConfirmationOutcome.Cancel &&
confirmationOutcome !==
ToolConfirmationOutcome.ModifyWithEditor,
outcome: confirmationOutcome,
payload,
});
} else if (confirmationDetails?.onConfirm) {
// Fallback for legacy callback-based confirmation
await confirmationDetails.onConfirm(confirmationOutcome, payload);
}
} finally {
// Once confirmationDetails.onConfirm finishes (or fails) with a payload,
// Once confirmation payload is sent or callback finishes,
// reset skipFinalTrueAfterInlineEdit so that external callers receive
// their call has been completed.
this.skipFinalTrueAfterInlineEdit = false;
}
} else {
await confirmationDetails.onConfirm(confirmationOutcome);
}
} finally {
if (gcpProject) {
process.env['GOOGLE_CLOUD_PROJECT'] = gcpProject;
@@ -920,6 +1147,7 @@ export class Task {
// Note !== ToolConfirmationOutcome.ModifyWithEditor does not work!
if (confirmationOutcome !== 'modify_with_editor') {
this.pendingToolConfirmationDetails.delete(callId);
this.pendingCorrelationIds.delete(callId);
}
// If outcome is Cancel, scheduler should update status to 'cancelled', which then resolves the tool.
@@ -953,6 +1181,9 @@ export class Task {
getAndClearCompletedTools(): CompletedToolCall[] {
const tools = [...this.completedToolCalls];
for (const tool of tools) {
this.processedToolCallIds.add(tool.request.callId);
}
this.completedToolCalls = [];
return tools;
}
@@ -1013,6 +1244,7 @@ export class Task {
};
// Set task state to working as we are about to call LLM
this.setTaskStateAndPublishUpdate('working', stateChange);
this.currentAgentMessageId = uuidv4();
yield* this.geminiClient.sendMessageStream(
llmParts,
aborted,
@@ -1034,6 +1266,10 @@ export class Task {
if (confirmationHandled) {
anyConfirmationHandled = true;
// If a confirmation was handled, the scheduler will now run the tool (or cancel it).
// We resolve the toolCompletionPromise manually in checkInputRequiredState
// to break the original execution loop, so we must reset it here so the
// new loop correctly awaits the tool's final execution.
this._resetToolCompletionPromise();
// We don't send anything to the LLM for this part.
// The subsequent tool execution will eventually lead to resolveToolCall.
continue;
@@ -1048,6 +1284,7 @@ export class Task {
if (hasContentForLlm) {
this.currentPromptId =
this.config.getSessionId() + '########' + this.promptCount++;
this.currentAgentMessageId = uuidv4();
logger.info('[Task] Sending new parts to LLM.');
const stateChange: StateChange = {
kind: CoderAgentEvent.StateChangeEvent,
@@ -1093,7 +1330,6 @@ export class Task {
if (content === '') {
return;
}
logger.info('[Task] Sending text content to event bus.');
const message = this._createTextMessage(content);
const textContent: TextContent = {
kind: CoderAgentEvent.TextContentEvent,
@@ -1125,7 +1361,7 @@ export class Task {
data: content,
} as Part,
],
messageId: uuidv4(),
messageId: this.currentAgentMessageId,
taskId: this.id,
contextId: this.contextId,
};

View File

@@ -106,6 +106,8 @@ export async function loadConfig(
trustedFolder: true,
extensionLoader,
checkpointing,
enableEventDrivenScheduler:
settings.experimental?.enableEventDrivenScheduler ?? true,
interactive: !isHeadlessMode(),
enableInteractiveShell: !isHeadlessMode(),
ptyInfo: 'auto',

View File

@@ -37,6 +37,12 @@ export interface Settings {
showMemoryUsage?: boolean;
checkpointing?: CheckpointingSettings;
folderTrust?: boolean;
general?: {
previewFeatures?: boolean;
};
experimental?: {
enableEventDrivenScheduler?: boolean;
};
// Git-aware file filtering settings
fileFiltering?: {

View File

@@ -64,6 +64,7 @@ export function createMockConfig(
getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'),
getSessionId: vi.fn().mockReturnValue('test-session-id'),
getUserTier: vi.fn(),
isEventDrivenSchedulerEnabled: vi.fn().mockReturnValue(false),
getMessageBus: vi.fn(),
getPolicyEngine: vi.fn(),
getEnableExtensionReloading: vi.fn().mockReturnValue(false),

View File

@@ -333,6 +333,48 @@ describe('PolicyEngine', () => {
PolicyDecision.ASK_USER,
);
});
it('should return ALLOW by default in YOLO mode when no rules match', async () => {
engine = new PolicyEngine({ approvalMode: ApprovalMode.YOLO });
// No rules defined, should return ALLOW in YOLO mode
const { decision } = await engine.check({ name: 'any-tool' }, undefined);
expect(decision).toBe(PolicyDecision.ALLOW);
});
it('should NOT override explicit DENY rules in YOLO mode', async () => {
const rules: PolicyRule[] = [
{ toolName: 'dangerous-tool', decision: PolicyDecision.DENY },
];
engine = new PolicyEngine({ rules, approvalMode: ApprovalMode.YOLO });
const { decision } = await engine.check(
{ name: 'dangerous-tool' },
undefined,
);
expect(decision).toBe(PolicyDecision.DENY);
// But other tools still allowed
expect(
(await engine.check({ name: 'safe-tool' }, undefined)).decision,
).toBe(PolicyDecision.ALLOW);
});
it('should respect rule priority in YOLO mode when a match exists', async () => {
const rules: PolicyRule[] = [
{
toolName: 'test-tool',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
{ toolName: 'test-tool', decision: PolicyDecision.DENY, priority: 20 },
];
engine = new PolicyEngine({ rules, approvalMode: ApprovalMode.YOLO });
// Priority 20 (DENY) should win over priority 10 (ASK_USER)
const { decision } = await engine.check({ name: 'test-tool' }, undefined);
expect(decision).toBe(PolicyDecision.DENY);
});
});
describe('addRule', () => {

View File

@@ -466,6 +466,15 @@ export class PolicyEngine {
// Default if no rule matched
if (decision === undefined) {
if (this.approvalMode === ApprovalMode.YOLO) {
debugLogger.debug(
`[PolicyEngine.check] NO MATCH in YOLO mode - using ALLOW`,
);
return {
decision: PolicyDecision.ALLOW,
};
}
debugLogger.debug(
`[PolicyEngine.check] NO MATCH - using default decision: ${this.defaultDecision}`,
);