fix(hooks): support 'ask' decision for BeforeTool hooks (#21146)

This commit is contained in:
Christian Gunderman
2026-03-21 03:52:39 +00:00
committed by GitHub
parent d3766875f8
commit d1dc4902fd
32 changed files with 1016 additions and 117 deletions
+12 -2
View File
@@ -16,6 +16,7 @@ import {
ToolConfirmationOutcome,
type ToolConfirmationPayload,
type ToolCallConfirmationDetails,
type ForcedToolDecision,
} from '../tools/tools.js';
import {
type ValidatingToolCall,
@@ -116,6 +117,8 @@ export async function resolveConfirmation(
getPreferredEditor: () => EditorType | undefined;
schedulerId: string;
onWaitingForConfirmation?: (waiting: boolean) => void;
systemMessage?: string;
forcedDecision?: ForcedToolDecision;
},
): Promise<ResolutionResult> {
const { state, onWaitingForConfirmation } = deps;
@@ -126,7 +129,7 @@ export async function resolveConfirmation(
// Loop exists to allow the user to modify the parameters and see the new
// diff.
while (outcome === ToolConfirmationOutcome.ModifyWithEditor) {
if (signal.aborted) throw new Error('Operation cancelled');
if (signal.aborted) throw new Error('Operation cancelled by user');
const currentCall = state.getToolCall(callId);
if (!currentCall || !('invocation' in currentCall)) {
@@ -134,12 +137,19 @@ export async function resolveConfirmation(
}
const currentInvocation = currentCall.invocation;
const details = await currentInvocation.shouldConfirmExecute(signal);
const details = await currentInvocation.shouldConfirmExecute(
signal,
deps.forcedDecision,
);
if (!details) {
outcome = ToolConfirmationOutcome.ProceedOnce;
break;
}
if (deps.systemMessage) {
details.systemMessage = deps.systemMessage;
}
await notifyHooks(deps, details);
const correlationId = randomUUID();
+109
View File
@@ -0,0 +1,109 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '../config/config.js';
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
import type { ToolCallRequestInfo } from './types.js';
import { extractMcpContext } from '../core/coreToolHookTriggers.js';
import { BeforeToolHookOutput } from '../hooks/types.js';
import { ToolErrorType } from '../tools/tool-error.js';
export type HookEvaluationResult =
| {
status: 'continue';
hookDecision?: 'ask' | 'block';
hookSystemMessage?: string;
modifiedArgs?: Record<string, unknown>;
newInvocation?: AnyToolInvocation;
}
| {
status: 'error';
error: Error;
errorType: ToolErrorType;
};
export async function evaluateBeforeToolHook(
config: Config,
tool: AnyDeclarativeTool,
request: ToolCallRequestInfo,
invocation: AnyToolInvocation,
): Promise<HookEvaluationResult> {
const hookSystem = config.getHookSystem();
if (!hookSystem) {
return { status: 'continue' };
}
const params = invocation.params || {};
const toolInput: Record<string, unknown> = { ...params };
const mcpContext = extractMcpContext(invocation, config);
const beforeOutput = await hookSystem.fireBeforeToolEvent(
request.name,
toolInput,
mcpContext,
request.originalRequestName,
);
if (!beforeOutput) {
return { status: 'continue' };
}
if (beforeOutput.shouldStopExecution()) {
return {
status: 'error',
error: new Error(
`Agent execution stopped by hook: ${beforeOutput.getEffectiveReason()}`,
),
errorType: ToolErrorType.STOP_EXECUTION,
};
}
const blockingError = beforeOutput.getBlockingError();
if (blockingError?.blocked) {
return {
status: 'error',
error: new Error(`Tool execution blocked: ${blockingError.reason}`),
errorType: ToolErrorType.POLICY_VIOLATION,
};
}
let hookDecision: 'ask' | 'block' | undefined;
let hookSystemMessage: string | undefined;
if (beforeOutput.isAskDecision()) {
hookDecision = 'ask';
hookSystemMessage = beforeOutput.systemMessage;
}
let modifiedArgs: Record<string, unknown> | undefined;
let newInvocation: AnyToolInvocation | undefined;
if (beforeOutput instanceof BeforeToolHookOutput) {
const modifiedInput = beforeOutput.getModifiedToolInput();
if (modifiedInput) {
modifiedArgs = modifiedInput;
try {
newInvocation = tool.build(modifiedInput);
} catch (error) {
return {
status: 'error',
error: new Error(
`Tool parameter modification by hook failed validation: ${error instanceof Error ? error.message : String(error)}`,
),
errorType: ToolErrorType.INVALID_TOOL_PARAMS,
};
}
}
}
return {
status: 'continue',
hookDecision,
hookSystemMessage,
modifiedArgs,
newInvocation,
};
}
@@ -824,6 +824,7 @@ describe('Plan Mode Denial Consistency', () => {
toolRegistry: mockToolRegistry,
getToolRegistry: () => mockToolRegistry,
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
getHookSystem: vi.fn().mockReturnValue(undefined),
isInteractive: vi.fn().mockReturnValue(true),
getEnableHooks: vi.fn().mockReturnValue(false),
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.PLAN), // Key: Plan Mode
@@ -170,6 +170,8 @@ describe('Scheduler (Orchestrator)', () => {
mockConfig = {
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
toolRegistry: mockToolRegistry,
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
getHookSystem: vi.fn().mockReturnValue(undefined),
isInteractive: vi.fn().mockReturnValue(true),
getEnableHooks: vi.fn().mockReturnValue(true),
setApprovalMode: vi.fn(),
@@ -1346,6 +1348,7 @@ describe('Scheduler MCP Progress', () => {
mockConfig = {
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
getHookSystem: vi.fn().mockReturnValue(undefined),
isInteractive: vi.fn().mockReturnValue(true),
getEnableHooks: vi.fn().mockReturnValue(true),
setApprovalMode: vi.fn(),
+39 -2
View File
@@ -10,6 +10,7 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { SchedulerStateManager } from './state-manager.js';
import { resolveConfirmation } from './confirmation.js';
import { checkPolicy, updatePolicy, getPolicyDenialError } from './policy.js';
import { evaluateBeforeToolHook } from './hook-utils.js';
import { ToolExecutor } from './tool-executor.js';
import { ToolModificationHandler } from './tool-modifier.js';
import {
@@ -572,12 +573,46 @@ export class Scheduler {
): Promise<void> {
const callId = toolCall.request.callId;
// Policy & Security
const { decision, rule } = await checkPolicy(
// 1. Hook Check (BeforeTool)
const hookResult = await evaluateBeforeToolHook(
this.config,
toolCall.tool,
toolCall.request,
toolCall.invocation,
);
if (hookResult.status === 'error') {
this.state.updateStatus(
callId,
CoreToolCallStatus.Error,
createErrorResponse(
toolCall.request,
hookResult.error,
hookResult.errorType,
),
);
return;
}
const { hookDecision, hookSystemMessage, modifiedArgs, newInvocation } =
hookResult;
if (modifiedArgs && newInvocation) {
toolCall.request.args = modifiedArgs;
toolCall.request.inputModifiedByHook = true;
toolCall.invocation = newInvocation;
}
// 2. Policy & Security
const { decision: policyDecision, rule } = await checkPolicy(
toolCall,
this.config,
this.subagent,
);
let decision = policyDecision;
if (hookDecision === 'ask') {
decision = PolicyDecision.ASK_USER;
}
if (decision === PolicyDecision.DENY) {
const { errorMessage, errorType } = getPolicyDenialError(
@@ -610,6 +645,8 @@ export class Scheduler {
getPreferredEditor: this.getPreferredEditor,
schedulerId: this.schedulerId,
onWaitingForConfirmation: this.onWaitingForConfirmation,
systemMessage: hookSystemMessage,
forcedDecision: hookDecision === 'ask' ? 'ask_user' : undefined,
});
outcome = result.outcome;
lastDetails = result.lastDetails;
@@ -212,6 +212,8 @@ describe('Scheduler Parallel Execution', () => {
mockConfig = {
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
toolRegistry: mockToolRegistry,
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
getHookSystem: vi.fn().mockReturnValue(undefined),
isInteractive: vi.fn().mockReturnValue(true),
getEnableHooks: vi.fn().mockReturnValue(true),
setApprovalMode: vi.fn(),
@@ -115,10 +115,25 @@ export class ToolExecutor {
{ shellExecutionConfig, setExecutionIdCallback },
this.config,
request.originalRequestName,
true, // skipBeforeHook
);
const toolResult: ToolResult = await promise;
if (call.request.inputModifiedByHook) {
const modificationMsg = `\n\n[System] Tool input parameters were modified by a hook before execution.`;
if (typeof toolResult.llmContent === 'string') {
toolResult.llmContent += modificationMsg;
} else if (Array.isArray(toolResult.llmContent)) {
toolResult.llmContent.push({ text: modificationMsg });
} else if (toolResult.llmContent) {
toolResult.llmContent = [
toolResult.llmContent,
{ text: modificationMsg },
];
}
}
if (signal.aborted) {
completedToolCall = await this.createCancelledResult(
call,
+2
View File
@@ -47,6 +47,8 @@ export interface ToolCallRequestInfo {
traceId?: string;
parentCallId?: string;
schedulerId?: string;
inputModifiedByHook?: boolean;
forcedAsk?: boolean;
}
export interface ToolCallResponseInfo {