mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-30 06:54:15 -07:00
fix(hooks): support 'ask' decision for BeforeTool hooks (#21146)
This commit is contained in:
committed by
GitHub
parent
d3766875f8
commit
d1dc4902fd
@@ -16,6 +16,7 @@ import {
|
||||
ToolConfirmationOutcome,
|
||||
type ToolConfirmationPayload,
|
||||
type ToolCallConfirmationDetails,
|
||||
type ForcedToolDecision,
|
||||
} from '../tools/tools.js';
|
||||
import {
|
||||
type ValidatingToolCall,
|
||||
@@ -116,6 +117,8 @@ export async function resolveConfirmation(
|
||||
getPreferredEditor: () => EditorType | undefined;
|
||||
schedulerId: string;
|
||||
onWaitingForConfirmation?: (waiting: boolean) => void;
|
||||
systemMessage?: string;
|
||||
forcedDecision?: ForcedToolDecision;
|
||||
},
|
||||
): Promise<ResolutionResult> {
|
||||
const { state, onWaitingForConfirmation } = deps;
|
||||
@@ -126,7 +129,7 @@ export async function resolveConfirmation(
|
||||
// Loop exists to allow the user to modify the parameters and see the new
|
||||
// diff.
|
||||
while (outcome === ToolConfirmationOutcome.ModifyWithEditor) {
|
||||
if (signal.aborted) throw new Error('Operation cancelled');
|
||||
if (signal.aborted) throw new Error('Operation cancelled by user');
|
||||
|
||||
const currentCall = state.getToolCall(callId);
|
||||
if (!currentCall || !('invocation' in currentCall)) {
|
||||
@@ -134,12 +137,19 @@ export async function resolveConfirmation(
|
||||
}
|
||||
const currentInvocation = currentCall.invocation;
|
||||
|
||||
const details = await currentInvocation.shouldConfirmExecute(signal);
|
||||
const details = await currentInvocation.shouldConfirmExecute(
|
||||
signal,
|
||||
deps.forcedDecision,
|
||||
);
|
||||
if (!details) {
|
||||
outcome = ToolConfirmationOutcome.ProceedOnce;
|
||||
break;
|
||||
}
|
||||
|
||||
if (deps.systemMessage) {
|
||||
details.systemMessage = deps.systemMessage;
|
||||
}
|
||||
|
||||
await notifyHooks(deps, details);
|
||||
|
||||
const correlationId = randomUUID();
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
|
||||
import type { ToolCallRequestInfo } from './types.js';
|
||||
import { extractMcpContext } from '../core/coreToolHookTriggers.js';
|
||||
import { BeforeToolHookOutput } from '../hooks/types.js';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
|
||||
export type HookEvaluationResult =
|
||||
| {
|
||||
status: 'continue';
|
||||
hookDecision?: 'ask' | 'block';
|
||||
hookSystemMessage?: string;
|
||||
modifiedArgs?: Record<string, unknown>;
|
||||
newInvocation?: AnyToolInvocation;
|
||||
}
|
||||
| {
|
||||
status: 'error';
|
||||
error: Error;
|
||||
errorType: ToolErrorType;
|
||||
};
|
||||
|
||||
export async function evaluateBeforeToolHook(
|
||||
config: Config,
|
||||
tool: AnyDeclarativeTool,
|
||||
request: ToolCallRequestInfo,
|
||||
invocation: AnyToolInvocation,
|
||||
): Promise<HookEvaluationResult> {
|
||||
const hookSystem = config.getHookSystem();
|
||||
if (!hookSystem) {
|
||||
return { status: 'continue' };
|
||||
}
|
||||
|
||||
const params = invocation.params || {};
|
||||
const toolInput: Record<string, unknown> = { ...params };
|
||||
const mcpContext = extractMcpContext(invocation, config);
|
||||
|
||||
const beforeOutput = await hookSystem.fireBeforeToolEvent(
|
||||
request.name,
|
||||
toolInput,
|
||||
mcpContext,
|
||||
request.originalRequestName,
|
||||
);
|
||||
|
||||
if (!beforeOutput) {
|
||||
return { status: 'continue' };
|
||||
}
|
||||
|
||||
if (beforeOutput.shouldStopExecution()) {
|
||||
return {
|
||||
status: 'error',
|
||||
error: new Error(
|
||||
`Agent execution stopped by hook: ${beforeOutput.getEffectiveReason()}`,
|
||||
),
|
||||
errorType: ToolErrorType.STOP_EXECUTION,
|
||||
};
|
||||
}
|
||||
|
||||
const blockingError = beforeOutput.getBlockingError();
|
||||
if (blockingError?.blocked) {
|
||||
return {
|
||||
status: 'error',
|
||||
error: new Error(`Tool execution blocked: ${blockingError.reason}`),
|
||||
errorType: ToolErrorType.POLICY_VIOLATION,
|
||||
};
|
||||
}
|
||||
|
||||
let hookDecision: 'ask' | 'block' | undefined;
|
||||
let hookSystemMessage: string | undefined;
|
||||
|
||||
if (beforeOutput.isAskDecision()) {
|
||||
hookDecision = 'ask';
|
||||
hookSystemMessage = beforeOutput.systemMessage;
|
||||
}
|
||||
|
||||
let modifiedArgs: Record<string, unknown> | undefined;
|
||||
let newInvocation: AnyToolInvocation | undefined;
|
||||
|
||||
if (beforeOutput instanceof BeforeToolHookOutput) {
|
||||
const modifiedInput = beforeOutput.getModifiedToolInput();
|
||||
if (modifiedInput) {
|
||||
modifiedArgs = modifiedInput;
|
||||
try {
|
||||
newInvocation = tool.build(modifiedInput);
|
||||
} catch (error) {
|
||||
return {
|
||||
status: 'error',
|
||||
error: new Error(
|
||||
`Tool parameter modification by hook failed validation: ${error instanceof Error ? error.message : String(error)}`,
|
||||
),
|
||||
errorType: ToolErrorType.INVALID_TOOL_PARAMS,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
status: 'continue',
|
||||
hookDecision,
|
||||
hookSystemMessage,
|
||||
modifiedArgs,
|
||||
newInvocation,
|
||||
};
|
||||
}
|
||||
@@ -824,6 +824,7 @@ describe('Plan Mode Denial Consistency', () => {
|
||||
toolRegistry: mockToolRegistry,
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getMessageBus: vi.fn().mockReturnValue(mockMessageBus),
|
||||
getHookSystem: vi.fn().mockReturnValue(undefined),
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
getEnableHooks: vi.fn().mockReturnValue(false),
|
||||
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.PLAN), // Key: Plan Mode
|
||||
|
||||
@@ -170,6 +170,8 @@ describe('Scheduler (Orchestrator)', () => {
|
||||
mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
toolRegistry: mockToolRegistry,
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
getHookSystem: vi.fn().mockReturnValue(undefined),
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
getEnableHooks: vi.fn().mockReturnValue(true),
|
||||
setApprovalMode: vi.fn(),
|
||||
@@ -1346,6 +1348,7 @@ describe('Scheduler MCP Progress', () => {
|
||||
mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
getHookSystem: vi.fn().mockReturnValue(undefined),
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
getEnableHooks: vi.fn().mockReturnValue(true),
|
||||
setApprovalMode: vi.fn(),
|
||||
|
||||
@@ -10,6 +10,7 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { SchedulerStateManager } from './state-manager.js';
|
||||
import { resolveConfirmation } from './confirmation.js';
|
||||
import { checkPolicy, updatePolicy, getPolicyDenialError } from './policy.js';
|
||||
import { evaluateBeforeToolHook } from './hook-utils.js';
|
||||
import { ToolExecutor } from './tool-executor.js';
|
||||
import { ToolModificationHandler } from './tool-modifier.js';
|
||||
import {
|
||||
@@ -572,12 +573,46 @@ export class Scheduler {
|
||||
): Promise<void> {
|
||||
const callId = toolCall.request.callId;
|
||||
|
||||
// Policy & Security
|
||||
const { decision, rule } = await checkPolicy(
|
||||
// 1. Hook Check (BeforeTool)
|
||||
const hookResult = await evaluateBeforeToolHook(
|
||||
this.config,
|
||||
toolCall.tool,
|
||||
toolCall.request,
|
||||
toolCall.invocation,
|
||||
);
|
||||
|
||||
if (hookResult.status === 'error') {
|
||||
this.state.updateStatus(
|
||||
callId,
|
||||
CoreToolCallStatus.Error,
|
||||
createErrorResponse(
|
||||
toolCall.request,
|
||||
hookResult.error,
|
||||
hookResult.errorType,
|
||||
),
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const { hookDecision, hookSystemMessage, modifiedArgs, newInvocation } =
|
||||
hookResult;
|
||||
|
||||
if (modifiedArgs && newInvocation) {
|
||||
toolCall.request.args = modifiedArgs;
|
||||
toolCall.request.inputModifiedByHook = true;
|
||||
toolCall.invocation = newInvocation;
|
||||
}
|
||||
|
||||
// 2. Policy & Security
|
||||
const { decision: policyDecision, rule } = await checkPolicy(
|
||||
toolCall,
|
||||
this.config,
|
||||
this.subagent,
|
||||
);
|
||||
let decision = policyDecision;
|
||||
if (hookDecision === 'ask') {
|
||||
decision = PolicyDecision.ASK_USER;
|
||||
}
|
||||
|
||||
if (decision === PolicyDecision.DENY) {
|
||||
const { errorMessage, errorType } = getPolicyDenialError(
|
||||
@@ -610,6 +645,8 @@ export class Scheduler {
|
||||
getPreferredEditor: this.getPreferredEditor,
|
||||
schedulerId: this.schedulerId,
|
||||
onWaitingForConfirmation: this.onWaitingForConfirmation,
|
||||
systemMessage: hookSystemMessage,
|
||||
forcedDecision: hookDecision === 'ask' ? 'ask_user' : undefined,
|
||||
});
|
||||
outcome = result.outcome;
|
||||
lastDetails = result.lastDetails;
|
||||
|
||||
@@ -212,6 +212,8 @@ describe('Scheduler Parallel Execution', () => {
|
||||
mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
toolRegistry: mockToolRegistry,
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
getHookSystem: vi.fn().mockReturnValue(undefined),
|
||||
isInteractive: vi.fn().mockReturnValue(true),
|
||||
getEnableHooks: vi.fn().mockReturnValue(true),
|
||||
setApprovalMode: vi.fn(),
|
||||
|
||||
@@ -115,10 +115,25 @@ export class ToolExecutor {
|
||||
{ shellExecutionConfig, setExecutionIdCallback },
|
||||
this.config,
|
||||
request.originalRequestName,
|
||||
true, // skipBeforeHook
|
||||
);
|
||||
|
||||
const toolResult: ToolResult = await promise;
|
||||
|
||||
if (call.request.inputModifiedByHook) {
|
||||
const modificationMsg = `\n\n[System] Tool input parameters were modified by a hook before execution.`;
|
||||
if (typeof toolResult.llmContent === 'string') {
|
||||
toolResult.llmContent += modificationMsg;
|
||||
} else if (Array.isArray(toolResult.llmContent)) {
|
||||
toolResult.llmContent.push({ text: modificationMsg });
|
||||
} else if (toolResult.llmContent) {
|
||||
toolResult.llmContent = [
|
||||
toolResult.llmContent,
|
||||
{ text: modificationMsg },
|
||||
];
|
||||
}
|
||||
}
|
||||
|
||||
if (signal.aborted) {
|
||||
completedToolCall = await this.createCancelledResult(
|
||||
call,
|
||||
|
||||
@@ -47,6 +47,8 @@ export interface ToolCallRequestInfo {
|
||||
traceId?: string;
|
||||
parentCallId?: string;
|
||||
schedulerId?: string;
|
||||
inputModifiedByHook?: boolean;
|
||||
forcedAsk?: boolean;
|
||||
}
|
||||
|
||||
export interface ToolCallResponseInfo {
|
||||
|
||||
Reference in New Issue
Block a user