diff --git a/packages/a2a-server/src/utils/testing_utils.ts b/packages/a2a-server/src/utils/testing_utils.ts index ab7c19dd70..865fc9d5ac 100644 --- a/packages/a2a-server/src/utils/testing_utils.ts +++ b/packages/a2a-server/src/utils/testing_utils.ts @@ -50,6 +50,9 @@ export function createMockConfig( getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'), getSessionId: vi.fn().mockReturnValue('test-session-id'), getUserTier: vi.fn(), + getEnableMessageBusIntegration: vi.fn().mockReturnValue(false), + getMessageBus: vi.fn(), + getPolicyEngine: vi.fn(), ...overrides, } as unknown as Config; diff --git a/packages/cli/src/ui/hooks/useToolScheduler.test.ts b/packages/cli/src/ui/hooks/useToolScheduler.test.ts index 4a626d93ce..9304666246 100644 --- a/packages/cli/src/ui/hooks/useToolScheduler.test.ts +++ b/packages/cli/src/ui/hooks/useToolScheduler.test.ts @@ -70,6 +70,9 @@ const mockConfig = { getUseModelRouter: () => false, getGeminiClient: () => null, // No client needed for these tests getShellExecutionConfig: () => ({ terminalWidth: 80, terminalHeight: 24 }), + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, } as unknown as Config; const mockTool = new MockTool({ diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts index cb86595be9..f1c58a7c31 100644 --- a/packages/core/src/confirmation-bus/types.ts +++ b/packages/core/src/confirmation-bus/types.ts @@ -24,6 +24,11 @@ export interface ToolConfirmationResponse { type: MessageBusType.TOOL_CONFIRMATION_RESPONSE; correlationId: string; confirmed: boolean; + /** + * When true, indicates that policy decision was ASK_USER and the tool should + * show its legacy confirmation UI instead of auto-proceeding. + */ + requiresUserConfirmation?: boolean; } export interface ToolPolicyRejection { diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 2855888288..16e3b2d90c 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -255,6 +255,9 @@ describe('CoreToolScheduler', () => { getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, // No client needed for these tests + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -332,6 +335,9 @@ describe('CoreToolScheduler', () => { getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -365,15 +371,18 @@ describe('CoreToolScheduler', () => { describe('getToolSuggestion', () => { it('should suggest the top N closest tool names for a typo', () => { // Create mocked tool registry + const mockToolRegistry = { + getAllToolNames: () => ['list_files', 'read_file', 'write_file'], + } as unknown as ToolRegistry; const mockConfig = { getToolRegistry: () => mockToolRegistry, getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, // No client needed for these tests + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, } as unknown as Config; - const mockToolRegistry = { - getAllToolNames: () => ['list_files', 'read_file', 'write_file'], - } as unknown as ToolRegistry; // Create scheduler const scheduler = new CoreToolScheduler({ @@ -448,6 +457,9 @@ describe('CoreToolScheduler with payload', () => { getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, // No client needed for these tests + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -768,6 +780,9 @@ describe('CoreToolScheduler edit cancellation', () => { getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, // No client needed for these tests + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -874,6 +889,9 @@ describe('CoreToolScheduler YOLO mode', () => { getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, // No client needed for these tests + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -981,6 +999,9 @@ describe('CoreToolScheduler request queueing', () => { getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, // No client needed for these tests + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -1113,6 +1134,9 @@ describe('CoreToolScheduler request queueing', () => { getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, // No client needed for these tests + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -1215,6 +1239,9 @@ describe('CoreToolScheduler request queueing', () => { getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, // No client needed for these tests + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -1287,6 +1314,9 @@ describe('CoreToolScheduler request queueing', () => { getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, // No client needed for these tests + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, } as unknown as Config; const testTool = new TestApprovalTool(mockConfig); @@ -1475,6 +1505,8 @@ describe('CoreToolScheduler Sequential Execution', () => { getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, } as unknown as Config; const scheduler = new CoreToolScheduler({ @@ -1595,6 +1627,8 @@ describe('CoreToolScheduler Sequential Execution', () => { getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, } as unknown as Config; const scheduler = new CoreToolScheduler({ diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index d1d7829871..f3a28fa96f 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -42,6 +42,8 @@ import * as path from 'node:path'; import { doesToolInvocationMatch } from '../utils/tool-utils.js'; import levenshtein from 'fast-levenshtein'; import { ShellToolInvocation } from '../tools/shell.js'; +import type { ToolConfirmationRequest } from '../confirmation-bus/types.js'; +import { MessageBusType } from '../confirmation-bus/types.js'; export type ValidatingToolCall = { status: 'validating'; @@ -352,6 +354,15 @@ export class CoreToolScheduler { this.onToolCallsUpdate = options.onToolCallsUpdate; this.getPreferredEditor = options.getPreferredEditor; this.onEditorClose = options.onEditorClose; + + // Subscribe to message bus for ASK_USER policy decisions + if (this.config.getEnableMessageBusIntegration()) { + const messageBus = this.config.getMessageBus(); + messageBus.subscribe( + MessageBusType.TOOL_CONFIRMATION_REQUEST, + this.handleToolConfirmationRequest.bind(this), + ); + } } private setStatusInternal( @@ -1160,6 +1171,26 @@ export class CoreToolScheduler { }); } + /** + * Handle tool confirmation requests from the message bus when policy decision is ASK_USER. + * This publishes a response with requiresUserConfirmation=true to signal the tool + * that it should fall back to its legacy confirmation UI. + */ + private handleToolConfirmationRequest( + request: ToolConfirmationRequest, + ): void { + // When ASK_USER policy decision is made, the message bus emits the request here. + // We respond with requiresUserConfirmation=true to tell the tool to use its + // legacy confirmation flow (which will show diffs, URLs, etc in the UI). + const messageBus = this.config.getMessageBus(); + messageBus.publish({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: request.correlationId, + confirmed: false, // Not auto-approved + requiresUserConfirmation: true, // Use legacy UI confirmation + }); + } + private async autoApproveCompatiblePendingTools( signal: AbortSignal, triggeringCallId: string, diff --git a/packages/core/src/core/nonInteractiveToolExecutor.test.ts b/packages/core/src/core/nonInteractiveToolExecutor.test.ts index 563ecab97d..0af6648485 100644 --- a/packages/core/src/core/nonInteractiveToolExecutor.test.ts +++ b/packages/core/src/core/nonInteractiveToolExecutor.test.ts @@ -62,6 +62,9 @@ describe('executeToolCall', () => { getUseSmartEdit: () => false, getUseModelRouter: () => false, getGeminiClient: () => null, // No client needed for these tests + getEnableMessageBusIntegration: () => false, + getMessageBus: () => null, + getPolicyEngine: () => null, } as unknown as Config; abortController = new AbortController(); diff --git a/packages/core/src/tools/message-bus-integration.test.ts b/packages/core/src/tools/message-bus-integration.test.ts index 41cec0535a..2ee38b0d22 100644 --- a/packages/core/src/tools/message-bus-integration.test.ts +++ b/packages/core/src/tools/message-bus-integration.test.ts @@ -52,6 +52,22 @@ class TestToolInvocation extends BaseToolInvocation { testValue: this.params.testParam, }; } + + override async shouldConfirmExecute( + abortSignal: AbortSignal, + ): Promise { + // This conditional is here to allow testing of the case where there is no message bus. + if (this.messageBus) { + const decision = await this.getMessageBusDecision(abortSignal); + if (decision === 'ALLOW') { + return false; + } + if (decision === 'DENY') { + throw new Error('Tool execution denied by policy'); + } + } + return false; + } } class TestTool extends BaseDeclarativeTool { @@ -200,7 +216,7 @@ describe('Message Bus Integration', () => { abortController.abort(); await expect(confirmationPromise).rejects.toThrow( - 'Tool confirmation aborted', + 'Tool execution denied by policy', ); }); diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index a21d7fe3df..af882da8b3 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -78,7 +78,7 @@ export abstract class BaseToolInvocation< protected readonly messageBus?: MessageBus, ) { if (this.messageBus) { - console.log( + console.debug( `[DEBUG] Tool ${this.constructor.name} created with messageBus: YES`, ); } @@ -93,27 +93,17 @@ export abstract class BaseToolInvocation< shouldConfirmExecute( _abortSignal: AbortSignal, ): Promise { - // If message bus is available, use it for confirmation - if (this.messageBus) { - console.log( - `[DEBUG] Using message bus for tool confirmation: ${this.constructor.name}`, - ); - return this.handleMessageBusConfirmation(_abortSignal); - } - - // Fall back to existing confirmation flow + // Default implementation for tools that don't override it. return Promise.resolve(false); } - /** - * Handle tool confirmation using the message bus. - * This method publishes a confirmation request and waits for the response. - */ - protected async handleMessageBusConfirmation( + protected getMessageBusDecision( abortSignal: AbortSignal, - ): Promise { + ): Promise<'ALLOW' | 'DENY' | 'ASK_USER'> { if (!this.messageBus) { - return false; + // If there's no message bus, we can't make a decision, so we allow. + // The legacy confirmation flow will still apply if the tool needs it. + return Promise.resolve('ALLOW'); } const correlationId = randomUUID(); @@ -122,85 +112,74 @@ export abstract class BaseToolInvocation< args: this.params as Record, }; - return new Promise( - (resolve, reject) => { - if (!this.messageBus) { - resolve(false); - return; + return new Promise<'ALLOW' | 'DENY' | 'ASK_USER'>((resolve) => { + if (!this.messageBus) { + resolve('ALLOW'); + return; + } + + let timeoutId: NodeJS.Timeout | undefined; + + const cleanup = () => { + if (timeoutId) { + clearTimeout(timeoutId); + timeoutId = undefined; } - - let timeoutId: NodeJS.Timeout | undefined; - - // Centralized cleanup function - const cleanup = () => { - if (timeoutId) { - clearTimeout(timeoutId); - timeoutId = undefined; - } - abortSignal.removeEventListener('abort', abortHandler); - this.messageBus?.unsubscribe( - MessageBusType.TOOL_CONFIRMATION_RESPONSE, - responseHandler, - ); - }; - - // Set up abort handler - const abortHandler = () => { - cleanup(); - reject(new Error('Tool confirmation aborted')); - }; - - // Check if already aborted - if (abortSignal.aborted) { - reject(new Error('Tool confirmation aborted')); - return; - } - - // Set up response handler - const responseHandler = (response: ToolConfirmationResponse) => { - if (response.correlationId === correlationId) { - cleanup(); - - if (response.confirmed) { - // Tool was confirmed, return false to indicate no further confirmation needed - resolve(false); - } else { - // Tool was denied, reject to prevent execution - reject(new Error('Tool execution denied by policy')); - } - } - }; - - // Add event listener for abort signal - abortSignal.addEventListener('abort', abortHandler); - - // Set up timeout - timeoutId = setTimeout(() => { - cleanup(); - resolve(false); - }, 30000); // 30 second timeout - - // Subscribe to response - this.messageBus.subscribe( + abortSignal.removeEventListener('abort', abortHandler); + this.messageBus?.unsubscribe( MessageBusType.TOOL_CONFIRMATION_RESPONSE, responseHandler, ); + }; - // Publish confirmation request - const request: ToolConfirmationRequest = { - type: MessageBusType.TOOL_CONFIRMATION_REQUEST, - toolCall, - correlationId, - }; + const abortHandler = () => { + cleanup(); + resolve('DENY'); + }; - try { - this.messageBus.publish(request); - } catch (_error) { + if (abortSignal.aborted) { + resolve('DENY'); + return; + } + + const responseHandler = (response: ToolConfirmationResponse) => { + if (response.correlationId === correlationId) { cleanup(); - resolve(false); + if (response.requiresUserConfirmation) { + resolve('ASK_USER'); + } else if (response.confirmed) { + resolve('ALLOW'); + } else { + resolve('DENY'); + } } - }, - ); + }; + + abortSignal.addEventListener('abort', abortHandler); + + timeoutId = setTimeout(() => { + cleanup(); + resolve('ASK_USER'); // Default to ASK_USER on timeout + }, 30000); + + this.messageBus.subscribe( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + responseHandler, + ); + + const request: ToolConfirmationRequest = { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall, + correlationId, + }; + + try { + this.messageBus.publish(request); + } catch (_error) { + cleanup(); + resolve('ALLOW'); + } + }); } abstract execute( diff --git a/packages/core/src/tools/web-fetch.test.ts b/packages/core/src/tools/web-fetch.test.ts index 47aa5fc2de..8dc4d3ae52 100644 --- a/packages/core/src/tools/web-fetch.test.ts +++ b/packages/core/src/tools/web-fetch.test.ts @@ -4,13 +4,20 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest'; import { WebFetchTool, parsePrompt } from './web-fetch.js'; import type { Config } from '../config/config.js'; import { ApprovalMode } from '../config/config.js'; import { ToolConfirmationOutcome } from './tools.js'; import { ToolErrorType } from './tool-error.js'; import * as fetchUtils from '../utils/fetch.js'; +import { MessageBus } from '../confirmation-bus/message-bus.js'; +import { PolicyEngine } from '../policy/policy-engine.js'; +import { + MessageBusType, + type ToolConfirmationResponse, +} from '../confirmation-bus/types.js'; +import { randomUUID } from 'node:crypto'; import { logWebFetchFallbackAttempt, WebFetchFallbackAttemptEvent, @@ -35,6 +42,10 @@ vi.mock('../utils/fetch.js', async (importOriginal) => { }; }); +vi.mock('node:crypto', () => ({ + randomUUID: vi.fn(), +})); + describe('parsePrompt', () => { it('should extract valid URLs separated by whitespace', () => { const prompt = 'Go to https://example.com and http://google.com'; @@ -313,4 +324,229 @@ describe('WebFetchTool', () => { ); }); }); + + describe('Message Bus Integration', () => { + let policyEngine: PolicyEngine; + let messageBus: MessageBus; + let mockUUID: Mock; + + beforeEach(() => { + policyEngine = new PolicyEngine(); + messageBus = new MessageBus(policyEngine); + mockUUID = vi.mocked(randomUUID); + mockUUID.mockReturnValue('test-correlation-id'); + }); + + it('should use message bus for confirmation when available', async () => { + const tool = new WebFetchTool(mockConfig, messageBus); + const params = { prompt: 'fetch https://example.com' }; + const invocation = tool.build(params); + + // Mock message bus publish and subscribe + const publishSpy = vi.spyOn(messageBus, 'publish'); + const subscribeSpy = vi.spyOn(messageBus, 'subscribe'); + const unsubscribeSpy = vi.spyOn(messageBus, 'unsubscribe'); + + // Start confirmation process + const confirmationPromise = invocation.shouldConfirmExecute( + new AbortController().signal, + ); + + // Verify confirmation request was published + expect(publishSpy).toHaveBeenCalledWith({ + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall: { + name: 'WebFetchToolInvocation', + args: { prompt: 'fetch https://example.com' }, + }, + correlationId: 'test-correlation-id', + }); + + // Verify subscription to response + expect(subscribeSpy).toHaveBeenCalledWith( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + expect.any(Function), + ); + + // Simulate confirmation response + const responseHandler = subscribeSpy.mock.calls[0][1]; + const response: ToolConfirmationResponse = { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'test-correlation-id', + confirmed: true, + }; + + responseHandler(response); + + const result = await confirmationPromise; + expect(result).toBe(false); // No further confirmation needed + expect(unsubscribeSpy).toHaveBeenCalled(); + }); + + it('should reject promise when confirmation is denied via message bus', async () => { + const tool = new WebFetchTool(mockConfig, messageBus); + const params = { prompt: 'fetch https://example.com' }; + const invocation = tool.build(params); + + const subscribeSpy = vi.spyOn(messageBus, 'subscribe'); + + const confirmationPromise = invocation.shouldConfirmExecute( + new AbortController().signal, + ); + + // Simulate denial response + const responseHandler = subscribeSpy.mock.calls[0][1]; + const response: ToolConfirmationResponse = { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'test-correlation-id', + confirmed: false, + }; + + responseHandler(response); + + // Should reject with error when denied + await expect(confirmationPromise).rejects.toThrow( + 'Tool execution denied by policy', + ); + }); + + it('should handle timeout gracefully', async () => { + vi.useFakeTimers(); + + const tool = new WebFetchTool(mockConfig, messageBus); + const params = { prompt: 'fetch https://example.com' }; + const invocation = tool.build(params); + + const confirmationPromise = invocation.shouldConfirmExecute( + new AbortController().signal, + ); + + // Fast-forward past timeout + await vi.advanceTimersByTimeAsync(30000); + const result = await confirmationPromise; + expect(result).not.toBe(false); + expect(result).toHaveProperty('type', 'info'); + + vi.useRealTimers(); + }); + + it('should handle abort signal during confirmation', async () => { + const tool = new WebFetchTool(mockConfig, messageBus); + const params = { prompt: 'fetch https://example.com' }; + const invocation = tool.build(params); + + const abortController = new AbortController(); + const confirmationPromise = invocation.shouldConfirmExecute( + abortController.signal, + ); + + // Abort the operation + abortController.abort(); + + await expect(confirmationPromise).rejects.toThrow( + 'Tool execution denied by policy.', + ); + }); + + it('should fall back to legacy confirmation when no message bus', async () => { + const tool = new WebFetchTool(mockConfig); // No message bus + const params = { prompt: 'fetch https://example.com' }; + const invocation = tool.build(params); + + const result = await invocation.shouldConfirmExecute( + new AbortController().signal, + ); + + // Should use legacy confirmation flow (returns confirmation details, not false) + expect(result).not.toBe(false); + expect(result).toHaveProperty('type', 'info'); + }); + + it('should ignore responses with wrong correlation ID', async () => { + vi.useFakeTimers(); + + const tool = new WebFetchTool(mockConfig, messageBus); + const params = { prompt: 'fetch https://example.com' }; + const invocation = tool.build(params); + + const subscribeSpy = vi.spyOn(messageBus, 'subscribe'); + const confirmationPromise = invocation.shouldConfirmExecute( + new AbortController().signal, + ); + + // Send response with wrong correlation ID + const responseHandler = subscribeSpy.mock.calls[0][1]; + const wrongResponse: ToolConfirmationResponse = { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'wrong-id', + confirmed: true, + }; + + responseHandler(wrongResponse); + + // Should timeout since correct response wasn't received + await vi.advanceTimersByTimeAsync(30000); + const result = await confirmationPromise; + expect(result).not.toBe(false); + expect(result).toHaveProperty('type', 'info'); + + vi.useRealTimers(); + }); + + it('should handle message bus publish errors gracefully', async () => { + const tool = new WebFetchTool(mockConfig, messageBus); + const params = { prompt: 'fetch https://example.com' }; + const invocation = tool.build(params); + + // Mock publish to throw error + vi.spyOn(messageBus, 'publish').mockImplementation(() => { + throw new Error('Message bus error'); + }); + + const result = await invocation.shouldConfirmExecute( + new AbortController().signal, + ); + expect(result).toBe(false); // Should gracefully fall back + }); + + it('should execute normally after confirmation approval', async () => { + vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false); + mockGenerateContent.mockResolvedValue({ + candidates: [ + { + content: { + parts: [{ text: 'Fetched content from https://example.com' }], + role: 'model', + }, + }, + ], + }); + + const tool = new WebFetchTool(mockConfig, messageBus); + const params = { prompt: 'fetch https://example.com' }; + const invocation = tool.build(params); + + const subscribeSpy = vi.spyOn(messageBus, 'subscribe'); + + // Start confirmation + const confirmationPromise = invocation.shouldConfirmExecute( + new AbortController().signal, + ); + + // Approve via message bus + const responseHandler = subscribeSpy.mock.calls[0][1]; + responseHandler({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: 'test-correlation-id', + confirmed: true, + }); + + await confirmationPromise; + + // Execute the tool + const result = await invocation.execute(new AbortController().signal); + expect(result.error).toBeUndefined(); + expect(result.llmContent).toContain('Fetched content'); + }); + }); }); diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index 6f74ba7e36..68fa28f527 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -15,6 +15,7 @@ import { Kind, ToolConfirmationOutcome, } from './tools.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { ToolErrorType } from './tool-error.js'; import { getErrorMessage } from '../utils/errors.js'; import type { Config } from '../config/config.js'; @@ -107,8 +108,9 @@ class WebFetchToolInvocation extends BaseToolInvocation< constructor( private readonly config: Config, params: WebFetchToolParams, + messageBus?: MessageBus, ) { - super(params); + super(params, messageBus); } private async executeFallback(signal: AbortSignal): Promise { @@ -181,9 +183,22 @@ ${textContent} return `Processing URLs and instructions from prompt: "${displayPrompt}"`; } - override async shouldConfirmExecute(): Promise< - ToolCallConfirmationDetails | false - > { + override async shouldConfirmExecute( + abortSignal: AbortSignal, + ): Promise { + // Try message bus confirmation first if available + if (this.messageBus) { + const decision = await this.getMessageBusDecision(abortSignal); + if (decision === 'ALLOW') { + return false; // No confirmation needed + } + if (decision === 'DENY') { + throw new Error('Tool execution denied by policy.'); + } + // if 'ASK_USER', fall through to legacy logic + } + + // Legacy confirmation flow (no message bus OR policy decision was ASK_USER) if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) { return false; } @@ -366,7 +381,10 @@ export class WebFetchTool extends BaseDeclarativeTool< > { static readonly Name: string = 'web_fetch'; - constructor(private readonly config: Config) { + constructor( + private readonly config: Config, + messageBus?: MessageBus, + ) { super( WebFetchTool.Name, 'WebFetch', @@ -383,6 +401,9 @@ export class WebFetchTool extends BaseDeclarativeTool< required: ['prompt'], type: 'object', }, + true, // isOutputMarkdown + false, // canUpdateOutput + messageBus, ); const proxy = config.getProxy(); if (proxy) { @@ -412,7 +433,8 @@ export class WebFetchTool extends BaseDeclarativeTool< protected createInvocation( params: WebFetchToolParams, + messageBus?: MessageBus, ): ToolInvocation { - return new WebFetchToolInvocation(this.config, params); + return new WebFetchToolInvocation(this.config, params, messageBus); } }