From f367b959cdb347487045a0d6d3fba35c3855fb88 Mon Sep 17 00:00:00 2001 From: Abhi <43648792+abhipatel12@users.noreply.github.com> Date: Thu, 15 Jan 2026 17:44:59 -0500 Subject: [PATCH] feat(scheduler): add functional awaitConfirmation utility (#16721) --- .../core/src/scheduler/confirmation.test.ts | 222 ++++++++++++++++++ packages/core/src/scheduler/confirmation.ts | 73 ++++++ 2 files changed, 295 insertions(+) create mode 100644 packages/core/src/scheduler/confirmation.test.ts create mode 100644 packages/core/src/scheduler/confirmation.ts diff --git a/packages/core/src/scheduler/confirmation.test.ts b/packages/core/src/scheduler/confirmation.test.ts new file mode 100644 index 0000000000..4e7453428b --- /dev/null +++ b/packages/core/src/scheduler/confirmation.test.ts @@ -0,0 +1,222 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { EventEmitter } from 'node:events'; +import { awaitConfirmation } from './confirmation.js'; +import { + MessageBusType, + type ToolConfirmationResponse, +} from '../confirmation-bus/types.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { ToolConfirmationOutcome } from '../tools/tools.js'; + +describe('awaitConfirmation', () => { + let mockMessageBus: MessageBus; + + beforeEach(() => { + mockMessageBus = new EventEmitter() as unknown as MessageBus; + mockMessageBus.publish = vi.fn().mockResolvedValue(undefined); + // on() from node:events uses addListener/removeListener or on/off internally. + vi.spyOn(mockMessageBus, 'on'); + vi.spyOn(mockMessageBus, 'removeListener'); + }); + + const emitResponse = (response: ToolConfirmationResponse) => { + mockMessageBus.emit(MessageBusType.TOOL_CONFIRMATION_RESPONSE, response); + }; + + it('should resolve when confirmed response matches correlationId', async () => { + const correlationId = 'test-correlation-id'; + const abortController = new AbortController(); + + const promise = awaitConfirmation( + mockMessageBus, + correlationId, + abortController.signal, + ); + + expect(mockMessageBus.on).toHaveBeenCalledWith( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + expect.any(Function), + ); + + // Simulate response + emitResponse({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId, + confirmed: true, + }); + + const result = await promise; + expect(result).toEqual({ + outcome: ToolConfirmationOutcome.ProceedOnce, + payload: undefined, + }); + expect(mockMessageBus.removeListener).toHaveBeenCalled(); + }); + + it('should resolve with mapped outcome when confirmed is false', async () => { + const correlationId = 'id-123'; + const abortController = new AbortController(); + + const promise = awaitConfirmation( + mockMessageBus, + correlationId, + abortController.signal, + ); + + emitResponse({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId, + confirmed: false, + }); + + const result = await promise; + expect(result.outcome).toBe(ToolConfirmationOutcome.Cancel); + }); + + it('should resolve with explicit outcome if provided', async () => { + const correlationId = 'id-456'; + const abortController = new AbortController(); + + const promise = awaitConfirmation( + mockMessageBus, + correlationId, + abortController.signal, + ); + + emitResponse({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId, + confirmed: true, + outcome: ToolConfirmationOutcome.ProceedAlways, + }); + + const result = await promise; + expect(result.outcome).toBe(ToolConfirmationOutcome.ProceedAlways); + }); + + it('should resolve with payload', async () => { + const correlationId = 'id-payload'; + const abortController = new AbortController(); + const payload = { newContent: 'updated' }; + + const promise = awaitConfirmation( + mockMessageBus, + correlationId, + abortController.signal, + ); + + emitResponse({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId, + confirmed: true, + outcome: ToolConfirmationOutcome.ModifyWithEditor, + payload, + }); + + const result = await promise; + expect(result.payload).toEqual(payload); + }); + + it('should ignore responses with different correlation IDs', async () => { + const correlationId = 'my-id'; + const abortController = new AbortController(); + + let resolved = false; + const promise = awaitConfirmation( + mockMessageBus, + correlationId, + abortController.signal, + ).then((r) => { + resolved = true; + return r; + }); + + // Emit wrong ID + emitResponse({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'wrong-id', + confirmed: true, + }); + + // Allow microtasks to process + await new Promise((r) => setTimeout(r, 0)); + expect(resolved).toBe(false); + + // Emit correct ID + emitResponse({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId, + confirmed: true, + }); + + await expect(promise).resolves.toBeDefined(); + }); + + it('should reject when abort signal is triggered', async () => { + const correlationId = 'abort-id'; + const abortController = new AbortController(); + + const promise = awaitConfirmation( + mockMessageBus, + correlationId, + abortController.signal, + ); + + abortController.abort(); + + await expect(promise).rejects.toThrow('Operation cancelled'); + expect(mockMessageBus.removeListener).toHaveBeenCalled(); + }); + + it('should reject when abort signal timeout is triggered', async () => { + vi.useFakeTimers(); + const correlationId = 'timeout-id'; + const signal = AbortSignal.timeout(100); + + const promise = awaitConfirmation(mockMessageBus, correlationId, signal); + + vi.advanceTimersByTime(101); + + await expect(promise).rejects.toThrow('Operation cancelled'); + expect(mockMessageBus.removeListener).toHaveBeenCalled(); + vi.useRealTimers(); + }); + + it('should reject immediately if signal is already aborted', async () => { + const correlationId = 'pre-abort-id'; + const abortController = new AbortController(); + abortController.abort(); + + const promise = awaitConfirmation( + mockMessageBus, + correlationId, + abortController.signal, + ); + + await expect(promise).rejects.toThrow('Operation cancelled'); + expect(mockMessageBus.on).not.toHaveBeenCalled(); + }); + + it('should cleanup and reject if subscribe throws', async () => { + const error = new Error('Subscribe failed'); + vi.mocked(mockMessageBus.on).mockImplementationOnce(() => { + throw error; + }); + + const abortController = new AbortController(); + const promise = awaitConfirmation( + mockMessageBus, + 'fail-id', + abortController.signal, + ); + + await expect(promise).rejects.toThrow(error); + expect(mockMessageBus.removeListener).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/core/src/scheduler/confirmation.ts b/packages/core/src/scheduler/confirmation.ts new file mode 100644 index 0000000000..5ed2f31e98 --- /dev/null +++ b/packages/core/src/scheduler/confirmation.ts @@ -0,0 +1,73 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { on } from 'node:events'; +import { + MessageBusType, + type ToolConfirmationResponse, +} from '../confirmation-bus/types.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { + ToolConfirmationOutcome, + type ToolConfirmationPayload, +} from '../tools/tools.js'; + +export interface ConfirmationResult { + outcome: ToolConfirmationOutcome; + payload?: ToolConfirmationPayload; +} + +/** + * Waits for a confirmation response with the matching correlationId. + * + * NOTE: It is the caller's responsibility to manage the lifecycle of this wait + * via the provided AbortSignal. To prevent memory leaks and "zombie" listeners + * in the event of a lost connection (e.g. IDE crash), it is strongly recommended + * to use a signal with a timeout (e.g. AbortSignal.timeout(ms)). + * + * @param messageBus The MessageBus to listen on. + * @param correlationId The correlationId to match. + * @param signal An AbortSignal to cancel the wait and cleanup listeners. + */ +export async function awaitConfirmation( + messageBus: MessageBus, + correlationId: string, + signal: AbortSignal, +): Promise { + if (signal.aborted) { + throw new Error('Operation cancelled'); + } + + try { + for await (const [msg] of on( + messageBus, + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + { signal }, + )) { + const response = msg as ToolConfirmationResponse; + if (response.correlationId === correlationId) { + return { + outcome: + response.outcome ?? + // TODO: Remove legacy confirmed boolean fallback once migration complete + (response.confirmed + ? ToolConfirmationOutcome.ProceedOnce + : ToolConfirmationOutcome.Cancel), + payload: response.payload, + }; + } + } + } catch (error) { + if (signal.aborted || (error as Error).name === 'AbortError') { + throw new Error('Operation cancelled'); + } + throw error; + } + + // This point should only be reached if the iterator closes without resolving, + // which generally means the signal was aborted. + throw new Error('Operation cancelled'); +}