From bcc4d81d1915e718857df9d58c5c5bddf9f4c335 Mon Sep 17 00:00:00 2001 From: Allen Hutchison Date: Mon, 22 Sep 2025 12:03:20 -0700 Subject: [PATCH] feat: add message bus integration for tool confirmation (#8938) --- packages/cli/src/config/config.ts | 2 + packages/cli/src/config/settings.ts | 1 + packages/cli/src/config/settingsSchema.ts | 10 + packages/core/src/config/config.ts | 27 +- .../src/tools/message-bus-integration.test.ts | 285 ++++++++++++++++++ packages/core/src/tools/tools.ts | 129 +++++++- 6 files changed, 451 insertions(+), 3 deletions(-) create mode 100644 packages/core/src/tools/message-bus-integration.test.ts diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 80dcb88d4e..acf1b8ca93 100755 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -667,6 +667,8 @@ export async function loadCliConfig( format: (argv.outputFormat ?? settings.output?.format) as OutputFormat, }, useModelRouter, + enableMessageBusIntegration: + settings.tools?.enableMessageBusIntegration ?? false, }); } diff --git a/packages/cli/src/config/settings.ts b/packages/cli/src/config/settings.ts index 6d73722eb1..cc28c3e3c5 100644 --- a/packages/cli/src/config/settings.ts +++ b/packages/cli/src/config/settings.ts @@ -74,6 +74,7 @@ const MIGRATION_MAP: Record = { disableAutoUpdate: 'general.disableAutoUpdate', disableUpdateNag: 'general.disableUpdateNag', dnsResolutionOrder: 'advanced.dnsResolutionOrder', + enableMessageBusIntegration: 'tools.enableMessageBusIntegration', enablePromptCompletion: 'general.enablePromptCompletion', enforcedAuthType: 'security.auth.enforcedType', excludeTools: 'tools.exclude', diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index bc9fa35af2..b27f7465f2 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -774,6 +774,16 @@ const SETTINGS_SCHEMA = { description: 'The number of lines to keep when truncating tool output.', showInDialog: true, }, + enableMessageBusIntegration: { + type: 'boolean', + label: 'Enable Message Bus Integration', + category: 'Tools', + requiresRestart: true, + default: false, + description: + 'Enable policy-based tool confirmation via message bus integration. When enabled, tools will automatically respect policy engine decisions (ALLOW/DENY/ASK_USER) without requiring individual tool implementations.', + showInDialog: true, + }, }, }, diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 88f8b4f974..343b8ad198 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -251,6 +251,7 @@ export interface ConfigParameters { policyEngineConfig?: PolicyEngineConfig; output?: OutputSettings; useModelRouter?: boolean; + enableMessageBusIntegration?: boolean; } export class Config { @@ -340,6 +341,7 @@ export class Config { private readonly policyEngine: PolicyEngine; private readonly outputSettings: OutputSettings; private readonly useModelRouter: boolean; + private readonly enableMessageBusIntegration: boolean; constructor(params: ConfigParameters) { this.sessionId = params.sessionId; @@ -427,6 +429,8 @@ export class Config { this.useSmartEdit = params.useSmartEdit ?? true; this.useWriteTodos = params.useWriteTodos ?? false; this.useModelRouter = params.useModelRouter ?? false; + this.enableMessageBusIntegration = + params.enableMessageBusIntegration ?? false; this.extensionManagement = params.extensionManagement ?? true; this.storage = new Storage(this.targetDir); this.enablePromptCompletion = params.enablePromptCompletion ?? false; @@ -986,6 +990,10 @@ export class Config { return this.policyEngine; } + getEnableMessageBusIntegration(): boolean { + return this.enableMessageBusIntegration; + } + async createToolRegistry(): Promise { const registry = new ToolRegistry(this, this.eventEmitter); @@ -1019,7 +1027,24 @@ export class Config { } if (isEnabled) { - registry.registerTool(new ToolClass(...args)); + // Pass message bus to tools when feature flag is enabled + // This first implementation is only focused on the general case of + // the tool registry. + const messageBusEnabled = this.getEnableMessageBusIntegration(); + if (this.debugMode) { + console.log( + `[DEBUG] enableMessageBusIntegration setting: ${messageBusEnabled}`, + ); + } + const toolArgs = messageBusEnabled + ? [...args, this.getMessageBus()] + : args; + if (this.debugMode) { + console.log( + `[DEBUG] Registering ${className} with messageBus: ${messageBusEnabled ? 'YES' : 'NO'}`, + ); + } + registry.registerTool(new ToolClass(...toolArgs)); } }; diff --git a/packages/core/src/tools/message-bus-integration.test.ts b/packages/core/src/tools/message-bus-integration.test.ts new file mode 100644 index 0000000000..41cec0535a --- /dev/null +++ b/packages/core/src/tools/message-bus-integration.test.ts @@ -0,0 +1,285 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + beforeEach, + afterEach, + vi, + type Mock, +} from 'vitest'; +import { + BaseToolInvocation, + BaseDeclarativeTool, + Kind, + type ToolResult, +} from './tools.js'; +import { MessageBus } from '../confirmation-bus/message-bus.js'; +import { PolicyEngine } from '../policy/policy-engine.js'; +import { + MessageBusType, + type ToolConfirmationResponse, +} from '../confirmation-bus/types.js'; +import { randomUUID } from 'node:crypto'; + +// Mock crypto module +vi.mock('node:crypto', () => ({ + randomUUID: vi.fn(), +})); + +interface TestParams { + testParam: string; +} + +interface TestResult extends ToolResult { + testValue: string; +} + +class TestToolInvocation extends BaseToolInvocation { + getDescription(): string { + return `Test tool with param: ${this.params.testParam}`; + } + + async execute(): Promise { + return { + llmContent: `Executed with ${this.params.testParam}`, + returnDisplay: `Test result: ${this.params.testParam}`, + testValue: this.params.testParam, + }; + } +} + +class TestTool extends BaseDeclarativeTool { + constructor(messageBus?: MessageBus) { + super( + 'test-tool', + 'Test Tool', + 'A test tool for message bus integration', + Kind.Other, + { + type: 'object', + properties: { + testParam: { type: 'string' }, + }, + required: ['testParam'], + }, + true, + false, + messageBus, + ); + } + + protected createInvocation(params: TestParams, messageBus?: MessageBus) { + return new TestToolInvocation(params, messageBus); + } +} + +describe('Message Bus Integration', () => { + let policyEngine: PolicyEngine; + let messageBus: MessageBus; + let mockUUID: Mock; + + beforeEach(() => { + vi.resetAllMocks(); + policyEngine = new PolicyEngine(); + messageBus = new MessageBus(policyEngine); + mockUUID = vi.mocked(randomUUID); + mockUUID.mockReturnValue('test-correlation-id'); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe('BaseToolInvocation with MessageBus', () => { + it('should use message bus for confirmation when available', async () => { + const tool = new TestTool(messageBus); + const invocation = tool.build({ testParam: 'test-value' }); + + // Mock message bus publish and subscribe + const publishSpy = vi.spyOn(messageBus, 'publish'); + const subscribeSpy = vi.spyOn(messageBus, 'subscribe'); + const unsubscribeSpy = vi.spyOn(messageBus, 'unsubscribe'); + + // Start confirmation process + const confirmationPromise = invocation.shouldConfirmExecute( + new AbortController().signal, + ); + + // Verify confirmation request was published + expect(publishSpy).toHaveBeenCalledWith({ + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall: { + name: 'TestToolInvocation', + args: { testParam: 'test-value' }, + }, + correlationId: 'test-correlation-id', + }); + + // Verify subscription to response + expect(subscribeSpy).toHaveBeenCalledWith( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + expect.any(Function), + ); + + // Simulate confirmation response + const responseHandler = subscribeSpy.mock.calls[0][1]; + const response: ToolConfirmationResponse = { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'test-correlation-id', + confirmed: true, + }; + + responseHandler(response); + + const result = await confirmationPromise; + expect(result).toBe(false); // No further confirmation needed + expect(unsubscribeSpy).toHaveBeenCalled(); + }); + + it('should reject promise when confirmation is denied', async () => { + const tool = new TestTool(messageBus); + const invocation = tool.build({ testParam: 'test-value' }); + + const subscribeSpy = vi.spyOn(messageBus, 'subscribe'); + + const confirmationPromise = invocation.shouldConfirmExecute( + new AbortController().signal, + ); + + // Simulate denial response + const responseHandler = subscribeSpy.mock.calls[0][1]; + const response: ToolConfirmationResponse = { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'test-correlation-id', + confirmed: false, + }; + + responseHandler(response); + + // Should reject with error when denied + await expect(confirmationPromise).rejects.toThrow( + 'Tool execution denied by policy', + ); + }); + + it('should handle timeout', async () => { + vi.useFakeTimers(); + + const tool = new TestTool(messageBus); + const invocation = tool.build({ testParam: 'test-value' }); + + const confirmationPromise = invocation.shouldConfirmExecute( + new AbortController().signal, + ); + + // Fast-forward past timeout + vi.advanceTimersByTime(30000); + + const result = await confirmationPromise; + expect(result).toBe(false); + + vi.useRealTimers(); + }); + + it('should handle abort signal', async () => { + const tool = new TestTool(messageBus); + const invocation = tool.build({ testParam: 'test-value' }); + + const abortController = new AbortController(); + const confirmationPromise = invocation.shouldConfirmExecute( + abortController.signal, + ); + + // Abort the operation + abortController.abort(); + + await expect(confirmationPromise).rejects.toThrow( + 'Tool confirmation aborted', + ); + }); + + it('should fall back to default behavior when no message bus', async () => { + const tool = new TestTool(); // No message bus + const invocation = tool.build({ testParam: 'test-value' }); + + const result = await invocation.shouldConfirmExecute( + new AbortController().signal, + ); + expect(result).toBe(false); + }); + + it('should ignore responses with wrong correlation ID', async () => { + vi.useFakeTimers(); + + const tool = new TestTool(messageBus); + const invocation = tool.build({ testParam: 'test-value' }); + + const subscribeSpy = vi.spyOn(messageBus, 'subscribe'); + const confirmationPromise = invocation.shouldConfirmExecute( + new AbortController().signal, + ); + + // Send response with wrong correlation ID + const responseHandler = subscribeSpy.mock.calls[0][1]; + const wrongResponse: ToolConfirmationResponse = { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'wrong-id', + confirmed: true, + }; + + responseHandler(wrongResponse); + + // Should timeout since correct response wasn't received + vi.advanceTimersByTime(30000); + const result = await confirmationPromise; + expect(result).toBe(false); + + vi.useRealTimers(); + }); + }); + + describe('Backward Compatibility', () => { + it('should work with existing tools that do not use message bus', async () => { + const tool = new TestTool(); // No message bus + const invocation = tool.build({ testParam: 'test-value' }); + + // Should execute normally + const result = await invocation.execute(new AbortController().signal); + expect(result.testValue).toBe('test-value'); + expect(result.llmContent).toBe('Executed with test-value'); + }); + + it('should work with tools that have message bus but use default confirmation', async () => { + const tool = new TestTool(messageBus); + const invocation = tool.build({ testParam: 'test-value' }); + + // Should execute normally even with message bus available + const result = await invocation.execute(new AbortController().signal); + expect(result.testValue).toBe('test-value'); + expect(result.llmContent).toBe('Executed with test-value'); + }); + }); + + describe('Error Handling', () => { + it('should handle message bus publish errors gracefully', async () => { + const tool = new TestTool(messageBus); + const invocation = tool.build({ testParam: 'test-value' }); + + // Mock publish to throw error + vi.spyOn(messageBus, 'publish').mockImplementation(() => { + throw new Error('Message bus error'); + }); + + const result = await invocation.shouldConfirmExecute( + new AbortController().signal, + ); + expect(result).toBe(false); // Should gracefully fall back + }); + }); +}); diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 48cf2d2d1f..a21d7fe3df 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -10,6 +10,13 @@ import type { DiffUpdateResult } from '../ide/ide-client.js'; import type { ShellExecutionConfig } from '../services/shellExecutionService.js'; import { SchemaValidator } from '../utils/schemaValidator.js'; import type { AnsiOutput } from '../utils/terminalSerializer.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { randomUUID } from 'node:crypto'; +import { + MessageBusType, + type ToolConfirmationRequest, + type ToolConfirmationResponse, +} from '../confirmation-bus/types.js'; /** * Represents a validated and ready-to-execute tool call. @@ -66,7 +73,16 @@ export abstract class BaseToolInvocation< TResult extends ToolResult, > implements ToolInvocation { - constructor(readonly params: TParams) {} + constructor( + readonly params: TParams, + protected readonly messageBus?: MessageBus, + ) { + if (this.messageBus) { + console.log( + `[DEBUG] Tool ${this.constructor.name} created with messageBus: YES`, + ); + } + } abstract getDescription(): string; @@ -77,9 +93,116 @@ export abstract class BaseToolInvocation< shouldConfirmExecute( _abortSignal: AbortSignal, ): Promise { + // If message bus is available, use it for confirmation + if (this.messageBus) { + console.log( + `[DEBUG] Using message bus for tool confirmation: ${this.constructor.name}`, + ); + return this.handleMessageBusConfirmation(_abortSignal); + } + + // Fall back to existing confirmation flow return Promise.resolve(false); } + /** + * Handle tool confirmation using the message bus. + * This method publishes a confirmation request and waits for the response. + */ + protected async handleMessageBusConfirmation( + abortSignal: AbortSignal, + ): Promise { + if (!this.messageBus) { + return false; + } + + const correlationId = randomUUID(); + const toolCall = { + name: this.constructor.name, + args: this.params as Record, + }; + + return new Promise( + (resolve, reject) => { + if (!this.messageBus) { + resolve(false); + return; + } + + let timeoutId: NodeJS.Timeout | undefined; + + // Centralized cleanup function + const cleanup = () => { + if (timeoutId) { + clearTimeout(timeoutId); + timeoutId = undefined; + } + abortSignal.removeEventListener('abort', abortHandler); + this.messageBus?.unsubscribe( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + responseHandler, + ); + }; + + // Set up abort handler + const abortHandler = () => { + cleanup(); + reject(new Error('Tool confirmation aborted')); + }; + + // Check if already aborted + if (abortSignal.aborted) { + reject(new Error('Tool confirmation aborted')); + return; + } + + // Set up response handler + const responseHandler = (response: ToolConfirmationResponse) => { + if (response.correlationId === correlationId) { + cleanup(); + + if (response.confirmed) { + // Tool was confirmed, return false to indicate no further confirmation needed + resolve(false); + } else { + // Tool was denied, reject to prevent execution + reject(new Error('Tool execution denied by policy')); + } + } + }; + + // Add event listener for abort signal + abortSignal.addEventListener('abort', abortHandler); + + // Set up timeout + timeoutId = setTimeout(() => { + cleanup(); + resolve(false); + }, 30000); // 30 second timeout + + // Subscribe to response + this.messageBus.subscribe( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + responseHandler, + ); + + // Publish confirmation request + const request: ToolConfirmationRequest = { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall, + correlationId, + }; + + try { + this.messageBus.publish(request); + } catch (_error) { + cleanup(); + resolve(false); + } + }, + ); + } + abstract execute( signal: AbortSignal, updateOutput?: (output: string | AnsiOutput) => void, @@ -159,6 +282,7 @@ export abstract class DeclarativeTool< readonly parameterSchema: unknown, readonly isOutputMarkdown: boolean = true, readonly canUpdateOutput: boolean = false, + readonly messageBus?: MessageBus, ) {} get schema(): FunctionDeclaration { @@ -282,7 +406,7 @@ export abstract class BaseDeclarativeTool< if (validationError) { throw new Error(validationError); } - return this.createInvocation(params); + return this.createInvocation(params, this.messageBus); } override validateToolParams(params: TParams): string | null { @@ -304,6 +428,7 @@ export abstract class BaseDeclarativeTool< protected abstract createInvocation( params: TParams, + messageBus?: MessageBus, ): ToolInvocation; }