diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index fdf67dba8e..e6d4c62d9f 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -50,8 +50,7 @@ import { ToolExecutor } from '../scheduler/tool-executor.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; import { getPolicyDenialError } from '../scheduler/policy.js'; import { GeminiCliOperation } from '../telemetry/constants.js'; -import { extractMcpContext } from './coreToolHookTriggers.js'; -import { BeforeToolHookOutput } from '../hooks/types.js'; +import { evaluateBeforeToolHook } from '../scheduler/hook-utils.js'; export type { ToolCall, @@ -639,85 +638,48 @@ export class CoreToolScheduler { } // 1. Hook Check (BeforeTool) - let hookDecision: 'ask' | 'block' | undefined; - let hookSystemMessage: string | undefined; + const hookResult = await evaluateBeforeToolHook( + this.config, + toolCall.tool, + reqInfo, + invocation, + ); - const hookSystem = this.config.getHookSystem(); - if (hookSystem) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const toolInput = (invocation.params || {}) as Record< - string, - unknown - >; - const mcpContext = extractMcpContext(invocation, this.config); - - const beforeOutput = await hookSystem.fireBeforeToolEvent( - toolCall.request.name, - toolInput, - mcpContext, - toolCall.request.originalRequestName, + if (hookResult.status === 'error') { + this.setStatusInternal( + reqInfo.callId, + CoreToolCallStatus.Error, + signal, + createErrorResponse( + reqInfo, + hookResult.error, + hookResult.errorType, + ), ); + await this.checkAndNotifyCompletion(signal); + return; + } - if (beforeOutput) { - // Check if hook requested to stop entire agent execution - if (beforeOutput.shouldStopExecution()) { - const reason = beforeOutput.getEffectiveReason(); - this.setStatusInternal( - reqInfo.callId, - CoreToolCallStatus.Error, - signal, - createErrorResponse( - reqInfo, - new Error(`Agent execution stopped by hook: ${reason}`), - ToolErrorType.STOP_EXECUTION, - ), - ); - await this.checkAndNotifyCompletion(signal); - return; - } + const { hookDecision, hookSystemMessage, modifiedArgs, newInvocation } = + hookResult; - // Check if hook blocked the tool execution - const blockingError = beforeOutput.getBlockingError(); - if (blockingError?.blocked) { - this.setStatusInternal( - reqInfo.callId, - CoreToolCallStatus.Error, - signal, - createErrorResponse( - reqInfo, - new Error(`Tool execution blocked: ${blockingError.reason}`), - ToolErrorType.POLICY_VIOLATION, - ), - ); - await this.checkAndNotifyCompletion(signal); - return; - } + if (hookDecision === 'ask') { + // Mark the request so UI knows not to auto-approve it + toolCall.request.forcedAsk = true; + } - if (beforeOutput.isAskDecision()) { - hookDecision = 'ask'; - hookSystemMessage = beforeOutput.systemMessage; - // Mark the request so UI knows not to auto-approve it - toolCall.request.forcedAsk = true; - } + if (modifiedArgs && newInvocation) { + this.setArgsInternal(reqInfo.callId, modifiedArgs); - // Check if hook requested to update tool input - if (beforeOutput instanceof BeforeToolHookOutput) { - const modifiedInput = beforeOutput.getModifiedToolInput(); - if (modifiedInput) { - this.setArgsInternal(reqInfo.callId, modifiedInput); - - // IMPORTANT: toolCall and invocation must be updated because setArgsInternal created a new one - const updatedCall = this.toolCalls.find( - (c) => c.request.callId === reqInfo.callId, - ); - if (updatedCall) { - toolCall = updatedCall; - toolCall.request.inputModifiedByHook = true; - if ('invocation' in updatedCall) { - invocation = updatedCall.invocation; - } - } - } + // IMPORTANT: toolCall and invocation must be updated because setArgsInternal created a new one + const updatedCall = this.toolCalls.find( + (c) => c.request.callId === reqInfo.callId, + ); + if (updatedCall) { + toolCall = updatedCall; + toolCall.request.inputModifiedByHook = true; + if ('invocation' in updatedCall) { + invocation = updatedCall.invocation; } } } diff --git a/packages/core/src/scheduler/hook-utils.ts b/packages/core/src/scheduler/hook-utils.ts new file mode 100644 index 0000000000..78d5aeaa53 --- /dev/null +++ b/packages/core/src/scheduler/hook-utils.ts @@ -0,0 +1,109 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../config/config.js'; +import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js'; +import type { ToolCallRequestInfo } from './types.js'; +import { extractMcpContext } from '../core/coreToolHookTriggers.js'; +import { BeforeToolHookOutput } from '../hooks/types.js'; +import { ToolErrorType } from '../tools/tool-error.js'; + +export type HookEvaluationResult = + | { + status: 'continue'; + hookDecision?: 'ask' | 'block'; + hookSystemMessage?: string; + modifiedArgs?: Record; + newInvocation?: AnyToolInvocation; + } + | { + status: 'error'; + error: Error; + errorType: ToolErrorType; + }; + +export async function evaluateBeforeToolHook( + config: Config, + tool: AnyDeclarativeTool, + request: ToolCallRequestInfo, + invocation: AnyToolInvocation, +): Promise { + const hookSystem = config.getHookSystem(); + if (!hookSystem) { + return { status: 'continue' }; + } + + const params = invocation.params || {}; + const toolInput: Record = { ...params }; + const mcpContext = extractMcpContext(invocation, config); + + const beforeOutput = await hookSystem.fireBeforeToolEvent( + request.name, + toolInput, + mcpContext, + request.originalRequestName, + ); + + if (!beforeOutput) { + return { status: 'continue' }; + } + + if (beforeOutput.shouldStopExecution()) { + return { + status: 'error', + error: new Error( + `Agent execution stopped by hook: ${beforeOutput.getEffectiveReason()}`, + ), + errorType: ToolErrorType.STOP_EXECUTION, + }; + } + + const blockingError = beforeOutput.getBlockingError(); + if (blockingError?.blocked) { + return { + status: 'error', + error: new Error(`Tool execution blocked: ${blockingError.reason}`), + errorType: ToolErrorType.POLICY_VIOLATION, + }; + } + + let hookDecision: 'ask' | 'block' | undefined; + let hookSystemMessage: string | undefined; + + if (beforeOutput.isAskDecision()) { + hookDecision = 'ask'; + hookSystemMessage = beforeOutput.systemMessage; + } + + let modifiedArgs: Record | undefined; + let newInvocation: AnyToolInvocation | undefined; + + if (beforeOutput instanceof BeforeToolHookOutput) { + const modifiedInput = beforeOutput.getModifiedToolInput(); + if (modifiedInput) { + modifiedArgs = modifiedInput; + try { + newInvocation = tool.build(modifiedInput); + } catch (error) { + return { + status: 'error', + error: new Error( + `Tool parameter modification by hook failed validation: ${error instanceof Error ? error.message : String(error)}`, + ), + errorType: ToolErrorType.INVALID_TOOL_PARAMS, + }; + } + } + } + + return { + status: 'continue', + hookDecision, + hookSystemMessage, + modifiedArgs, + newInvocation, + }; +} diff --git a/packages/core/src/scheduler/scheduler.ts b/packages/core/src/scheduler/scheduler.ts index 8b841ee69e..9aa1eb376a 100644 --- a/packages/core/src/scheduler/scheduler.ts +++ b/packages/core/src/scheduler/scheduler.ts @@ -25,8 +25,8 @@ import { type ScheduledToolCall, } from './types.js'; import { ToolErrorType } from '../tools/tool-error.js'; -import { extractMcpContext } from '../core/coreToolHookTriggers.js'; -import { BeforeToolHookOutput } from '../hooks/types.js'; +import { evaluateBeforeToolHook } from './hook-utils.js'; + import { PolicyDecision, type ApprovalMode } from '../policy/types.js'; import { ToolConfirmationOutcome, @@ -564,85 +564,37 @@ export class Scheduler { ): Promise { const callId = toolCall.request.callId; - let hookDecision: 'ask' | 'block' | undefined; - let hookSystemMessage: string | undefined; + const hookResult = await evaluateBeforeToolHook( + this.config, + toolCall.tool, + toolCall.request, + toolCall.invocation, + ); - const hookSystem = this.config.getHookSystem(); - if (hookSystem) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const toolInput = (toolCall.invocation.params || {}) as Record< - string, - unknown - >; - const mcpContext = extractMcpContext(toolCall.invocation, this.config); - - const beforeOutput = await hookSystem.fireBeforeToolEvent( - toolCall.request.name, - toolInput, - mcpContext, - toolCall.request.originalRequestName, + if (hookResult.status === 'error') { + this.state.updateStatus( + callId, + CoreToolCallStatus.Error, + createErrorResponse( + toolCall.request, + hookResult.error, + hookResult.errorType, + ), ); + return; + } - if (beforeOutput) { - if (beforeOutput.shouldStopExecution()) { - this.state.updateStatus( - callId, - CoreToolCallStatus.Error, - createErrorResponse( - toolCall.request, - new Error( - `Agent execution stopped by hook: ${beforeOutput.getEffectiveReason()}`, - ), - ToolErrorType.STOP_EXECUTION, - ), - ); - return; - } + const { hookDecision, hookSystemMessage, modifiedArgs, newInvocation } = + hookResult; - const blockingError = beforeOutput.getBlockingError(); - if (blockingError?.blocked) { - this.state.updateStatus( - callId, - CoreToolCallStatus.Error, - createErrorResponse( - toolCall.request, - new Error(`Tool execution blocked: ${blockingError.reason}`), - ToolErrorType.POLICY_VIOLATION, - ), - ); - return; - } + if (hookDecision === 'ask') { + toolCall.request.forcedAsk = true; + } - if (beforeOutput.isAskDecision()) { - hookDecision = 'ask'; - hookSystemMessage = beforeOutput.systemMessage; - toolCall.request.forcedAsk = true; - } - - if (beforeOutput instanceof BeforeToolHookOutput) { - const modifiedInput = beforeOutput.getModifiedToolInput(); - if (modifiedInput) { - toolCall.request.args = modifiedInput; - toolCall.request.inputModifiedByHook = true; - try { - toolCall.invocation = toolCall.tool.build(modifiedInput); - } catch (error) { - this.state.updateStatus( - callId, - CoreToolCallStatus.Error, - createErrorResponse( - toolCall.request, - new Error( - `Tool parameter modification by hook failed validation: ${error instanceof Error ? error.message : String(error)}`, - ), - ToolErrorType.INVALID_TOOL_PARAMS, - ), - ); - return; - } - } - } - } + if (modifiedArgs && newInvocation) { + toolCall.request.args = modifiedArgs; + toolCall.request.inputModifiedByHook = true; + toolCall.invocation = newInvocation; } // Policy & Security diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index ae4b592f39..588f2b30fc 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -712,13 +712,13 @@ class EditToolInvocation * It needs to calculate the diff to show the user. */ protected override async getConfirmationDetails( - _abortSignal: AbortSignal, + abortSignal: AbortSignal, ): Promise { let editData: CalculatedEdit; try { - editData = await this.calculateEdit(this.params, _abortSignal); + editData = await this.calculateEdit(this.params, abortSignal); } catch (error) { - if (_abortSignal.aborted) { + if (abortSignal.aborted) { throw error; } const errorMsg = error instanceof Error ? error.message : String(error); diff --git a/packages/core/src/tools/write-file.ts b/packages/core/src/tools/write-file.ts index dc69c96f5e..0bb868eff8 100644 --- a/packages/core/src/tools/write-file.ts +++ b/packages/core/src/tools/write-file.ts @@ -181,13 +181,13 @@ class WriteFileToolInvocation extends BaseToolInvocation< } protected override async getConfirmationDetails( - _abortSignal: AbortSignal, + abortSignal: AbortSignal, ): Promise { const correctedContentResult = await getCorrectedFileContent( this.config, this.resolvedPath, this.params.content, - _abortSignal, + abortSignal, ); if (correctedContentResult.error) {