mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-25 04:24:51 -07:00
feat(core): Introduce message bus for tool execution confirmation (#11544)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -46,6 +46,7 @@ import levenshtein from 'fast-levenshtein';
|
||||
import { ShellToolInvocation } from '../tools/shell.js';
|
||||
import type { ToolConfirmationRequest } from '../confirmation-bus/types.js';
|
||||
import { MessageBusType } from '../confirmation-bus/types.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
|
||||
export type ValidatingToolCall = {
|
||||
status: 'validating';
|
||||
@@ -331,6 +332,13 @@ interface CoreToolSchedulerOptions {
|
||||
}
|
||||
|
||||
export class CoreToolScheduler {
|
||||
// Static WeakMap to track which MessageBus instances already have a handler subscribed
|
||||
// This prevents duplicate subscriptions when multiple CoreToolScheduler instances are created
|
||||
private static subscribedMessageBuses = new WeakMap<
|
||||
MessageBus,
|
||||
(request: ToolConfirmationRequest) => void
|
||||
>();
|
||||
|
||||
private toolCalls: ToolCall[] = [];
|
||||
private outputUpdateHandler?: OutputUpdateHandler;
|
||||
private onAllToolCallsComplete?: AllToolCallsCompleteHandler;
|
||||
@@ -356,12 +364,34 @@ export class CoreToolScheduler {
|
||||
this.onEditorClose = options.onEditorClose;
|
||||
|
||||
// Subscribe to message bus for ASK_USER policy decisions
|
||||
// Use a static WeakMap to ensure we only subscribe ONCE per MessageBus instance
|
||||
// This prevents memory leaks when multiple CoreToolScheduler instances are created
|
||||
// (e.g., on every React render, or for each non-interactive tool call)
|
||||
if (this.config.getEnableMessageBusIntegration()) {
|
||||
const messageBus = this.config.getMessageBus();
|
||||
messageBus.subscribe(
|
||||
MessageBusType.TOOL_CONFIRMATION_REQUEST,
|
||||
this.handleToolConfirmationRequest.bind(this),
|
||||
);
|
||||
|
||||
// Check if we've already subscribed a handler to this message bus
|
||||
if (!CoreToolScheduler.subscribedMessageBuses.has(messageBus)) {
|
||||
// Create a shared handler that will be used for this message bus
|
||||
const sharedHandler = (request: ToolConfirmationRequest) => {
|
||||
// When ASK_USER policy decision is made, respond with requiresUserConfirmation=true
|
||||
// to tell tools to use their legacy confirmation flow
|
||||
messageBus.publish({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: request.correlationId,
|
||||
confirmed: false,
|
||||
requiresUserConfirmation: true,
|
||||
});
|
||||
};
|
||||
|
||||
messageBus.subscribe(
|
||||
MessageBusType.TOOL_CONFIRMATION_REQUEST,
|
||||
sharedHandler,
|
||||
);
|
||||
|
||||
// Store the handler in the WeakMap so we don't subscribe again
|
||||
CoreToolScheduler.subscribedMessageBuses.set(messageBus, sharedHandler);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1170,26 +1200,6 @@ export class CoreToolScheduler {
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle tool confirmation requests from the message bus when policy decision is ASK_USER.
|
||||
* This publishes a response with requiresUserConfirmation=true to signal the tool
|
||||
* that it should fall back to its legacy confirmation UI.
|
||||
*/
|
||||
private handleToolConfirmationRequest(
|
||||
request: ToolConfirmationRequest,
|
||||
): void {
|
||||
// When ASK_USER policy decision is made, the message bus emits the request here.
|
||||
// We respond with requiresUserConfirmation=true to tell the tool to use its
|
||||
// legacy confirmation flow (which will show diffs, URLs, etc in the UI).
|
||||
const messageBus = this.config.getMessageBus();
|
||||
messageBus.publish({
|
||||
type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
|
||||
correlationId: request.correlationId,
|
||||
confirmed: false, // Not auto-approved
|
||||
requiresUserConfirmation: true, // Use legacy UI confirmation
|
||||
});
|
||||
}
|
||||
|
||||
private isAutoApproved(toolCall: ValidatingToolCall): boolean {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
|
||||
return true;
|
||||
|
||||
@@ -19,15 +19,17 @@ export async function executeToolCall(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<CompletedToolCall> {
|
||||
return new Promise<CompletedToolCall>((resolve, reject) => {
|
||||
new CoreToolScheduler({
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config,
|
||||
getPreferredEditor: () => undefined,
|
||||
onEditorClose: () => {},
|
||||
onAllToolCallsComplete: async (completedToolCalls) => {
|
||||
resolve(completedToolCalls[0]);
|
||||
},
|
||||
})
|
||||
.schedule(toolCallRequest, abortSignal)
|
||||
.catch(reject);
|
||||
});
|
||||
|
||||
scheduler.schedule(toolCallRequest, abortSignal).catch((error) => {
|
||||
reject(error);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -14,7 +14,13 @@ import type {
|
||||
ToolLocation,
|
||||
ToolResult,
|
||||
} from './tools.js';
|
||||
import { BaseDeclarativeTool, Kind, ToolConfirmationOutcome } from './tools.js';
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
Kind,
|
||||
ToolConfirmationOutcome,
|
||||
} from './tools.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { isNodeError } from '../utils/errors.js';
|
||||
@@ -102,13 +108,21 @@ interface CalculatedEdit {
|
||||
isNewFile: boolean;
|
||||
}
|
||||
|
||||
class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
class EditToolInvocation
|
||||
extends BaseToolInvocation<EditToolParams, ToolResult>
|
||||
implements ToolInvocation<EditToolParams, ToolResult>
|
||||
{
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
public params: EditToolParams,
|
||||
) {}
|
||||
params: EditToolParams,
|
||||
messageBus?: MessageBus,
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
) {
|
||||
super(params, messageBus, toolName, displayName);
|
||||
}
|
||||
|
||||
toolLocations(): ToolLocation[] {
|
||||
override toolLocations(): ToolLocation[] {
|
||||
return [{ path: this.params.file_path }];
|
||||
}
|
||||
|
||||
@@ -241,7 +255,7 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
* Handles the confirmation prompt for the Edit tool in the CLI.
|
||||
* It needs to calculate the diff to show the user.
|
||||
*/
|
||||
async shouldConfirmExecute(
|
||||
protected override async getConfirmationDetails(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
@@ -467,7 +481,10 @@ export class EditTool
|
||||
{
|
||||
static readonly Name = EDIT_TOOL_NAME;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(
|
||||
EditTool.Name,
|
||||
'Edit',
|
||||
@@ -510,6 +527,9 @@ Expectation for required parameters:
|
||||
required: ['file_path', 'old_string', 'new_string'],
|
||||
type: 'object',
|
||||
},
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -540,8 +560,17 @@ Expectation for required parameters:
|
||||
|
||||
protected createInvocation(
|
||||
params: EditToolParams,
|
||||
messageBus?: MessageBus,
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
): ToolInvocation<EditToolParams, ToolResult> {
|
||||
return new EditToolInvocation(this.config, params);
|
||||
return new EditToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
toolName ?? this.name,
|
||||
displayName ?? this.displayName,
|
||||
);
|
||||
}
|
||||
|
||||
getModifyContext(_: AbortSignal): ModifyContext<EditToolParams> {
|
||||
|
||||
@@ -20,6 +20,7 @@ import {
|
||||
import type { CallableTool, FunctionCall, Part } from '@google/genai';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
|
||||
type ToolParams = Record<string, unknown>;
|
||||
|
||||
@@ -244,6 +245,9 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: ToolParams,
|
||||
_messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_displayName?: string,
|
||||
): ToolInvocation<ToolParams, ToolResult> {
|
||||
return new DiscoveredMCPToolInvocation(
|
||||
this.mcpTool,
|
||||
|
||||
@@ -24,6 +24,7 @@ import type {
|
||||
} from './modifiable-tool.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { MEMORY_TOOL_NAME } from './tool-names.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
|
||||
const memoryToolSchemaData: FunctionDeclaration = {
|
||||
name: MEMORY_TOOL_NAME,
|
||||
@@ -58,8 +59,7 @@ Do NOT use this tool:
|
||||
|
||||
## Parameters
|
||||
|
||||
- \`fact\` (string, required): The specific fact or piece of information to remember. This should be a clear, self-contained statement. For example, if the user says "My favorite color is blue", the fact would be "My favorite color is blue".
|
||||
`;
|
||||
- \`fact\` (string, required): The specific fact or piece of information to remember. This should be a clear, self-contained statement. For example, if the user says "My favorite color is blue", the fact would be "My favorite color is blue".`;
|
||||
|
||||
export const DEFAULT_CONTEXT_FILENAME = 'GEMINI.md';
|
||||
export const MEMORY_SECTION_HEADER = '## Gemini Added Memories';
|
||||
@@ -177,12 +177,21 @@ class MemoryToolInvocation extends BaseToolInvocation<
|
||||
> {
|
||||
private static readonly allowlist: Set<string> = new Set();
|
||||
|
||||
constructor(
|
||||
params: SaveMemoryParams,
|
||||
messageBus?: MessageBus,
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
) {
|
||||
super(params, messageBus, toolName, displayName);
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
const memoryFilePath = getGlobalMemoryFilePath();
|
||||
return `in ${tildeifyPath(memoryFilePath)}`;
|
||||
}
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolEditConfirmationDetails | false> {
|
||||
const memoryFilePath = getGlobalMemoryFilePath();
|
||||
@@ -291,13 +300,16 @@ export class MemoryTool
|
||||
{
|
||||
static readonly Name = MEMORY_TOOL_NAME;
|
||||
|
||||
constructor() {
|
||||
constructor(messageBus?: MessageBus) {
|
||||
super(
|
||||
MemoryTool.Name,
|
||||
'Save Memory',
|
||||
memoryToolDescription,
|
||||
Kind.Think,
|
||||
memoryToolSchemaData.parametersJsonSchema as Record<string, unknown>,
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -311,8 +323,18 @@ export class MemoryTool
|
||||
return null;
|
||||
}
|
||||
|
||||
protected createInvocation(params: SaveMemoryParams) {
|
||||
return new MemoryToolInvocation(params);
|
||||
protected createInvocation(
|
||||
params: SaveMemoryParams,
|
||||
messageBus?: MessageBus,
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
) {
|
||||
return new MemoryToolInvocation(
|
||||
params,
|
||||
messageBus ?? this.messageBus,
|
||||
toolName ?? this.name,
|
||||
displayName ?? this.displayName,
|
||||
);
|
||||
}
|
||||
|
||||
static async performAddMemoryEntry(
|
||||
|
||||
@@ -41,6 +41,7 @@ import {
|
||||
stripShellWrapper,
|
||||
} from '../utils/shell-utils.js';
|
||||
import { SHELL_TOOL_NAME } from './tool-names.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
|
||||
export const OUTPUT_UPDATE_INTERVAL_MS = 1000;
|
||||
|
||||
@@ -58,8 +59,9 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
private readonly config: Config,
|
||||
params: ShellToolParams,
|
||||
private readonly allowlist: Set<string>,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(params);
|
||||
super(params, messageBus);
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
@@ -76,7 +78,7 @@ export class ShellToolInvocation extends BaseToolInvocation<
|
||||
return description;
|
||||
}
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
const command = stripShellWrapper(this.params.command);
|
||||
@@ -372,7 +374,10 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
|
||||
private allowlist: Set<string> = new Set();
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
void initializeShellParsers().catch(() => {
|
||||
// Errors are surfaced when parsing commands.
|
||||
});
|
||||
@@ -403,6 +408,7 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
},
|
||||
false, // output is not markdown
|
||||
true, // output can be updated
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -444,7 +450,13 @@ export class ShellTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: ShellToolParams,
|
||||
messageBus?: MessageBus,
|
||||
): ToolInvocation<ShellToolParams, ToolResult> {
|
||||
return new ShellToolInvocation(this.config, params, this.allowlist);
|
||||
return new ShellToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
this.allowlist,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import * as crypto from 'node:crypto';
|
||||
import * as Diff from 'diff';
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
Kind,
|
||||
type ToolCallConfirmationDetails,
|
||||
ToolConfirmationOutcome,
|
||||
@@ -19,6 +20,7 @@ import {
|
||||
type ToolResult,
|
||||
type ToolResultDisplay,
|
||||
} from './tools.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { isNodeError } from '../utils/errors.js';
|
||||
@@ -369,13 +371,21 @@ interface CalculatedEdit {
|
||||
originalLineEnding: '\r\n' | '\n';
|
||||
}
|
||||
|
||||
class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
class EditToolInvocation
|
||||
extends BaseToolInvocation<EditToolParams, ToolResult>
|
||||
implements ToolInvocation<EditToolParams, ToolResult>
|
||||
{
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
public params: EditToolParams,
|
||||
) {}
|
||||
params: EditToolParams,
|
||||
messageBus?: MessageBus,
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
) {
|
||||
super(params, messageBus, toolName, displayName);
|
||||
}
|
||||
|
||||
toolLocations(): ToolLocation[] {
|
||||
override toolLocations(): ToolLocation[] {
|
||||
return [{ path: this.params.file_path }];
|
||||
}
|
||||
|
||||
@@ -602,7 +612,7 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
* Handles the confirmation prompt for the Edit tool in the CLI.
|
||||
* It needs to calculate the diff to show the user.
|
||||
*/
|
||||
async shouldConfirmExecute(
|
||||
protected override async getConfirmationDetails(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
@@ -818,7 +828,10 @@ export class SmartEditTool
|
||||
{
|
||||
static readonly Name = EDIT_TOOL_NAME;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(
|
||||
SmartEditTool.Name,
|
||||
'Edit',
|
||||
@@ -875,6 +888,9 @@ A good instruction should concisely answer:
|
||||
required: ['file_path', 'instruction', 'old_string', 'new_string'],
|
||||
type: 'object',
|
||||
},
|
||||
true, // isOutputMarkdown
|
||||
false, // canUpdateOutput
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -914,7 +930,13 @@ A good instruction should concisely answer:
|
||||
protected createInvocation(
|
||||
params: EditToolParams,
|
||||
): ToolInvocation<EditToolParams, ToolResult> {
|
||||
return new EditToolInvocation(this.config, params);
|
||||
return new EditToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
this.messageBus,
|
||||
this.name,
|
||||
this.displayName,
|
||||
);
|
||||
}
|
||||
|
||||
getModifyContext(_: AbortSignal): ModifyContext<EditToolParams> {
|
||||
|
||||
@@ -21,6 +21,7 @@ import { parse } from 'shell-quote';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
|
||||
import type { EventEmitter } from 'node:events';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
type ToolParams = Record<string, unknown>;
|
||||
@@ -162,6 +163,9 @@ Signal: Signal number or \`(none)\` if no signal was received.
|
||||
|
||||
protected createInvocation(
|
||||
params: ToolParams,
|
||||
_messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_displayName?: string,
|
||||
): ToolInvocation<ToolParams, ToolResult> {
|
||||
return new DiscoveredToolInvocation(this.config, this.name, params);
|
||||
}
|
||||
|
||||
@@ -104,25 +104,37 @@ export abstract class BaseToolInvocation<
|
||||
}
|
||||
|
||||
if (decision === 'ASK_USER') {
|
||||
const confirmationDetails: ToolCallConfirmationDetails = {
|
||||
type: 'info',
|
||||
title: `Confirm: ${this._toolDisplayName || this._toolName}`,
|
||||
prompt: this.getDescription(),
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
if (this.messageBus && this._toolName) {
|
||||
this.messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName: this._toolName,
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
return confirmationDetails;
|
||||
return this.getConfirmationDetails(abortSignal);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
// When no message bus, use default confirmation flow
|
||||
return this.getConfirmationDetails(abortSignal);
|
||||
}
|
||||
|
||||
/**
|
||||
* Subclasses should override this method to provide custom confirmation UI
|
||||
* when the policy engine's decision is 'ASK_USER'.
|
||||
* The base implementation provides a generic confirmation prompt.
|
||||
*/
|
||||
protected async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
const confirmationDetails: ToolCallConfirmationDetails = {
|
||||
type: 'info',
|
||||
title: `Confirm: ${this._toolDisplayName || this._toolName}`,
|
||||
prompt: this.getDescription(),
|
||||
onConfirm: async (outcome: ToolConfirmationOutcome) => {
|
||||
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
|
||||
if (this.messageBus && this._toolName) {
|
||||
this.messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName: this._toolName,
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
return confirmationDetails;
|
||||
}
|
||||
|
||||
protected getMessageBusDecision(
|
||||
|
||||
@@ -521,7 +521,7 @@ describe('WebFetchTool', () => {
|
||||
|
||||
// Should reject with error when denied
|
||||
await expect(confirmationPromise).rejects.toThrow(
|
||||
'Tool execution denied by policy',
|
||||
'Tool execution for "WebFetch" denied by policy.',
|
||||
);
|
||||
});
|
||||
|
||||
@@ -559,7 +559,7 @@ describe('WebFetchTool', () => {
|
||||
abortController.abort();
|
||||
|
||||
await expect(confirmationPromise).rejects.toThrow(
|
||||
'Tool execution denied by policy.',
|
||||
'Tool execution for "WebFetch" denied by policy.',
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -205,21 +205,9 @@ ${textContent}
|
||||
return `Processing URLs and instructions from prompt: "${displayPrompt}"`;
|
||||
}
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
// Try message bus confirmation first if available
|
||||
if (this.messageBus) {
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'ALLOW') {
|
||||
return false; // No confirmation needed
|
||||
}
|
||||
if (decision === 'DENY') {
|
||||
throw new Error('Tool execution denied by policy.');
|
||||
}
|
||||
// if 'ASK_USER', fall through to legacy logic
|
||||
}
|
||||
|
||||
// Legacy confirmation flow (no message bus OR policy decision was ASK_USER)
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
|
||||
@@ -42,6 +42,7 @@ import { FileOperationEvent } from '../telemetry/types.js';
|
||||
import { FileOperation } from '../telemetry/metrics.js';
|
||||
import { getSpecificMimeType } from '../utils/fileUtils.js';
|
||||
import { getLanguageFromFilePath } from '../utils/language-detection.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
|
||||
/**
|
||||
* Parameters for the WriteFile tool
|
||||
@@ -144,8 +145,11 @@ class WriteFileToolInvocation extends BaseToolInvocation<
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
params: WriteFileToolParams,
|
||||
messageBus?: MessageBus,
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
) {
|
||||
super(params);
|
||||
super(params, messageBus, toolName, displayName);
|
||||
}
|
||||
|
||||
override toolLocations(): ToolLocation[] {
|
||||
@@ -160,7 +164,7 @@ class WriteFileToolInvocation extends BaseToolInvocation<
|
||||
return `Writing to ${shortenPath(relativePath)}`;
|
||||
}
|
||||
|
||||
override async shouldConfirmExecute(
|
||||
protected override async getConfirmationDetails(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
@@ -392,7 +396,10 @@ export class WriteFileTool
|
||||
{
|
||||
static readonly Name = WRITE_FILE_TOOL_NAME;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
constructor(
|
||||
private readonly config: Config,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
super(
|
||||
WriteFileTool.Name,
|
||||
'WriteFile',
|
||||
@@ -415,6 +422,9 @@ export class WriteFileTool
|
||||
required: ['file_path', 'content'],
|
||||
type: 'object',
|
||||
},
|
||||
true,
|
||||
false,
|
||||
messageBus,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -458,7 +468,13 @@ export class WriteFileTool
|
||||
protected createInvocation(
|
||||
params: WriteFileToolParams,
|
||||
): ToolInvocation<WriteFileToolParams, ToolResult> {
|
||||
return new WriteFileToolInvocation(this.config, params);
|
||||
return new WriteFileToolInvocation(
|
||||
this.config,
|
||||
params,
|
||||
this.messageBus,
|
||||
this.name,
|
||||
this.displayName,
|
||||
);
|
||||
}
|
||||
|
||||
getModifyContext(
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
type Todo,
|
||||
type ToolResult,
|
||||
} from './tools.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { WRITE_TODOS_TOOL_NAME } from './tool-names.js';
|
||||
|
||||
const TODO_STATUSES = [
|
||||
@@ -204,6 +205,9 @@ export class WriteTodosTool extends BaseDeclarativeTool<
|
||||
|
||||
protected createInvocation(
|
||||
params: WriteTodosToolParams,
|
||||
_messageBus?: MessageBus,
|
||||
_toolName?: string,
|
||||
_displayName?: string,
|
||||
): ToolInvocation<WriteTodosToolParams, ToolResult> {
|
||||
return new WriteTodosToolInvocation(params);
|
||||
}
|
||||
|
||||
@@ -70,6 +70,13 @@ export class FatalCancellationError extends FatalError {
|
||||
}
|
||||
}
|
||||
|
||||
export class CanceledError extends Error {
|
||||
constructor(message = 'The operation was canceled.') {
|
||||
super(message);
|
||||
this.name = 'CanceledError';
|
||||
}
|
||||
}
|
||||
|
||||
export class ForbiddenError extends Error {}
|
||||
export class UnauthorizedError extends Error {}
|
||||
export class BadRequestError extends Error {}
|
||||
|
||||
Reference in New Issue
Block a user