fix(hooks): support 'ask' decision for BeforeTool hooks (#21146)

This commit is contained in:
Christian Gunderman
2026-03-21 03:52:39 +00:00
committed by GitHub
parent ad98e9c17c
commit 148b8f3ebd
32 changed files with 1016 additions and 117 deletions
@@ -166,7 +166,7 @@ describe('Tool Confirmation Policy Updates', () => {
// Mock getMessageBusDecision to trigger ASK_USER flow
vi.spyOn(invocation as any, 'getMessageBusDecision').mockResolvedValue(
'ASK_USER',
'ask_user',
);
const confirmation = await invocation.shouldConfirmExecute(
@@ -194,5 +194,39 @@ describe('Tool Confirmation Policy Updates', () => {
}
},
);
it('should skip confirmation in AUTO_EDIT mode', async () => {
vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue(
ApprovalMode.AUTO_EDIT,
);
const tool = create(mockConfig, mockMessageBus);
const invocation = tool.build(params as any);
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).toBe(false);
});
it('should NOT skip confirmation in AUTO_EDIT mode if forcedDecision is ask_user', async () => {
vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue(
ApprovalMode.AUTO_EDIT,
);
const tool = create(mockConfig, mockMessageBus);
const invocation = tool.build(params as any);
// Mock getMessageBusDecision to return ask_user
vi.spyOn(invocation as any, 'getMessageBusDecision').mockResolvedValue(
'ask_user',
);
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
'ask_user',
);
expect(confirmation).not.toBe(false);
});
});
});
+10 -6
View File
@@ -29,7 +29,6 @@ import { makeRelative, shortenPath } from '../utils/paths.js';
import { isNodeError } from '../utils/errors.js';
import { correctPath } from '../utils/pathCorrector.js';
import type { Config } from '../config/config.js';
import { ApprovalMode } from '../policy/types.js';
import { CoreToolCallStatus } from '../scheduler/types.js';
import { DEFAULT_DIFF_OPTIONS, getDiffStat } from './diffOptions.js';
@@ -454,7 +453,16 @@ class EditToolInvocation
toolName?: string,
displayName?: string,
) {
super(params, messageBus, toolName, displayName);
super(
params,
messageBus,
toolName,
displayName,
undefined,
undefined,
true,
() => this.config.getApprovalMode(),
);
if (!path.isAbsolute(this.params.file_path)) {
const result = correctPath(this.params.file_path, this.config);
if (result.success) {
@@ -732,10 +740,6 @@ class EditToolInvocation
protected override async getConfirmationDetails(
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
return false;
}
let editData: CalculatedEdit;
try {
editData = await this.calculateEdit(this.params, abortSignal);
@@ -47,7 +47,7 @@ describe('EnterPlanModeTool', () => {
getMessageBusDecision: () => Promise<string>;
},
'getMessageBusDecision',
).mockResolvedValue('ASK_USER');
).mockResolvedValue('ask_user');
const result = await invocation.shouldConfirmExecute(
new AbortController().signal,
@@ -74,7 +74,7 @@ describe('EnterPlanModeTool', () => {
getMessageBusDecision: () => Promise<string>;
},
'getMessageBusDecision',
).mockResolvedValue('ALLOW');
).mockResolvedValue('allow');
const result = await invocation.shouldConfirmExecute(
new AbortController().signal,
@@ -92,7 +92,7 @@ describe('EnterPlanModeTool', () => {
getMessageBusDecision: () => Promise<string>;
},
'getMessageBusDecision',
).mockResolvedValue('DENY');
).mockResolvedValue('deny');
await expect(
invocation.shouldConfirmExecute(new AbortController().signal),
@@ -136,7 +136,7 @@ describe('EnterPlanModeTool', () => {
getMessageBusDecision: () => Promise<string>;
},
'getMessageBusDecision',
).mockResolvedValue('ASK_USER');
).mockResolvedValue('ask_user');
const details = await invocation.shouldConfirmExecute(
new AbortController().signal,
+3 -3
View File
@@ -87,11 +87,11 @@ export class EnterPlanModeInvocation extends BaseToolInvocation<
abortSignal: AbortSignal,
): Promise<ToolInfoConfirmationDetails | false> {
const decision = await this.getMessageBusDecision(abortSignal);
if (decision === 'ALLOW') {
if (decision === 'allow') {
return false;
}
if (decision === 'DENY') {
if (decision === 'deny') {
throw new Error(
`Tool execution for "${
this._toolDisplayName || this._toolName
@@ -99,7 +99,7 @@ export class EnterPlanModeInvocation extends BaseToolInvocation<
);
}
// ASK_USER
// ask_user
return {
type: 'info',
title: 'Enter Plan Mode',
@@ -59,7 +59,7 @@ describe('ExitPlanModeTool', () => {
getMessageBusDecision: () => Promise<string>;
},
'getMessageBusDecision',
).mockResolvedValue('ASK_USER');
).mockResolvedValue('ask_user');
});
afterEach(() => {
@@ -127,7 +127,7 @@ describe('ExitPlanModeTool', () => {
getMessageBusDecision: () => Promise<string>;
},
'getMessageBusDecision',
).mockResolvedValue('ALLOW');
).mockResolvedValue('allow');
const result = await invocation.shouldConfirmExecute(
new AbortController().signal,
@@ -150,7 +150,7 @@ describe('ExitPlanModeTool', () => {
getMessageBusDecision: () => Promise<string>;
},
'getMessageBusDecision',
).mockResolvedValue('DENY');
).mockResolvedValue('deny');
await expect(
invocation.shouldConfirmExecute(new AbortController().signal),
+3 -3
View File
@@ -138,7 +138,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation<
}
const decision = await this.getMessageBusDecision(abortSignal);
if (decision === 'DENY') {
if (decision === 'deny') {
throw new Error(
`Tool execution for "${
this._toolDisplayName || this._toolName
@@ -146,7 +146,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation<
);
}
if (decision === 'ALLOW') {
if (decision === 'allow') {
// If policy is allow, auto-approve with default settings and execute.
this.confirmationOutcome = ToolConfirmationOutcome.ProceedOnce;
this.approvalPayload = {
@@ -156,7 +156,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation<
return false;
}
// decision is 'ASK_USER'
// decision is 'ask_user'
return {
type: 'exit_plan_mode',
title: 'Plan Approval',
@@ -57,10 +57,10 @@ class TestToolInvocation extends BaseToolInvocation<TestParams, TestResult> {
abortSignal: AbortSignal,
): Promise<false> {
const decision = await this.getMessageBusDecision(abortSignal);
if (decision === 'ALLOW') {
if (decision === 'allow') {
return false;
}
if (decision === 'DENY') {
if (decision === 'deny') {
throw new Error('Tool execution denied by policy');
}
return false;
+44 -17
View File
@@ -19,9 +19,15 @@ import {
type ToolConfirmationResponse,
type Question,
} from '../confirmation-bus/types.js';
import { type ApprovalMode } from '../policy/types.js';
import { ApprovalMode } from '../policy/types.js';
import type { SubagentProgress } from '../agents/types.js';
/**
/**
* Supported decisions for forcing tool execution behavior.
*/
export type ForcedToolDecision = 'allow' | 'deny' | 'ask_user';
/**
* Options bag for tool execution, replacing positional parameters that are
* only relevant to specific tool types.
@@ -65,6 +71,7 @@ export interface ToolInvocation<
*/
shouldConfirmExecute(
abortSignal: AbortSignal,
forcedDecision?: ForcedToolDecision,
): Promise<ToolCallConfirmationDetails | false>;
/**
@@ -148,6 +155,8 @@ export abstract class BaseToolInvocation<
readonly _toolDisplayName?: string,
readonly _serverName?: string,
readonly _toolAnnotations?: Record<string, unknown>,
readonly respectsAutoEdit: boolean = false,
readonly getApprovalMode: () => ApprovalMode = () => ApprovalMode.DEFAULT,
) {}
abstract getDescription(): string;
@@ -158,13 +167,23 @@ export abstract class BaseToolInvocation<
async shouldConfirmExecute(
abortSignal: AbortSignal,
forcedDecision?: ForcedToolDecision,
): Promise<ToolCallConfirmationDetails | false> {
const decision = await this.getMessageBusDecision(abortSignal);
if (decision === 'ALLOW') {
if (
this.respectsAutoEdit &&
this.getApprovalMode() === ApprovalMode.AUTO_EDIT &&
forcedDecision !== 'ask_user'
) {
return false;
}
if (decision === 'DENY') {
const decision =
forcedDecision ?? (await this.getMessageBusDecision(abortSignal));
if (decision === 'allow') {
return false;
}
if (decision === 'deny') {
throw new Error(
`Tool execution for "${
this._toolDisplayName || this._toolName
@@ -172,7 +191,7 @@ export abstract class BaseToolInvocation<
);
}
if (decision === 'ASK_USER') {
if (decision === 'ask_user') {
return this.getConfirmationDetails(abortSignal);
}
@@ -216,7 +235,7 @@ export abstract class BaseToolInvocation<
/**
* Subclasses should override this method to provide custom confirmation UI
* when the policy engine's decision is 'ASK_USER'.
* when the policy engine's decision is 'ask_user'.
* The base implementation provides a generic confirmation prompt.
*/
protected async getConfirmationDetails(
@@ -239,11 +258,12 @@ export abstract class BaseToolInvocation<
protected getMessageBusDecision(
abortSignal: AbortSignal,
): Promise<'ALLOW' | 'DENY' | 'ASK_USER'> {
forcedDecision?: ForcedToolDecision,
): Promise<ForcedToolDecision> {
if (!this.messageBus || !this._toolName) {
// If there's no message bus, we can't make a decision, so we allow.
// The legacy confirmation flow will still apply if the tool needs it.
return Promise.resolve('ALLOW');
return Promise.resolve('allow');
}
const correlationId = randomUUID();
@@ -257,11 +277,12 @@ export abstract class BaseToolInvocation<
},
serverName: this._serverName,
toolAnnotations: this._toolAnnotations,
forcedDecision,
};
return new Promise<'ALLOW' | 'DENY' | 'ASK_USER'>((resolve) => {
return new Promise<ForcedToolDecision>((resolve) => {
if (!this.messageBus) {
resolve('ALLOW');
resolve('allow');
return;
}
@@ -282,11 +303,11 @@ export abstract class BaseToolInvocation<
const abortHandler = () => {
cleanup();
resolve('DENY');
resolve('deny');
};
if (abortSignal.aborted) {
resolve('DENY');
resolve('deny');
return;
}
@@ -294,11 +315,11 @@ export abstract class BaseToolInvocation<
if (response.correlationId === correlationId) {
cleanup();
if (response.requiresUserConfirmation) {
resolve('ASK_USER');
resolve('ask_user');
} else if (response.confirmed) {
resolve('ALLOW');
resolve('allow');
} else {
resolve('DENY');
resolve('deny');
}
}
};
@@ -307,7 +328,7 @@ export abstract class BaseToolInvocation<
timeoutId = setTimeout(() => {
cleanup();
resolve('ASK_USER'); // Default to ASK_USER on timeout
resolve('ask_user'); // Default to ask_user on timeout
}, 30000);
this.messageBus.subscribe(
@@ -325,7 +346,7 @@ export abstract class BaseToolInvocation<
void this.messageBus.publish(request);
} catch (_error) {
cleanup();
resolve('ALLOW');
resolve('allow');
}
});
}
@@ -859,6 +880,7 @@ export interface DiffStat {
export interface ToolEditConfirmationDetails {
type: 'edit';
title: string;
systemMessage?: string;
onConfirm: (
outcome: ToolConfirmationOutcome,
payload?: ToolConfirmationPayload,
@@ -897,6 +919,7 @@ export type ToolConfirmationPayload =
export interface ToolExecuteConfirmationDetails {
type: 'exec';
title: string;
systemMessage?: string;
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
command: string;
rootCommand: string;
@@ -907,6 +930,7 @@ export interface ToolExecuteConfirmationDetails {
export interface ToolMcpConfirmationDetails {
type: 'mcp';
title: string;
systemMessage?: string;
serverName: string;
toolName: string;
toolDisplayName: string;
@@ -919,6 +943,7 @@ export interface ToolMcpConfirmationDetails {
export interface ToolInfoConfirmationDetails {
type: 'info';
title: string;
systemMessage?: string;
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
prompt: string;
urls?: string[];
@@ -927,6 +952,7 @@ export interface ToolInfoConfirmationDetails {
export interface ToolAskUserConfirmationDetails {
type: 'ask_user';
title: string;
systemMessage?: string;
questions: Question[];
onConfirm: (
outcome: ToolConfirmationOutcome,
@@ -937,6 +963,7 @@ export interface ToolAskUserConfirmationDetails {
export interface ToolExitPlanModeConfirmationDetails {
type: 'exit_plan_mode';
title: string;
systemMessage?: string;
planPath: string;
onConfirm: (
outcome: ToolConfirmationOutcome,
+10 -8
View File
@@ -18,7 +18,6 @@ import { buildParamArgsPattern } from '../policy/utils.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { ToolErrorType } from './tool-error.js';
import { getErrorMessage } from '../utils/errors.js';
import { ApprovalMode } from '../policy/types.js';
import { getResponseText } from '../utils/partUtils.js';
import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js';
import { truncateString } from '../utils/textUtils.js';
@@ -231,7 +230,16 @@ class WebFetchToolInvocation extends BaseToolInvocation<
_toolName?: string,
_toolDisplayName?: string,
) {
super(params, messageBus, _toolName, _toolDisplayName);
super(
params,
messageBus,
_toolName,
_toolDisplayName,
undefined,
undefined,
true,
() => this.context.config.getApprovalMode(),
);
}
private handleRetry(attempt: number, error: unknown, delayMs: number): void {
@@ -516,12 +524,6 @@ ${aggregatedContent}
protected override async getConfirmationDetails(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
// Check for AUTO_EDIT approval mode. This tool has a specific behavior
// where ProceedAlways switches the entire session to AUTO_EDIT.
if (this.context.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
return false;
}
let urls: string[] = [];
let prompt = this.params.prompt || '';
+10 -6
View File
@@ -11,7 +11,6 @@ import os from 'node:os';
import * as Diff from 'diff';
import { WRITE_FILE_TOOL_NAME, WRITE_FILE_DISPLAY_NAME } from './tool-names.js';
import type { Config } from '../config/config.js';
import { ApprovalMode } from '../policy/types.js';
import {
BaseDeclarativeTool,
@@ -156,7 +155,16 @@ class WriteFileToolInvocation extends BaseToolInvocation<
toolName?: string,
displayName?: string,
) {
super(params, messageBus, toolName, displayName);
super(
params,
messageBus,
toolName,
displayName,
undefined,
undefined,
true,
() => this.config.getApprovalMode(),
);
this.resolvedPath = path.resolve(
this.config.getTargetDir(),
this.params.file_path,
@@ -186,10 +194,6 @@ class WriteFileToolInvocation extends BaseToolInvocation<
protected override async getConfirmationDetails(
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
return false;
}
const correctedContentResult = await getCorrectedFileContent(
this.config,
this.resolvedPath,