PR cleanup.

This commit is contained in:
Christian Gunderman
2026-03-13 09:42:05 -07:00
parent 6f50890fcd
commit eaefa83036
5 changed files with 179 additions and 156 deletions

View File

@@ -50,8 +50,7 @@ import { ToolExecutor } from '../scheduler/tool-executor.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import { getPolicyDenialError } from '../scheduler/policy.js'; import { getPolicyDenialError } from '../scheduler/policy.js';
import { GeminiCliOperation } from '../telemetry/constants.js'; import { GeminiCliOperation } from '../telemetry/constants.js';
import { extractMcpContext } from './coreToolHookTriggers.js'; import { evaluateBeforeToolHook } from '../scheduler/hook-utils.js';
import { BeforeToolHookOutput } from '../hooks/types.js';
export type { export type {
ToolCall, ToolCall,
@@ -639,85 +638,48 @@ export class CoreToolScheduler {
} }
// 1. Hook Check (BeforeTool) // 1. Hook Check (BeforeTool)
let hookDecision: 'ask' | 'block' | undefined; const hookResult = await evaluateBeforeToolHook(
let hookSystemMessage: string | undefined; this.config,
toolCall.tool,
reqInfo,
invocation,
);
const hookSystem = this.config.getHookSystem(); if (hookResult.status === 'error') {
if (hookSystem) { this.setStatusInternal(
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion reqInfo.callId,
const toolInput = (invocation.params || {}) as Record< CoreToolCallStatus.Error,
string, signal,
unknown createErrorResponse(
>; reqInfo,
const mcpContext = extractMcpContext(invocation, this.config); hookResult.error,
hookResult.errorType,
const beforeOutput = await hookSystem.fireBeforeToolEvent( ),
toolCall.request.name,
toolInput,
mcpContext,
toolCall.request.originalRequestName,
); );
await this.checkAndNotifyCompletion(signal);
return;
}
if (beforeOutput) { const { hookDecision, hookSystemMessage, modifiedArgs, newInvocation } =
// Check if hook requested to stop entire agent execution hookResult;
if (beforeOutput.shouldStopExecution()) {
const reason = beforeOutput.getEffectiveReason();
this.setStatusInternal(
reqInfo.callId,
CoreToolCallStatus.Error,
signal,
createErrorResponse(
reqInfo,
new Error(`Agent execution stopped by hook: ${reason}`),
ToolErrorType.STOP_EXECUTION,
),
);
await this.checkAndNotifyCompletion(signal);
return;
}
// Check if hook blocked the tool execution if (hookDecision === 'ask') {
const blockingError = beforeOutput.getBlockingError(); // Mark the request so UI knows not to auto-approve it
if (blockingError?.blocked) { toolCall.request.forcedAsk = true;
this.setStatusInternal( }
reqInfo.callId,
CoreToolCallStatus.Error,
signal,
createErrorResponse(
reqInfo,
new Error(`Tool execution blocked: ${blockingError.reason}`),
ToolErrorType.POLICY_VIOLATION,
),
);
await this.checkAndNotifyCompletion(signal);
return;
}
if (beforeOutput.isAskDecision()) { if (modifiedArgs && newInvocation) {
hookDecision = 'ask'; this.setArgsInternal(reqInfo.callId, modifiedArgs);
hookSystemMessage = beforeOutput.systemMessage;
// Mark the request so UI knows not to auto-approve it
toolCall.request.forcedAsk = true;
}
// Check if hook requested to update tool input // IMPORTANT: toolCall and invocation must be updated because setArgsInternal created a new one
if (beforeOutput instanceof BeforeToolHookOutput) { const updatedCall = this.toolCalls.find(
const modifiedInput = beforeOutput.getModifiedToolInput(); (c) => c.request.callId === reqInfo.callId,
if (modifiedInput) { );
this.setArgsInternal(reqInfo.callId, modifiedInput); if (updatedCall) {
toolCall = updatedCall;
// IMPORTANT: toolCall and invocation must be updated because setArgsInternal created a new one toolCall.request.inputModifiedByHook = true;
const updatedCall = this.toolCalls.find( if ('invocation' in updatedCall) {
(c) => c.request.callId === reqInfo.callId, invocation = updatedCall.invocation;
);
if (updatedCall) {
toolCall = updatedCall;
toolCall.request.inputModifiedByHook = true;
if ('invocation' in updatedCall) {
invocation = updatedCall.invocation;
}
}
}
} }
} }
} }

