feat(hooks): add mcp_context to BeforeTool and AfterTool hook inputs (#15656)

Co-authored-by: Tommaso Sciortino <sciortino@gmail.com>
This commit is contained in:
Vijay Vasudevan
2026-01-08 10:35:33 -08:00
committed by GitHub
parent 660368f249
commit eb3f3cfdb8
7 changed files with 325 additions and 6 deletions
+10
View File
@@ -46,6 +46,16 @@ specific event.
- `tool_input`: (`object`) The arguments passed to the tool. - `tool_input`: (`object`) The arguments passed to the tool.
- `tool_response`: (`object`, **AfterTool only**) The raw output from the tool - `tool_response`: (`object`, **AfterTool only**) The raw output from the tool
execution. execution.
- `mcp_context`: (`object`, **optional**) Present only for MCP tool invocations.
Contains server identity information:
- `server_name`: (`string`) The configured name of the MCP server.
- `tool_name`: (`string`) The original tool name from the MCP server.
- `command`: (`string`, optional) For stdio transport, the command used to
start the server.
- `args`: (`string[]`, optional) For stdio transport, the command arguments.
- `cwd`: (`string`, optional) For stdio transport, the working directory.
- `url`: (`string`, optional) For SSE/HTTP transport, the server URL.
- `tcp`: (`string`, optional) For WebSocket transport, the TCP address.
#### Agent Events (`BeforeAgent`, `AfterAgent`) #### Agent Events (`BeforeAgent`, `AfterAgent`)
@@ -14,8 +14,10 @@ import {
createHookOutput, createHookOutput,
NotificationType, NotificationType,
type DefaultHookOutput, type DefaultHookOutput,
type McpToolContext,
BeforeToolHookOutput, BeforeToolHookOutput,
} from '../hooks/types.js'; } from '../hooks/types.js';
import type { Config } from '../config/config.js';
import type { import type {
ToolCallConfirmationDetails, ToolCallConfirmationDetails,
ToolResult, ToolResult,
@@ -26,6 +28,7 @@ import { debugLogger } from '../utils/debugLogger.js';
import type { AnsiOutput, ShellExecutionConfig } from '../index.js'; import type { AnsiOutput, ShellExecutionConfig } from '../index.js';
import type { AnyToolInvocation } from '../tools/tools.js'; import type { AnyToolInvocation } from '../tools/tools.js';
import { ShellToolInvocation } from '../tools/shell.js'; import { ShellToolInvocation } from '../tools/shell.js';
import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js';
/** /**
* Serializable representation of tool confirmation details for hooks. * Serializable representation of tool confirmation details for hooks.
@@ -154,18 +157,57 @@ export async function fireToolNotificationHook(
} }
} }
/**
* Extracts MCP context from a tool invocation if it's an MCP tool.
*
* @param invocation The tool invocation
* @param config Config to look up server details
* @returns MCP context if this is an MCP tool, undefined otherwise
*/
function extractMcpContext(
invocation: ShellToolInvocation | AnyToolInvocation,
config: Config,
): McpToolContext | undefined {
if (!(invocation instanceof DiscoveredMCPToolInvocation)) {
return undefined;
}
// Get the server config
const mcpServers =
config.getMcpClientManager()?.getMcpServers() ??
config.getMcpServers() ??
{};
const serverConfig = mcpServers[invocation.serverName];
if (!serverConfig) {
return undefined;
}
return {
server_name: invocation.serverName,
tool_name: invocation.serverToolName,
// Non-sensitive connection details only
command: serverConfig.command,
args: serverConfig.args,
cwd: serverConfig.cwd,
url: serverConfig.url ?? serverConfig.httpUrl,
tcp: serverConfig.tcp,
};
}
/** /**
* Fires the BeforeTool hook and returns the hook output. * Fires the BeforeTool hook and returns the hook output.
* *
* @param messageBus The message bus to use for hook communication * @param messageBus The message bus to use for hook communication
* @param toolName The name of the tool being executed * @param toolName The name of the tool being executed
* @param toolInput The input parameters for the tool * @param toolInput The input parameters for the tool
* @param mcpContext Optional MCP context for MCP tools
* @returns The hook output, or undefined if no hook was executed or on error * @returns The hook output, or undefined if no hook was executed or on error
*/ */
export async function fireBeforeToolHook( export async function fireBeforeToolHook(
messageBus: MessageBus, messageBus: MessageBus,
toolName: string, toolName: string,
toolInput: Record<string, unknown>, toolInput: Record<string, unknown>,
mcpContext?: McpToolContext,
): Promise<DefaultHookOutput | undefined> { ): Promise<DefaultHookOutput | undefined> {
try { try {
const response = await messageBus.request< const response = await messageBus.request<
@@ -178,6 +220,7 @@ export async function fireBeforeToolHook(
input: { input: {
tool_name: toolName, tool_name: toolName,
tool_input: toolInput, tool_input: toolInput,
...(mcpContext && { mcp_context: mcpContext }),
}, },
}, },
MessageBusType.HOOK_EXECUTION_RESPONSE, MessageBusType.HOOK_EXECUTION_RESPONSE,
@@ -199,6 +242,7 @@ export async function fireBeforeToolHook(
* @param toolName The name of the tool that was executed * @param toolName The name of the tool that was executed
* @param toolInput The input parameters for the tool * @param toolInput The input parameters for the tool
* @param toolResponse The result from the tool execution * @param toolResponse The result from the tool execution
* @param mcpContext Optional MCP context for MCP tools
* @returns The hook output, or undefined if no hook was executed or on error * @returns The hook output, or undefined if no hook was executed or on error
*/ */
export async function fireAfterToolHook( export async function fireAfterToolHook(
@@ -210,6 +254,7 @@ export async function fireAfterToolHook(
returnDisplay: ToolResult['returnDisplay']; returnDisplay: ToolResult['returnDisplay'];
error: ToolResult['error']; error: ToolResult['error'];
}, },
mcpContext?: McpToolContext,
): Promise<DefaultHookOutput | undefined> { ): Promise<DefaultHookOutput | undefined> {
try { try {
const response = await messageBus.request< const response = await messageBus.request<
@@ -223,6 +268,7 @@ export async function fireAfterToolHook(
tool_name: toolName, tool_name: toolName,
tool_input: toolInput, tool_input: toolInput,
tool_response: toolResponse, tool_response: toolResponse,
...(mcpContext && { mcp_context: mcpContext }),
}, },
}, },
MessageBusType.HOOK_EXECUTION_RESPONSE, MessageBusType.HOOK_EXECUTION_RESPONSE,
@@ -248,6 +294,7 @@ export async function fireAfterToolHook(
* @param liveOutputCallback Optional callback for live output updates * @param liveOutputCallback Optional callback for live output updates
* @param shellExecutionConfig Optional shell execution config * @param shellExecutionConfig Optional shell execution config
* @param setPidCallback Optional callback to set the PID for shell invocations * @param setPidCallback Optional callback to set the PID for shell invocations
* @param config Config to look up MCP server details for hook context
* @returns The tool result * @returns The tool result
*/ */
export async function executeToolWithHooks( export async function executeToolWithHooks(
@@ -260,17 +307,22 @@ export async function executeToolWithHooks(
liveOutputCallback?: (outputChunk: string | AnsiOutput) => void, liveOutputCallback?: (outputChunk: string | AnsiOutput) => void,
shellExecutionConfig?: ShellExecutionConfig, shellExecutionConfig?: ShellExecutionConfig,
setPidCallback?: (pid: number) => void, setPidCallback?: (pid: number) => void,
config?: Config,
): Promise<ToolResult> { ): Promise<ToolResult> {
const toolInput = (invocation.params || {}) as Record<string, unknown>; const toolInput = (invocation.params || {}) as Record<string, unknown>;
let inputWasModified = false; let inputWasModified = false;
let modifiedKeys: string[] = []; let modifiedKeys: string[] = [];
// Extract MCP context if this is an MCP tool (only if config is provided)
const mcpContext = config ? extractMcpContext(invocation, config) : undefined;
// Fire BeforeTool hook through MessageBus (only if hooks are enabled) // Fire BeforeTool hook through MessageBus (only if hooks are enabled)
if (hooksEnabled && messageBus) { if (hooksEnabled && messageBus) {
const beforeOutput = await fireBeforeToolHook( const beforeOutput = await fireBeforeToolHook(
messageBus, messageBus,
toolName, toolName,
toolInput, toolInput,
mcpContext,
); );
// Check if hook requested to stop entire agent execution // Check if hook requested to stop entire agent execution
@@ -378,6 +430,7 @@ export async function executeToolWithHooks(
returnDisplay: toolResult.returnDisplay, returnDisplay: toolResult.returnDisplay,
error: toolResult.error, error: toolResult.error,
}, },
mcpContext,
); );
// Check if hook requested to stop entire agent execution // Check if hook requested to stop entire agent execution
@@ -258,6 +258,128 @@ describe('HookEventHandler', () => {
expect.stringContaining('F12'), expect.stringContaining('F12'),
); );
}); });
it('should fire BeforeTool event with MCP context when provided', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './test.sh',
} as unknown as HookConfig,
eventName: HookEventName.BeforeTool,
},
];
const mockResults: HookExecutionResult[] = [
{
success: true,
duration: 100,
hookConfig: {
type: HookType.Command,
command: './test.sh',
timeout: 30000,
},
eventName: HookEventName.BeforeTool,
},
];
const mockAggregated = {
success: true,
allOutputs: [],
errors: [],
totalDuration: 100,
};
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.BeforeTool,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
mockResults,
);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
mockAggregated,
);
const mcpContext = {
server_name: 'my-mcp-server',
tool_name: 'read_file',
command: 'npx',
args: ['-y', '@my-org/mcp-server'],
};
const result = await hookEventHandler.fireBeforeToolEvent(
'my-mcp-server__read_file',
{ path: '/etc/passwd' },
mcpContext,
);
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
[mockPlan[0].hookConfig],
HookEventName.BeforeTool,
expect.objectContaining({
session_id: 'test-session',
cwd: '/test/project',
hook_event_name: 'BeforeTool',
tool_name: 'my-mcp-server__read_file',
tool_input: { path: '/etc/passwd' },
mcp_context: mcpContext,
}),
expect.any(Function),
expect.any(Function),
);
expect(result).toBe(mockAggregated);
});
it('should not include mcp_context when not provided', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './test.sh',
} as unknown as HookConfig,
eventName: HookEventName.BeforeTool,
},
];
const mockResults: HookExecutionResult[] = [
{
success: true,
duration: 100,
hookConfig: {
type: HookType.Command,
command: './test.sh',
timeout: 30000,
},
eventName: HookEventName.BeforeTool,
},
];
const mockAggregated = {
success: true,
allOutputs: [],
errors: [],
totalDuration: 100,
};
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.BeforeTool,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
mockResults,
);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
mockAggregated,
);
await hookEventHandler.fireBeforeToolEvent('EditTool', {
file: 'test.txt',
});
const callArgs = vi.mocked(mockHookRunner.executeHooksParallel).mock
.calls[0][2];
expect(callArgs).not.toHaveProperty('mcp_context');
});
}); });
describe('fireAfterToolEvent', () => { describe('fireAfterToolEvent', () => {
@@ -325,6 +447,78 @@ describe('HookEventHandler', () => {
expect(result).toBe(mockAggregated); expect(result).toBe(mockAggregated);
}); });
it('should fire AfterTool event with MCP context when provided', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './after.sh',
} as unknown as HookConfig,
eventName: HookEventName.AfterTool,
},
];
const mockResults: HookExecutionResult[] = [
{
success: true,
duration: 100,
hookConfig: {
type: HookType.Command,
command: './after.sh',
timeout: 30000,
},
eventName: HookEventName.AfterTool,
},
];
const mockAggregated = {
success: true,
allOutputs: [],
errors: [],
totalDuration: 100,
};
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.AfterTool,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
mockResults,
);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
mockAggregated,
);
const toolInput = { path: '/etc/passwd' };
const toolResponse = { success: true, content: 'File content' };
const mcpContext = {
server_name: 'my-mcp-server',
tool_name: 'read_file',
url: 'https://mcp.example.com',
};
const result = await hookEventHandler.fireAfterToolEvent(
'my-mcp-server__read_file',
toolInput,
toolResponse,
mcpContext,
);
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
[mockPlan[0].hookConfig],
HookEventName.AfterTool,
expect.objectContaining({
tool_name: 'my-mcp-server__read_file',
tool_input: toolInput,
tool_response: toolResponse,
mcp_context: mcpContext,
}),
expect.any(Function),
expect.any(Function),
);
expect(result).toBe(mockAggregated);
});
}); });
describe('fireBeforeAgentEvent', () => { describe('fireBeforeAgentEvent', () => {
+38 -5
View File
@@ -29,6 +29,7 @@ import type {
SessionEndReason, SessionEndReason,
PreCompressTrigger, PreCompressTrigger,
HookExecutionResult, HookExecutionResult,
McpToolContext,
} from './types.js'; } from './types.js';
import { defaultHookTranslator } from './hookTranslator.js'; import { defaultHookTranslator } from './hookTranslator.js';
import type { import type {
@@ -58,9 +59,11 @@ function isObject(value: unknown): value is Record<string, unknown> {
function validateBeforeToolInput(input: Record<string, unknown>): { function validateBeforeToolInput(input: Record<string, unknown>): {
toolName: string; toolName: string;
toolInput: Record<string, unknown>; toolInput: Record<string, unknown>;
mcpContext?: McpToolContext;
} { } {
const toolName = input['tool_name']; const toolName = input['tool_name'];
const toolInput = input['tool_input']; const toolInput = input['tool_input'];
const mcpContext = input['mcp_context'];
if (typeof toolName !== 'string') { if (typeof toolName !== 'string') {
throw new Error( throw new Error(
'Invalid input for BeforeTool hook event: tool_name must be a string', 'Invalid input for BeforeTool hook event: tool_name must be a string',
@@ -71,7 +74,16 @@ function validateBeforeToolInput(input: Record<string, unknown>): {
'Invalid input for BeforeTool hook event: tool_input must be an object', 'Invalid input for BeforeTool hook event: tool_input must be an object',
); );
} }
return { toolName, toolInput }; if (mcpContext !== undefined && !isObject(mcpContext)) {
throw new Error(
'Invalid input for BeforeTool hook event: mcp_context must be an object',
);
}
return {
toolName,
toolInput,
mcpContext: mcpContext as McpToolContext | undefined,
};
} }
/** /**
@@ -81,10 +93,12 @@ function validateAfterToolInput(input: Record<string, unknown>): {
toolName: string; toolName: string;
toolInput: Record<string, unknown>; toolInput: Record<string, unknown>;
toolResponse: Record<string, unknown>; toolResponse: Record<string, unknown>;
mcpContext?: McpToolContext;
} { } {
const toolName = input['tool_name']; const toolName = input['tool_name'];
const toolInput = input['tool_input']; const toolInput = input['tool_input'];
const toolResponse = input['tool_response']; const toolResponse = input['tool_response'];
const mcpContext = input['mcp_context'];
if (typeof toolName !== 'string') { if (typeof toolName !== 'string') {
throw new Error( throw new Error(
'Invalid input for AfterTool hook event: tool_name must be a string', 'Invalid input for AfterTool hook event: tool_name must be a string',
@@ -100,7 +114,17 @@ function validateAfterToolInput(input: Record<string, unknown>): {
'Invalid input for AfterTool hook event: tool_response must be an object', 'Invalid input for AfterTool hook event: tool_response must be an object',
); );
} }
return { toolName, toolInput, toolResponse }; if (mcpContext !== undefined && !isObject(mcpContext)) {
throw new Error(
'Invalid input for AfterTool hook event: mcp_context must be an object',
);
}
return {
toolName,
toolInput,
toolResponse,
mcpContext: mcpContext as McpToolContext | undefined,
};
} }
/** /**
@@ -313,11 +337,13 @@ export class HookEventHandler {
async fireBeforeToolEvent( async fireBeforeToolEvent(
toolName: string, toolName: string,
toolInput: Record<string, unknown>, toolInput: Record<string, unknown>,
mcpContext?: McpToolContext,
): Promise<AggregatedHookResult> { ): Promise<AggregatedHookResult> {
const input: BeforeToolInput = { const input: BeforeToolInput = {
...this.createBaseInput(HookEventName.BeforeTool), ...this.createBaseInput(HookEventName.BeforeTool),
tool_name: toolName, tool_name: toolName,
tool_input: toolInput, tool_input: toolInput,
...(mcpContext && { mcp_context: mcpContext }),
}; };
const context: HookEventContext = { toolName }; const context: HookEventContext = { toolName };
@@ -332,12 +358,14 @@ export class HookEventHandler {
toolName: string, toolName: string,
toolInput: Record<string, unknown>, toolInput: Record<string, unknown>,
toolResponse: Record<string, unknown>, toolResponse: Record<string, unknown>,
mcpContext?: McpToolContext,
): Promise<AggregatedHookResult> { ): Promise<AggregatedHookResult> {
const input: AfterToolInput = { const input: AfterToolInput = {
...this.createBaseInput(HookEventName.AfterTool), ...this.createBaseInput(HookEventName.AfterTool),
tool_name: toolName, tool_name: toolName,
tool_input: toolInput, tool_input: toolInput,
tool_response: toolResponse, tool_response: toolResponse,
...(mcpContext && { mcp_context: mcpContext }),
}; };
const context: HookEventContext = { toolName }; const context: HookEventContext = { toolName };
@@ -725,18 +753,23 @@ export class HookEventHandler {
// Route to appropriate event handler based on eventName // Route to appropriate event handler based on eventName
switch (request.eventName) { switch (request.eventName) {
case HookEventName.BeforeTool: { case HookEventName.BeforeTool: {
const { toolName, toolInput } = const { toolName, toolInput, mcpContext } =
validateBeforeToolInput(enrichedInput); validateBeforeToolInput(enrichedInput);
result = await this.fireBeforeToolEvent(toolName, toolInput); result = await this.fireBeforeToolEvent(
toolName,
toolInput,
mcpContext,
);
break; break;
} }
case HookEventName.AfterTool: { case HookEventName.AfterTool: {
const { toolName, toolInput, toolResponse } = const { toolName, toolInput, toolResponse, mcpContext } =
validateAfterToolInput(enrichedInput); validateAfterToolInput(enrichedInput);
result = await this.fireAfterToolEvent( result = await this.fireAfterToolEvent(
toolName, toolName,
toolInput, toolInput,
toolResponse, toolResponse,
mcpContext,
); );
break; break;
} }
+26
View File
@@ -373,12 +373,37 @@ export class AfterModelHookOutput extends DefaultHookOutput {
} }
} }
/**
* Context for MCP tool executions.
* Contains non-sensitive connection information about the MCP server
* identity. Since server_name is user controlled and arbitrary, we
* also include connection information (e.g., command or url) to
* help identify the MCP server.
*
* NOTE: In the future, consider defining a shared sanitized interface
* from MCPServerConfig to avoid duplication and ensure consistency.
*/
export interface McpToolContext {
server_name: string;
tool_name: string; // Original tool name from the MCP server
// Connection info (mutually exclusive based on transport type)
command?: string; // For stdio transport
args?: string[]; // For stdio transport
cwd?: string; // For stdio transport
url?: string; // For SSE/HTTP transport
tcp?: string; // For WebSocket transport
}
/** /**
* BeforeTool hook input * BeforeTool hook input
*/ */
export interface BeforeToolInput extends HookInput { export interface BeforeToolInput extends HookInput {
tool_name: string; tool_name: string;
tool_input: Record<string, unknown>; tool_input: Record<string, unknown>;
mcp_context?: McpToolContext; // Only present for MCP tools
} }
/** /**
@@ -398,6 +423,7 @@ export interface AfterToolInput extends HookInput {
tool_name: string; tool_name: string;
tool_input: Record<string, unknown>; tool_input: Record<string, unknown>;
tool_response: Record<string, unknown>; tool_response: Record<string, unknown>;
mcp_context?: McpToolContext; // Only present for MCP tools
} }
/** /**
@@ -98,6 +98,7 @@ export class ToolExecutor {
liveOutputCallback, liveOutputCallback,
shellExecutionConfig, shellExecutionConfig,
setPidCallback, setPidCallback,
this.config,
); );
} else { } else {
promise = executeToolWithHooks( promise = executeToolWithHooks(
@@ -109,6 +110,8 @@ export class ToolExecutor {
tool, tool,
liveOutputCallback, liveOutputCallback,
shellExecutionConfig, shellExecutionConfig,
undefined,
this.config,
); );
} }
+1 -1
View File
@@ -59,7 +59,7 @@ type McpContentBlock =
| McpResourceBlock | McpResourceBlock
| McpResourceLinkBlock; | McpResourceLinkBlock;
class DiscoveredMCPToolInvocation extends BaseToolInvocation< export class DiscoveredMCPToolInvocation extends BaseToolInvocation<
ToolParams, ToolParams,
ToolResult ToolResult
> { > {