PR cleanup.

This commit is contained in:
Christian Gunderman
2026-03-13 09:42:05 -07:00
parent 6f50890fcd
commit eaefa83036
5 changed files with 179 additions and 156 deletions

View File

@@ -50,8 +50,7 @@ import { ToolExecutor } from '../scheduler/tool-executor.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import { getPolicyDenialError } from '../scheduler/policy.js';
import { GeminiCliOperation } from '../telemetry/constants.js';
import { extractMcpContext } from './coreToolHookTriggers.js';
import { BeforeToolHookOutput } from '../hooks/types.js';
import { evaluateBeforeToolHook } from '../scheduler/hook-utils.js';
export type {
ToolCall,
@@ -639,85 +638,48 @@ export class CoreToolScheduler {
}
// 1. Hook Check (BeforeTool)
let hookDecision: 'ask' | 'block' | undefined;
let hookSystemMessage: string | undefined;
const hookResult = await evaluateBeforeToolHook(
this.config,
toolCall.tool,
reqInfo,
invocation,
);
const hookSystem = this.config.getHookSystem();
if (hookSystem) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const toolInput = (invocation.params || {}) as Record<
string,
unknown
>;
const mcpContext = extractMcpContext(invocation, this.config);
const beforeOutput = await hookSystem.fireBeforeToolEvent(
toolCall.request.name,
toolInput,
mcpContext,
toolCall.request.originalRequestName,
if (hookResult.status === 'error') {
this.setStatusInternal(
reqInfo.callId,
CoreToolCallStatus.Error,
signal,
createErrorResponse(
reqInfo,
hookResult.error,
hookResult.errorType,
),
);
await this.checkAndNotifyCompletion(signal);
return;
}
if (beforeOutput) {
// Check if hook requested to stop entire agent execution
if (beforeOutput.shouldStopExecution()) {
const reason = beforeOutput.getEffectiveReason();
this.setStatusInternal(
reqInfo.callId,
CoreToolCallStatus.Error,
signal,
createErrorResponse(
reqInfo,
new Error(`Agent execution stopped by hook: ${reason}`),
ToolErrorType.STOP_EXECUTION,
),
);
await this.checkAndNotifyCompletion(signal);
return;
}
const { hookDecision, hookSystemMessage, modifiedArgs, newInvocation } =
hookResult;
// Check if hook blocked the tool execution
const blockingError = beforeOutput.getBlockingError();
if (blockingError?.blocked) {
this.setStatusInternal(
reqInfo.callId,
CoreToolCallStatus.Error,
signal,
createErrorResponse(
reqInfo,
new Error(`Tool execution blocked: ${blockingError.reason}`),
ToolErrorType.POLICY_VIOLATION,
),
);
await this.checkAndNotifyCompletion(signal);
return;
}
if (hookDecision === 'ask') {
// Mark the request so UI knows not to auto-approve it
toolCall.request.forcedAsk = true;
}
if (beforeOutput.isAskDecision()) {
hookDecision = 'ask';
hookSystemMessage = beforeOutput.systemMessage;
// Mark the request so UI knows not to auto-approve it
toolCall.request.forcedAsk = true;
}
if (modifiedArgs && newInvocation) {
this.setArgsInternal(reqInfo.callId, modifiedArgs);
// Check if hook requested to update tool input
if (beforeOutput instanceof BeforeToolHookOutput) {
const modifiedInput = beforeOutput.getModifiedToolInput();
if (modifiedInput) {
this.setArgsInternal(reqInfo.callId, modifiedInput);
// IMPORTANT: toolCall and invocation must be updated because setArgsInternal created a new one
const updatedCall = this.toolCalls.find(
(c) => c.request.callId === reqInfo.callId,
);
if (updatedCall) {
toolCall = updatedCall;
toolCall.request.inputModifiedByHook = true;
if ('invocation' in updatedCall) {
invocation = updatedCall.invocation;
}
}
}
// IMPORTANT: toolCall and invocation must be updated because setArgsInternal created a new one
const updatedCall = this.toolCalls.find(
(c) => c.request.callId === reqInfo.callId,
);
if (updatedCall) {
toolCall = updatedCall;
toolCall.request.inputModifiedByHook = true;
if ('invocation' in updatedCall) {
invocation = updatedCall.invocation;
}
}
}

View File

@@ -0,0 +1,109 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '../config/config.js';
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
import type { ToolCallRequestInfo } from './types.js';
import { extractMcpContext } from '../core/coreToolHookTriggers.js';
import { BeforeToolHookOutput } from '../hooks/types.js';
import { ToolErrorType } from '../tools/tool-error.js';
export type HookEvaluationResult =
| {
status: 'continue';
hookDecision?: 'ask' | 'block';
hookSystemMessage?: string;
modifiedArgs?: Record<string, unknown>;
newInvocation?: AnyToolInvocation;
}
| {
status: 'error';
error: Error;
errorType: ToolErrorType;
};
export async function evaluateBeforeToolHook(
config: Config,
tool: AnyDeclarativeTool,
request: ToolCallRequestInfo,
invocation: AnyToolInvocation,
): Promise<HookEvaluationResult> {
const hookSystem = config.getHookSystem();
if (!hookSystem) {
return { status: 'continue' };
}
const params = invocation.params || {};
const toolInput: Record<string, unknown> = { ...params };
const mcpContext = extractMcpContext(invocation, config);
const beforeOutput = await hookSystem.fireBeforeToolEvent(
request.name,
toolInput,
mcpContext,
request.originalRequestName,
);
if (!beforeOutput) {
return { status: 'continue' };
}
if (beforeOutput.shouldStopExecution()) {
return {
status: 'error',
error: new Error(
`Agent execution stopped by hook: ${beforeOutput.getEffectiveReason()}`,
),
errorType: ToolErrorType.STOP_EXECUTION,
};
}
const blockingError = beforeOutput.getBlockingError();
if (blockingError?.blocked) {
return {
status: 'error',
error: new Error(`Tool execution blocked: ${blockingError.reason}`),
errorType: ToolErrorType.POLICY_VIOLATION,
};
}
let hookDecision: 'ask' | 'block' | undefined;
let hookSystemMessage: string | undefined;
if (beforeOutput.isAskDecision()) {
hookDecision = 'ask';
hookSystemMessage = beforeOutput.systemMessage;
}
let modifiedArgs: Record<string, unknown> | undefined;
let newInvocation: AnyToolInvocation | undefined;
if (beforeOutput instanceof BeforeToolHookOutput) {
const modifiedInput = beforeOutput.getModifiedToolInput();
if (modifiedInput) {
modifiedArgs = modifiedInput;
try {
newInvocation = tool.build(modifiedInput);
} catch (error) {
return {
status: 'error',
error: new Error(
`Tool parameter modification by hook failed validation: ${error instanceof Error ? error.message : String(error)}`,
),
errorType: ToolErrorType.INVALID_TOOL_PARAMS,
};
}
}
}
return {
status: 'continue',
hookDecision,
hookSystemMessage,
modifiedArgs,
newInvocation,
};
}

View File

@@ -25,8 +25,8 @@ import {
type ScheduledToolCall,
} from './types.js';
import { ToolErrorType } from '../tools/tool-error.js';
import { extractMcpContext } from '../core/coreToolHookTriggers.js';
import { BeforeToolHookOutput } from '../hooks/types.js';
import { evaluateBeforeToolHook } from './hook-utils.js';
import { PolicyDecision, type ApprovalMode } from '../policy/types.js';
import {
ToolConfirmationOutcome,
@@ -564,85 +564,37 @@ export class Scheduler {
): Promise<void> {
const callId = toolCall.request.callId;
let hookDecision: 'ask' | 'block' | undefined;
let hookSystemMessage: string | undefined;
const hookResult = await evaluateBeforeToolHook(
this.config,
toolCall.tool,
toolCall.request,
toolCall.invocation,
);
const hookSystem = this.config.getHookSystem();
if (hookSystem) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const toolInput = (toolCall.invocation.params || {}) as Record<
string,
unknown
>;
const mcpContext = extractMcpContext(toolCall.invocation, this.config);
const beforeOutput = await hookSystem.fireBeforeToolEvent(
toolCall.request.name,
toolInput,
mcpContext,
toolCall.request.originalRequestName,
if (hookResult.status === 'error') {
this.state.updateStatus(
callId,
CoreToolCallStatus.Error,
createErrorResponse(
toolCall.request,
hookResult.error,
hookResult.errorType,
),
);
return;
}
if (beforeOutput) {
if (beforeOutput.shouldStopExecution()) {
this.state.updateStatus(
callId,
CoreToolCallStatus.Error,
createErrorResponse(
toolCall.request,
new Error(
`Agent execution stopped by hook: ${beforeOutput.getEffectiveReason()}`,
),
ToolErrorType.STOP_EXECUTION,
),
);
return;
}
const { hookDecision, hookSystemMessage, modifiedArgs, newInvocation } =
hookResult;
const blockingError = beforeOutput.getBlockingError();
if (blockingError?.blocked) {
this.state.updateStatus(
callId,
CoreToolCallStatus.Error,
createErrorResponse(
toolCall.request,
new Error(`Tool execution blocked: ${blockingError.reason}`),
ToolErrorType.POLICY_VIOLATION,
),
);
return;
}
if (hookDecision === 'ask') {
toolCall.request.forcedAsk = true;
}
if (beforeOutput.isAskDecision()) {
hookDecision = 'ask';
hookSystemMessage = beforeOutput.systemMessage;
toolCall.request.forcedAsk = true;
}
if (beforeOutput instanceof BeforeToolHookOutput) {
const modifiedInput = beforeOutput.getModifiedToolInput();
if (modifiedInput) {
toolCall.request.args = modifiedInput;
toolCall.request.inputModifiedByHook = true;
try {
toolCall.invocation = toolCall.tool.build(modifiedInput);
} catch (error) {
this.state.updateStatus(
callId,
CoreToolCallStatus.Error,
createErrorResponse(
toolCall.request,
new Error(
`Tool parameter modification by hook failed validation: ${error instanceof Error ? error.message : String(error)}`,
),
ToolErrorType.INVALID_TOOL_PARAMS,
),
);
return;
}
}
}
}
if (modifiedArgs && newInvocation) {
toolCall.request.args = modifiedArgs;
toolCall.request.inputModifiedByHook = true;
toolCall.invocation = newInvocation;
}
// Policy & Security

View File

@@ -712,13 +712,13 @@ class EditToolInvocation
* It needs to calculate the diff to show the user.
*/
protected override async getConfirmationDetails(
_abortSignal: AbortSignal,
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
let editData: CalculatedEdit;
try {
editData = await this.calculateEdit(this.params, _abortSignal);
editData = await this.calculateEdit(this.params, abortSignal);
} catch (error) {
if (_abortSignal.aborted) {
if (abortSignal.aborted) {
throw error;
}
const errorMsg = error instanceof Error ? error.message : String(error);

View File

@@ -181,13 +181,13 @@ class WriteFileToolInvocation extends BaseToolInvocation<
}
protected override async getConfirmationDetails(
_abortSignal: AbortSignal,
abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
const correctedContentResult = await getCorrectedFileContent(
this.config,
this.resolvedPath,
this.params.content,
_abortSignal,
abortSignal,
);
if (correctedContentResult.error) {