Migrate beforeTool and afterTool hooks to hookSystem (#17204)

Co-authored-by: Ishaan Gupta <ishaankone@gmail.com>
Co-authored-by: Tommaso Sciortino <sciortino@gmail.com>
This commit is contained in:
Vedant Mahajan
2026-01-21 22:43:03 +05:30
committed by GitHub
parent 4d77934a83
commit 6b14dc8240
6 changed files with 149 additions and 225 deletions

View File

@@ -13,10 +13,12 @@ import {
type AnyDeclarativeTool,
} from '../tools/tools.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import type { HookSystem } from '../hooks/hookSystem.js';
import type { Config } from '../config/config.js';
import {
MessageBusType,
type HookExecutionResponse,
} from '../confirmation-bus/types.js';
type DefaultHookOutput,
BeforeToolHookOutput,
} from '../hooks/types.js';
class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> {
constructor(params: { key?: string }, messageBus: MessageBus) {
@@ -38,6 +40,8 @@ class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> {
describe('executeToolWithHooks', () => {
let messageBus: MessageBus;
let mockTool: AnyDeclarativeTool;
let mockHookSystem: HookSystem;
let mockConfig: Config;
beforeEach(() => {
messageBus = {
@@ -46,6 +50,15 @@ describe('executeToolWithHooks', () => {
subscribe: vi.fn(),
unsubscribe: vi.fn(),
} as unknown as MessageBus;
mockHookSystem = {
fireBeforeToolEvent: vi.fn(),
fireAfterToolEvent: vi.fn(),
} as unknown as HookSystem;
mockConfig = {
getHookSystem: vi.fn().mockReturnValue(mockHookSystem),
getMcpClientManager: vi.fn().mockReturnValue(undefined),
getMcpServers: vi.fn().mockReturnValue({}),
} as unknown as Config;
mockTool = {
build: vi
.fn()
@@ -57,25 +70,24 @@ describe('executeToolWithHooks', () => {
const invocation = new MockInvocation({}, messageBus);
const abortSignal = new AbortController().signal;
vi.mocked(messageBus.request).mockResolvedValue({
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: 'test-id',
success: true,
output: {
continue: false,
stopReason: 'Stop immediately',
decision: 'block',
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
shouldStopExecution: () => true,
getEffectiveReason: () => 'Stop immediately',
getBlockingError: () => ({
blocked: false,
reason: 'Should be ignored because continue is false',
},
} as HookExecutionResponse);
}),
} as unknown as DefaultHookOutput);
const result = await executeToolWithHooks(
invocation,
'test_tool',
abortSignal,
messageBus,
true,
mockTool,
undefined,
undefined,
undefined,
mockConfig,
);
expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION);
@@ -86,23 +98,21 @@ describe('executeToolWithHooks', () => {
const invocation = new MockInvocation({}, messageBus);
const abortSignal = new AbortController().signal;
vi.mocked(messageBus.request).mockResolvedValue({
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: 'test-id',
success: true,
output: {
decision: 'block',
reason: 'Execution blocked',
},
} as HookExecutionResponse);
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
shouldStopExecution: () => false,
getEffectiveReason: () => '',
getBlockingError: () => ({ blocked: true, reason: 'Execution blocked' }),
} as unknown as DefaultHookOutput);
const result = await executeToolWithHooks(
invocation,
'test_tool',
abortSignal,
messageBus,
true,
mockTool,
undefined,
undefined,
undefined,
mockConfig,
);
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
@@ -114,32 +124,27 @@ describe('executeToolWithHooks', () => {
const abortSignal = new AbortController().signal;
const spy = vi.spyOn(invocation, 'execute');
// BeforeTool allow
vi.mocked(messageBus.request)
.mockResolvedValueOnce({
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: 'test-id',
success: true,
output: { decision: 'allow' },
} as HookExecutionResponse)
// AfterTool stop
.mockResolvedValueOnce({
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: 'test-id',
success: true,
output: {
continue: false,
stopReason: 'Stop after execution',
},
} as HookExecutionResponse);
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
shouldStopExecution: () => false,
getEffectiveReason: () => '',
getBlockingError: () => ({ blocked: false, reason: '' }),
} as unknown as DefaultHookOutput);
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue({
shouldStopExecution: () => true,
getEffectiveReason: () => 'Stop after execution',
getBlockingError: () => ({ blocked: false, reason: '' }),
} as unknown as DefaultHookOutput);
const result = await executeToolWithHooks(
invocation,
'test_tool',
abortSignal,
messageBus,
true,
mockTool,
undefined,
undefined,
undefined,
mockConfig,
);
expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION);
@@ -151,32 +156,27 @@ describe('executeToolWithHooks', () => {
const invocation = new MockInvocation({}, messageBus);
const abortSignal = new AbortController().signal;
// BeforeTool allow
vi.mocked(messageBus.request)
.mockResolvedValueOnce({
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: 'test-id',
success: true,
output: { decision: 'allow' },
} as HookExecutionResponse)
// AfterTool deny
.mockResolvedValueOnce({
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: 'test-id',
success: true,
output: {
decision: 'deny',
reason: 'Result denied',
},
} as HookExecutionResponse);
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
shouldStopExecution: () => false,
getEffectiveReason: () => '',
getBlockingError: () => ({ blocked: false, reason: '' }),
} as unknown as DefaultHookOutput);
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue({
shouldStopExecution: () => false,
getEffectiveReason: () => '',
getBlockingError: () => ({ blocked: true, reason: 'Result denied' }),
} as unknown as DefaultHookOutput);
const result = await executeToolWithHooks(
invocation,
'test_tool',
abortSignal,
messageBus,
true,
mockTool,
undefined,
undefined,
undefined,
mockConfig,
);
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
@@ -189,39 +189,28 @@ describe('executeToolWithHooks', () => {
const toolName = 'test-tool';
const abortSignal = new AbortController().signal;
// Capture arguments to verify what was passed before modification
const requestSpy = vi.fn().mockImplementation(async (request) => {
if (request.eventName === 'BeforeTool') {
// Verify input is original before we return modification instruction
expect(request.input.tool_input.key).toBe('original');
return {
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: 'test-id',
success: true,
output: {
hookSpecificOutput: {
hookEventName: 'BeforeTool',
tool_input: { key: 'modified' },
},
},
} as HookExecutionResponse;
}
return {
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: 'test-id',
success: true,
output: {},
} as HookExecutionResponse;
const mockBeforeOutput = new BeforeToolHookOutput({
continue: true,
hookSpecificOutput: {
hookEventName: 'BeforeTool',
tool_input: { key: 'modified' },
},
});
messageBus.request = requestSpy;
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue(
mockBeforeOutput,
);
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue(undefined);
const result = await executeToolWithHooks(
invocation,
toolName,
abortSignal,
messageBus,
true, // hooksEnabled
mockTool,
undefined,
undefined,
undefined,
mockConfig,
);
// Verify result reflects modified input
@@ -231,7 +220,7 @@ describe('executeToolWithHooks', () => {
// Verify params object was modified in place
expect(invocation.params.key).toBe('modified');
expect(requestSpy).toHaveBeenCalled();
expect(mockHookSystem.fireBeforeToolEvent).toHaveBeenCalled();
expect(mockTool.build).toHaveBeenCalledWith({ key: 'modified' });
});
@@ -241,25 +230,28 @@ describe('executeToolWithHooks', () => {
const toolName = 'test-tool';
const abortSignal = new AbortController().signal;
vi.mocked(messageBus.request).mockResolvedValue({
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: 'test-id',
success: true,
output: {
hookSpecificOutput: {
hookEventName: 'BeforeTool',
// No tool_input
},
const mockBeforeOutput = new BeforeToolHookOutput({
continue: true,
hookSpecificOutput: {
hookEventName: 'BeforeTool',
// No tool input
},
} as HookExecutionResponse);
});
vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue(
mockBeforeOutput,
);
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue(undefined);
const result = await executeToolWithHooks(
invocation,
toolName,
abortSignal,
messageBus,
true, // hooksEnabled
mockTool,
undefined,
undefined,
undefined,
mockConfig,
);
expect(result.llmContent).toBe('key: original');

View File

@@ -11,9 +11,7 @@ import {
type HookExecutionResponse,
} from '../confirmation-bus/types.js';
import {
createHookOutput,
NotificationType,
type DefaultHookOutput,
type McpToolContext,
BeforeToolHookOutput,
} from '../hooks/types.js';
@@ -194,103 +192,12 @@ function extractMcpContext(
};
}
/**
* Fires the BeforeTool hook and returns the hook output.
*
* @param messageBus The message bus to use for hook communication
* @param toolName The name of the tool being executed
* @param toolInput The input parameters for the tool
* @param mcpContext Optional MCP context for MCP tools
* @returns The hook output, or undefined if no hook was executed or on error
*/
export async function fireBeforeToolHook(
messageBus: MessageBus,
toolName: string,
toolInput: Record<string, unknown>,
mcpContext?: McpToolContext,
): Promise<DefaultHookOutput | undefined> {
try {
const response = await messageBus.request<
HookExecutionRequest,
HookExecutionResponse
>(
{
type: MessageBusType.HOOK_EXECUTION_REQUEST,
eventName: 'BeforeTool',
input: {
tool_name: toolName,
tool_input: toolInput,
...(mcpContext && { mcp_context: mcpContext }),
},
},
MessageBusType.HOOK_EXECUTION_RESPONSE,
);
return response.output
? createHookOutput('BeforeTool', response.output)
: undefined;
} catch (error) {
debugLogger.debug(`BeforeTool hook failed for ${toolName}:`, error);
return undefined;
}
}
/**
* Fires the AfterTool hook and returns the hook output.
*
* @param messageBus The message bus to use for hook communication
* @param toolName The name of the tool that was executed
* @param toolInput The input parameters for the tool
* @param toolResponse The result from the tool execution
* @param mcpContext Optional MCP context for MCP tools
* @returns The hook output, or undefined if no hook was executed or on error
*/
export async function fireAfterToolHook(
messageBus: MessageBus,
toolName: string,
toolInput: Record<string, unknown>,
toolResponse: {
llmContent: ToolResult['llmContent'];
returnDisplay: ToolResult['returnDisplay'];
error: ToolResult['error'];
},
mcpContext?: McpToolContext,
): Promise<DefaultHookOutput | undefined> {
try {
const response = await messageBus.request<
HookExecutionRequest,
HookExecutionResponse
>(
{
type: MessageBusType.HOOK_EXECUTION_REQUEST,
eventName: 'AfterTool',
input: {
tool_name: toolName,
tool_input: toolInput,
tool_response: toolResponse,
...(mcpContext && { mcp_context: mcpContext }),
},
},
MessageBusType.HOOK_EXECUTION_RESPONSE,
);
return response.output
? createHookOutput('AfterTool', response.output)
: undefined;
} catch (error) {
debugLogger.debug(`AfterTool hook failed for ${toolName}:`, error);
return undefined;
}
}
/**
* Execute a tool with BeforeTool and AfterTool hooks.
*
* @param invocation The tool invocation to execute
* @param toolName The name of the tool
* @param signal Abort signal for cancellation
* @param messageBus Optional message bus for hook communication
* @param hooksEnabled Whether hooks are enabled
* @param liveOutputCallback Optional callback for live output updates
* @param shellExecutionConfig Optional shell execution config
* @param setPidCallback Optional callback to set the PID for shell invocations
@@ -301,8 +208,6 @@ export async function executeToolWithHooks(
invocation: ShellToolInvocation | AnyToolInvocation,
toolName: string,
signal: AbortSignal,
messageBus: MessageBus | undefined,
hooksEnabled: boolean,
tool: AnyDeclarativeTool,
liveOutputCallback?: (outputChunk: string | AnsiOutput) => void,
shellExecutionConfig?: ShellExecutionConfig,
@@ -316,10 +221,9 @@ export async function executeToolWithHooks(
// Extract MCP context if this is an MCP tool (only if config is provided)
const mcpContext = config ? extractMcpContext(invocation, config) : undefined;
// Fire BeforeTool hook through MessageBus (only if hooks are enabled)
if (hooksEnabled && messageBus) {
const beforeOutput = await fireBeforeToolHook(
messageBus,
const hookSystem = config?.getHookSystem();
if (hookSystem) {
const beforeOutput = await hookSystem.fireBeforeToolEvent(
toolName,
toolInput,
mcpContext,
@@ -419,10 +323,8 @@ export async function executeToolWithHooks(
}
}
// Fire AfterTool hook through MessageBus (only if hooks are enabled)
if (hooksEnabled && messageBus) {
const afterOutput = await fireAfterToolHook(
messageBus,
if (hookSystem) {
const afterOutput = await hookSystem.fireAfterToolEvent(
toolName,
toolInput,
{

View File

@@ -1889,6 +1889,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
}) as unknown as PolicyEngine,
isInteractive: () => false,
});
mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined);
const scheduler = new CoreToolScheduler({
config: mockConfig,
@@ -2018,6 +2019,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
getApprovalMode: () => ApprovalMode.YOLO,
isInteractive: () => false,
});
mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined);
const scheduler = new CoreToolScheduler({
config: mockConfig,
@@ -2105,6 +2107,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
check: async () => ({ decision: PolicyDecision.DENY }),
}) as unknown as PolicyEngine,
});
mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined);
const scheduler = new CoreToolScheduler({
config: mockConfig,

View File

@@ -22,6 +22,7 @@ import type {
BeforeModelHookOutput,
AfterModelHookOutput,
BeforeToolSelectionHookOutput,
McpToolContext,
} from './types.js';
import type { AggregatedHookResult } from './hookAggregator.js';
import type {
@@ -297,4 +298,46 @@ export class HookSystem {
return {};
}
}
async fireBeforeToolEvent(
toolName: string,
toolInput: Record<string, unknown>,
mcpContext?: McpToolContext,
): Promise<DefaultHookOutput | undefined> {
try {
const result = await this.hookEventHandler.fireBeforeToolEvent(
toolName,
toolInput,
mcpContext,
);
return result.finalOutput;
} catch (error) {
debugLogger.debug(`BeforeTool hook failed for ${toolName}:`, error);
return undefined;
}
}
async fireAfterToolEvent(
toolName: string,
toolInput: Record<string, unknown>,
toolResponse: {
llmContent: unknown;
returnDisplay: unknown;
error: unknown;
},
mcpContext?: McpToolContext,
): Promise<DefaultHookOutput | undefined> {
try {
const result = await this.hookEventHandler.fireAfterToolEvent(
toolName,
toolInput,
toolResponse as Record<string, unknown>,
mcpContext,
);
return result.finalOutput;
} catch (error) {
debugLogger.debug(`AfterTool hook failed for ${toolName}:`, error);
return undefined;
}
}
}

View File

@@ -253,17 +253,7 @@ describe('ToolExecutor', () => {
// 2. Mock executeToolWithHooks to trigger the PID callback
const testPid = 12345;
vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation(
async (
_inv,
_name,
_sig,
_bus,
_hooks,
_tool,
_liveCb,
_shellCfg,
setPidCallback,
) => {
async (_inv, _name, _sig, _tool, _liveCb, _shellCfg, setPidCallback) => {
// Simulate the shell tool reporting a PID
if (setPidCallback) {
setPidCallback(testPid);

View File

@@ -66,8 +66,6 @@ export class ToolExecutor {
: undefined;
const shellExecutionConfig = this.config.getShellExecutionConfig();
const hooksEnabled = this.config.getEnableHooks();
const messageBus = this.config.getMessageBus();
return runInDevTraceSpan(
{
@@ -95,8 +93,6 @@ export class ToolExecutor {
invocation,
toolName,
signal,
messageBus,
hooksEnabled,
tool,
liveOutputCallback,
shellExecutionConfig,
@@ -108,8 +104,6 @@ export class ToolExecutor {
invocation,
toolName,
signal,
messageBus,
hooksEnabled,
tool,
liveOutputCallback,
shellExecutionConfig,