mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
feat: add message bus integration for tool confirmation (#8938)
This commit is contained in:
@@ -667,6 +667,8 @@ export async function loadCliConfig(
|
||||
format: (argv.outputFormat ?? settings.output?.format) as OutputFormat,
|
||||
},
|
||||
useModelRouter,
|
||||
enableMessageBusIntegration:
|
||||
settings.tools?.enableMessageBusIntegration ?? false,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -74,6 +74,7 @@ const MIGRATION_MAP: Record<string, string> = {
|
||||
disableAutoUpdate: 'general.disableAutoUpdate',
|
||||
disableUpdateNag: 'general.disableUpdateNag',
|
||||
dnsResolutionOrder: 'advanced.dnsResolutionOrder',
|
||||
enableMessageBusIntegration: 'tools.enableMessageBusIntegration',
|
||||
enablePromptCompletion: 'general.enablePromptCompletion',
|
||||
enforcedAuthType: 'security.auth.enforcedType',
|
||||
excludeTools: 'tools.exclude',
|
||||
|
||||
@@ -774,6 +774,16 @@ const SETTINGS_SCHEMA = {
|
||||
description: 'The number of lines to keep when truncating tool output.',
|
||||
showInDialog: true,
|
||||
},
|
||||
enableMessageBusIntegration: {
|
||||
type: 'boolean',
|
||||
label: 'Enable Message Bus Integration',
|
||||
category: 'Tools',
|
||||
requiresRestart: true,
|
||||
default: false,
|
||||
description:
|
||||
'Enable policy-based tool confirmation via message bus integration. When enabled, tools will automatically respect policy engine decisions (ALLOW/DENY/ASK_USER) without requiring individual tool implementations.',
|
||||
showInDialog: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
|
||||
@@ -251,6 +251,7 @@ export interface ConfigParameters {
|
||||
policyEngineConfig?: PolicyEngineConfig;
|
||||
output?: OutputSettings;
|
||||
useModelRouter?: boolean;
|
||||
enableMessageBusIntegration?: boolean;
|
||||
}
|
||||
|
||||
export class Config {
|
||||
@@ -340,6 +341,7 @@ export class Config {
|
||||
private readonly policyEngine: PolicyEngine;
|
||||
private readonly outputSettings: OutputSettings;
|
||||
private readonly useModelRouter: boolean;
|
||||
private readonly enableMessageBusIntegration: boolean;
|
||||
|
||||
constructor(params: ConfigParameters) {
|
||||
this.sessionId = params.sessionId;
|
||||
@@ -427,6 +429,8 @@ export class Config {
|
||||
this.useSmartEdit = params.useSmartEdit ?? true;
|
||||
this.useWriteTodos = params.useWriteTodos ?? false;
|
||||
this.useModelRouter = params.useModelRouter ?? false;
|
||||
this.enableMessageBusIntegration =
|
||||
params.enableMessageBusIntegration ?? false;
|
||||
this.extensionManagement = params.extensionManagement ?? true;
|
||||
this.storage = new Storage(this.targetDir);
|
||||
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
|
||||
@@ -986,6 +990,10 @@ export class Config {
|
||||
return this.policyEngine;
|
||||
}
|
||||
|
||||
getEnableMessageBusIntegration(): boolean {
|
||||
return this.enableMessageBusIntegration;
|
||||
}
|
||||
|
||||
async createToolRegistry(): Promise<ToolRegistry> {
|
||||
const registry = new ToolRegistry(this, this.eventEmitter);
|
||||
|
||||
@@ -1019,7 +1027,24 @@ export class Config {
|
||||
}
|
||||
|
||||
if (isEnabled) {
|
||||
registry.registerTool(new ToolClass(...args));
|
||||
// Pass message bus to tools when feature flag is enabled
|
||||
// This first implementation is only focused on the general case of
|
||||
// the tool registry.
|
||||
const messageBusEnabled = this.getEnableMessageBusIntegration();
|
||||
if (this.debugMode) {
|
||||
console.log(
|
||||
`[DEBUG] enableMessageBusIntegration setting: ${messageBusEnabled}`,
|
||||
);
|
||||
}
|
||||
const toolArgs = messageBusEnabled
|
||||
? [...args, this.getMessageBus()]
|
||||
: args;
|
||||
if (this.debugMode) {
|
||||
console.log(
|
||||
`[DEBUG] Registering ${className} with messageBus: ${messageBusEnabled ? 'YES' : 'NO'}`,
|
||||
);
|
||||
}
|
||||
registry.registerTool(new ToolClass(...toolArgs));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
285
packages/core/src/tools/message-bus-integration.test.ts
Normal file
285
packages/core/src/tools/message-bus-integration.test.ts
Normal file
@@ -0,0 +1,285 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
vi,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import {
|
||||
BaseToolInvocation,
|
||||
BaseDeclarativeTool,
|
||||
Kind,
|
||||
type ToolResult,
|
||||
} from './tools.js';
|
||||
import { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { PolicyEngine } from '../policy/policy-engine.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type ToolConfirmationResponse,
|
||||
} from '../confirmation-bus/types.js';
|
||||
import { randomUUID } from 'node:crypto';
|
||||
|
||||
// Mock crypto module
|
||||
vi.mock('node:crypto', () => ({
|
||||
randomUUID: vi.fn(),
|
||||
}));
|
||||
|
||||
interface TestParams {
|
||||
testParam: string;
|
||||
}
|
||||
|
||||
interface TestResult extends ToolResult {
|
||||
testValue: string;
|
||||
}
|
||||
|
||||
class TestToolInvocation extends BaseToolInvocation<TestParams, TestResult> {
|
||||
getDescription(): string {
|
||||
return `Test tool with param: ${this.params.testParam}`;
|
||||
}
|
||||
|
||||
async execute(): Promise<TestResult> {
|
||||
return {
|
||||
llmContent: `Executed with ${this.params.testParam}`,
|
||||
returnDisplay: `Test result: ${this.params.testParam}`,
|
||||
testValue: this.params.testParam,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
class TestTool extends BaseDeclarativeTool<TestParams, TestResult> {
|
||||
constructor(messageBus?: MessageBus) {
|
||||
super(
|
||||
'test-tool',
|
||||
'Test Tool',
|
||||
'A test tool for message bus integration',
|
||||
Kind.Other,
|
||||
{
|
||||
type: 'object',
|
||||
properties: {
|
||||
testParam: { type: 'string' },
|
||||
},
|
||||
required: ['testParam'],
|
||||
},
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
protected createInvocation(params: TestParams, messageBus?: MessageBus) {
|
||||
return new TestToolInvocation(params, messageBus);
|
||||
}
|
||||
}
|
||||
|
||||
describe('Message Bus Integration', () => {
|
||||
let policyEngine: PolicyEngine;
|
||||
let messageBus: MessageBus;
|
||||
let mockUUID: Mock;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
policyEngine = new PolicyEngine();
|
||||
messageBus = new MessageBus(policyEngine);
|
||||
mockUUID = vi.mocked(randomUUID);
|
||||
mockUUID.mockReturnValue('test-correlation-id');
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('BaseToolInvocation with MessageBus', () => {
|
||||
it('should use message bus for confirmation when available', async () => {
|
||||
const tool = new TestTool(messageBus);
|
||||
const invocation = tool.build({ testParam: 'test-value' });
|
||||
|
||||
// Mock message bus publish and subscribe
|
||||
const publishSpy = vi.spyOn(messageBus, 'publish');
|
||||
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
|
||||
const unsubscribeSpy = vi.spyOn(messageBus, 'unsubscribe');
|
||||
|
||||
// Start confirmation process
|
||||
const confirmationPromise = invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Verify confirmation request was published
|
||||
expect(publishSpy).toHaveBeenCalledWith({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
|
||||
toolCall: {
|
||||
name: 'TestToolInvocation',
|
||||
args: { testParam: 'test-value' },
|
||||
},
|
||||
correlationId: 'test-correlation-id',
|
||||
});
|
||||
|
||||
// Verify subscription to response
|
||||
expect(subscribeSpy).toHaveBeenCalledWith(
|
||||
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
expect.any(Function),
|
||||
);
|
||||
|
||||
// Simulate confirmation response
|
||||
const responseHandler = subscribeSpy.mock.calls[0][1];
|
||||
const response: ToolConfirmationResponse = {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: 'test-correlation-id',
|
||||
confirmed: true,
|
||||
};
|
||||
|
||||
responseHandler(response);
|
||||
|
||||
const result = await confirmationPromise;
|
||||
expect(result).toBe(false); // No further confirmation needed
|
||||
expect(unsubscribeSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should reject promise when confirmation is denied', async () => {
|
||||
const tool = new TestTool(messageBus);
|
||||
const invocation = tool.build({ testParam: 'test-value' });
|
||||
|
||||
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
|
||||
|
||||
const confirmationPromise = invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Simulate denial response
|
||||
const responseHandler = subscribeSpy.mock.calls[0][1];
|
||||
const response: ToolConfirmationResponse = {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: 'test-correlation-id',
|
||||
confirmed: false,
|
||||
};
|
||||
|
||||
responseHandler(response);
|
||||
|
||||
// Should reject with error when denied
|
||||
await expect(confirmationPromise).rejects.toThrow(
|
||||
'Tool execution denied by policy',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle timeout', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
const tool = new TestTool(messageBus);
|
||||
const invocation = tool.build({ testParam: 'test-value' });
|
||||
|
||||
const confirmationPromise = invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Fast-forward past timeout
|
||||
vi.advanceTimersByTime(30000);
|
||||
|
||||
const result = await confirmationPromise;
|
||||
expect(result).toBe(false);
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should handle abort signal', async () => {
|
||||
const tool = new TestTool(messageBus);
|
||||
const invocation = tool.build({ testParam: 'test-value' });
|
||||
|
||||
const abortController = new AbortController();
|
||||
const confirmationPromise = invocation.shouldConfirmExecute(
|
||||
abortController.signal,
|
||||
);
|
||||
|
||||
// Abort the operation
|
||||
abortController.abort();
|
||||
|
||||
await expect(confirmationPromise).rejects.toThrow(
|
||||
'Tool confirmation aborted',
|
||||
);
|
||||
});
|
||||
|
||||
it('should fall back to default behavior when no message bus', async () => {
|
||||
const tool = new TestTool(); // No message bus
|
||||
const invocation = tool.build({ testParam: 'test-value' });
|
||||
|
||||
const result = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('should ignore responses with wrong correlation ID', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
const tool = new TestTool(messageBus);
|
||||
const invocation = tool.build({ testParam: 'test-value' });
|
||||
|
||||
const subscribeSpy = vi.spyOn(messageBus, 'subscribe');
|
||||
const confirmationPromise = invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Send response with wrong correlation ID
|
||||
const responseHandler = subscribeSpy.mock.calls[0][1];
|
||||
const wrongResponse: ToolConfirmationResponse = {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: 'wrong-id',
|
||||
confirmed: true,
|
||||
};
|
||||
|
||||
responseHandler(wrongResponse);
|
||||
|
||||
// Should timeout since correct response wasn't received
|
||||
vi.advanceTimersByTime(30000);
|
||||
const result = await confirmationPromise;
|
||||
expect(result).toBe(false);
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Backward Compatibility', () => {
|
||||
it('should work with existing tools that do not use message bus', async () => {
|
||||
const tool = new TestTool(); // No message bus
|
||||
const invocation = tool.build({ testParam: 'test-value' });
|
||||
|
||||
// Should execute normally
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
expect(result.testValue).toBe('test-value');
|
||||
expect(result.llmContent).toBe('Executed with test-value');
|
||||
});
|
||||
|
||||
it('should work with tools that have message bus but use default confirmation', async () => {
|
||||
const tool = new TestTool(messageBus);
|
||||
const invocation = tool.build({ testParam: 'test-value' });
|
||||
|
||||
// Should execute normally even with message bus available
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
expect(result.testValue).toBe('test-value');
|
||||
expect(result.llmContent).toBe('Executed with test-value');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle message bus publish errors gracefully', async () => {
|
||||
const tool = new TestTool(messageBus);
|
||||
const invocation = tool.build({ testParam: 'test-value' });
|
||||
|
||||
// Mock publish to throw error
|
||||
vi.spyOn(messageBus, 'publish').mockImplementation(() => {
|
||||
throw new Error('Message bus error');
|
||||
});
|
||||
|
||||
const result = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
expect(result).toBe(false); // Should gracefully fall back
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -10,6 +10,13 @@ import type { DiffUpdateResult } from '../ide/ide-client.js';
|
||||
import type { ShellExecutionConfig } from '../services/shellExecutionService.js';
|
||||
import { SchemaValidator } from '../utils/schemaValidator.js';
|
||||
import type { AnsiOutput } from '../utils/terminalSerializer.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import {
|
||||
MessageBusType,
|
||||
type ToolConfirmationRequest,
|
||||
type ToolConfirmationResponse,
|
||||
} from '../confirmation-bus/types.js';
|
||||
|
||||
/**
|
||||
* Represents a validated and ready-to-execute tool call.
|
||||
@@ -66,7 +73,16 @@ export abstract class BaseToolInvocation<
|
||||
TResult extends ToolResult,
|
||||
> implements ToolInvocation<TParams, TResult>
|
||||
{
|
||||
constructor(readonly params: TParams) {}
|
||||
constructor(
|
||||
readonly params: TParams,
|
||||
protected readonly messageBus?: MessageBus,
|
||||
) {
|
||||
if (this.messageBus) {
|
||||
console.log(
|
||||
`[DEBUG] Tool ${this.constructor.name} created with messageBus: YES`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
abstract getDescription(): string;
|
||||
|
||||
@@ -77,9 +93,116 @@ export abstract class BaseToolInvocation<
|
||||
shouldConfirmExecute(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
// If message bus is available, use it for confirmation
|
||||
if (this.messageBus) {
|
||||
console.log(
|
||||
`[DEBUG] Using message bus for tool confirmation: ${this.constructor.name}`,
|
||||
);
|
||||
return this.handleMessageBusConfirmation(_abortSignal);
|
||||
}
|
||||
|
||||
// Fall back to existing confirmation flow
|
||||
return Promise.resolve(false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle tool confirmation using the message bus.
|
||||
* This method publishes a confirmation request and waits for the response.
|
||||
*/
|
||||
protected async handleMessageBusConfirmation(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (!this.messageBus) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const correlationId = randomUUID();
|
||||
const toolCall = {
|
||||
name: this.constructor.name,
|
||||
args: this.params as Record<string, unknown>,
|
||||
};
|
||||
|
||||
return new Promise<ToolCallConfirmationDetails | false>(
|
||||
(resolve, reject) => {
|
||||
if (!this.messageBus) {
|
||||
resolve(false);
|
||||
return;
|
||||
}
|
||||
|
||||
let timeoutId: NodeJS.Timeout | undefined;
|
||||
|
||||
// Centralized cleanup function
|
||||
const cleanup = () => {
|
||||
if (timeoutId) {
|
||||
clearTimeout(timeoutId);
|
||||
timeoutId = undefined;
|
||||
}
|
||||
abortSignal.removeEventListener('abort', abortHandler);
|
||||
this.messageBus?.unsubscribe(
|
||||
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
responseHandler,
|
||||
);
|
||||
};
|
||||
|
||||
// Set up abort handler
|
||||
const abortHandler = () => {
|
||||
cleanup();
|
||||
reject(new Error('Tool confirmation aborted'));
|
||||
};
|
||||
|
||||
// Check if already aborted
|
||||
if (abortSignal.aborted) {
|
||||
reject(new Error('Tool confirmation aborted'));
|
||||
return;
|
||||
}
|
||||
|
||||
// Set up response handler
|
||||
const responseHandler = (response: ToolConfirmationResponse) => {
|
||||
if (response.correlationId === correlationId) {
|
||||
cleanup();
|
||||
|
||||
if (response.confirmed) {
|
||||
// Tool was confirmed, return false to indicate no further confirmation needed
|
||||
resolve(false);
|
||||
} else {
|
||||
// Tool was denied, reject to prevent execution
|
||||
reject(new Error('Tool execution denied by policy'));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Add event listener for abort signal
|
||||
abortSignal.addEventListener('abort', abortHandler);
|
||||
|
||||
// Set up timeout
|
||||
timeoutId = setTimeout(() => {
|
||||
cleanup();
|
||||
resolve(false);
|
||||
}, 30000); // 30 second timeout
|
||||
|
||||
// Subscribe to response
|
||||
this.messageBus.subscribe(
|
||||
MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
responseHandler,
|
||||
);
|
||||
|
||||
// Publish confirmation request
|
||||
const request: ToolConfirmationRequest = {
|
||||
type: MessageBusType.TOOL_CONFIRMATION_REQUEST,
|
||||
toolCall,
|
||||
correlationId,
|
||||
};
|
||||
|
||||
try {
|
||||
this.messageBus.publish(request);
|
||||
} catch (_error) {
|
||||
cleanup();
|
||||
resolve(false);
|
||||
}
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
abstract execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string | AnsiOutput) => void,
|
||||
@@ -159,6 +282,7 @@ export abstract class DeclarativeTool<
|
||||
readonly parameterSchema: unknown,
|
||||
readonly isOutputMarkdown: boolean = true,
|
||||
readonly canUpdateOutput: boolean = false,
|
||||
readonly messageBus?: MessageBus,
|
||||
) {}
|
||||
|
||||
get schema(): FunctionDeclaration {
|
||||
@@ -282,7 +406,7 @@ export abstract class BaseDeclarativeTool<
|
||||
if (validationError) {
|
||||
throw new Error(validationError);
|
||||
}
|
||||
return this.createInvocation(params);
|
||||
return this.createInvocation(params, this.messageBus);
|
||||
}
|
||||
|
||||
override validateToolParams(params: TParams): string | null {
|
||||
@@ -304,6 +428,7 @@ export abstract class BaseDeclarativeTool<
|
||||
|
||||
protected abstract createInvocation(
|
||||
params: TParams,
|
||||
messageBus?: MessageBus,
|
||||
): ToolInvocation<TParams, TResult>;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user