feat: add message bus integration for tool confirmation (#8938)

This commit is contained in:
Allen Hutchison
2025-09-22 12:03:20 -07:00
committed by GitHub
parent 6c559e2338
commit bcc4d81d19
6 changed files with 451 additions and 3 deletions

View File

@@ -667,6 +667,8 @@ export async function loadCliConfig(
format: (argv.outputFormat ?? settings.output?.format) as OutputFormat,
},
useModelRouter,
enableMessageBusIntegration:
settings.tools?.enableMessageBusIntegration ?? false,
});
}

View File

@@ -74,6 +74,7 @@ const MIGRATION_MAP: Record<string, string> = {
disableAutoUpdate: 'general.disableAutoUpdate',
disableUpdateNag: 'general.disableUpdateNag',
dnsResolutionOrder: 'advanced.dnsResolutionOrder',
enableMessageBusIntegration: 'tools.enableMessageBusIntegration',
enablePromptCompletion: 'general.enablePromptCompletion',
enforcedAuthType: 'security.auth.enforcedType',
excludeTools: 'tools.exclude',

View File

@@ -774,6 +774,16 @@ const SETTINGS_SCHEMA = {
description: 'The number of lines to keep when truncating tool output.',
showInDialog: true,
},
enableMessageBusIntegration: {
type: 'boolean',
label: 'Enable Message Bus Integration',
category: 'Tools',
requiresRestart: true,
default: false,
description:
'Enable policy-based tool confirmation via message bus integration. When enabled, tools will automatically respect policy engine decisions (ALLOW/DENY/ASK_USER) without requiring individual tool implementations.',
showInDialog: true,
},
},
},

View File

