Migrate beforeTool and afterTool hooks to hookSystem (#17204)

Co-authored-by: Ishaan Gupta <ishaankone@gmail.com>
Co-authored-by: Tommaso Sciortino <sciortino@gmail.com>
This commit is contained in:
Vedant Mahajan
2026-01-21 22:43:03 +05:30
committed by GitHub
parent 4d77934a83
commit 6b14dc8240
6 changed files with 149 additions and 225 deletions
@@ -13,10 +13,12 @@ import {
type AnyDeclarativeTool, type AnyDeclarativeTool,
} from '../tools/tools.js'; } from '../tools/tools.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js';
import type { HookSystem } from '../hooks/hookSystem.js';
import type { Config } from '../config/config.js';
import { import {
MessageBusType, type DefaultHookOutput,
type HookExecutionResponse, BeforeToolHookOutput,
} from '../confirmation-bus/types.js'; } from '../hooks/types.js';
class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> { class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> {
constructor(params: { key?: string }, messageBus: MessageBus) { constructor(params: { key?: string }, messageBus: MessageBus) {
@@ -38,6 +40,8 @@ class MockInvocation extends BaseToolInvocation<{ key?: string }, ToolResult> {
describe('executeToolWithHooks', () => { describe('executeToolWithHooks', () => {
let messageBus: MessageBus; let messageBus: MessageBus;
let mockTool: AnyDeclarativeTool; let mockTool: AnyDeclarativeTool;
let mockHookSystem: HookSystem;
let mockConfig: Config;
beforeEach(() => { beforeEach(() => {
messageBus = { messageBus = {
@@ -46,6 +50,15 @@ describe('executeToolWithHooks', () => {
subscribe: vi.fn(), subscribe: vi.fn(),
unsubscribe: vi.fn(), unsubscribe: vi.fn(),
} as unknown as MessageBus; } as unknown as MessageBus;
mockHookSystem = {
fireBeforeToolEvent: vi.fn(),
fireAfterToolEvent: vi.fn(),
} as unknown as HookSystem;
mockConfig = {
getHookSystem: vi.fn().mockReturnValue(mockHookSystem),
getMcpClientManager: vi.fn().mockReturnValue(undefined),
getMcpServers: vi.fn().mockReturnValue({}),
} as unknown as Config;
mockTool = { mockTool = {
build: vi build: vi
.fn() .fn()
@@ -57,25 +70,24 @@ describe('executeToolWithHooks', () => {
const invocation = new MockInvocation({}, messageBus); const invocation = new MockInvocation({}, messageBus);
const abortSignal = new AbortController().signal; const abortSignal = new AbortController().signal;
vi.mocked(messageBus.request).mockResolvedValue({ vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
type: MessageBusType.HOOK_EXECUTION_RESPONSE, shouldStopExecution: () => true,
correlationId: 'test-id', getEffectiveReason: () => 'Stop immediately',
success: true, getBlockingError: () => ({
output: { blocked: false,
continue: false,
stopReason: 'Stop immediately',
decision: 'block',
reason: 'Should be ignored because continue is false', reason: 'Should be ignored because continue is false',
}, }),
} as HookExecutionResponse); } as unknown as DefaultHookOutput);
const result = await executeToolWithHooks( const result = await executeToolWithHooks(
invocation, invocation,
'test_tool', 'test_tool',
abortSignal, abortSignal,
messageBus,
true,
mockTool, mockTool,
undefined,
undefined,
undefined,
mockConfig,
); );
expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION); expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION);
@@ -86,23 +98,21 @@ describe('executeToolWithHooks', () => {
const invocation = new MockInvocation({}, messageBus); const invocation = new MockInvocation({}, messageBus);
const abortSignal = new AbortController().signal; const abortSignal = new AbortController().signal;
vi.mocked(messageBus.request).mockResolvedValue({ vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
type: MessageBusType.HOOK_EXECUTION_RESPONSE, shouldStopExecution: () => false,
correlationId: 'test-id', getEffectiveReason: () => '',
success: true, getBlockingError: () => ({ blocked: true, reason: 'Execution blocked' }),
output: { } as unknown as DefaultHookOutput);
decision: 'block',
reason: 'Execution blocked',
},
} as HookExecutionResponse);
const result = await executeToolWithHooks( const result = await executeToolWithHooks(
invocation, invocation,
'test_tool', 'test_tool',
abortSignal, abortSignal,
messageBus,
true,
mockTool, mockTool,
undefined,
undefined,
undefined,
mockConfig,
); );
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED); expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
@@ -114,32 +124,27 @@ describe('executeToolWithHooks', () => {
const abortSignal = new AbortController().signal; const abortSignal = new AbortController().signal;
const spy = vi.spyOn(invocation, 'execute'); const spy = vi.spyOn(invocation, 'execute');
// BeforeTool allow vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
vi.mocked(messageBus.request) shouldStopExecution: () => false,
.mockResolvedValueOnce({ getEffectiveReason: () => '',
type: MessageBusType.HOOK_EXECUTION_RESPONSE, getBlockingError: () => ({ blocked: false, reason: '' }),
correlationId: 'test-id', } as unknown as DefaultHookOutput);
success: true,
output: { decision: 'allow' }, vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue({
} as HookExecutionResponse) shouldStopExecution: () => true,
// AfterTool stop getEffectiveReason: () => 'Stop after execution',
.mockResolvedValueOnce({ getBlockingError: () => ({ blocked: false, reason: '' }),
type: MessageBusType.HOOK_EXECUTION_RESPONSE, } as unknown as DefaultHookOutput);
correlationId: 'test-id',
success: true,
output: {
continue: false,
stopReason: 'Stop after execution',
},
} as HookExecutionResponse);
const result = await executeToolWithHooks( const result = await executeToolWithHooks(
invocation, invocation,
'test_tool', 'test_tool',
abortSignal, abortSignal,
messageBus,
true,
mockTool, mockTool,
undefined,
undefined,
undefined,
mockConfig,
); );
expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION); expect(result.error?.type).toBe(ToolErrorType.STOP_EXECUTION);
@@ -151,32 +156,27 @@ describe('executeToolWithHooks', () => {
const invocation = new MockInvocation({}, messageBus); const invocation = new MockInvocation({}, messageBus);
const abortSignal = new AbortController().signal; const abortSignal = new AbortController().signal;
// BeforeTool allow vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue({
vi.mocked(messageBus.request) shouldStopExecution: () => false,
.mockResolvedValueOnce({ getEffectiveReason: () => '',
type: MessageBusType.HOOK_EXECUTION_RESPONSE, getBlockingError: () => ({ blocked: false, reason: '' }),
correlationId: 'test-id', } as unknown as DefaultHookOutput);
success: true,
output: { decision: 'allow' }, vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue({
} as HookExecutionResponse) shouldStopExecution: () => false,
// AfterTool deny getEffectiveReason: () => '',
.mockResolvedValueOnce({ getBlockingError: () => ({ blocked: true, reason: 'Result denied' }),
type: MessageBusType.HOOK_EXECUTION_RESPONSE, } as unknown as DefaultHookOutput);
correlationId: 'test-id',
success: true,
output: {
decision: 'deny',
reason: 'Result denied',
},
} as HookExecutionResponse);
const result = await executeToolWithHooks( const result = await executeToolWithHooks(
invocation, invocation,
'test_tool', 'test_tool',
abortSignal, abortSignal,
messageBus,
true,
mockTool, mockTool,
undefined,
undefined,
undefined,
mockConfig,
); );
expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED); expect(result.error?.type).toBe(ToolErrorType.EXECUTION_FAILED);
@@ -189,39 +189,28 @@ describe('executeToolWithHooks', () => {
const toolName = 'test-tool'; const toolName = 'test-tool';
const abortSignal = new AbortController().signal; const abortSignal = new AbortController().signal;
// Capture arguments to verify what was passed before modification const mockBeforeOutput = new BeforeToolHookOutput({
const requestSpy = vi.fn().mockImplementation(async (request) => { continue: true,
if (request.eventName === 'BeforeTool') {
// Verify input is original before we return modification instruction
expect(request.input.tool_input.key).toBe('original');
return {
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: 'test-id',
success: true,
output: {
hookSpecificOutput: { hookSpecificOutput: {
hookEventName: 'BeforeTool', hookEventName: 'BeforeTool',
tool_input: { key: 'modified' }, tool_input: { key: 'modified' },
}, },
},
} as HookExecutionResponse;
}
return {
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: 'test-id',
success: true,
output: {},
} as HookExecutionResponse;
}); });
messageBus.request = requestSpy; vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue(
mockBeforeOutput,
);
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue(undefined);
const result = await executeToolWithHooks( const result = await executeToolWithHooks(
invocation, invocation,
toolName, toolName,
abortSignal, abortSignal,
messageBus,
true, // hooksEnabled
mockTool, mockTool,
undefined,
undefined,
undefined,
mockConfig,
); );
// Verify result reflects modified input // Verify result reflects modified input
@@ -231,7 +220,7 @@ describe('executeToolWithHooks', () => {
// Verify params object was modified in place // Verify params object was modified in place
expect(invocation.params.key).toBe('modified'); expect(invocation.params.key).toBe('modified');
expect(requestSpy).toHaveBeenCalled(); expect(mockHookSystem.fireBeforeToolEvent).toHaveBeenCalled();
expect(mockTool.build).toHaveBeenCalledWith({ key: 'modified' }); expect(mockTool.build).toHaveBeenCalledWith({ key: 'modified' });
}); });
@@ -241,25 +230,28 @@ describe('executeToolWithHooks', () => {
const toolName = 'test-tool'; const toolName = 'test-tool';
const abortSignal = new AbortController().signal; const abortSignal = new AbortController().signal;
vi.mocked(messageBus.request).mockResolvedValue({ const mockBeforeOutput = new BeforeToolHookOutput({
type: MessageBusType.HOOK_EXECUTION_RESPONSE, continue: true,
correlationId: 'test-id',
success: true,
output: {
hookSpecificOutput: { hookSpecificOutput: {
hookEventName: 'BeforeTool', hookEventName: 'BeforeTool',
// No tool_input // No tool input
}, },
}, });
} as HookExecutionResponse); vi.mocked(mockHookSystem.fireBeforeToolEvent).mockResolvedValue(
mockBeforeOutput,
);
vi.mocked(mockHookSystem.fireAfterToolEvent).mockResolvedValue(undefined);
const result = await executeToolWithHooks( const result = await executeToolWithHooks(
invocation, invocation,
toolName, toolName,
abortSignal, abortSignal,
messageBus,
true, // hooksEnabled
mockTool, mockTool,
undefined,
undefined,
undefined,
mockConfig,
); );
expect(result.llmContent).toBe('key: original'); expect(result.llmContent).toBe('key: original');
+5 -103
View File
@@ -11,9 +11,7 @@ import {
type HookExecutionResponse, type HookExecutionResponse,
} from '../confirmation-bus/types.js'; } from '../confirmation-bus/types.js';
import { import {
createHookOutput,
NotificationType, NotificationType,
type DefaultHookOutput,
type McpToolContext, type McpToolContext,
BeforeToolHookOutput, BeforeToolHookOutput,
} from '../hooks/types.js'; } from '../hooks/types.js';
@@ -194,103 +192,12 @@ function extractMcpContext(
}; };
} }
/**
* Fires the BeforeTool hook and returns the hook output.
*
* @param messageBus The message bus to use for hook communication
* @param toolName The name of the tool being executed
* @param toolInput The input parameters for the tool
* @param mcpContext Optional MCP context for MCP tools
* @returns The hook output, or undefined if no hook was executed or on error
*/
export async function fireBeforeToolHook(
messageBus: MessageBus,
toolName: string,
toolInput: Record<string, unknown>,
mcpContext?: McpToolContext,
): Promise<DefaultHookOutput | undefined> {
try {
const response = await messageBus.request<
HookExecutionRequest,
HookExecutionResponse
>(
{
type: MessageBusType.HOOK_EXECUTION_REQUEST,
eventName: 'BeforeTool',
input: {
tool_name: toolName,
tool_input: toolInput,
...(mcpContext && { mcp_context: mcpContext }),
},
},
MessageBusType.HOOK_EXECUTION_RESPONSE,
);
return response.output
? createHookOutput('BeforeTool', response.output)
: undefined;
} catch (error) {
debugLogger.debug(`BeforeTool hook failed for ${toolName}:`, error);
return undefined;
}
}
/**
* Fires the AfterTool hook and returns the hook output.
*
* @param messageBus The message bus to use for hook communication
* @param toolName The name of the tool that was executed
* @param toolInput The input parameters for the tool
* @param toolResponse The result from the tool execution
* @param mcpContext Optional MCP context for MCP tools
* @returns The hook output, or undefined if no hook was executed or on error
*/
export async function fireAfterToolHook(
messageBus: MessageBus,
toolName: string,
toolInput: Record<string, unknown>,
toolResponse: {
llmContent: ToolResult['llmContent'];
returnDisplay: ToolResult['returnDisplay'];
error: ToolResult['error'];
},
mcpContext?: McpToolContext,
): Promise<DefaultHookOutput | undefined> {
try {
const response = await messageBus.request<
HookExecutionRequest,
HookExecutionResponse
>(
{
type: MessageBusType.HOOK_EXECUTION_REQUEST,
eventName: 'AfterTool',
input: {
tool_name: toolName,
tool_input: toolInput,
tool_response: toolResponse,
...(mcpContext && { mcp_context: mcpContext }),
},
},
MessageBusType.HOOK_EXECUTION_RESPONSE,
);
return response.output
? createHookOutput('AfterTool', response.output)
: undefined;
} catch (error) {
debugLogger.debug(`AfterTool hook failed for ${toolName}:`, error);
return undefined;
}
}
/** /**
* Execute a tool with BeforeTool and AfterTool hooks. * Execute a tool with BeforeTool and AfterTool hooks.
* *
* @param invocation The tool invocation to execute * @param invocation The tool invocation to execute
* @param toolName The name of the tool * @param toolName The name of the tool
* @param signal Abort signal for cancellation * @param signal Abort signal for cancellation
* @param messageBus Optional message bus for hook communication
* @param hooksEnabled Whether hooks are enabled
* @param liveOutputCallback Optional callback for live output updates * @param liveOutputCallback Optional callback for live output updates
* @param shellExecutionConfig Optional shell execution config * @param shellExecutionConfig Optional shell execution config
* @param setPidCallback Optional callback to set the PID for shell invocations * @param setPidCallback Optional callback to set the PID for shell invocations
@@ -301,8 +208,6 @@ export async function executeToolWithHooks(
invocation: ShellToolInvocation | AnyToolInvocation, invocation: ShellToolInvocation | AnyToolInvocation,
toolName: string, toolName: string,
signal: AbortSignal, signal: AbortSignal,
messageBus: MessageBus | undefined,
hooksEnabled: boolean,
tool: AnyDeclarativeTool, tool: AnyDeclarativeTool,
liveOutputCallback?: (outputChunk: string | AnsiOutput) => void, liveOutputCallback?: (outputChunk: string | AnsiOutput) => void,
shellExecutionConfig?: ShellExecutionConfig, shellExecutionConfig?: ShellExecutionConfig,
@@ -316,10 +221,9 @@ export async function executeToolWithHooks(
// Extract MCP context if this is an MCP tool (only if config is provided) // Extract MCP context if this is an MCP tool (only if config is provided)
const mcpContext = config ? extractMcpContext(invocation, config) : undefined; const mcpContext = config ? extractMcpContext(invocation, config) : undefined;
// Fire BeforeTool hook through MessageBus (only if hooks are enabled) const hookSystem = config?.getHookSystem();
if (hooksEnabled && messageBus) { if (hookSystem) {
const beforeOutput = await fireBeforeToolHook( const beforeOutput = await hookSystem.fireBeforeToolEvent(
messageBus,
toolName, toolName,
toolInput, toolInput,
mcpContext, mcpContext,
@@ -419,10 +323,8 @@ export async function executeToolWithHooks(
} }
} }
// Fire AfterTool hook through MessageBus (only if hooks are enabled) if (hookSystem) {
if (hooksEnabled && messageBus) { const afterOutput = await hookSystem.fireAfterToolEvent(
const afterOutput = await fireAfterToolHook(
messageBus,
toolName, toolName,
toolInput, toolInput,
{ {
@@ -1889,6 +1889,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
}) as unknown as PolicyEngine, }) as unknown as PolicyEngine,
isInteractive: () => false, isInteractive: () => false,
}); });
mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined);
const scheduler = new CoreToolScheduler({ const scheduler = new CoreToolScheduler({
config: mockConfig, config: mockConfig,
@@ -2018,6 +2019,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
getApprovalMode: () => ApprovalMode.YOLO, getApprovalMode: () => ApprovalMode.YOLO,
isInteractive: () => false, isInteractive: () => false,
}); });
mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined);
const scheduler = new CoreToolScheduler({ const scheduler = new CoreToolScheduler({
config: mockConfig, config: mockConfig,
@@ -2105,6 +2107,7 @@ describe('CoreToolScheduler Sequential Execution', () => {
check: async () => ({ decision: PolicyDecision.DENY }), check: async () => ({ decision: PolicyDecision.DENY }),
}) as unknown as PolicyEngine, }) as unknown as PolicyEngine,
}); });
mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined);
const scheduler = new CoreToolScheduler({ const scheduler = new CoreToolScheduler({
config: mockConfig, config: mockConfig,
+43
View File
@@ -22,6 +22,7 @@ import type {
BeforeModelHookOutput, BeforeModelHookOutput,
AfterModelHookOutput, AfterModelHookOutput,
BeforeToolSelectionHookOutput, BeforeToolSelectionHookOutput,
McpToolContext,
} from './types.js'; } from './types.js';
import type { AggregatedHookResult } from './hookAggregator.js'; import type { AggregatedHookResult } from './hookAggregator.js';
import type { import type {
@@ -297,4 +298,46 @@ export class HookSystem {
return {}; return {};
} }
} }
async fireBeforeToolEvent(
toolName: string,
toolInput: Record<string, unknown>,
mcpContext?: McpToolContext,
): Promise<DefaultHookOutput | undefined> {
try {
const result = await this.hookEventHandler.fireBeforeToolEvent(
toolName,
toolInput,
mcpContext,
);
return result.finalOutput;
} catch (error) {
debugLogger.debug(`BeforeTool hook failed for ${toolName}:`, error);
return undefined;
}
}
async fireAfterToolEvent(
toolName: string,
toolInput: Record<string, unknown>,
toolResponse: {
llmContent: unknown;
returnDisplay: unknown;
error: unknown;
},
mcpContext?: McpToolContext,
): Promise<DefaultHookOutput | undefined> {
try {
const result = await this.hookEventHandler.fireAfterToolEvent(
toolName,
toolInput,
toolResponse as Record<string, unknown>,
mcpContext,
);
return result.finalOutput;
} catch (error) {
debugLogger.debug(`AfterTool hook failed for ${toolName}:`, error);
return undefined;
}
}
} }
@@ -253,17 +253,7 @@ describe('ToolExecutor', () => {
// 2. Mock executeToolWithHooks to trigger the PID callback // 2. Mock executeToolWithHooks to trigger the PID callback
const testPid = 12345; const testPid = 12345;
vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation( vi.mocked(coreToolHookTriggers.executeToolWithHooks).mockImplementation(
async ( async (_inv, _name, _sig, _tool, _liveCb, _shellCfg, setPidCallback) => {
_inv,
_name,
_sig,
_bus,
_hooks,
_tool,
_liveCb,
_shellCfg,
setPidCallback,
) => {
// Simulate the shell tool reporting a PID // Simulate the shell tool reporting a PID
if (setPidCallback) { if (setPidCallback) {
setPidCallback(testPid); setPidCallback(testPid);
@@ -66,8 +66,6 @@ export class ToolExecutor {
: undefined; : undefined;
const shellExecutionConfig = this.config.getShellExecutionConfig(); const shellExecutionConfig = this.config.getShellExecutionConfig();
const hooksEnabled = this.config.getEnableHooks();
const messageBus = this.config.getMessageBus();
return runInDevTraceSpan( return runInDevTraceSpan(
{ {
@@ -95,8 +93,6 @@ export class ToolExecutor {
invocation, invocation,
toolName, toolName,
signal, signal,
messageBus,
hooksEnabled,
tool, tool,
liveOutputCallback, liveOutputCallback,
shellExecutionConfig, shellExecutionConfig,
@@ -108,8 +104,6 @@ export class ToolExecutor {
invocation, invocation,
toolName, toolName,
signal, signal,
messageBus,
hooksEnabled,
tool, tool,
liveOutputCallback, liveOutputCallback,
shellExecutionConfig, shellExecutionConfig,