View File

@@ -0,0 +1,109 @@
/**
* @license
* Copyright 2026 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '../config/config.js';
import type { AnyDeclarativeTool, AnyToolInvocation } from '../tools/tools.js';
import type { ToolCallRequestInfo } from './types.js';
import { extractMcpContext } from '../core/coreToolHookTriggers.js';
import { BeforeToolHookOutput } from '../hooks/types.js';
import { ToolErrorType } from '../tools/tool-error.js';
export type HookEvaluationResult =
| {
status: 'continue';
hookDecision?: 'ask' | 'block';
hookSystemMessage?: string;
modifiedArgs?: Record<string, unknown>;
newInvocation?: AnyToolInvocation;
}
| {
status: 'error';
error: Error;
errorType: ToolErrorType;
};
export async function evaluateBeforeToolHook(
config: Config,
tool: AnyDeclarativeTool,
request: ToolCallRequestInfo,
invocation: AnyToolInvocation,
): Promise<HookEvaluationResult> {
const hookSystem = config.getHookSystem();
if (!hookSystem) {
return { status: 'continue' };
}
const params = invocation.params || {};
const toolInput: Record<string, unknown> = { ...params };
const mcpContext = extractMcpContext(invocation, config);
const beforeOutput = await hookSystem.fireBeforeToolEvent(
request.name,
toolInput,
mcpContext,
request.originalRequestName,
);
if (!beforeOutput) {
return { status: 'continue' };
}
if (beforeOutput.shouldStopExecution()) {
return {
status: 'error',
error: new Error(
`Agent execution stopped by hook: ${beforeOutput.getEffectiveReason()}`,
),
errorType: ToolErrorType.STOP_EXECUTION,
};
}
const blockingError = beforeOutput.getBlockingError();
if (blockingError?.blocked) {
return {
status: 'error',
error: new Error(`Tool execution blocked: ${blockingError.reason}`),
errorType: ToolErrorType.POLICY_VIOLATION,
};
}
let hookDecision: 'ask' | 'block' | undefined;
let hookSystemMessage: string | undefined;
if (beforeOutput.isAskDecision()) {
hookDecision = 'ask';
hookSystemMessage = beforeOutput.systemMessage;
}
let modifiedArgs: Record<string, unknown> | undefined;
let newInvocation: AnyToolInvocation | undefined;
if (beforeOutput instanceof BeforeToolHookOutput) {
const modifiedInput = beforeOutput.getModifiedToolInput();
if (modifiedInput) {
modifiedArgs = modifiedInput;
try {
newInvocation = tool.build(modifiedInput);
} catch (error) {
return {
status: 'error',
error: new Error(
`Tool parameter modification by hook failed validation: ${error instanceof Error ? error.message : String(error)}`,
),
errorType: ToolErrorType.INVALID_TOOL_PARAMS,
};
}
}
}
return {
status: 'continue',
hookDecision,
hookSystemMessage,
modifiedArgs,
newInvocation,
};
}

View File

@@ -25,8 +25,8 @@ import {
type ScheduledToolCall, type ScheduledToolCall,
} from './types.js'; } from './types.js';
import { ToolErrorType } from '../tools/tool-error.js'; import { ToolErrorType } from '../tools/tool-error.js';
import { extractMcpContext } from '../core/coreToolHookTriggers.js'; import { evaluateBeforeToolHook } from './hook-utils.js';
import { BeforeToolHookOutput } from '../hooks/types.js';
import { PolicyDecision, type ApprovalMode } from '../policy/types.js'; import { PolicyDecision, type ApprovalMode } from '../policy/types.js';
import { import {
ToolConfirmationOutcome, ToolConfirmationOutcome,
@@ -564,85 +564,37 @@ export class Scheduler {
): Promise<void> { ): Promise<void> {
const callId = toolCall.request.callId; const callId = toolCall.request.callId;
let hookDecision: 'ask' | 'block' | undefined; const hookResult = await evaluateBeforeToolHook(
let hookSystemMessage: string | undefined; this.config,
toolCall.tool,
toolCall.request,
toolCall.invocation,
);
const hookSystem = this.config.getHookSystem(); if (hookResult.status === 'error') {
if (hookSystem) { this.state.updateStatus(
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion callId,
const toolInput = (toolCall.invocation.params || {}) as Record< CoreToolCallStatus.Error,
string, createErrorResponse(
unknown toolCall.request,
>; hookResult.error,
const mcpContext = extractMcpContext(toolCall.invocation, this.config); hookResult.errorType,
),
const beforeOutput = await hookSystem.fireBeforeToolEvent(
toolCall.request.name,
toolInput,
mcpContext,
toolCall.request.originalRequestName,
); );
return;
}
if (beforeOutput) { const { hookDecision, hookSystemMessage, modifiedArgs, newInvocation } =
if (beforeOutput.shouldStopExecution()) { hookResult;
this.state.updateStatus(
callId,
CoreToolCallStatus.Error,
createErrorResponse(
toolCall.request,
new Error(
`Agent execution stopped by hook: ${beforeOutput.getEffectiveReason()}`,
),
ToolErrorType.STOP_EXECUTION,
),
);
return;
}
const blockingError = beforeOutput.getBlockingError(); if (hookDecision === 'ask') {
if (blockingError?.blocked) { toolCall.request.forcedAsk = true;
this.state.updateStatus( }
callId,
CoreToolCallStatus.Error,
createErrorResponse(
toolCall.request,
new Error(`Tool execution blocked: ${blockingError.reason}`),
ToolErrorType.POLICY_VIOLATION,
),
);
return;
}
if (beforeOutput.isAskDecision()) { if (modifiedArgs && newInvocation) {
hookDecision = 'ask'; toolCall.request.args = modifiedArgs;
hookSystemMessage = beforeOutput.systemMessage; toolCall.request.inputModifiedByHook = true;
toolCall.request.forcedAsk = true; toolCall.invocation = newInvocation;
}
if (beforeOutput instanceof BeforeToolHookOutput) {
const modifiedInput = beforeOutput.getModifiedToolInput();
if (modifiedInput) {
toolCall.request.args = modifiedInput;
toolCall.request.inputModifiedByHook = true;
try {
toolCall.invocation = toolCall.tool.build(modifiedInput);
} catch (error) {
this.state.updateStatus(
callId,
CoreToolCallStatus.Error,
createErrorResponse(
toolCall.request,
new Error(
`Tool parameter modification by hook failed validation: ${error instanceof Error ? error.message : String(error)}`,
),
ToolErrorType.INVALID_TOOL_PARAMS,
),
);
return;
}
}
}
}
} }
// Policy & Security // Policy & Security

View File

@@ -712,13 +712,13 @@ class EditToolInvocation
* It needs to calculate the diff to show the user. * It needs to calculate the diff to show the user.
*/ */
protected override async getConfirmationDetails( protected override async getConfirmationDetails(
_abortSignal: AbortSignal, abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> { ): Promise<ToolCallConfirmationDetails | false> {
let editData: CalculatedEdit; let editData: CalculatedEdit;
try { try {
editData = await this.calculateEdit(this.params, _abortSignal); editData = await this.calculateEdit(this.params, abortSignal);
} catch (error) { } catch (error) {
if (_abortSignal.aborted) { if (abortSignal.aborted) {
throw error; throw error;
} }
const errorMsg = error instanceof Error ? error.message : String(error); const errorMsg = error instanceof Error ? error.message : String(error);

View File

@@ -181,13 +181,13 @@ class WriteFileToolInvocation extends BaseToolInvocation<
} }
protected override async getConfirmationDetails( protected override async getConfirmationDetails(
_abortSignal: AbortSignal, abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> { ): Promise<ToolCallConfirmationDetails | false> {
const correctedContentResult = await getCorrectedFileContent( const correctedContentResult = await getCorrectedFileContent(
this.config, this.config,
this.resolvedPath, this.resolvedPath,
this.params.content, this.params.content,
_abortSignal, abortSignal,
); );
if (correctedContentResult.error) { if (correctedContentResult.error) {