@@ -251,6 +251,7 @@ export interface ConfigParameters {
policyEngineConfig?: PolicyEngineConfig;
output?: OutputSettings;
useModelRouter?: boolean;
enableMessageBusIntegration?: boolean;
}
export class Config {
@@ -340,6 +341,7 @@ export class Config {
private readonly policyEngine: PolicyEngine;
private readonly outputSettings: OutputSettings;
private readonly useModelRouter: boolean;
private readonly enableMessageBusIntegration: boolean;
constructor(params: ConfigParameters) {
this.sessionId = params.sessionId;
@@ -427,6 +429,8 @@ export class Config {
this.useSmartEdit = params.useSmartEdit ?? true;
this.useWriteTodos = params.useWriteTodos ?? false;
this.useModelRouter = params.useModelRouter ?? false;
this.enableMessageBusIntegration =
params.enableMessageBusIntegration ?? false;
this.extensionManagement = params.extensionManagement ?? true;
this.storage = new Storage(this.targetDir);
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
@@ -986,6 +990,10 @@ export class Config {
return this.policyEngine;
}
getEnableMessageBusIntegration(): boolean {
return this.enableMessageBusIntegration;
}
async createToolRegistry(): Promise<ToolRegistry> {
const registry = new ToolRegistry(this, this.eventEmitter);
@@ -1019,7 +1027,24 @@ export class Config {
}
if (isEnabled) {
registry.registerTool(new ToolClass(...args));
// Pass message bus to tools when feature flag is enabled
// This first implementation is only focused on the general case of
// the tool registry.
const messageBusEnabled = this.getEnableMessageBusIntegration();
if (this.debugMode) {
console.log(
`[DEBUG] enableMessageBusIntegration setting: ${messageBusEnabled}`,
);
}
const toolArgs = messageBusEnabled
? [...args, this.getMessageBus()]
: args;
if (this.debugMode) {
console.log(
`[DEBUG] Registering ${className} with messageBus: ${messageBusEnabled ? 'YES' : 'NO'}`,
);
}
registry.registerTool(new ToolClass(...toolArgs));
}
};

View File

@@ -0,0 +1,285 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
describe,
it,
expect,
beforeEach,
afterEach,
vi,
type Mock,
} from 'vitest';
import {
BaseToolInvocation,
BaseDeclarativeTool,
Kind,
type ToolResult,
} from './tools.js';
import { MessageBus } from '../confirmation-bus/message-bus.js';
import { PolicyEngine } from '../policy/policy-engine.js';
import {
MessageBusType,
type ToolConfirmationResponse,
} from '../confirmation-bus/types.js';
import { randomUUID } from 'node:crypto';
// Mock crypto module
vi.mock('node:crypto', () => ({
randomUUID: vi.fn(),
}));
interface TestParams {
testParam: string;
}
interface TestResult extends ToolResult {
testValue: string;
}
class TestToolInvocation extends BaseToolInvocation<TestParams, TestResult> {
getDescription(): string {
return `Test tool with param: ${this.params.testParam}`;
}
async execute(): Promise<TestResult> {
return {
llmContent: `Executed with ${this.params.testParam}`,
returnDisplay: `Test result: ${this.params.testParam}`,
testValue: this.params.testParam,
};
}
}
class TestTool extends BaseDeclarativeTool<TestParams, TestResult> {
constructor(messageBus?: MessageBus) {
super(
'test-tool',
'Test Tool',
'A test tool for message bus integration',
Kind.Other,
{
type: 'object',
properties: {
testParam: { type: 'string' },
},
required: ['testParam'],
},
true,
false,
messageBus,
);
}
protected createInvocation(params: TestParams, messageBus?: MessageBus) {
return new TestToolInvocation(params, messageBus);
}
}
describe('Message Bus Integration', () => {
let policyEngine: PolicyEngine;
let messageBus: MessageBus;
let mockUUID: Mock;
beforeEach(() => {
vi.resetAllMocks();
policyEngine = new PolicyEngine();
messageBus = new MessageBus(policyEngine);
mockUUID = vi.mocked(randomUUID);
mockUUID.mockReturnValue('test-correlation-id');
});
afterEach(() => {
vi.restoreAllMocks();
});
describe('BaseToolInvocation with MessageBus', () => {
it('should use message bus for confirmation when available', async () => {
const tool = new TestTool(messageBus);
const invocation = tool.build({ testParam: 'test-value' });
// Mock message bus publish and subscribe
const publishSpy = vi.spyOn(messageBus, 'publish');
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
const unsubscribeSpy = vi.spyOn(messageBus, 'unsubscribe');
// Start confirmation process
const confirmationPromise = invocation.shouldConfirmExecute(
new AbortController().signal,
);
// Verify confirmation request was published
expect(publishSpy).toHaveBeenCalledWith({
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
toolCall: {
name: 'TestToolInvocation',
args: { testParam: 'test-value' },
},
correlationId: 'test-correlation-id',
});
// Verify subscription to response
expect(subscribeSpy).toHaveBeenCalledWith(
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
expect.any(Function),
);
// Simulate confirmation response
const responseHandler = subscribeSpy.mock.calls[0][1];
const response: ToolConfirmationResponse = {
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'test-correlation-id',
confirmed: true,
};
responseHandler(response);
const result = await confirmationPromise;
expect(result).toBe(false); // No further confirmation needed
expect(unsubscribeSpy).toHaveBeenCalled();
});
it('should reject promise when confirmation is denied', async () => {
const tool = new TestTool(messageBus);
const invocation = tool.build({ testParam: 'test-value' });
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
const confirmationPromise = invocation.shouldConfirmExecute(
new AbortController().signal,
);
// Simulate denial response
const responseHandler = subscribeSpy.mock.calls[0][1];
const response: ToolConfirmationResponse = {
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'test-correlation-id',
confirmed: false,
};
responseHandler(response);
// Should reject with error when denied
await expect(confirmationPromise).rejects.toThrow(
'Tool execution denied by policy',
);
});
it('should handle timeout', async () => {
vi.useFakeTimers();
const tool = new TestTool(messageBus);
const invocation = tool.build({ testParam: 'test-value' });
const confirmationPromise = invocation.shouldConfirmExecute(
new AbortController().signal,
);
// Fast-forward past timeout
vi.advanceTimersByTime(30000);
const result = await confirmationPromise;
expect(result).toBe(false);
vi.useRealTimers();
});
it('should handle abort signal', async () => {
const tool = new TestTool(messageBus);
const invocation = tool.build({ testParam: 'test-value' });
const abortController = new AbortController();
const confirmationPromise = invocation.shouldConfirmExecute(
abortController.signal,
);
// Abort the operation
abortController.abort();
await expect(confirmationPromise).rejects.toThrow(
'Tool confirmation aborted',
);
});
it('should fall back to default behavior when no message bus', async () => {
const tool = new TestTool(); // No message bus
const invocation = tool.build({ testParam: 'test-value' });
const result = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(result).toBe(false);
});
it('should ignore responses with wrong correlation ID', async () => {
vi.useFakeTimers();
const tool = new TestTool(messageBus);
const invocation = tool.build({ testParam: 'test-value' });
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
const confirmationPromise = invocation.shouldConfirmExecute(
new AbortController().signal,
);
// Send response with wrong correlation ID
const responseHandler = subscribeSpy.mock.calls[0][1];
const wrongResponse: ToolConfirmationResponse = {
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: 'wrong-id',
confirmed: true,
};
responseHandler(wrongResponse);
// Should timeout since correct response wasn't received
vi.advanceTimersByTime(30000);
const result = await confirmationPromise;
expect(result).toBe(false);
vi.useRealTimers();
});
});
describe('Backward Compatibility', () => {
it('should work with existing tools that do not use message bus', async () => {
const tool = new TestTool(); // No message bus
const invocation = tool.build({ testParam: 'test-value' });
// Should execute normally
const result = await invocation.execute(new AbortController().signal);
expect(result.testValue).toBe('test-value');
expect(result.llmContent).toBe('Executed with test-value');
});
it('should work with tools that have message bus but use default confirmation', async () => {
const tool = new TestTool(messageBus);
const invocation = tool.build({ testParam: 'test-value' });
// Should execute normally even with message bus available
const result = await invocation.execute(new AbortController().signal);
expect(result.testValue).toBe('test-value');
expect(result.llmContent).toBe('Executed with test-value');
});
});
describe('Error Handling', () => {
it('should handle message bus publish errors gracefully', async () => {
const tool = new TestTool(messageBus);
const invocation = tool.build({ testParam: 'test-value' });
// Mock publish to throw error
vi.spyOn(messageBus, 'publish').mockImplementation(() => {
throw new Error('Message bus error');
});
const result = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(result).toBe(false); // Should gracefully fall back
});
});
});

View File

@@ -10,6 +10,13 @@ import type { DiffUpdateResult } from '../ide/ide-client.js';
import type { ShellExecutionConfig } from '../services/shellExecutionService.js';
import { SchemaValidator } from '../utils/schemaValidator.js';
import type { AnsiOutput } from '../utils/terminalSerializer.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { randomUUID } from 'node:crypto';
import {
MessageBusType,
type ToolConfirmationRequest,
type ToolConfirmationResponse,
} from '../confirmation-bus/types.js';
/**
* Represents a validated and ready-to-execute tool call.
@@ -66,7 +73,16 @@ export abstract class BaseToolInvocation<
TResult extends ToolResult,
> implements ToolInvocation<TParams, TResult>
{
constructor(readonly params: TParams) {}
constructor(
readonly params: TParams,
protected readonly messageBus?: MessageBus,
) {
if (this.messageBus) {
console.log(
`[DEBUG] Tool ${this.constructor.name} created with messageBus: YES`,
);
}
}
abstract getDescription(): string;
@@ -77,9 +93,116 @@ export abstract class BaseToolInvocation<
shouldConfirmExecute(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
// If message bus is available, use it for confirmation
if (this.messageBus) {
console.log(
`[DEBUG] Using message bus for tool confirmation: ${this.constructor.name}`,
);
return this.handleMessageBusConfirmation(_abortSignal);
}
// Fall back to existing confirmation flow
return Promise.resolve(false);
}
/**
* Handle tool confirmation using the message bus.
* This method publishes a confirmation request and waits for the response.
*/
protected async handleMessageBusConfirmation(
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (!this.messageBus) {
return false;
}
const correlationId = randomUUID();
const toolCall = {
name: this.constructor.name,
args: this.params as Record<string, unknown>,
};
return new Promise<ToolCallConfirmationDetails | false>(
(resolve, reject) => {
if (!this.messageBus) {
resolve(false);
return;
}
let timeoutId: NodeJS.Timeout | undefined;
// Centralized cleanup function
const cleanup = () => {
if (timeoutId) {
clearTimeout(timeoutId);
timeoutId = undefined;
}
abortSignal.removeEventListener('abort', abortHandler);
this.messageBus?.unsubscribe(
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
responseHandler,
);
};
// Set up abort handler
const abortHandler = () => {
cleanup();
reject(new Error('Tool confirmation aborted'));
};
// Check if already aborted
if (abortSignal.aborted) {
reject(new Error('Tool confirmation aborted'));
return;
}
// Set up response handler
const responseHandler = (response: ToolConfirmationResponse) => {
if (response.correlationId === correlationId) {
cleanup();
if (response.confirmed) {
// Tool was confirmed, return false to indicate no further confirmation needed
resolve(false);
} else {
// Tool was denied, reject to prevent execution
reject(new Error('Tool execution denied by policy'));
}
}
};
// Add event listener for abort signal
abortSignal.addEventListener('abort', abortHandler);
// Set up timeout
timeoutId = setTimeout(() => {
cleanup();
resolve(false);
}, 30000); // 30 second timeout
// Subscribe to response
this.messageBus.subscribe(
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
responseHandler,
);
// Publish confirmation request
const request: ToolConfirmationRequest = {
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
toolCall,
correlationId,
};
try {
this.messageBus.publish(request);
} catch (_error) {
cleanup();
resolve(false);
}
},
);
}
abstract execute(
signal: AbortSignal,
updateOutput?: (output: string | AnsiOutput) => void,
@@ -159,6 +282,7 @@ export abstract class DeclarativeTool<
readonly parameterSchema: unknown,
readonly isOutputMarkdown: boolean = true,
readonly canUpdateOutput: boolean = false,
readonly messageBus?: MessageBus,
) {}
get schema(): FunctionDeclaration {
@@ -282,7 +406,7 @@ export abstract class BaseDeclarativeTool<
if (validationError) {
throw new Error(validationError);
}
return this.createInvocation(params);
return this.createInvocation(params, this.messageBus);
}
override validateToolParams(params: TParams): string | null {
@@ -304,6 +428,7 @@ export abstract class BaseDeclarativeTool<
protected abstract createInvocation(
params: TParams,
messageBus?: MessageBus,
): ToolInvocation<TParams, TResult>;
}