From e2901f3f7e2ffeb369ccb81d372e236b51d65e88 Mon Sep 17 00:00:00 2001 From: Abhi <43648792+abhipatel12@users.noreply.github.com> Date: Mon, 19 Jan 2026 18:19:17 -0500 Subject: [PATCH] refactor(core): decouple scheduler into orchestration, policy, and confirmation (#16895) --- .../core/src/scheduler/confirmation.test.ts | 527 ++++++--- packages/core/src/scheduler/confirmation.ts | 222 +++- packages/core/src/scheduler/policy.test.ts | 422 +++++++ packages/core/src/scheduler/policy.ts | 176 +++ packages/core/src/scheduler/scheduler.test.ts | 1008 +++++++++++++++++ packages/core/src/scheduler/scheduler.ts | 477 ++++++++ 6 files changed, 2670 insertions(+), 162 deletions(-) create mode 100644 packages/core/src/scheduler/policy.test.ts create mode 100644 packages/core/src/scheduler/policy.ts create mode 100644 packages/core/src/scheduler/scheduler.test.ts create mode 100644 packages/core/src/scheduler/scheduler.ts diff --git a/packages/core/src/scheduler/confirmation.test.ts b/packages/core/src/scheduler/confirmation.test.ts index 4e7453428b..12243137cd 100644 --- a/packages/core/src/scheduler/confirmation.test.ts +++ b/packages/core/src/scheduler/confirmation.test.ts @@ -4,219 +4,424 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mocked, + type Mock, +} from 'vitest'; import { EventEmitter } from 'node:events'; -import { awaitConfirmation } from './confirmation.js'; +import { awaitConfirmation, resolveConfirmation } from './confirmation.js'; import { MessageBusType, type ToolConfirmationResponse, } from '../confirmation-bus/types.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; -import { ToolConfirmationOutcome } from '../tools/tools.js'; +import { + ToolConfirmationOutcome, + type AnyToolInvocation, + type AnyDeclarativeTool, +} from '../tools/tools.js'; +import type { SchedulerStateManager } from './state-manager.js'; +import type { ToolModificationHandler } from './tool-modifier.js'; +import type { ValidatingToolCall, WaitingToolCall } from './types.js'; +import type { Config } from '../config/config.js'; +import type { EditorType } from '../utils/editor.js'; +import { randomUUID } from 'node:crypto'; +import { fireToolNotificationHook } from '../core/coreToolHookTriggers.js'; -describe('awaitConfirmation', () => { +// Mock Dependencies +vi.mock('node:crypto', () => ({ + randomUUID: vi.fn(), +})); + +vi.mock('../core/coreToolHookTriggers.js', () => ({ + fireToolNotificationHook: vi.fn(), +})); + +describe('confirmation.ts', () => { let mockMessageBus: MessageBus; beforeEach(() => { mockMessageBus = new EventEmitter() as unknown as MessageBus; mockMessageBus.publish = vi.fn().mockResolvedValue(undefined); - // on() from node:events uses addListener/removeListener or on/off internally. vi.spyOn(mockMessageBus, 'on'); vi.spyOn(mockMessageBus, 'removeListener'); + vi.mocked(randomUUID).mockReturnValue( + '123e4567-e89b-12d3-a456-426614174000', + ); + }); + + afterEach(() => { + vi.clearAllMocks(); }); const emitResponse = (response: ToolConfirmationResponse) => { mockMessageBus.emit(MessageBusType.TOOL_CONFIRMATION_RESPONSE, response); }; - it('should resolve when confirmed response matches correlationId', async () => { - const correlationId = 'test-correlation-id'; - const abortController = new AbortController(); - - const promise = awaitConfirmation( - mockMessageBus, - correlationId, - abortController.signal, - ); - - expect(mockMessageBus.on).toHaveBeenCalledWith( - MessageBusType.TOOL_CONFIRMATION_RESPONSE, - expect.any(Function), - ); - - // Simulate response - emitResponse({ - type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, - correlationId, - confirmed: true, + /** + * Helper to wait for a listener to be attached to the bus. + * This is more robust than setTimeout for synchronizing with the async iterator. + */ + const waitForListener = (eventName: string | symbol): Promise => + new Promise((resolve) => { + const handler = (event: string | symbol) => { + if (event === eventName) { + mockMessageBus.off('newListener', handler); + resolve(); + } + }; + mockMessageBus.on('newListener', handler); }); - const result = await promise; - expect(result).toEqual({ - outcome: ToolConfirmationOutcome.ProceedOnce, - payload: undefined, + describe('awaitConfirmation', () => { + it('should resolve when confirmed response matches correlationId', async () => { + const correlationId = 'test-correlation-id'; + const abortController = new AbortController(); + + const promise = awaitConfirmation( + mockMessageBus, + correlationId, + abortController.signal, + ); + + emitResponse({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId, + confirmed: true, + }); + + const result = await promise; + expect(result).toEqual({ + outcome: ToolConfirmationOutcome.ProceedOnce, + payload: undefined, + }); + }); + + it('should reject when abort signal is triggered', async () => { + const correlationId = 'abort-id'; + const abortController = new AbortController(); + const promise = awaitConfirmation( + mockMessageBus, + correlationId, + abortController.signal, + ); + abortController.abort(); + await expect(promise).rejects.toThrow('Operation cancelled'); }); - expect(mockMessageBus.removeListener).toHaveBeenCalled(); }); - it('should resolve with mapped outcome when confirmed is false', async () => { - const correlationId = 'id-123'; - const abortController = new AbortController(); + describe('resolveConfirmation', () => { + let mockState: Mocked; + let mockModifier: Mocked; + let mockConfig: Mocked; + let getPreferredEditor: Mock<() => EditorType | undefined>; + let signal: AbortSignal; + let toolCall: ValidatingToolCall; + let invocationMock: Mocked; + let toolMock: Mocked; - const promise = awaitConfirmation( - mockMessageBus, - correlationId, - abortController.signal, - ); + beforeEach(() => { + signal = new AbortController().signal; - emitResponse({ - type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, - correlationId, - confirmed: false, + mockState = { + getToolCall: vi.fn(), + updateStatus: vi.fn(), + updateArgs: vi.fn(), + } as unknown as Mocked; + // Mock accessors via defineProperty + Object.defineProperty(mockState, 'firstActiveCall', { + get: vi.fn(), + configurable: true, + }); + + mockModifier = { + handleModifyWithEditor: vi.fn(), + applyInlineModify: vi.fn(), + } as unknown as Mocked; + + mockConfig = { + getEnableHooks: vi.fn().mockReturnValue(true), + } as unknown as Mocked; + + getPreferredEditor = vi.fn().mockReturnValue('vim'); + + invocationMock = { + shouldConfirmExecute: vi.fn(), + } as unknown as Mocked; + + toolMock = { + build: vi.fn(), + } as unknown as Mocked; + + toolCall = { + status: 'validating', + request: { + callId: 'call-1', + name: 'tool', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-1', + }, + invocation: invocationMock, + tool: toolMock, + } as ValidatingToolCall; + + // Default: state returns the current call + mockState.getToolCall.mockReturnValue(toolCall); + // Default: define firstActiveCall for modifiers + vi.spyOn(mockState, 'firstActiveCall', 'get').mockReturnValue( + toolCall as unknown as WaitingToolCall, + ); }); - const result = await promise; - expect(result.outcome).toBe(ToolConfirmationOutcome.Cancel); - }); + it('should return ProceedOnce immediately if no confirmation needed', async () => { + invocationMock.shouldConfirmExecute.mockResolvedValue(false); - it('should resolve with explicit outcome if provided', async () => { - const correlationId = 'id-456'; - const abortController = new AbortController(); + const result = await resolveConfirmation(toolCall, signal, { + config: mockConfig, + messageBus: mockMessageBus, + state: mockState, + modifier: mockModifier, + getPreferredEditor, + }); - const promise = awaitConfirmation( - mockMessageBus, - correlationId, - abortController.signal, - ); - - emitResponse({ - type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, - correlationId, - confirmed: true, - outcome: ToolConfirmationOutcome.ProceedAlways, + expect(result.outcome).toBe(ToolConfirmationOutcome.ProceedOnce); + expect(mockState.updateStatus).not.toHaveBeenCalledWith( + expect.anything(), + 'awaiting_approval', + expect.anything(), + ); }); - const result = await promise; - expect(result.outcome).toBe(ToolConfirmationOutcome.ProceedAlways); - }); + it('should return ProceedOnce after successful user confirmation', async () => { + const details = { + type: 'info' as const, + prompt: 'Confirm?', + title: 'Title', + onConfirm: vi.fn(), + }; + invocationMock.shouldConfirmExecute.mockResolvedValue(details); - it('should resolve with payload', async () => { - const correlationId = 'id-payload'; - const abortController = new AbortController(); - const payload = { newContent: 'updated' }; + // Wait for listener to attach + const listenerPromise = waitForListener( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + ); + const promise = resolveConfirmation(toolCall, signal, { + config: mockConfig, + messageBus: mockMessageBus, + state: mockState, + modifier: mockModifier, + getPreferredEditor, + }); + await listenerPromise; - const promise = awaitConfirmation( - mockMessageBus, - correlationId, - abortController.signal, - ); + emitResponse({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: '123e4567-e89b-12d3-a456-426614174000', + confirmed: true, + }); - emitResponse({ - type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, - correlationId, - confirmed: true, - outcome: ToolConfirmationOutcome.ModifyWithEditor, - payload, + const result = await promise; + expect(result.outcome).toBe(ToolConfirmationOutcome.ProceedOnce); + expect(mockState.updateStatus).toHaveBeenCalledWith( + 'call-1', + 'awaiting_approval', + expect.objectContaining({ + correlationId: '123e4567-e89b-12d3-a456-426614174000', + }), + ); }); - const result = await promise; - expect(result.payload).toEqual(payload); - }); + it('should fire hooks if enabled', async () => { + const details = { + type: 'info' as const, + prompt: 'Confirm?', + title: 'Title', + onConfirm: vi.fn(), + }; + invocationMock.shouldConfirmExecute.mockResolvedValue(details); - it('should ignore responses with different correlation IDs', async () => { - const correlationId = 'my-id'; - const abortController = new AbortController(); + const promise = resolveConfirmation(toolCall, signal, { + config: mockConfig, + messageBus: mockMessageBus, + state: mockState, + modifier: mockModifier, + getPreferredEditor, + }); - let resolved = false; - const promise = awaitConfirmation( - mockMessageBus, - correlationId, - abortController.signal, - ).then((r) => { - resolved = true; - return r; + await waitForListener(MessageBusType.TOOL_CONFIRMATION_RESPONSE); + emitResponse({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: '123e4567-e89b-12d3-a456-426614174000', + confirmed: true, + }); + await promise; + + expect(fireToolNotificationHook).toHaveBeenCalledWith( + mockMessageBus, + expect.objectContaining({ + type: details.type, + prompt: details.prompt, + title: details.title, + }), + ); }); - // Emit wrong ID - emitResponse({ - type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, - correlationId: 'wrong-id', - confirmed: true, + it('should handle ModifyWithEditor loop', async () => { + const details = { + type: 'info' as const, + prompt: 'Confirm?', + title: 'Title', + onConfirm: vi.fn(), + }; + invocationMock.shouldConfirmExecute.mockResolvedValue(details); + + // 1. User says Modify + // 2. User says Proceed + const listenerPromise1 = waitForListener( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + ); + const promise = resolveConfirmation(toolCall, signal, { + config: mockConfig, + messageBus: mockMessageBus, + state: mockState, + modifier: mockModifier, + getPreferredEditor, + }); + + await listenerPromise1; + + // First response: Modify + emitResponse({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: '123e4567-e89b-12d3-a456-426614174000', + confirmed: true, + outcome: ToolConfirmationOutcome.ModifyWithEditor, + }); + + // Mock the modifier action + mockModifier.handleModifyWithEditor.mockResolvedValue({ + updatedParams: { foo: 'bar' }, + }); + toolMock.build.mockReturnValue({} as unknown as AnyToolInvocation); + + // Wait for loop to cycle and re-subscribe + const listenerPromise2 = waitForListener( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + ); + await listenerPromise2; + + // Expect state update + expect(mockState.updateArgs).toHaveBeenCalled(); + + // Second response: Proceed + emitResponse({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: '123e4567-e89b-12d3-a456-426614174000', + confirmed: true, + outcome: ToolConfirmationOutcome.ProceedOnce, + }); + + const result = await promise; + expect(result.outcome).toBe(ToolConfirmationOutcome.ProceedOnce); + expect(mockModifier.handleModifyWithEditor).toHaveBeenCalled(); }); - // Allow microtasks to process - await new Promise((r) => setTimeout(r, 0)); - expect(resolved).toBe(false); + it('should handle inline modification (payload)', async () => { + const details = { + type: 'info' as const, + prompt: 'Confirm?', + title: 'Title', + onConfirm: vi.fn(), + }; + invocationMock.shouldConfirmExecute.mockResolvedValue(details); - // Emit correct ID - emitResponse({ - type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, - correlationId, - confirmed: true, + const listenerPromise = waitForListener( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + ); + const promise = resolveConfirmation(toolCall, signal, { + config: mockConfig, + messageBus: mockMessageBus, + state: mockState, + modifier: mockModifier, + getPreferredEditor, + }); + + await listenerPromise; + + // Response with payload + emitResponse({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: '123e4567-e89b-12d3-a456-426614174000', + confirmed: true, + outcome: ToolConfirmationOutcome.ProceedOnce, // Ignored if payload present + payload: { newContent: 'inline' }, + }); + + mockModifier.applyInlineModify.mockResolvedValue({ + updatedParams: { inline: 'true' }, + }); + toolMock.build.mockReturnValue({} as unknown as AnyToolInvocation); + + const result = await promise; + expect(result.outcome).toBe(ToolConfirmationOutcome.ProceedOnce); + expect(mockModifier.applyInlineModify).toHaveBeenCalled(); + expect(mockState.updateArgs).toHaveBeenCalled(); }); - await expect(promise).resolves.toBeDefined(); - }); + it('should resolve immediately if IDE confirmation resolves first', async () => { + const idePromise = Promise.resolve({ + status: 'accepted' as const, + content: 'ide-content', + }); - it('should reject when abort signal is triggered', async () => { - const correlationId = 'abort-id'; - const abortController = new AbortController(); + const details = { + type: 'info' as const, + prompt: 'Confirm?', + title: 'Title', + onConfirm: vi.fn(), + ideConfirmation: idePromise, + }; + invocationMock.shouldConfirmExecute.mockResolvedValue(details); - const promise = awaitConfirmation( - mockMessageBus, - correlationId, - abortController.signal, - ); + // We don't strictly need to wait for the listener because the race might finish instantly + const promise = resolveConfirmation(toolCall, signal, { + config: mockConfig, + messageBus: mockMessageBus, + state: mockState, + modifier: mockModifier, + getPreferredEditor, + }); - abortController.abort(); - - await expect(promise).rejects.toThrow('Operation cancelled'); - expect(mockMessageBus.removeListener).toHaveBeenCalled(); - }); - - it('should reject when abort signal timeout is triggered', async () => { - vi.useFakeTimers(); - const correlationId = 'timeout-id'; - const signal = AbortSignal.timeout(100); - - const promise = awaitConfirmation(mockMessageBus, correlationId, signal); - - vi.advanceTimersByTime(101); - - await expect(promise).rejects.toThrow('Operation cancelled'); - expect(mockMessageBus.removeListener).toHaveBeenCalled(); - vi.useRealTimers(); - }); - - it('should reject immediately if signal is already aborted', async () => { - const correlationId = 'pre-abort-id'; - const abortController = new AbortController(); - abortController.abort(); - - const promise = awaitConfirmation( - mockMessageBus, - correlationId, - abortController.signal, - ); - - await expect(promise).rejects.toThrow('Operation cancelled'); - expect(mockMessageBus.on).not.toHaveBeenCalled(); - }); - - it('should cleanup and reject if subscribe throws', async () => { - const error = new Error('Subscribe failed'); - vi.mocked(mockMessageBus.on).mockImplementationOnce(() => { - throw error; + const result = await promise; + expect(result.outcome).toBe(ToolConfirmationOutcome.ProceedOnce); }); - const abortController = new AbortController(); - const promise = awaitConfirmation( - mockMessageBus, - 'fail-id', - abortController.signal, - ); + it('should throw if tool call is lost from state during loop', async () => { + invocationMock.shouldConfirmExecute.mockResolvedValue({ + type: 'info' as const, + title: 'Title', + onConfirm: vi.fn(), + prompt: 'Prompt', + }); + // Simulate state losing the call (undefined) + mockState.getToolCall.mockReturnValue(undefined); - await expect(promise).rejects.toThrow(error); - expect(mockMessageBus.removeListener).not.toHaveBeenCalled(); + await expect( + resolveConfirmation(toolCall, signal, { + config: mockConfig, + messageBus: mockMessageBus, + state: mockState, + modifier: mockModifier, + getPreferredEditor, + }), + ).rejects.toThrow(/lost during confirmation loop/); + }); }); }); diff --git a/packages/core/src/scheduler/confirmation.ts b/packages/core/src/scheduler/confirmation.ts index 5ed2f31e98..f8d5f6b6b4 100644 --- a/packages/core/src/scheduler/confirmation.ts +++ b/packages/core/src/scheduler/confirmation.ts @@ -5,21 +5,40 @@ */ import { on } from 'node:events'; +import { randomUUID } from 'node:crypto'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { MessageBusType, type ToolConfirmationResponse, + type SerializableConfirmationDetails, } from '../confirmation-bus/types.js'; -import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { ToolConfirmationOutcome, type ToolConfirmationPayload, + type ToolCallConfirmationDetails, } from '../tools/tools.js'; +import type { ValidatingToolCall, WaitingToolCall } from './types.js'; +import type { Config } from '../config/config.js'; +import type { SchedulerStateManager } from './state-manager.js'; +import type { ToolModificationHandler } from './tool-modifier.js'; +import type { EditorType } from '../utils/editor.js'; +import type { DiffUpdateResult } from '../ide/ide-client.js'; +import { fireToolNotificationHook } from '../core/coreToolHookTriggers.js'; +import { debugLogger } from '../utils/debugLogger.js'; export interface ConfirmationResult { outcome: ToolConfirmationOutcome; payload?: ToolConfirmationPayload; } +/** + * Result of the full confirmation flow, including any user modifications. + */ +export interface ResolutionResult { + outcome: ToolConfirmationOutcome; + lastDetails?: SerializableConfirmationDetails; +} + /** * Waits for a confirmation response with the matching correlationId. * @@ -71,3 +90,204 @@ export async function awaitConfirmation( // which generally means the signal was aborted. throw new Error('Operation cancelled'); } + +/** + * Manages the interactive confirmation loop, handling user modifications + * via inline diffs or external editors (Vim). + */ +export async function resolveConfirmation( + toolCall: ValidatingToolCall, + signal: AbortSignal, + deps: { + config: Config; + messageBus: MessageBus; + state: SchedulerStateManager; + modifier: ToolModificationHandler; + getPreferredEditor: () => EditorType | undefined; + }, +): Promise { + const { state } = deps; + const callId = toolCall.request.callId; + let outcome = ToolConfirmationOutcome.ModifyWithEditor; + let lastDetails: SerializableConfirmationDetails | undefined; + + // Loop exists to allow the user to modify the parameters and see the new + // diff. + while (outcome === ToolConfirmationOutcome.ModifyWithEditor) { + if (signal.aborted) throw new Error('Operation cancelled'); + + const currentCall = state.getToolCall(callId); + if (!currentCall || !('invocation' in currentCall)) { + throw new Error(`Tool call ${callId} lost during confirmation loop`); + } + const currentInvocation = currentCall.invocation; + + const details = await currentInvocation.shouldConfirmExecute(signal); + if (!details) { + outcome = ToolConfirmationOutcome.ProceedOnce; + break; + } + + await notifyHooks(deps, details); + + const correlationId = randomUUID(); + const serializableDetails = details as SerializableConfirmationDetails; + lastDetails = serializableDetails; + + const ideConfirmation = + 'ideConfirmation' in details ? details.ideConfirmation : undefined; + + state.updateStatus(callId, 'awaiting_approval', { + confirmationDetails: serializableDetails, + correlationId, + }); + + const response = await waitForConfirmation( + deps.messageBus, + correlationId, + signal, + ideConfirmation, + ); + outcome = response.outcome; + + if (outcome === ToolConfirmationOutcome.ModifyWithEditor) { + await handleExternalModification(deps, toolCall, signal); + } else if (response.payload?.newContent) { + await handleInlineModification(deps, toolCall, response.payload, signal); + outcome = ToolConfirmationOutcome.ProceedOnce; + } + } + + return { outcome, lastDetails }; +} + +/** + * Fires hook notifications. + */ +async function notifyHooks( + deps: { config: Config; messageBus: MessageBus }, + details: ToolCallConfirmationDetails, +): Promise { + if (deps.config.getEnableHooks()) { + await fireToolNotificationHook(deps.messageBus, { + ...details, + // Pass no-op onConfirm to satisfy type definition; side-effects via + // callbacks are disallowed. + onConfirm: async () => {}, + } as ToolCallConfirmationDetails); + } +} + +/** + * Handles modification via an external editor (e.g. Vim). + */ +async function handleExternalModification( + deps: { + state: SchedulerStateManager; + modifier: ToolModificationHandler; + getPreferredEditor: () => EditorType | undefined; + }, + toolCall: ValidatingToolCall, + signal: AbortSignal, +): Promise { + const { state, modifier, getPreferredEditor } = deps; + const editor = getPreferredEditor(); + if (!editor) return; + + const result = await modifier.handleModifyWithEditor( + state.firstActiveCall as WaitingToolCall, + editor, + signal, + ); + if (result) { + const newInvocation = toolCall.tool.build(result.updatedParams); + state.updateArgs( + toolCall.request.callId, + result.updatedParams, + newInvocation, + ); + } +} + +/** + * Handles modification via inline payload (e.g. from IDE or TUI). + */ +async function handleInlineModification( + deps: { state: SchedulerStateManager; modifier: ToolModificationHandler }, + toolCall: ValidatingToolCall, + payload: ToolConfirmationPayload, + signal: AbortSignal, +): Promise { + const { state, modifier } = deps; + const result = await modifier.applyInlineModify( + state.firstActiveCall as WaitingToolCall, + payload, + signal, + ); + if (result) { + const newInvocation = toolCall.tool.build(result.updatedParams); + state.updateArgs( + toolCall.request.callId, + result.updatedParams, + newInvocation, + ); + } +} + +/** + * Waits for user confirmation, allowing either the MessageBus (TUI) or IDE to + * resolve it. + */ +async function waitForConfirmation( + messageBus: MessageBus, + correlationId: string, + signal: AbortSignal, + ideConfirmation?: Promise, +): Promise { + // Create a controller to abort the bus listener if the IDE wins (or vice versa) + const raceController = new AbortController(); + const raceSignal = raceController.signal; + + // Propagate the parent signal's abort to our race controller + const onParentAbort = () => raceController.abort(); + if (signal.aborted) { + raceController.abort(); + } else { + signal.addEventListener('abort', onParentAbort); + } + + try { + const busPromise = awaitConfirmation(messageBus, correlationId, raceSignal); + + if (!ideConfirmation) { + return await busPromise; + } + + // Wrap IDE promise to match ConfirmationResult signature + const idePromise = ideConfirmation + .then( + (resolution) => + ({ + outcome: + resolution.status === 'accepted' + ? ToolConfirmationOutcome.ProceedOnce + : ToolConfirmationOutcome.Cancel, + payload: resolution.content + ? { newContent: resolution.content } + : undefined, + }) as ConfirmationResult, + ) + .catch((error) => { + debugLogger.warn('Error waiting for confirmation via IDE', error); + // Return a never-resolving promise so the race continues with the bus + return new Promise(() => {}); + }); + + return await Promise.race([busPromise, idePromise]); + } finally { + // Cleanup: remove parent listener and abort the race signal to ensure + // the losing listener (e.g. bus iterator) is closed. + signal.removeEventListener('abort', onParentAbort); + raceController.abort(); + } +} diff --git a/packages/core/src/scheduler/policy.test.ts b/packages/core/src/scheduler/policy.test.ts new file mode 100644 index 0000000000..0e347d1c62 --- /dev/null +++ b/packages/core/src/scheduler/policy.test.ts @@ -0,0 +1,422 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, type Mocked } from 'vitest'; +import { checkPolicy, updatePolicy } from './policy.js'; +import type { Config } from '../config/config.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { MessageBusType } from '../confirmation-bus/types.js'; +import { ApprovalMode, PolicyDecision } from '../policy/types.js'; +import { + ToolConfirmationOutcome, + type AnyDeclarativeTool, + type ToolMcpConfirmationDetails, + type ToolExecuteConfirmationDetails, +} from '../tools/tools.js'; +import type { ValidatingToolCall } from './types.js'; +import type { PolicyEngine } from '../policy/policy-engine.js'; +import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; + +describe('policy.ts', () => { + describe('checkPolicy', () => { + it('should return the decision from the policy engine', async () => { + const mockPolicyEngine = { + check: vi.fn().mockResolvedValue({ decision: PolicyDecision.ALLOW }), + } as unknown as Mocked; + + const mockConfig = { + getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), + } as unknown as Mocked; + + const toolCall = { + request: { name: 'test-tool', args: {} }, + tool: { name: 'test-tool' }, + } as ValidatingToolCall; + + const decision = await checkPolicy(toolCall, mockConfig); + expect(decision).toBe(PolicyDecision.ALLOW); + expect(mockPolicyEngine.check).toHaveBeenCalledWith( + { name: 'test-tool', args: {} }, + undefined, + ); + }); + + it('should pass serverName for MCP tools', async () => { + const mockPolicyEngine = { + check: vi.fn().mockResolvedValue({ decision: PolicyDecision.ALLOW }), + } as unknown as Mocked; + + const mockConfig = { + getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), + } as unknown as Mocked; + + const mcpTool = Object.create(DiscoveredMCPTool.prototype); + mcpTool.serverName = 'my-server'; + + const toolCall = { + request: { name: 'mcp-tool', args: {} }, + tool: mcpTool, + } as ValidatingToolCall; + + await checkPolicy(toolCall, mockConfig); + expect(mockPolicyEngine.check).toHaveBeenCalledWith( + { name: 'mcp-tool', args: {} }, + 'my-server', + ); + }); + + it('should throw if ASK_USER is returned in non-interactive mode', async () => { + const mockPolicyEngine = { + check: vi.fn().mockResolvedValue({ decision: PolicyDecision.ASK_USER }), + } as unknown as Mocked; + + const mockConfig = { + getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), + isInteractive: vi.fn().mockReturnValue(false), + } as unknown as Mocked; + + const toolCall = { + request: { name: 'test-tool', args: {} }, + tool: { name: 'test-tool' }, + } as ValidatingToolCall; + + await expect(checkPolicy(toolCall, mockConfig)).rejects.toThrow( + /not supported in non-interactive mode/, + ); + }); + + it('should return DENY without throwing', async () => { + const mockPolicyEngine = { + check: vi.fn().mockResolvedValue({ decision: PolicyDecision.DENY }), + } as unknown as Mocked; + + const mockConfig = { + getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), + } as unknown as Mocked; + + const toolCall = { + request: { name: 'test-tool', args: {} }, + tool: { name: 'test-tool' }, + } as ValidatingToolCall; + + const decision = await checkPolicy(toolCall, mockConfig); + expect(decision).toBe(PolicyDecision.DENY); + }); + + it('should return ASK_USER without throwing in interactive mode', async () => { + const mockPolicyEngine = { + check: vi.fn().mockResolvedValue({ decision: PolicyDecision.ASK_USER }), + } as unknown as Mocked; + + const mockConfig = { + getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), + isInteractive: vi.fn().mockReturnValue(true), + } as unknown as Mocked; + + const toolCall = { + request: { name: 'test-tool', args: {} }, + tool: { name: 'test-tool' }, + } as ValidatingToolCall; + + const decision = await checkPolicy(toolCall, mockConfig); + expect(decision).toBe(PolicyDecision.ASK_USER); + }); + }); + + describe('updatePolicy', () => { + it('should set AUTO_EDIT mode for auto-edit transition tools', async () => { + const mockConfig = { + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + + const tool = { name: 'replace' } as AnyDeclarativeTool; // 'replace' is in EDIT_TOOL_NAMES + + await updatePolicy( + tool, + ToolConfirmationOutcome.ProceedAlways, + undefined, + { config: mockConfig, messageBus: mockMessageBus }, + ); + + expect(mockConfig.setApprovalMode).toHaveBeenCalledWith( + ApprovalMode.AUTO_EDIT, + ); + expect(mockMessageBus.publish).not.toHaveBeenCalled(); + }); + + it('should handle standard policy updates (persist=false)', async () => { + const mockConfig = { + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'test-tool' } as AnyDeclarativeTool; + + await updatePolicy( + tool, + ToolConfirmationOutcome.ProceedAlways, + undefined, + { config: mockConfig, messageBus: mockMessageBus }, + ); + + expect(mockMessageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'test-tool', + persist: false, + }), + ); + }); + + it('should handle standard policy updates with persistence', async () => { + const mockConfig = { + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'test-tool' } as AnyDeclarativeTool; + + await updatePolicy( + tool, + ToolConfirmationOutcome.ProceedAlwaysAndSave, + undefined, + { config: mockConfig, messageBus: mockMessageBus }, + ); + + expect(mockMessageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'test-tool', + persist: true, + }), + ); + }); + + it('should handle shell command prefixes', async () => { + const mockConfig = { + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'run_shell_command' } as AnyDeclarativeTool; + const details: ToolExecuteConfirmationDetails = { + type: 'exec', + command: 'ls -la', + rootCommand: 'ls', + rootCommands: ['ls'], + title: 'Shell', + onConfirm: vi.fn(), + }; + + await updatePolicy(tool, ToolConfirmationOutcome.ProceedAlways, details, { + config: mockConfig, + messageBus: mockMessageBus, + }); + + expect(mockMessageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'run_shell_command', + commandPrefix: ['ls'], + }), + ); + }); + + it('should handle MCP policy updates (server scope)', async () => { + const mockConfig = { + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'mcp-tool' } as AnyDeclarativeTool; + const details: ToolMcpConfirmationDetails = { + type: 'mcp', + serverName: 'my-server', + toolName: 'mcp-tool', + toolDisplayName: 'My Tool', + title: 'MCP', + onConfirm: vi.fn(), + }; + + await updatePolicy( + tool, + ToolConfirmationOutcome.ProceedAlwaysServer, + details, + { config: mockConfig, messageBus: mockMessageBus }, + ); + + expect(mockMessageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'my-server__*', + mcpName: 'my-server', + persist: false, + }), + ); + }); + + it('should NOT publish update for ProceedOnce', async () => { + const mockConfig = { + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'test-tool' } as AnyDeclarativeTool; + + await updatePolicy(tool, ToolConfirmationOutcome.ProceedOnce, undefined, { + config: mockConfig, + messageBus: mockMessageBus, + }); + + expect(mockMessageBus.publish).not.toHaveBeenCalled(); + expect(mockConfig.setApprovalMode).not.toHaveBeenCalled(); + }); + + it('should NOT publish update for Cancel', async () => { + const mockConfig = { + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'test-tool' } as AnyDeclarativeTool; + + await updatePolicy(tool, ToolConfirmationOutcome.Cancel, undefined, { + config: mockConfig, + messageBus: mockMessageBus, + }); + + expect(mockMessageBus.publish).not.toHaveBeenCalled(); + }); + + it('should NOT publish update for ModifyWithEditor', async () => { + const mockConfig = { + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'test-tool' } as AnyDeclarativeTool; + + await updatePolicy( + tool, + ToolConfirmationOutcome.ModifyWithEditor, + undefined, + { config: mockConfig, messageBus: mockMessageBus }, + ); + + expect(mockMessageBus.publish).not.toHaveBeenCalled(); + }); + + it('should handle MCP ProceedAlwaysTool (specific tool name)', async () => { + const mockConfig = { + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'mcp-tool' } as AnyDeclarativeTool; + const details: ToolMcpConfirmationDetails = { + type: 'mcp', + serverName: 'my-server', + toolName: 'mcp-tool', + toolDisplayName: 'My Tool', + title: 'MCP', + onConfirm: vi.fn(), + }; + + await updatePolicy( + tool, + ToolConfirmationOutcome.ProceedAlwaysTool, + details, + { config: mockConfig, messageBus: mockMessageBus }, + ); + + expect(mockMessageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'mcp-tool', // Specific name, not wildcard + mcpName: 'my-server', + persist: false, + }), + ); + }); + + it('should handle MCP ProceedAlways (persist: false)', async () => { + const mockConfig = { + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'mcp-tool' } as AnyDeclarativeTool; + const details: ToolMcpConfirmationDetails = { + type: 'mcp', + serverName: 'my-server', + toolName: 'mcp-tool', + toolDisplayName: 'My Tool', + title: 'MCP', + onConfirm: vi.fn(), + }; + + await updatePolicy(tool, ToolConfirmationOutcome.ProceedAlways, details, { + config: mockConfig, + messageBus: mockMessageBus, + }); + + expect(mockMessageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'mcp-tool', + mcpName: 'my-server', + persist: false, + }), + ); + }); + + it('should handle MCP ProceedAlwaysAndSave (persist: true)', async () => { + const mockConfig = { + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'mcp-tool' } as AnyDeclarativeTool; + const details: ToolMcpConfirmationDetails = { + type: 'mcp', + serverName: 'my-server', + toolName: 'mcp-tool', + toolDisplayName: 'My Tool', + title: 'MCP', + onConfirm: vi.fn(), + }; + + await updatePolicy( + tool, + ToolConfirmationOutcome.ProceedAlwaysAndSave, + details, + { config: mockConfig, messageBus: mockMessageBus }, + ); + + expect(mockMessageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'mcp-tool', + mcpName: 'my-server', + persist: true, + }), + ); + }); + }); +}); diff --git a/packages/core/src/scheduler/policy.ts b/packages/core/src/scheduler/policy.ts new file mode 100644 index 0000000000..18e7a3f852 --- /dev/null +++ b/packages/core/src/scheduler/policy.ts @@ -0,0 +1,176 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { ApprovalMode, PolicyDecision } from '../policy/types.js'; +import type { Config } from '../config/config.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { + MessageBusType, + type SerializableConfirmationDetails, +} from '../confirmation-bus/types.js'; +import { + ToolConfirmationOutcome, + type AnyDeclarativeTool, + type PolicyUpdateOptions, +} from '../tools/tools.js'; +import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; +import { EDIT_TOOL_NAMES } from '../tools/tool-names.js'; +import type { ValidatingToolCall } from './types.js'; + +/** + * Queries the system PolicyEngine to determine tool allowance. + * @returns The PolicyDecision. + * @throws Error if policy requires ASK_USER but the CLI is non-interactive. + */ +export async function checkPolicy( + toolCall: ValidatingToolCall, + config: Config, +): Promise { + const serverName = + toolCall.tool instanceof DiscoveredMCPTool + ? toolCall.tool.serverName + : undefined; + + const { decision } = await config + .getPolicyEngine() + .check( + { name: toolCall.request.name, args: toolCall.request.args }, + serverName, + ); + + if (decision === PolicyDecision.ASK_USER) { + if (!config.isInteractive()) { + throw new Error( + `Tool execution for "${ + toolCall.tool.displayName || toolCall.tool.name + }" requires user confirmation, which is not supported in non-interactive mode.`, + ); + } + } + + return decision; +} + +/** + * Evaluates the outcome of a user confirmation and dispatches + * policy config updates. + */ +export async function updatePolicy( + tool: AnyDeclarativeTool, + outcome: ToolConfirmationOutcome, + confirmationDetails: SerializableConfirmationDetails | undefined, + deps: { config: Config; messageBus: MessageBus }, +): Promise { + // Mode Transitions (AUTO_EDIT) + if (isAutoEditTransition(tool, outcome)) { + deps.config.setApprovalMode(ApprovalMode.AUTO_EDIT); + return; + } + + // Specialized Tools (MCP) + if (confirmationDetails?.type === 'mcp') { + await handleMcpPolicyUpdate( + tool, + outcome, + confirmationDetails, + deps.messageBus, + ); + return; + } + + // Generic Fallback (Shell, Info, etc.) + await handleStandardPolicyUpdate( + tool, + outcome, + confirmationDetails, + deps.messageBus, + ); +} + +/** + * Returns true if the user's 'Always Allow' selection for a specific tool + * should trigger a session-wide transition to AUTO_EDIT mode. + */ +function isAutoEditTransition( + tool: AnyDeclarativeTool, + outcome: ToolConfirmationOutcome, +): boolean { + // TODO: This is a temporary fix to enable AUTO_EDIT mode for specific + // tools. We should refactor this so that callbacks can be removed from + // tools. + return ( + outcome === ToolConfirmationOutcome.ProceedAlways && + EDIT_TOOL_NAMES.has(tool.name) + ); +} + +/** + * Handles policy updates for standard tools (Shell, Info, etc.), including + * session-level and persistent approvals. + */ +async function handleStandardPolicyUpdate( + tool: AnyDeclarativeTool, + outcome: ToolConfirmationOutcome, + confirmationDetails: SerializableConfirmationDetails | undefined, + messageBus: MessageBus, +): Promise { + if ( + outcome === ToolConfirmationOutcome.ProceedAlways || + outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave + ) { + const options: PolicyUpdateOptions = {}; + + if (confirmationDetails?.type === 'exec') { + options.commandPrefix = confirmationDetails.rootCommands; + } + + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName: tool.name, + persist: outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave, + ...options, + }); + } +} + +/** + * Handles policy updates specifically for MCP tools, including session-level + * and persistent approvals. + */ +async function handleMcpPolicyUpdate( + tool: AnyDeclarativeTool, + outcome: ToolConfirmationOutcome, + confirmationDetails: Extract< + SerializableConfirmationDetails, + { type: 'mcp' } + >, + messageBus: MessageBus, +): Promise { + const isMcpAlways = + outcome === ToolConfirmationOutcome.ProceedAlways || + outcome === ToolConfirmationOutcome.ProceedAlwaysTool || + outcome === ToolConfirmationOutcome.ProceedAlwaysServer || + outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave; + + if (!isMcpAlways) { + return; + } + + let toolName = tool.name; + const persist = outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave; + + // If "Always allow all tools from this server", use the wildcard pattern + if (outcome === ToolConfirmationOutcome.ProceedAlwaysServer) { + toolName = `${confirmationDetails.serverName}__*`; + } + + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName, + mcpName: confirmationDetails.serverName, + persist, + }); +} diff --git a/packages/core/src/scheduler/scheduler.test.ts b/packages/core/src/scheduler/scheduler.test.ts new file mode 100644 index 0000000000..25bdb34deb --- /dev/null +++ b/packages/core/src/scheduler/scheduler.test.ts @@ -0,0 +1,1008 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, + type Mocked, +} from 'vitest'; +import { randomUUID } from 'node:crypto'; + +vi.mock('node:crypto', () => ({ + randomUUID: vi.fn(), +})); + +vi.mock('../telemetry/trace.js', () => ({ + runInDevTraceSpan: vi.fn(async (_opts, fn) => + fn({ metadata: { input: {}, output: {} } }), + ), +})); + +import { logToolCall } from '../telemetry/loggers.js'; +import { ToolCallEvent } from '../telemetry/types.js'; +vi.mock('../telemetry/loggers.js', () => ({ + logToolCall: vi.fn(), +})); +vi.mock('../telemetry/types.js', () => ({ + ToolCallEvent: vi.fn().mockImplementation((call) => ({ ...call })), +})); + +vi.mock('../core/coreToolHookTriggers.js', () => ({ + fireToolNotificationHook: vi.fn(), +})); + +import { SchedulerStateManager } from './state-manager.js'; +import { resolveConfirmation } from './confirmation.js'; +import { checkPolicy, updatePolicy } from './policy.js'; +import { ToolExecutor } from './tool-executor.js'; +import { ToolModificationHandler } from './tool-modifier.js'; + +vi.mock('./state-manager.js'); +vi.mock('./confirmation.js'); +vi.mock('./policy.js'); +vi.mock('./tool-executor.js'); +vi.mock('./tool-modifier.js'); + +import { Scheduler } from './scheduler.js'; +import type { Config } from '../config/config.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import type { PolicyEngine } from '../policy/policy-engine.js'; +import type { ToolRegistry } from '../tools/tool-registry.js'; +import { PolicyDecision } from '../policy/types.js'; +import { + ToolConfirmationOutcome, + type AnyDeclarativeTool, + type AnyToolInvocation, +} from '../tools/tools.js'; +import type { + ToolCallRequestInfo, + ValidatingToolCall, + SuccessfulToolCall, + ErroredToolCall, + CancelledToolCall, + ToolCallResponseInfo, +} from './types.js'; +import { ToolErrorType } from '../tools/tool-error.js'; +import * as ToolUtils from '../utils/tool-utils.js'; +import type { EditorType } from '../utils/editor.js'; + +describe('Scheduler (Orchestrator)', () => { + let scheduler: Scheduler; + let signal: AbortSignal; + let abortController: AbortController; + + // Mocked Services (Injected via Config/Options) + let mockConfig: Mocked; + let mockMessageBus: Mocked; + let mockPolicyEngine: Mocked; + let mockToolRegistry: Mocked; + let getPreferredEditor: Mock<() => EditorType | undefined>; + + // Mocked Sub-components (Instantiated by Scheduler) + let mockStateManager: Mocked; + let mockExecutor: Mocked; + let mockModifier: Mocked; + + // Test Data + const req1: ToolCallRequestInfo = { + callId: 'call-1', + name: 'test-tool', + args: { foo: 'bar' }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }; + + const req2: ToolCallRequestInfo = { + callId: 'call-2', + name: 'test-tool', + args: { foo: 'baz' }, + isClientInitiated: false, + prompt_id: 'prompt-1', + }; + + const mockTool = { + name: 'test-tool', + build: vi.fn(), + } as unknown as AnyDeclarativeTool; + + const mockInvocation = { + shouldConfirmExecute: vi.fn(), + }; + + beforeEach(() => { + vi.mocked(randomUUID).mockReturnValue( + '123e4567-e89b-12d3-a456-426614174000', + ); + abortController = new AbortController(); + signal = abortController.signal; + + // --- Setup Injected Mocks --- + mockPolicyEngine = { + check: vi.fn().mockResolvedValue({ decision: PolicyDecision.ALLOW }), + } as unknown as Mocked; + + mockToolRegistry = { + getTool: vi.fn().mockReturnValue(mockTool), + getAllToolNames: vi.fn().mockReturnValue(['test-tool']), + } as unknown as Mocked; + + mockConfig = { + getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine), + getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), + isInteractive: vi.fn().mockReturnValue(true), + getEnableHooks: vi.fn().mockReturnValue(true), + setApprovalMode: vi.fn(), + } as unknown as Mocked; + + mockMessageBus = { + publish: vi.fn(), + subscribe: vi.fn(), + } as unknown as Mocked; + + getPreferredEditor = vi.fn().mockReturnValue('vim'); + + // --- Setup Sub-component Mocks --- + mockStateManager = { + enqueue: vi.fn(), + dequeue: vi.fn(), + getToolCall: vi.fn(), + updateStatus: vi.fn(), + finalizeCall: vi.fn(), + updateArgs: vi.fn(), + setOutcome: vi.fn(), + cancelAllQueued: vi.fn(), + clearBatch: vi.fn(), + } as unknown as Mocked; + + // Define getters for accessors idiomatically + Object.defineProperty(mockStateManager, 'isActive', { + get: vi.fn().mockReturnValue(false), + configurable: true, + }); + Object.defineProperty(mockStateManager, 'queueLength', { + get: vi.fn().mockReturnValue(0), + configurable: true, + }); + Object.defineProperty(mockStateManager, 'firstActiveCall', { + get: vi.fn().mockReturnValue(undefined), + configurable: true, + }); + Object.defineProperty(mockStateManager, 'completedBatch', { + get: vi.fn().mockReturnValue([]), + configurable: true, + }); + + vi.spyOn(mockStateManager, 'cancelAllQueued').mockImplementation(() => {}); + vi.spyOn(mockStateManager, 'clearBatch').mockImplementation(() => {}); + + vi.mocked(resolveConfirmation).mockReset(); + vi.mocked(checkPolicy).mockReset(); + vi.mocked(updatePolicy).mockReset(); + + mockExecutor = { + execute: vi.fn(), + } as unknown as Mocked; + + mockModifier = { + handleModifyWithEditor: vi.fn(), + applyInlineModify: vi.fn(), + } as unknown as Mocked; + + // Wire up class constructors to return our mock instances + vi.mocked(SchedulerStateManager).mockReturnValue( + mockStateManager as unknown as Mocked, + ); + vi.mocked(ToolExecutor).mockReturnValue( + mockExecutor as unknown as Mocked, + ); + vi.mocked(ToolModificationHandler).mockReturnValue( + mockModifier as unknown as Mocked, + ); + + // Initialize Scheduler + scheduler = new Scheduler({ + config: mockConfig, + messageBus: mockMessageBus, + getPreferredEditor, + }); + + // Reset Tool build behavior + vi.mocked(mockTool.build).mockReturnValue( + mockInvocation as unknown as AnyToolInvocation, + ); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + describe('Phase 1: Ingestion & Resolution', () => { + it('should create an ErroredToolCall if tool is not found', async () => { + vi.mocked(mockToolRegistry.getTool).mockReturnValue(undefined); + vi.spyOn(ToolUtils, 'getToolSuggestion').mockReturnValue( + ' (Did you mean "test-tool"?)', + ); + + await scheduler.schedule(req1, signal); + + // Verify it was enqueued with an error status + expect(mockStateManager.enqueue).toHaveBeenCalledWith( + expect.arrayContaining([ + expect.objectContaining({ + status: 'error', + response: expect.objectContaining({ + errorType: ToolErrorType.TOOL_NOT_REGISTERED, + }), + }), + ]), + ); + }); + + it('should create an ErroredToolCall if tool.build throws (invalid args)', async () => { + vi.mocked(mockTool.build).mockImplementation(() => { + throw new Error('Invalid schema'); + }); + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.enqueue).toHaveBeenCalledWith( + expect.arrayContaining([ + expect.objectContaining({ + status: 'error', + response: expect.objectContaining({ + errorType: ToolErrorType.INVALID_TOOL_PARAMS, + }), + }), + ]), + ); + }); + + it('should correctly build ValidatingToolCalls for happy path', async () => { + await scheduler.schedule(req1, signal); + + expect(mockStateManager.enqueue).toHaveBeenCalledWith( + expect.arrayContaining([ + expect.objectContaining({ + status: 'validating', + request: req1, + tool: mockTool, + invocation: mockInvocation, + }), + ]), + ); + }); + }); + + describe('Phase 2: Queue Management', () => { + it('should drain the queue if multiple calls are scheduled', async () => { + const validatingCall: ValidatingToolCall = { + status: 'validating', + request: req1, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + // Setup queue simulation: two items + Object.defineProperty(mockStateManager, 'queueLength', { + get: vi + .fn() + .mockReturnValueOnce(2) + .mockReturnValueOnce(1) + .mockReturnValue(0), + configurable: true, + }); + + Object.defineProperty(mockStateManager, 'isActive', { + get: vi.fn().mockReturnValue(false), + configurable: true, + }); + + mockStateManager.dequeue.mockReturnValue(validatingCall); + vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall); + Object.defineProperty(mockStateManager, 'firstActiveCall', { + get: vi.fn().mockReturnValue(validatingCall), + configurable: true, + }); + + // Execute is the end of the loop, stub it + mockExecutor.execute.mockResolvedValue({ + status: 'success', + } as unknown as SuccessfulToolCall); + + await scheduler.schedule(req1, signal); + + // Verify loop ran twice + expect(mockStateManager.dequeue).toHaveBeenCalledTimes(2); + expect(mockStateManager.finalizeCall).toHaveBeenCalledTimes(2); + }); + + it('should execute tool calls sequentially (first completes before second starts)', async () => { + // Setup queue simulation: two items + Object.defineProperty(mockStateManager, 'queueLength', { + get: vi + .fn() + .mockReturnValueOnce(2) + .mockReturnValueOnce(1) + .mockReturnValue(0), + configurable: true, + }); + + Object.defineProperty(mockStateManager, 'isActive', { + get: vi.fn().mockReturnValue(false), + configurable: true, + }); + + const validatingCall1: ValidatingToolCall = { + status: 'validating', + request: req1, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + const validatingCall2: ValidatingToolCall = { + status: 'validating', + request: req2, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + vi.mocked(mockStateManager.dequeue) + .mockReturnValueOnce(validatingCall1) + .mockReturnValueOnce(validatingCall2) + .mockReturnValue(undefined); + + Object.defineProperty(mockStateManager, 'firstActiveCall', { + get: vi + .fn() + .mockReturnValueOnce(validatingCall1) // Used in loop check for call 1 + .mockReturnValueOnce(validatingCall1) // Used in _execute for call 1 + .mockReturnValueOnce(validatingCall2) // Used in loop check for call 2 + .mockReturnValueOnce(validatingCall2), // Used in _execute for call 2 + configurable: true, + }); + + const executionLog: string[] = []; + + // Mock executor to push to log with a deterministic microtask delay + mockExecutor.execute.mockImplementation(async ({ call }) => { + const id = call.request.callId; + executionLog.push(`start-${id}`); + // Yield to the event loop deterministically using queueMicrotask + await new Promise((resolve) => queueMicrotask(resolve)); + executionLog.push(`end-${id}`); + return { status: 'success' } as unknown as SuccessfulToolCall; + }); + + // Action: Schedule batch of 2 tools + await scheduler.schedule([req1, req2], signal); + + // Assert: The second tool only started AFTER the first one ended + expect(executionLog).toEqual([ + 'start-call-1', + 'end-call-1', + 'start-call-2', + 'end-call-2', + ]); + }); + + it('should queue and process multiple schedule() calls made synchronously', async () => { + const validatingCall1: ValidatingToolCall = { + status: 'validating', + request: req1, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + const validatingCall2: ValidatingToolCall = { + status: 'validating', + request: req2, // Second request + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + // Mock state responses dynamically + Object.defineProperty(mockStateManager, 'isActive', { + get: vi.fn().mockReturnValue(false), + configurable: true, + }); + + // Queue state responses for the two batches: + // Batch 1: length 1 -> 0 + // Batch 2: length 1 -> 0 + Object.defineProperty(mockStateManager, 'queueLength', { + get: vi + .fn() + .mockReturnValueOnce(1) + .mockReturnValueOnce(0) + .mockReturnValueOnce(1) + .mockReturnValue(0), + configurable: true, + }); + + vi.mocked(mockStateManager.dequeue) + .mockReturnValueOnce(validatingCall1) + .mockReturnValueOnce(validatingCall2); + Object.defineProperty(mockStateManager, 'firstActiveCall', { + get: vi + .fn() + .mockReturnValueOnce(validatingCall1) + .mockReturnValueOnce(validatingCall1) + .mockReturnValueOnce(validatingCall2) + .mockReturnValueOnce(validatingCall2), + configurable: true, + }); + + // Executor succeeds instantly + mockExecutor.execute.mockResolvedValue({ + status: 'success', + } as unknown as SuccessfulToolCall); + + // ACT: Call schedule twice synchronously (without awaiting the first) + const promise1 = scheduler.schedule(req1, signal); + const promise2 = scheduler.schedule(req2, signal); + + await Promise.all([promise1, promise2]); + + // ASSERT: Both requests were eventually pulled from the queue and executed + expect(mockExecutor.execute).toHaveBeenCalledTimes(2); + expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1'); + expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-2'); + }); + + it('should queue requests when scheduler is busy (overlapping batches)', async () => { + const validatingCall1: ValidatingToolCall = { + status: 'validating', + request: req1, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + const validatingCall2: ValidatingToolCall = { + status: 'validating', + request: req2, // Second request + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + // 1. Setup State Manager for 2 sequential batches + Object.defineProperty(mockStateManager, 'isActive', { + get: vi.fn().mockReturnValue(false), + configurable: true, + }); + + Object.defineProperty(mockStateManager, 'queueLength', { + get: vi + .fn() + .mockReturnValueOnce(1) // Batch 1 + .mockReturnValueOnce(0) + .mockReturnValueOnce(1) // Batch 2 + .mockReturnValue(0), + configurable: true, + }); + + vi.mocked(mockStateManager.dequeue) + .mockReturnValueOnce(validatingCall1) + .mockReturnValueOnce(validatingCall2); + + Object.defineProperty(mockStateManager, 'firstActiveCall', { + get: vi + .fn() + .mockReturnValueOnce(validatingCall1) + .mockReturnValueOnce(validatingCall1) + .mockReturnValueOnce(validatingCall2) + .mockReturnValueOnce(validatingCall2), + configurable: true, + }); + + // 2. Setup Executor with a controllable lock for the first batch + const executionLog: string[] = []; + let finishFirstBatch: (value: unknown) => void; + const firstBatchPromise = new Promise((resolve) => { + finishFirstBatch = resolve; + }); + + mockExecutor.execute.mockImplementationOnce(async () => { + executionLog.push('start-batch-1'); + await firstBatchPromise; // Simulating long-running tool execution + executionLog.push('end-batch-1'); + return { status: 'success' } as unknown as SuccessfulToolCall; + }); + + mockExecutor.execute.mockImplementationOnce(async () => { + executionLog.push('start-batch-2'); + executionLog.push('end-batch-2'); + return { status: 'success' } as unknown as SuccessfulToolCall; + }); + + // 3. ACTIONS + // Start Batch 1 (it will block indefinitely inside execution) + const promise1 = scheduler.schedule(req1, signal); + + // Schedule Batch 2 WHILE Batch 1 is executing + const promise2 = scheduler.schedule(req2, signal); + + // Yield event loop to let promise2 hit the queue + await new Promise((r) => setTimeout(r, 0)); + + // At this point, Batch 2 should NOT have started + expect(executionLog).not.toContain('start-batch-2'); + + // Now resolve Batch 1, which should trigger the request queue drain + finishFirstBatch!({}); + + await Promise.all([promise1, promise2]); + + // 4. ASSERTIONS + // Verify complete sequential ordering of the two overlapping batches + expect(executionLog).toEqual([ + 'start-batch-1', + 'end-batch-1', + 'start-batch-2', + 'end-batch-2', + ]); + }); + + it('should cancel all queues if AbortSignal is triggered during loop', async () => { + Object.defineProperty(mockStateManager, 'queueLength', { + get: vi.fn().mockReturnValue(1), + configurable: true, + }); + abortController.abort(); // Signal aborted + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.cancelAllQueued).toHaveBeenCalledWith( + 'Operation cancelled', + ); + expect(mockStateManager.dequeue).not.toHaveBeenCalled(); // Loop broke + }); + + it('cancelAll() should cancel active call and clear queue', () => { + const activeCall: ValidatingToolCall = { + status: 'validating', + request: req1, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + Object.defineProperty(mockStateManager, 'firstActiveCall', { + get: vi.fn().mockReturnValue(activeCall), + configurable: true, + }); + + scheduler.cancelAll(); + + expect(mockStateManager.updateStatus).toHaveBeenCalledWith( + 'call-1', + 'cancelled', + 'Operation cancelled by user', + ); + // finalizeCall is handled by the processing loop, not synchronously by cancelAll + // expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1'); + expect(mockStateManager.cancelAllQueued).toHaveBeenCalledWith( + 'Operation cancelled by user', + ); + }); + + it('cancelAll() should clear the requestQueue and reject pending promises', async () => { + // 1. Setup a busy scheduler with one batch processing + Object.defineProperty(mockStateManager, 'isActive', { + get: vi.fn().mockReturnValue(true), + configurable: true, + }); + const promise1 = scheduler.schedule(req1, signal); + // Catch promise1 to avoid unhandled rejection when we cancelAll + promise1.catch(() => {}); + + // 2. Queue another batch while the first is busy + const promise2 = scheduler.schedule(req2, signal); + + // 3. ACT: Cancel everything + scheduler.cancelAll(); + + // 4. ASSERT: The second batch's promise should be rejected + await expect(promise2).rejects.toThrow('Operation cancelled by user'); + }); + }); + + describe('Phase 3: Policy & Confirmation Loop', () => { + const validatingCall: ValidatingToolCall = { + status: 'validating', + request: req1, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + beforeEach(() => { + Object.defineProperty(mockStateManager, 'queueLength', { + get: vi.fn().mockReturnValueOnce(1).mockReturnValue(0), + configurable: true, + }); + vi.mocked(mockStateManager.dequeue).mockReturnValue(validatingCall); + Object.defineProperty(mockStateManager, 'firstActiveCall', { + get: vi.fn().mockReturnValue(validatingCall), + configurable: true, + }); + }); + + it('should update state to error with POLICY_VIOLATION if Policy returns DENY', async () => { + vi.mocked(checkPolicy).mockResolvedValue(PolicyDecision.DENY); + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.updateStatus).toHaveBeenCalledWith( + 'call-1', + 'error', + expect.objectContaining({ + errorType: ToolErrorType.POLICY_VIOLATION, + }), + ); + // Deny shouldn't throw, execution is just skipped, state is updated + expect(mockExecutor.execute).not.toHaveBeenCalled(); + }); + + it('should handle errors from checkPolicy (e.g. non-interactive ASK_USER)', async () => { + const error = new Error('Not interactive'); + vi.mocked(checkPolicy).mockRejectedValue(error); + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.updateStatus).toHaveBeenCalledWith( + 'call-1', + 'error', + expect.objectContaining({ + errorType: ToolErrorType.UNHANDLED_EXCEPTION, + responseParts: expect.arrayContaining([ + expect.objectContaining({ + functionResponse: expect.objectContaining({ + response: { error: 'Not interactive' }, + }), + }), + ]), + }), + ); + }); + + it('should bypass confirmation and ProceedOnce if Policy returns ALLOW (YOLO/AllowedTools)', async () => { + vi.mocked(checkPolicy).mockResolvedValue(PolicyDecision.ALLOW); + + // Provide a mock execute to finish the loop + mockExecutor.execute.mockResolvedValue({ + status: 'success', + } as unknown as SuccessfulToolCall); + + await scheduler.schedule(req1, signal); + + // Never called coordinator + expect(resolveConfirmation).not.toHaveBeenCalled(); + + // State recorded as ProceedOnce + expect(mockStateManager.setOutcome).toHaveBeenCalledWith( + 'call-1', + ToolConfirmationOutcome.ProceedOnce, + ); + + // Triggered execution + expect(mockStateManager.updateStatus).toHaveBeenCalledWith( + 'call-1', + 'executing', + ); + expect(mockExecutor.execute).toHaveBeenCalled(); + }); + + it('should auto-approve remaining identical tools in batch after ProceedAlways', async () => { + // Setup: two identical tools + const validatingCall1: ValidatingToolCall = { + status: 'validating', + request: req1, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + const validatingCall2: ValidatingToolCall = { + status: 'validating', + request: req2, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + vi.mocked(mockStateManager.dequeue) + .mockReturnValueOnce(validatingCall1) + .mockReturnValueOnce(validatingCall2) + .mockReturnValue(undefined); + + vi.spyOn(mockStateManager, 'queueLength', 'get') + .mockReturnValueOnce(2) + .mockReturnValueOnce(1) + .mockReturnValue(0); + + // First call requires confirmation, second is auto-approved (simulating policy update) + vi.mocked(checkPolicy) + .mockResolvedValueOnce(PolicyDecision.ASK_USER) + .mockResolvedValueOnce(PolicyDecision.ALLOW); + + vi.mocked(resolveConfirmation).mockResolvedValue({ + outcome: ToolConfirmationOutcome.ProceedAlways, + lastDetails: undefined, + }); + + mockExecutor.execute.mockResolvedValue({ + status: 'success', + } as unknown as SuccessfulToolCall); + + await scheduler.schedule([req1, req2], signal); + + // resolveConfirmation only called ONCE + expect(resolveConfirmation).toHaveBeenCalledTimes(1); + // updatePolicy called for the first tool + expect(updatePolicy).toHaveBeenCalled(); + // execute called TWICE + expect(mockExecutor.execute).toHaveBeenCalledTimes(2); + }); + + it('should call resolveConfirmation and updatePolicy when ASK_USER', async () => { + vi.mocked(checkPolicy).mockResolvedValue(PolicyDecision.ASK_USER); + + const resolution = { + outcome: ToolConfirmationOutcome.ProceedAlways, + lastDetails: { + type: 'info' as const, + title: 'Title', + prompt: 'Confirm?', + }, + }; + vi.mocked(resolveConfirmation).mockResolvedValue(resolution); + + mockExecutor.execute.mockResolvedValue({ + status: 'success', + } as unknown as SuccessfulToolCall); + + await scheduler.schedule(req1, signal); + + expect(resolveConfirmation).toHaveBeenCalledWith( + expect.anything(), // toolCall + signal, + expect.objectContaining({ + config: mockConfig, + messageBus: mockMessageBus, + state: mockStateManager, + }), + ); + + expect(updatePolicy).toHaveBeenCalledWith( + mockTool, + resolution.outcome, + resolution.lastDetails, + expect.objectContaining({ + config: mockConfig, + messageBus: mockMessageBus, + }), + ); + + expect(mockExecutor.execute).toHaveBeenCalled(); + }); + + it('should cancel and NOT execute if resolveConfirmation returns Cancel', async () => { + vi.mocked(checkPolicy).mockResolvedValue(PolicyDecision.ASK_USER); + + const resolution = { + outcome: ToolConfirmationOutcome.Cancel, + lastDetails: undefined, + }; + vi.mocked(resolveConfirmation).mockResolvedValue(resolution); + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.updateStatus).toHaveBeenCalledWith( + 'call-1', + 'cancelled', + 'User denied execution.', + ); + expect(mockStateManager.cancelAllQueued).toHaveBeenCalledWith( + 'User cancelled operation', + ); + expect(mockExecutor.execute).not.toHaveBeenCalled(); + }); + + it('should mark as cancelled (not errored) when abort happens during confirmation error', async () => { + vi.mocked(checkPolicy).mockResolvedValue(PolicyDecision.ASK_USER); + + // Simulate shouldConfirmExecute logic throwing while aborted + vi.mocked(resolveConfirmation).mockImplementation(async () => { + // Trigger abort + abortController.abort(); + throw new Error('Some internal network abort error'); + }); + + await scheduler.schedule(req1, signal); + + // Verify execution did NOT happen + expect(mockExecutor.execute).not.toHaveBeenCalled(); + + // Because the signal is aborted, the catch block should convert the error to a cancellation + expect(mockStateManager.updateStatus).toHaveBeenCalledWith( + 'call-1', + 'cancelled', + 'Operation cancelled', + ); + }); + + it('should preserve confirmation details (e.g. diff) in cancelled state', async () => { + vi.mocked(checkPolicy).mockResolvedValue(PolicyDecision.ASK_USER); + + const confirmDetails = { + type: 'edit' as const, + title: 'Edit', + fileName: 'file.txt', + fileDiff: 'diff content', + filePath: '/path/to/file.txt', + originalContent: 'old', + newContent: 'new', + }; + + const resolution = { + outcome: ToolConfirmationOutcome.Cancel, + lastDetails: confirmDetails, + }; + vi.mocked(resolveConfirmation).mockResolvedValue(resolution); + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.updateStatus).toHaveBeenCalledWith( + 'call-1', + 'cancelled', + 'User denied execution.', + ); + // We assume the state manager stores these details. + // Since we mock state manager, we just verify the flow passed the details. + // In a real integration, StateManager.updateStatus would merge these. + }); + }); + + describe('Phase 4: Execution Outcomes', () => { + const validatingCall: ValidatingToolCall = { + status: 'validating', + request: req1, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + + beforeEach(() => { + vi.spyOn(mockStateManager, 'queueLength', 'get') + .mockReturnValueOnce(1) + .mockReturnValue(0); + mockStateManager.dequeue.mockReturnValue(validatingCall); + vi.spyOn(mockStateManager, 'firstActiveCall', 'get').mockReturnValue( + validatingCall, + ); + mockPolicyEngine.check.mockResolvedValue({ + decision: PolicyDecision.ALLOW, + }); // Bypass confirmation + }); + + it('should update state to success on successful execution', async () => { + const mockResponse = { + callId: 'call-1', + responseParts: [], + } as unknown as ToolCallResponseInfo; + + mockExecutor.execute.mockResolvedValue({ + status: 'success', + response: mockResponse, + } as unknown as SuccessfulToolCall); + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.updateStatus).toHaveBeenCalledWith( + 'call-1', + 'success', + mockResponse, + ); + }); + + it('should update state to cancelled when executor returns cancelled status', async () => { + mockExecutor.execute.mockResolvedValue({ + status: 'cancelled', + response: { callId: 'call-1', responseParts: [] }, + } as unknown as CancelledToolCall); + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.updateStatus).toHaveBeenCalledWith( + 'call-1', + 'cancelled', + 'Operation cancelled', + ); + }); + + it('should update state to error on execution failure', async () => { + const mockResponse = { + callId: 'call-1', + error: new Error('fail'), + } as unknown as ToolCallResponseInfo; + + mockExecutor.execute.mockResolvedValue({ + status: 'error', + response: mockResponse, + } as unknown as ErroredToolCall); + + await scheduler.schedule(req1, signal); + + expect(mockStateManager.updateStatus).toHaveBeenCalledWith( + 'call-1', + 'error', + mockResponse, + ); + }); + + it('should log telemetry for terminal states in the queue processor', async () => { + const mockResponse = { + callId: 'call-1', + responseParts: [], + } as unknown as ToolCallResponseInfo; + + // Mock the execution so the state advances + mockExecutor.execute.mockResolvedValue({ + status: 'success', + response: mockResponse, + } as unknown as SuccessfulToolCall); + + // Mock the state manager to return a SUCCESS state when getToolCall is + // called + const successfulCall: SuccessfulToolCall = { + status: 'success', + request: req1, + response: mockResponse, + tool: mockTool, + invocation: mockInvocation as unknown as AnyToolInvocation, + }; + mockStateManager.getToolCall.mockReturnValue(successfulCall); + Object.defineProperty(mockStateManager, 'completedBatch', { + get: vi.fn().mockReturnValue([successfulCall]), + configurable: true, + }); + + await scheduler.schedule(req1, signal); + + // Verify the finalizer and logger were called + expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1'); + expect(ToolCallEvent).toHaveBeenCalledWith(successfulCall); + expect(logToolCall).toHaveBeenCalledWith( + mockConfig, + expect.objectContaining(successfulCall), + ); + }); + + it('should not double-report completed tools when concurrent completions occur', async () => { + // Simulate a race where execution finishes but cancelAll is called immediately after + const response: ToolCallResponseInfo = { + callId: 'call-1', + responseParts: [], + resultDisplay: undefined, + error: undefined, + errorType: undefined, + contentLength: 0, + }; + + mockExecutor.execute.mockResolvedValue({ + status: 'success', + response, + } as unknown as SuccessfulToolCall); + + const promise = scheduler.schedule(req1, signal); + scheduler.cancelAll(); + await promise; + + // finalizeCall should be called exactly once for this ID + expect(mockStateManager.finalizeCall).toHaveBeenCalledTimes(1); + expect(mockStateManager.finalizeCall).toHaveBeenCalledWith('call-1'); + }); + }); +}); diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts new file mode 100644 index 0000000000..b4021faa0b --- /dev/null +++ b/packages/core/src/scheduler/scheduler.ts @@ -0,0 +1,477 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../config/config.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { SchedulerStateManager } from './state-manager.js'; +import { resolveConfirmation } from './confirmation.js'; +import { checkPolicy, updatePolicy } from './policy.js'; +import { ToolExecutor } from './tool-executor.js'; +import { ToolModificationHandler } from './tool-modifier.js'; +import { + type ToolCallRequestInfo, + type ToolCall, + type ToolCallResponseInfo, + type CompletedToolCall, + type ExecutingToolCall, + type ValidatingToolCall, + type ErroredToolCall, +} from './types.js'; +import { ToolErrorType } from '../tools/tool-error.js'; +import { PolicyDecision } from '../policy/types.js'; +import { + ToolConfirmationOutcome, + type AnyDeclarativeTool, +} from '../tools/tools.js'; +import { getToolSuggestion } from '../utils/tool-utils.js'; +import { runInDevTraceSpan } from '../telemetry/trace.js'; +import { logToolCall } from '../telemetry/loggers.js'; +import { ToolCallEvent } from '../telemetry/types.js'; +import type { EditorType } from '../utils/editor.js'; +import { + MessageBusType, + type SerializableConfirmationDetails, + type ToolConfirmationRequest, +} from '../confirmation-bus/types.js'; + +interface SchedulerQueueItem { + requests: ToolCallRequestInfo[]; + signal: AbortSignal; + resolve: (results: CompletedToolCall[]) => void; + reject: (reason?: Error) => void; +} + +export interface SchedulerOptions { + config: Config; + messageBus: MessageBus; + getPreferredEditor: () => EditorType | undefined; +} + +const createErrorResponse = ( + request: ToolCallRequestInfo, + error: Error, + errorType: ToolErrorType | undefined, +): ToolCallResponseInfo => ({ + callId: request.callId, + error, + responseParts: [ + { + functionResponse: { + id: request.callId, + name: request.name, + response: { error: error.message }, + }, + }, + ], + resultDisplay: error.message, + errorType, + contentLength: error.message.length, +}); + +/** + * Event-Driven Orchestrator for Tool Execution. + * Coordinates execution via state updates and event listening. + */ +export class Scheduler { + // Tracks which MessageBus instances have the legacy listener attached to prevent duplicates. + private static subscribedMessageBuses = new WeakSet(); + + private readonly state: SchedulerStateManager; + private readonly executor: ToolExecutor; + private readonly modifier: ToolModificationHandler; + private readonly config: Config; + private readonly messageBus: MessageBus; + private readonly getPreferredEditor: () => EditorType | undefined; + + private isProcessing = false; + private isCancelling = false; + private readonly requestQueue: SchedulerQueueItem[] = []; + + constructor(options: SchedulerOptions) { + this.config = options.config; + this.messageBus = options.messageBus; + this.getPreferredEditor = options.getPreferredEditor; + this.state = new SchedulerStateManager(this.messageBus); + this.executor = new ToolExecutor(this.config); + this.modifier = new ToolModificationHandler(); + + this.setupMessageBusListener(this.messageBus); + } + + private setupMessageBusListener(messageBus: MessageBus): void { + if (Scheduler.subscribedMessageBuses.has(messageBus)) { + return; + } + + // TODO: Optimize policy checks. Currently, tools check policy via + // MessageBus even though the Scheduler already checked it. + messageBus.subscribe( + MessageBusType.TOOL_CONFIRMATION_REQUEST, + async (request: ToolConfirmationRequest) => { + await messageBus.publish({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: request.correlationId, + confirmed: false, + requiresUserConfirmation: true, + }); + }, + ); + + Scheduler.subscribedMessageBuses.add(messageBus); + } + + /** + * Schedules a batch of tool calls. + * @returns A promise that resolves with the results of the completed batch. + */ + async schedule( + request: ToolCallRequestInfo | ToolCallRequestInfo[], + signal: AbortSignal, + ): Promise { + return runInDevTraceSpan( + { name: 'schedule' }, + async ({ metadata: spanMetadata }) => { + const requests = Array.isArray(request) ? request : [request]; + spanMetadata.input = requests; + + if (this.isProcessing || this.state.isActive) { + return this._enqueueRequest(requests, signal); + } + + return this._startBatch(requests, signal); + }, + ); + } + + private _enqueueRequest( + requests: ToolCallRequestInfo[], + signal: AbortSignal, + ): Promise { + return new Promise((resolve, reject) => { + const abortHandler = () => { + const index = this.requestQueue.findIndex( + (item) => item.requests === requests, + ); + if (index > -1) { + this.requestQueue.splice(index, 1); + reject(new Error('Tool call cancelled while in queue.')); + } + }; + + if (signal.aborted) { + reject(new Error('Operation cancelled')); + return; + } + + signal.addEventListener('abort', abortHandler, { once: true }); + + this.requestQueue.push({ + requests, + signal, + resolve: (results) => { + signal.removeEventListener('abort', abortHandler); + resolve(results); + }, + reject: (err) => { + signal.removeEventListener('abort', abortHandler); + reject(err); + }, + }); + }); + } + + cancelAll(): void { + if (this.isCancelling) return; + this.isCancelling = true; + + // Clear scheduler request queue + while (this.requestQueue.length > 0) { + const next = this.requestQueue.shift(); + next?.reject(new Error('Operation cancelled by user')); + } + + // Cancel active call + const activeCall = this.state.firstActiveCall; + if (activeCall && !this.isTerminal(activeCall.status)) { + this.state.updateStatus( + activeCall.request.callId, + 'cancelled', + 'Operation cancelled by user', + ); + } + + // Clear queue + this.state.cancelAllQueued('Operation cancelled by user'); + } + + get completedCalls(): CompletedToolCall[] { + return this.state.completedBatch; + } + + private isTerminal(status: string) { + return status === 'success' || status === 'error' || status === 'cancelled'; + } + + // --- Phase 1: Ingestion & Resolution --- + + private async _startBatch( + requests: ToolCallRequestInfo[], + signal: AbortSignal, + ): Promise { + this.isProcessing = true; + this.isCancelling = false; + this.state.clearBatch(); + + try { + const toolRegistry = this.config.getToolRegistry(); + const newCalls: ToolCall[] = requests.map((request) => { + const tool = toolRegistry.getTool(request.name); + + if (!tool) { + return this._createToolNotFoundErroredToolCall( + request, + toolRegistry.getAllToolNames(), + ); + } + + return this._validateAndCreateToolCall(request, tool); + }); + + this.state.enqueue(newCalls); + await this._processQueue(signal); + return this.state.completedBatch; + } finally { + this.isProcessing = false; + this._processNextInRequestQueue(); + } + } + + private _createToolNotFoundErroredToolCall( + request: ToolCallRequestInfo, + toolNames: string[], + ): ErroredToolCall { + const suggestion = getToolSuggestion(request.name, toolNames); + return { + status: 'error', + request, + response: createErrorResponse( + request, + new Error(`Tool "${request.name}" not found.${suggestion}`), + ToolErrorType.TOOL_NOT_REGISTERED, + ), + durationMs: 0, + }; + } + + private _validateAndCreateToolCall( + request: ToolCallRequestInfo, + tool: AnyDeclarativeTool, + ): ValidatingToolCall | ErroredToolCall { + try { + const invocation = tool.build(request.args); + return { + status: 'validating', + request, + tool, + invocation, + startTime: Date.now(), + }; + } catch (e) { + return { + status: 'error', + request, + tool, + response: createErrorResponse( + request, + e instanceof Error ? e : new Error(String(e)), + ToolErrorType.INVALID_TOOL_PARAMS, + ), + durationMs: 0, + }; + } + } + + // --- Phase 2: Processing Loop --- + + private async _processQueue(signal: AbortSignal): Promise { + while (this.state.queueLength > 0 || this.state.isActive) { + const shouldContinue = await this._processNextItem(signal); + if (!shouldContinue) break; + } + } + + /** + * Processes the next item in the queue. + * @returns true if the loop should continue, false if it should terminate. + */ + private async _processNextItem(signal: AbortSignal): Promise { + if (signal.aborted || this.isCancelling) { + this.state.cancelAllQueued('Operation cancelled'); + return false; + } + + if (!this.state.isActive) { + const next = this.state.dequeue(); + if (!next) return false; + + if (next.status === 'error') { + this.state.updateStatus(next.request.callId, 'error', next.response); + this.state.finalizeCall(next.request.callId); + return true; + } + } + + const active = this.state.firstActiveCall; + if (!active) return false; + + if (active.status === 'validating') { + await this._processValidatingCall(active, signal); + } + + return true; + } + + private async _processValidatingCall( + active: ValidatingToolCall, + signal: AbortSignal, + ): Promise { + try { + await this._processToolCall(active, signal); + } catch (error) { + const err = error instanceof Error ? error : new Error(String(error)); + // If the signal aborted while we were waiting on something, treat as + // cancelled. Otherwise, it's a genuine unhandled system exception. + if (signal.aborted || err.name === 'AbortError') { + this.state.updateStatus( + active.request.callId, + 'cancelled', + 'Operation cancelled', + ); + } else { + this.state.updateStatus( + active.request.callId, + 'error', + createErrorResponse( + active.request, + err, + ToolErrorType.UNHANDLED_EXCEPTION, + ), + ); + } + } + + // Fetch the updated call from state before finalizing to capture the + // terminal status. + const terminalCall = this.state.getToolCall(active.request.callId); + if (terminalCall && this.isTerminal(terminalCall.status)) { + logToolCall( + this.config, + new ToolCallEvent(terminalCall as CompletedToolCall), + ); + } + + this.state.finalizeCall(active.request.callId); + } + + // --- Phase 3: Single Call Orchestration --- + + private async _processToolCall( + toolCall: ValidatingToolCall, + signal: AbortSignal, + ): Promise { + const callId = toolCall.request.callId; + + // Policy & Security + const decision = await checkPolicy(toolCall, this.config); + + if (decision === PolicyDecision.DENY) { + this.state.updateStatus( + callId, + 'error', + createErrorResponse( + toolCall.request, + new Error('Tool execution denied by policy.'), + ToolErrorType.POLICY_VIOLATION, + ), + ); + return; + } + + // User Confirmation Loop + let outcome = ToolConfirmationOutcome.ProceedOnce; + let lastDetails: SerializableConfirmationDetails | undefined; + + if (decision === PolicyDecision.ASK_USER) { + const result = await resolveConfirmation(toolCall, signal, { + config: this.config, + messageBus: this.messageBus, + state: this.state, + modifier: this.modifier, + getPreferredEditor: this.getPreferredEditor, + }); + outcome = result.outcome; + lastDetails = result.lastDetails; + } else { + this.state.setOutcome(callId, ToolConfirmationOutcome.ProceedOnce); + } + + // Handle Policy Updates + await updatePolicy(toolCall.tool, outcome, lastDetails, { + config: this.config, + messageBus: this.messageBus, + }); + + // Handle cancellation (cascades to entire batch) + if (outcome === ToolConfirmationOutcome.Cancel) { + this.state.updateStatus(callId, 'cancelled', 'User denied execution.'); + this.state.cancelAllQueued('User cancelled operation'); + return; // Skip execution + } + + // Execution + await this._execute(callId, signal); + } + + // --- Sub-phase Handlers --- + + /** + * Executes the tool and records the result. + */ + private async _execute(callId: string, signal: AbortSignal): Promise { + this.state.updateStatus(callId, 'scheduled'); + if (signal.aborted) throw new Error('Operation cancelled'); + this.state.updateStatus(callId, 'executing'); + + const result = await this.executor.execute({ + call: this.state.firstActiveCall as ExecutingToolCall, + signal, + outputUpdateHandler: (id, out) => + this.state.updateStatus(id, 'executing', { liveOutput: out }), + onUpdateToolCall: (updated) => { + if (updated.status === 'executing' && updated.pid) { + this.state.updateStatus(callId, 'executing', { pid: updated.pid }); + } + }, + }); + + if (result.status === 'success') { + this.state.updateStatus(callId, 'success', result.response); + } else if (result.status === 'cancelled') { + this.state.updateStatus(callId, 'cancelled', 'Operation cancelled'); + } else { + this.state.updateStatus(callId, 'error', result.response); + } + } + + private _processNextInRequestQueue() { + if (this.requestQueue.length > 0) { + const next = this.requestQueue.shift()!; + this.schedule(next.requests, next.signal) + .then(next.resolve) + .catch(next.reject); + } + } +}