mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-06-21 16:57:08 -07:00
fix(hooks): support 'ask' decision for BeforeTool hooks (#21146)
This commit is contained in:
committed by
GitHub
parent
ad98e9c17c
commit
148b8f3ebd
@@ -166,7 +166,7 @@ describe('Tool Confirmation Policy Updates', () => {
|
||||
|
||||
// Mock getMessageBusDecision to trigger ASK_USER flow
|
||||
vi.spyOn(invocation as any, 'getMessageBusDecision').mockResolvedValue(
|
||||
'ASK_USER',
|
||||
'ask_user',
|
||||
);
|
||||
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
@@ -194,5 +194,39 @@ describe('Tool Confirmation Policy Updates', () => {
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
it('should skip confirmation in AUTO_EDIT mode', async () => {
|
||||
vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue(
|
||||
ApprovalMode.AUTO_EDIT,
|
||||
);
|
||||
const tool = create(mockConfig, mockMessageBus);
|
||||
const invocation = tool.build(params as any);
|
||||
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
expect(confirmation).toBe(false);
|
||||
});
|
||||
|
||||
it('should NOT skip confirmation in AUTO_EDIT mode if forcedDecision is ask_user', async () => {
|
||||
vi.spyOn(mockConfig, 'getApprovalMode').mockReturnValue(
|
||||
ApprovalMode.AUTO_EDIT,
|
||||
);
|
||||
const tool = create(mockConfig, mockMessageBus);
|
||||
const invocation = tool.build(params as any);
|
||||
|
||||
// Mock getMessageBusDecision to return ask_user
|
||||
vi.spyOn(invocation as any, 'getMessageBusDecision').mockResolvedValue(
|
||||
'ask_user',
|
||||
);
|
||||
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
'ask_user',
|
||||
);
|
||||
|
||||
expect(confirmation).not.toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -29,7 +29,6 @@ import { makeRelative, shortenPath } from '../utils/paths.js';
|
||||
import { isNodeError } from '../utils/errors.js';
|
||||
import { correctPath } from '../utils/pathCorrector.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
import { CoreToolCallStatus } from '../scheduler/types.js';
|
||||
|
||||
import { DEFAULT_DIFF_OPTIONS, getDiffStat } from './diffOptions.js';
|
||||
@@ -454,7 +453,16 @@ class EditToolInvocation
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
) {
|
||||
super(params, messageBus, toolName, displayName);
|
||||
super(
|
||||
params,
|
||||
messageBus,
|
||||
toolName,
|
||||
displayName,
|
||||
undefined,
|
||||
undefined,
|
||||
true,
|
||||
() => this.config.getApprovalMode(),
|
||||
);
|
||||
if (!path.isAbsolute(this.params.file_path)) {
|
||||
const result = correctPath(this.params.file_path, this.config);
|
||||
if (result.success) {
|
||||
@@ -732,10 +740,6 @@ class EditToolInvocation
|
||||
protected override async getConfirmationDetails(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let editData: CalculatedEdit;
|
||||
try {
|
||||
editData = await this.calculateEdit(this.params, abortSignal);
|
||||
|
||||
@@ -47,7 +47,7 @@ describe('EnterPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('ASK_USER');
|
||||
).mockResolvedValue('ask_user');
|
||||
|
||||
const result = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
@@ -74,7 +74,7 @@ describe('EnterPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('ALLOW');
|
||||
).mockResolvedValue('allow');
|
||||
|
||||
const result = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
@@ -92,7 +92,7 @@ describe('EnterPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('DENY');
|
||||
).mockResolvedValue('deny');
|
||||
|
||||
await expect(
|
||||
invocation.shouldConfirmExecute(new AbortController().signal),
|
||||
@@ -136,7 +136,7 @@ describe('EnterPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('ASK_USER');
|
||||
).mockResolvedValue('ask_user');
|
||||
|
||||
const details = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
|
||||
@@ -87,11 +87,11 @@ export class EnterPlanModeInvocation extends BaseToolInvocation<
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolInfoConfirmationDetails | false> {
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'ALLOW') {
|
||||
if (decision === 'allow') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decision === 'DENY') {
|
||||
if (decision === 'deny') {
|
||||
throw new Error(
|
||||
`Tool execution for "${
|
||||
this._toolDisplayName || this._toolName
|
||||
@@ -99,7 +99,7 @@ export class EnterPlanModeInvocation extends BaseToolInvocation<
|
||||
);
|
||||
}
|
||||
|
||||
// ASK_USER
|
||||
// ask_user
|
||||
return {
|
||||
type: 'info',
|
||||
title: 'Enter Plan Mode',
|
||||
|
||||
@@ -59,7 +59,7 @@ describe('ExitPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('ASK_USER');
|
||||
).mockResolvedValue('ask_user');
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -127,7 +127,7 @@ describe('ExitPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('ALLOW');
|
||||
).mockResolvedValue('allow');
|
||||
|
||||
const result = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
@@ -150,7 +150,7 @@ describe('ExitPlanModeTool', () => {
|
||||
getMessageBusDecision: () => Promise<string>;
|
||||
},
|
||||
'getMessageBusDecision',
|
||||
).mockResolvedValue('DENY');
|
||||
).mockResolvedValue('deny');
|
||||
|
||||
await expect(
|
||||
invocation.shouldConfirmExecute(new AbortController().signal),
|
||||
|
||||
@@ -138,7 +138,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation<
|
||||
}
|
||||
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'DENY') {
|
||||
if (decision === 'deny') {
|
||||
throw new Error(
|
||||
`Tool execution for "${
|
||||
this._toolDisplayName || this._toolName
|
||||
@@ -146,7 +146,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation<
|
||||
);
|
||||
}
|
||||
|
||||
if (decision === 'ALLOW') {
|
||||
if (decision === 'allow') {
|
||||
// If policy is allow, auto-approve with default settings and execute.
|
||||
this.confirmationOutcome = ToolConfirmationOutcome.ProceedOnce;
|
||||
this.approvalPayload = {
|
||||
@@ -156,7 +156,7 @@ export class ExitPlanModeInvocation extends BaseToolInvocation<
|
||||
return false;
|
||||
}
|
||||
|
||||
// decision is 'ASK_USER'
|
||||
// decision is 'ask_user'
|
||||
return {
|
||||
type: 'exit_plan_mode',
|
||||
title: 'Plan Approval',
|
||||
|
||||
@@ -57,10 +57,10 @@ class TestToolInvocation extends BaseToolInvocation<TestParams, TestResult> {
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<false> {
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'ALLOW') {
|
||||
if (decision === 'allow') {
|
||||
return false;
|
||||
}
|
||||
if (decision === 'DENY') {
|
||||
if (decision === 'deny') {
|
||||
throw new Error('Tool execution denied by policy');
|
||||
}
|
||||
return false;
|
||||
|
||||
@@ -19,9 +19,15 @@ import {
|
||||
type ToolConfirmationResponse,
|
||||
type Question,
|
||||
} from '../confirmation-bus/types.js';
|
||||
import { type ApprovalMode } from '../policy/types.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
import type { SubagentProgress } from '../agents/types.js';
|
||||
|
||||
/**
|
||||
/**
|
||||
* Supported decisions for forcing tool execution behavior.
|
||||
*/
|
||||
export type ForcedToolDecision = 'allow' | 'deny' | 'ask_user';
|
||||
|
||||
/**
|
||||
* Options bag for tool execution, replacing positional parameters that are
|
||||
* only relevant to specific tool types.
|
||||
@@ -65,6 +71,7 @@ export interface ToolInvocation<
|
||||
*/
|
||||
shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false>;
|
||||
|
||||
/**
|
||||
@@ -148,6 +155,8 @@ export abstract class BaseToolInvocation<
|
||||
readonly _toolDisplayName?: string,
|
||||
readonly _serverName?: string,
|
||||
readonly _toolAnnotations?: Record<string, unknown>,
|
||||
readonly respectsAutoEdit: boolean = false,
|
||||
readonly getApprovalMode: () => ApprovalMode = () => ApprovalMode.DEFAULT,
|
||||
) {}
|
||||
|
||||
abstract getDescription(): string;
|
||||
@@ -158,13 +167,23 @@ export abstract class BaseToolInvocation<
|
||||
|
||||
async shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
const decision = await this.getMessageBusDecision(abortSignal);
|
||||
if (decision === 'ALLOW') {
|
||||
if (
|
||||
this.respectsAutoEdit &&
|
||||
this.getApprovalMode() === ApprovalMode.AUTO_EDIT &&
|
||||
forcedDecision !== 'ask_user'
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decision === 'DENY') {
|
||||
const decision =
|
||||
forcedDecision ?? (await this.getMessageBusDecision(abortSignal));
|
||||
if (decision === 'allow') {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (decision === 'deny') {
|
||||
throw new Error(
|
||||
`Tool execution for "${
|
||||
this._toolDisplayName || this._toolName
|
||||
@@ -172,7 +191,7 @@ export abstract class BaseToolInvocation<
|
||||
);
|
||||
}
|
||||
|
||||
if (decision === 'ASK_USER') {
|
||||
if (decision === 'ask_user') {
|
||||
return this.getConfirmationDetails(abortSignal);
|
||||
}
|
||||
|
||||
@@ -216,7 +235,7 @@ export abstract class BaseToolInvocation<
|
||||
|
||||
/**
|
||||
* Subclasses should override this method to provide custom confirmation UI
|
||||
* when the policy engine's decision is 'ASK_USER'.
|
||||
* when the policy engine's decision is 'ask_user'.
|
||||
* The base implementation provides a generic confirmation prompt.
|
||||
*/
|
||||
protected async getConfirmationDetails(
|
||||
@@ -239,11 +258,12 @@ export abstract class BaseToolInvocation<
|
||||
|
||||
protected getMessageBusDecision(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<'ALLOW' | 'DENY' | 'ASK_USER'> {
|
||||
forcedDecision?: ForcedToolDecision,
|
||||
): Promise<ForcedToolDecision> {
|
||||
if (!this.messageBus || !this._toolName) {
|
||||
// If there's no message bus, we can't make a decision, so we allow.
|
||||
// The legacy confirmation flow will still apply if the tool needs it.
|
||||
return Promise.resolve('ALLOW');
|
||||
return Promise.resolve('allow');
|
||||
}
|
||||
|
||||
const correlationId = randomUUID();
|
||||
@@ -257,11 +277,12 @@ export abstract class BaseToolInvocation<
|
||||
},
|
||||
serverName: this._serverName,
|
||||
toolAnnotations: this._toolAnnotations,
|
||||
forcedDecision,
|
||||
};
|
||||
|
||||
return new Promise<'ALLOW' | 'DENY' | 'ASK_USER'>((resolve) => {
|
||||
return new Promise<ForcedToolDecision>((resolve) => {
|
||||
if (!this.messageBus) {
|
||||
resolve('ALLOW');
|
||||
resolve('allow');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -282,11 +303,11 @@ export abstract class BaseToolInvocation<
|
||||
|
||||
const abortHandler = () => {
|
||||
cleanup();
|
||||
resolve('DENY');
|
||||
resolve('deny');
|
||||
};
|
||||
|
||||
if (abortSignal.aborted) {
|
||||
resolve('DENY');
|
||||
resolve('deny');
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -294,11 +315,11 @@ export abstract class BaseToolInvocation<
|
||||
if (response.correlationId === correlationId) {
|
||||
cleanup();
|
||||
if (response.requiresUserConfirmation) {
|
||||
resolve('ASK_USER');
|
||||
resolve('ask_user');
|
||||
} else if (response.confirmed) {
|
||||
resolve('ALLOW');
|
||||
resolve('allow');
|
||||
} else {
|
||||
resolve('DENY');
|
||||
resolve('deny');
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -307,7 +328,7 @@ export abstract class BaseToolInvocation<
|
||||
|
||||
timeoutId = setTimeout(() => {
|
||||
cleanup();
|
||||
resolve('ASK_USER'); // Default to ASK_USER on timeout
|
||||
resolve('ask_user'); // Default to ask_user on timeout
|
||||
}, 30000);
|
||||
|
||||
this.messageBus.subscribe(
|
||||
@@ -325,7 +346,7 @@ export abstract class BaseToolInvocation<
|
||||
void this.messageBus.publish(request);
|
||||
} catch (_error) {
|
||||
cleanup();
|
||||
resolve('ALLOW');
|
||||
resolve('allow');
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -859,6 +880,7 @@ export interface DiffStat {
|
||||
export interface ToolEditConfirmationDetails {
|
||||
type: 'edit';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
onConfirm: (
|
||||
outcome: ToolConfirmationOutcome,
|
||||
payload?: ToolConfirmationPayload,
|
||||
@@ -897,6 +919,7 @@ export type ToolConfirmationPayload =
|
||||
export interface ToolExecuteConfirmationDetails {
|
||||
type: 'exec';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||
command: string;
|
||||
rootCommand: string;
|
||||
@@ -907,6 +930,7 @@ export interface ToolExecuteConfirmationDetails {
|
||||
export interface ToolMcpConfirmationDetails {
|
||||
type: 'mcp';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
serverName: string;
|
||||
toolName: string;
|
||||
toolDisplayName: string;
|
||||
@@ -919,6 +943,7 @@ export interface ToolMcpConfirmationDetails {
|
||||
export interface ToolInfoConfirmationDetails {
|
||||
type: 'info';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
onConfirm: (outcome: ToolConfirmationOutcome) => Promise<void>;
|
||||
prompt: string;
|
||||
urls?: string[];
|
||||
@@ -927,6 +952,7 @@ export interface ToolInfoConfirmationDetails {
|
||||
export interface ToolAskUserConfirmationDetails {
|
||||
type: 'ask_user';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
questions: Question[];
|
||||
onConfirm: (
|
||||
outcome: ToolConfirmationOutcome,
|
||||
@@ -937,6 +963,7 @@ export interface ToolAskUserConfirmationDetails {
|
||||
export interface ToolExitPlanModeConfirmationDetails {
|
||||
type: 'exit_plan_mode';
|
||||
title: string;
|
||||
systemMessage?: string;
|
||||
planPath: string;
|
||||
onConfirm: (
|
||||
outcome: ToolConfirmationOutcome,
|
||||
|
||||
@@ -18,7 +18,6 @@ import { buildParamArgsPattern } from '../policy/utils.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
import { getResponseText } from '../utils/partUtils.js';
|
||||
import { fetchWithTimeout, isPrivateIp } from '../utils/fetch.js';
|
||||
import { truncateString } from '../utils/textUtils.js';
|
||||
@@ -231,7 +230,16 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
_toolName?: string,
|
||||
_toolDisplayName?: string,
|
||||
) {
|
||||
super(params, messageBus, _toolName, _toolDisplayName);
|
||||
super(
|
||||
params,
|
||||
messageBus,
|
||||
_toolName,
|
||||
_toolDisplayName,
|
||||
undefined,
|
||||
undefined,
|
||||
true,
|
||||
() => this.context.config.getApprovalMode(),
|
||||
);
|
||||
}
|
||||
|
||||
private handleRetry(attempt: number, error: unknown, delayMs: number): void {
|
||||
@@ -516,12 +524,6 @@ ${aggregatedContent}
|
||||
protected override async getConfirmationDetails(
|
||||
_abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
// Check for AUTO_EDIT approval mode. This tool has a specific behavior
|
||||
// where ProceedAlways switches the entire session to AUTO_EDIT.
|
||||
if (this.context.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let urls: string[] = [];
|
||||
let prompt = this.params.prompt || '';
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ import os from 'node:os';
|
||||
import * as Diff from 'diff';
|
||||
import { WRITE_FILE_TOOL_NAME, WRITE_FILE_DISPLAY_NAME } from './tool-names.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
@@ -156,7 +155,16 @@ class WriteFileToolInvocation extends BaseToolInvocation<
|
||||
toolName?: string,
|
||||
displayName?: string,
|
||||
) {
|
||||
super(params, messageBus, toolName, displayName);
|
||||
super(
|
||||
params,
|
||||
messageBus,
|
||||
toolName,
|
||||
displayName,
|
||||
undefined,
|
||||
undefined,
|
||||
true,
|
||||
() => this.config.getApprovalMode(),
|
||||
);
|
||||
this.resolvedPath = path.resolve(
|
||||
this.config.getTargetDir(),
|
||||
this.params.file_path,
|
||||
@@ -186,10 +194,6 @@ class WriteFileToolInvocation extends BaseToolInvocation<
|
||||
protected override async getConfirmationDetails(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
if (this.config.getApprovalMode() === ApprovalMode.AUTO_EDIT) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const correctedContentResult = await getCorrectedFileContent(
|
||||
this.config,
|
||||
this.resolvedPath,
|
||||
|
||||
Reference in New Issue
Block a user