feat(core): Introduce message bus for tool execution confirmation (#11544)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Allen Hutchison
2025-10-24 13:04:40 -07:00
committed by GitHub
parent 63a90836fe
commit b188a51c32
15 changed files with 224 additions and 92 deletions
+34 -24
View File
@@ -46,6 +46,7 @@ import levenshtein from 'fast-levenshtein';
import { ShellToolInvocation } from '../tools/shell.js';
import type { ToolConfirmationRequest } from '../confirmation-bus/types.js';
import { MessageBusType } from '../confirmation-bus/types.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
export type ValidatingToolCall = {
status: 'validating';
@@ -331,6 +332,13 @@ interface CoreToolSchedulerOptions {
}
export class CoreToolScheduler {
// Static WeakMap to track which MessageBus instances already have a handler subscribed
// This prevents duplicate subscriptions when multiple CoreToolScheduler instances are created
private static subscribedMessageBuses = new WeakMap<
MessageBus,
(request: ToolConfirmationRequest) => void
>();
private toolCalls: ToolCall[] = [];
private outputUpdateHandler?: OutputUpdateHandler;
private onAllToolCallsComplete?: AllToolCallsCompleteHandler;
@@ -356,12 +364,34 @@ export class CoreToolScheduler {
this.onEditorClose = options.onEditorClose;
// Subscribe to message bus for ASK_USER policy decisions
// Use a static WeakMap to ensure we only subscribe ONCE per MessageBus instance
// This prevents memory leaks when multiple CoreToolScheduler instances are created
// (e.g., on every React render, or for each non-interactive tool call)
if (this.config.getEnableMessageBusIntegration()) {
const messageBus = this.config.getMessageBus();
messageBus.subscribe(
MessageBusType.TOOL_CONFIRMATION_REQUEST,
this.handleToolConfirmationRequest.bind(this),
);
// Check if we've already subscribed a handler to this message bus
if (!CoreToolScheduler.subscribedMessageBuses.has(messageBus)) {
// Create a shared handler that will be used for this message bus
const sharedHandler = (request: ToolConfirmationRequest) => {
// When ASK_USER policy decision is made, respond with requiresUserConfirmation=true
// to tell tools to use their legacy confirmation flow
messageBus.publish({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: request.correlationId,
confirmed: false,
requiresUserConfirmation: true,
});
};
messageBus.subscribe(
MessageBusType.TOOL_CONFIRMATION_REQUEST,
sharedHandler,
);
// Store the handler in the WeakMap so we don't subscribe again
CoreToolScheduler.subscribedMessageBuses.set(messageBus, sharedHandler);
}
}
}
@@ -1170,26 +1200,6 @@ export class CoreToolScheduler {
});
}
/**
* Handle tool confirmation requests from the message bus when policy decision is ASK_USER.
* This publishes a response with requiresUserConfirmation=true to signal the tool
* that it should fall back to its legacy confirmation UI.
*/
private handleToolConfirmationRequest(
request: ToolConfirmationRequest,
): void {
// When ASK_USER policy decision is made, the message bus emits the request here.
// We respond with requiresUserConfirmation=true to tell the tool to use its
// legacy confirmation flow (which will show diffs, URLs, etc in the UI).
const messageBus = this.config.getMessageBus();
messageBus.publish({
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
correlationId: request.correlationId,
confirmed: false, // Not auto-approved
requiresUserConfirmation: true, // Use legacy UI confirmation
});
}
private isAutoApproved(toolCall: ValidatingToolCall): boolean {
if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
return true;
@@ -19,15 +19,17 @@ export async function executeToolCall(
abortSignal: AbortSignal,
): Promise<CompletedToolCall> {
return new Promise<CompletedToolCall>((resolve, reject) => {
new CoreToolScheduler({
const scheduler = new CoreToolScheduler({
config,
getPreferredEditor: () => undefined,
onEditorClose: () => {},
onAllToolCallsComplete: async (completedToolCalls) => {
resolve(completedToolCalls[0]);
},
})
.schedule(toolCallRequest, abortSignal)
.catch(reject);
});
scheduler.schedule(toolCallRequest, abortSignal).catch((error) => {
reject(error);
});
});
}
+37 -8
View File
@@ -14,7 +14,13 @@ import type {
ToolLocation,
ToolResult,
} from './tools.js';
import { BaseDeclarativeTool, Kind, ToolConfirmationOutcome } from './tools.js';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
ToolConfirmationOutcome,
} from './tools.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { ToolErrorType } from './tool-error.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { isNodeError } from '../utils/errors.js';
@@ -102,13 +108,21 @@ interface CalculatedEdit {
isNewFile: boolean;
}
class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
class EditToolInvocation
extends BaseToolInvocation<EditToolParams, ToolResult>
implements ToolInvocation<EditToolParams, ToolResult>
{
constructor(
private readonly config: Config,
public params: EditToolParams,
) {}
params: EditToolParams,
messageBus?: MessageBus,
toolName?: string,
displayName?: string,
) {
super(params, messageBus, toolName, displayName);
}
toolLocations(): ToolLocation[] {
override toolLocations(): ToolLocation[] {
return [{ path: this.params.file_path }];
}
@@ -241,7 +255,7 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
* Handles the confirmation prompt for the Edit tool in the CLI.
* It needs to calculate the diff to show the user.
*/
async shouldConfirmExecute(
protected override async getConfirmationDetails(
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
@@ -467,7 +481,10 @@ export class EditTool
{
static readonly Name = EDIT_TOOL_NAME;
constructor(private readonly config: Config) {
constructor(
private readonly config: Config,
messageBus?: MessageBus,
) {
super(
EditTool.Name,
'Edit',
@@ -510,6 +527,9 @@ Expectation for required parameters:
required: ['file_path', 'old_string', 'new_string'],
type: 'object',
},
true, // isOutputMarkdown
false, // canUpdateOutput
messageBus,
);
}
@@ -540,8 +560,17 @@ Expectation for required parameters:
protected createInvocation(
params: EditToolParams,
messageBus?: MessageBus,
toolName?: string,
displayName?: string,
): ToolInvocation<EditToolParams, ToolResult> {
return new EditToolInvocation(this.config, params);
return new EditToolInvocation(
this.config,
params,
messageBus ?? this.messageBus,
toolName ?? this.name,
displayName ?? this.displayName,
);
}
getModifyContext(_: AbortSignal): ModifyContext<EditToolParams> {
+4
View File
@@ -20,6 +20,7 @@ import {
import type { CallableTool, FunctionCall, Part } from '@google/genai';
import { ToolErrorType } from './tool-error.js';
import type { Config } from '../config/config.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
type ToolParams = Record<string, unknown>;
@@ -244,6 +245,9 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
protected createInvocation(
params: ToolParams,
_messageBus?: MessageBus,
_toolName?: string,
_displayName?: string,
): ToolInvocation<ToolParams, ToolResult> {
return new DiscoveredMCPToolInvocation(
this.mcpTool,
+28 -6
View File
@@ -24,6 +24,7 @@ import type {
} from './modifiable-tool.js';
import { ToolErrorType } from './tool-error.js';
import { MEMORY_TOOL_NAME } from './tool-names.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
const memoryToolSchemaData: FunctionDeclaration = {
name: MEMORY_TOOL_NAME,
@@ -58,8 +59,7 @@ Do NOT use this tool:
## Parameters
- \`fact\` (string, required): The specific fact or piece of information to remember. This should be a clear, self-contained statement. For example, if the user says "My favorite color is blue", the fact would be "My favorite color is blue".
`;
- \`fact\` (string, required): The specific fact or piece of information to remember. This should be a clear, self-contained statement. For example, if the user says "My favorite color is blue", the fact would be "My favorite color is blue".`;
export const DEFAULT_CONTEXT_FILENAME = 'GEMINI.md';
export const MEMORY_SECTION_HEADER = '## Gemini Added Memories';
@@ -177,12 +177,21 @@ class MemoryToolInvocation extends BaseToolInvocation<
> {
private static readonly allowlist: Set<string> = new Set();
constructor(
params: SaveMemoryParams,
messageBus?: MessageBus,
toolName?: string,
displayName?: string,
) {
super(params, messageBus, toolName, displayName);
}
getDescription(): string {
const memoryFilePath = getGlobalMemoryFilePath();
return `in ${tildeifyPath(memoryFilePath)}`;
}
override async shouldConfirmExecute(
protected override async getConfirmationDetails(
_abortSignal: AbortSignal,
): Promise<ToolEditConfirmationDetails | false> {
const memoryFilePath = getGlobalMemoryFilePath();
@@ -291,13 +300,16 @@ export class MemoryTool
{
static readonly Name = MEMORY_TOOL_NAME;
constructor() {
constructor(messageBus?: MessageBus) {
super(
MemoryTool.Name,
'Save Memory',
memoryToolDescription,
Kind.Think,
memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>,
true,
false,
messageBus,
);
}
@@ -311,8 +323,18 @@ export class MemoryTool
return null;
}
protected createInvocation(params: SaveMemoryParams) {
return new MemoryToolInvocation(params);
protected createInvocation(
params: SaveMemoryParams,
messageBus?: MessageBus,
toolName?: string,
displayName?: string,
) {
return new MemoryToolInvocation(
params,
messageBus ?? this.messageBus,
toolName ?? this.name,
displayName ?? this.displayName,
);
}
static async performAddMemoryEntry(
+16 -4
View File
@@ -41,6 +41,7 @@ import {
stripShellWrapper,
} from '../utils/shell-utils.js';
import { SHELL_TOOL_NAME } from './tool-names.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
export const OUTPUT_UPDATE_INTERVAL_MS = 1000;
@@ -58,8 +59,9 @@ export class ShellToolInvocation extends BaseToolInvocation<
private readonly config: Config,
params: ShellToolParams,
private readonly allowlist: Set<string>,
messageBus?: MessageBus,
) {
super(params);
super(params, messageBus);
}
getDescription(): string {
@@ -76,7 +78,7 @@ export class ShellToolInvocation extends BaseToolInvocation<
return description;
}
override async shouldConfirmExecute(
protected override async getConfirmationDetails(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
const command = stripShellWrapper(this.params.command);
@@ -372,7 +374,10 @@ export class ShellTool extends BaseDeclarativeTool<
private allowlist: Set<string> = new Set();
constructor(private readonly config: Config) {
constructor(
private readonly config: Config,
messageBus?: MessageBus,
) {
void initializeShellParsers().catch(() => {
// Errors are surfaced when parsing commands.
});
@@ -403,6 +408,7 @@ export class ShellTool extends BaseDeclarativeTool<
},
false, // output is not markdown
true, // output can be updated
messageBus,
);
}
@@ -444,7 +450,13 @@ export class ShellTool extends BaseDeclarativeTool<
protected createInvocation(
params: ShellToolParams,
messageBus?: MessageBus,
): ToolInvocation<ShellToolParams, ToolResult> {
return new ShellToolInvocation(this.config, params, this.allowlist);
return new ShellToolInvocation(
this.config,
params,
this.allowlist,
messageBus,
);
}
}
+29 -7
View File
@@ -10,6 +10,7 @@ import * as crypto from 'node:crypto';
import * as Diff from 'diff';
import {
BaseDeclarativeTool,
BaseToolInvocation,
Kind,
type ToolCallConfirmationDetails,
ToolConfirmationOutcome,
@@ -19,6 +20,7 @@ import {
type ToolResult,
type ToolResultDisplay,
} from './tools.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { ToolErrorType } from './tool-error.js';
import { makeRelative, shortenPath } from '../utils/paths.js';
import { isNodeError } from '../utils/errors.js';
@@ -369,13 +371,21 @@ interface CalculatedEdit {
originalLineEnding: '\r\n' | '\n';
}
class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
class EditToolInvocation
extends BaseToolInvocation<EditToolParams, ToolResult>
implements ToolInvocation<EditToolParams, ToolResult>
{
constructor(
private readonly config: Config,
public params: EditToolParams,
) {}
params: EditToolParams,
messageBus?: MessageBus,
toolName?: string,
displayName?: string,
) {
super(params, messageBus, toolName, displayName);
}
toolLocations(): ToolLocation[] {
override toolLocations(): ToolLocation[] {
return [{ path: this.params.file_path }];
}
@@ -602,7 +612,7 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
* Handles the confirmation prompt for the Edit tool in the CLI.
* It needs to calculate the diff to show the user.
*/
async shouldConfirmExecute(
protected override async getConfirmationDetails(
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
@@ -818,7 +828,10 @@ export class SmartEditTool
{
static readonly Name = EDIT_TOOL_NAME;
constructor(private readonly config: Config) {
constructor(
private readonly config: Config,
messageBus?: MessageBus,
) {
super(
SmartEditTool.Name,
'Edit',
@@ -875,6 +888,9 @@ A good instruction should concisely answer:
required: ['file_path', 'instruction', 'old_string', 'new_string'],
type: 'object',
},
true, // isOutputMarkdown
false, // canUpdateOutput
messageBus,
);
}
@@ -914,7 +930,13 @@ A good instruction should concisely answer:
protected createInvocation(
params: EditToolParams,
): ToolInvocation<EditToolParams, ToolResult> {
return new EditToolInvocation(this.config, params);
return new EditToolInvocation(
this.config,
params,
this.messageBus,
this.name,
this.displayName,
);
}
getModifyContext(_: AbortSignal): ModifyContext<EditToolParams> {
+4
View File
@@ -21,6 +21,7 @@ import { parse } from 'shell-quote';
import { ToolErrorType } from './tool-error.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
import type { EventEmitter } from 'node:events';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { debugLogger } from '../utils/debugLogger.js';
type ToolParams = Record<string, unknown>;
@@ -162,6 +163,9 @@ Signal: Signal number or \`(none)\` if no signal was received.
protected createInvocation(
params: ToolParams,
_messageBus?: MessageBus,
_toolName?: string,
_displayName?: string,
): ToolInvocation<ToolParams, ToolResult> {
return new DiscoveredToolInvocation(this.config, this.name, params);
}
+29 -17
View File
@@ -104,25 +104,37 @@ export abstract class BaseToolInvocation<
}
if (decision === 'ASK_USER') {
const confirmationDetails: ToolCallConfirmationDetails = {
type: 'info',
title: `Confirm: ${this._toolDisplayName || this._toolName}`,
prompt: this.getDescription(),
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
if (this.messageBus && this._toolName) {
this.messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: this._toolName,
});
}
}
},
};
return confirmationDetails;
return this.getConfirmationDetails(abortSignal);
}
}
return false;
// When no message bus, use default confirmation flow
return this.getConfirmationDetails(abortSignal);
}
/**
* Subclasses should override this method to provide custom confirmation UI
* when the policy engine's decision is 'ASK_USER'.
* The base implementation provides a generic confirmation prompt.
*/
protected async getConfirmationDetails(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
const confirmationDetails: ToolCallConfirmationDetails = {
type: 'info',
title: `Confirm: ${this._toolDisplayName || this._toolName}`,
prompt: this.getDescription(),
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
if (this.messageBus && this._toolName) {
this.messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: this._toolName,
});
}
}
},
};
return confirmationDetails;
}
protected getMessageBusDecision(
+2 -2
View File
@@ -521,7 +521,7 @@ describe('WebFetchTool', () => {
// Should reject with error when denied
await expect(confirmationPromise).rejects.toThrow(
'Tool execution denied by policy',
'Tool execution for "WebFetch" denied by policy.',
);
});
@@ -559,7 +559,7 @@ describe('WebFetchTool', () => {
abortController.abort();
await expect(confirmationPromise).rejects.toThrow(
'Tool execution denied by policy.',
'Tool execution for "WebFetch" denied by policy.',
);
});
+2 -14
View File
@@ -205,21 +205,9 @@ ${textContent}
return `Processing URLs and instructions from prompt: "${displayPrompt}"`;
}
override async shouldConfirmExecute(
abortSignal: AbortSignal,
protected override async getConfirmationDetails(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
// Try message bus confirmation first if available
if (this.messageBus) {
const decision = await this.getMessageBusDecision(abortSignal);
if (decision === 'ALLOW') {
return false; // No confirmation needed
}
if (decision === 'DENY') {
throw new Error('Tool execution denied by policy.');
}
// if 'ASK_USER', fall through to legacy logic
}
// Legacy confirmation flow (no message bus OR policy decision was ASK_USER)
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
return false;
+20 -4
View File
@@ -42,6 +42,7 @@ import { FileOperationEvent } from '../telemetry/types.js';
import { FileOperation } from '../telemetry/metrics.js';
import { getSpecificMimeType } from '../utils/fileUtils.js';
import { getLanguageFromFilePath } from '../utils/language-detection.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
/**
* Parameters for the WriteFile tool
@@ -144,8 +145,11 @@ class WriteFileToolInvocation extends BaseToolInvocation<
constructor(
private readonly config: Config,
params: WriteFileToolParams,
messageBus?: MessageBus,
toolName?: string,
displayName?: string,
) {
super(params);
super(params, messageBus, toolName, displayName);
}
override toolLocations(): ToolLocation[] {
@@ -160,7 +164,7 @@ class WriteFileToolInvocation extends BaseToolInvocation<
return `Writing to ${shortenPath(relativePath)}`;
}
override async shouldConfirmExecute(
protected override async getConfirmationDetails(
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
@@ -392,7 +396,10 @@ export class WriteFileTool
{
static readonly Name = WRITE_FILE_TOOL_NAME;
constructor(private readonly config: Config) {
constructor(
private readonly config: Config,
messageBus?: MessageBus,
) {
super(
WriteFileTool.Name,
'WriteFile',
@@ -415,6 +422,9 @@ export class WriteFileTool
required: ['file_path', 'content'],
type: 'object',
},
true,
false,
messageBus,
);
}
@@ -458,7 +468,13 @@ export class WriteFileTool
protected createInvocation(
params: WriteFileToolParams,
): ToolInvocation<WriteFileToolParams, ToolResult> {
return new WriteFileToolInvocation(this.config, params);
return new WriteFileToolInvocation(
this.config,
params,
this.messageBus,
this.name,
this.displayName,
);
}
getModifyContext(
+4
View File
@@ -12,6 +12,7 @@ import {
type Todo,
type ToolResult,
} from './tools.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { WRITE_TODOS_TOOL_NAME } from './tool-names.js';
const TODO_STATUSES = [
@@ -204,6 +205,9 @@ export class WriteTodosTool extends BaseDeclarativeTool<
protected createInvocation(
params: WriteTodosToolParams,
_messageBus?: MessageBus,
_toolName?: string,
_displayName?: string,
): ToolInvocation<WriteTodosToolParams, ToolResult> {
return new WriteTodosToolInvocation(params);
}
+7
View File
@@ -70,6 +70,13 @@ export class FatalCancellationError extends FatalError {
}
}
export class CanceledError extends Error {
constructor(message = 'The operation was canceled.') {
super(message);
this.name = 'CanceledError';
}
}
export class ForbiddenError extends Error {}
export class UnauthorizedError extends Error {}
export class BadRequestError extends Error {}