diff --git a/docs/hooks/reference.md b/docs/hooks/reference.md index bc7b6e5fa2..b5174f827e 100644 --- a/docs/hooks/reference.md +++ b/docs/hooks/reference.md @@ -46,6 +46,16 @@ specific event. - `tool_input`: (`object`) The arguments passed to the tool. - `tool_response`: (`object`, **AfterTool only**) The raw output from the tool execution. +- `mcp_context`: (`object`, **optional**) Present only for MCP tool invocations. + Contains server identity information: + - `server_name`: (`string`) The configured name of the MCP server. + - `tool_name`: (`string`) The original tool name from the MCP server. + - `command`: (`string`, optional) For stdio transport, the command used to + start the server. + - `args`: (`string[]`, optional) For stdio transport, the command arguments. + - `cwd`: (`string`, optional) For stdio transport, the working directory. + - `url`: (`string`, optional) For SSE/HTTP transport, the server URL. + - `tcp`: (`string`, optional) For WebSocket transport, the TCP address. #### Agent Events (`BeforeAgent`, `AfterAgent`) diff --git a/packages/core/src/core/coreToolHookTriggers.ts b/packages/core/src/core/coreToolHookTriggers.ts index 70f9e93c1d..ca1467518b 100644 --- a/packages/core/src/core/coreToolHookTriggers.ts +++ b/packages/core/src/core/coreToolHookTriggers.ts @@ -14,8 +14,10 @@ import { createHookOutput, NotificationType, type DefaultHookOutput, + type McpToolContext, BeforeToolHookOutput, } from '../hooks/types.js'; +import type { Config } from '../config/config.js'; import type { ToolCallConfirmationDetails, ToolResult, @@ -26,6 +28,7 @@ import { debugLogger } from '../utils/debugLogger.js'; import type { AnsiOutput, ShellExecutionConfig } from '../index.js'; import type { AnyToolInvocation } from '../tools/tools.js'; import { ShellToolInvocation } from '../tools/shell.js'; +import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js'; /** * Serializable representation of tool confirmation details for hooks. @@ -154,18 +157,57 @@ export async function fireToolNotificationHook( } } +/** + * Extracts MCP context from a tool invocation if it's an MCP tool. + * + * @param invocation The tool invocation + * @param config Config to look up server details + * @returns MCP context if this is an MCP tool, undefined otherwise + */ +function extractMcpContext( + invocation: ShellToolInvocation | AnyToolInvocation, + config: Config, +): McpToolContext | undefined { + if (!(invocation instanceof DiscoveredMCPToolInvocation)) { + return undefined; + } + + // Get the server config + const mcpServers = + config.getMcpClientManager()?.getMcpServers() ?? + config.getMcpServers() ?? + {}; + const serverConfig = mcpServers[invocation.serverName]; + if (!serverConfig) { + return undefined; + } + + return { + server_name: invocation.serverName, + tool_name: invocation.serverToolName, + // Non-sensitive connection details only + command: serverConfig.command, + args: serverConfig.args, + cwd: serverConfig.cwd, + url: serverConfig.url ?? serverConfig.httpUrl, + tcp: serverConfig.tcp, + }; +} + /** * Fires the BeforeTool hook and returns the hook output. * * @param messageBus The message bus to use for hook communication * @param toolName The name of the tool being executed * @param toolInput The input parameters for the tool + * @param mcpContext Optional MCP context for MCP tools * @returns The hook output, or undefined if no hook was executed or on error */ export async function fireBeforeToolHook( messageBus: MessageBus, toolName: string, toolInput: Record, + mcpContext?: McpToolContext, ): Promise { try { const response = await messageBus.request< @@ -178,6 +220,7 @@ export async function fireBeforeToolHook( input: { tool_name: toolName, tool_input: toolInput, + ...(mcpContext && { mcp_context: mcpContext }), }, }, MessageBusType.HOOK_EXECUTION_RESPONSE, @@ -199,6 +242,7 @@ export async function fireBeforeToolHook( * @param toolName The name of the tool that was executed * @param toolInput The input parameters for the tool * @param toolResponse The result from the tool execution + * @param mcpContext Optional MCP context for MCP tools * @returns The hook output, or undefined if no hook was executed or on error */ export async function fireAfterToolHook( @@ -210,6 +254,7 @@ export async function fireAfterToolHook( returnDisplay: ToolResult['returnDisplay']; error: ToolResult['error']; }, + mcpContext?: McpToolContext, ): Promise { try { const response = await messageBus.request< @@ -223,6 +268,7 @@ export async function fireAfterToolHook( tool_name: toolName, tool_input: toolInput, tool_response: toolResponse, + ...(mcpContext && { mcp_context: mcpContext }), }, }, MessageBusType.HOOK_EXECUTION_RESPONSE, @@ -248,6 +294,7 @@ export async function fireAfterToolHook( * @param liveOutputCallback Optional callback for live output updates * @param shellExecutionConfig Optional shell execution config * @param setPidCallback Optional callback to set the PID for shell invocations + * @param config Config to look up MCP server details for hook context * @returns The tool result */ export async function executeToolWithHooks( @@ -260,17 +307,22 @@ export async function executeToolWithHooks( liveOutputCallback?: (outputChunk: string | AnsiOutput) => void, shellExecutionConfig?: ShellExecutionConfig, setPidCallback?: (pid: number) => void, + config?: Config, ): Promise { const toolInput = (invocation.params || {}) as Record; let inputWasModified = false; let modifiedKeys: string[] = []; + // Extract MCP context if this is an MCP tool (only if config is provided) + const mcpContext = config ? extractMcpContext(invocation, config) : undefined; + // Fire BeforeTool hook through MessageBus (only if hooks are enabled) if (hooksEnabled && messageBus) { const beforeOutput = await fireBeforeToolHook( messageBus, toolName, toolInput, + mcpContext, ); // Check if hook requested to stop entire agent execution @@ -378,6 +430,7 @@ export async function executeToolWithHooks( returnDisplay: toolResult.returnDisplay, error: toolResult.error, }, + mcpContext, ); // Check if hook requested to stop entire agent execution diff --git a/packages/core/src/hooks/hookEventHandler.test.ts b/packages/core/src/hooks/hookEventHandler.test.ts index 2bffc805b6..af7a6be37a 100644 --- a/packages/core/src/hooks/hookEventHandler.test.ts +++ b/packages/core/src/hooks/hookEventHandler.test.ts @@ -258,6 +258,128 @@ describe('HookEventHandler', () => { expect.stringContaining('F12'), ); }); + + it('should fire BeforeTool event with MCP context when provided', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './test.sh', + } as unknown as HookConfig, + eventName: HookEventName.BeforeTool, + }, + ]; + const mockResults: HookExecutionResult[] = [ + { + success: true, + duration: 100, + hookConfig: { + type: HookType.Command, + command: './test.sh', + timeout: 30000, + }, + eventName: HookEventName.BeforeTool, + }, + ]; + const mockAggregated = { + success: true, + allOutputs: [], + errors: [], + totalDuration: 100, + }; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.BeforeTool, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue( + mockResults, + ); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue( + mockAggregated, + ); + + const mcpContext = { + server_name: 'my-mcp-server', + tool_name: 'read_file', + command: 'npx', + args: ['-y', '@my-org/mcp-server'], + }; + + const result = await hookEventHandler.fireBeforeToolEvent( + 'my-mcp-server__read_file', + { path: '/etc/passwd' }, + mcpContext, + ); + + expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith( + [mockPlan[0].hookConfig], + HookEventName.BeforeTool, + expect.objectContaining({ + session_id: 'test-session', + cwd: '/test/project', + hook_event_name: 'BeforeTool', + tool_name: 'my-mcp-server__read_file', + tool_input: { path: '/etc/passwd' }, + mcp_context: mcpContext, + }), + expect.any(Function), + expect.any(Function), + ); + + expect(result).toBe(mockAggregated); + }); + + it('should not include mcp_context when not provided', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './test.sh', + } as unknown as HookConfig, + eventName: HookEventName.BeforeTool, + }, + ]; + const mockResults: HookExecutionResult[] = [ + { + success: true, + duration: 100, + hookConfig: { + type: HookType.Command, + command: './test.sh', + timeout: 30000, + }, + eventName: HookEventName.BeforeTool, + }, + ]; + const mockAggregated = { + success: true, + allOutputs: [], + errors: [], + totalDuration: 100, + }; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.BeforeTool, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue( + mockResults, + ); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue( + mockAggregated, + ); + + await hookEventHandler.fireBeforeToolEvent('EditTool', { + file: 'test.txt', + }); + + const callArgs = vi.mocked(mockHookRunner.executeHooksParallel).mock + .calls[0][2]; + expect(callArgs).not.toHaveProperty('mcp_context'); + }); }); describe('fireAfterToolEvent', () => { @@ -325,6 +447,78 @@ describe('HookEventHandler', () => { expect(result).toBe(mockAggregated); }); + + it('should fire AfterTool event with MCP context when provided', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './after.sh', + } as unknown as HookConfig, + eventName: HookEventName.AfterTool, + }, + ]; + const mockResults: HookExecutionResult[] = [ + { + success: true, + duration: 100, + hookConfig: { + type: HookType.Command, + command: './after.sh', + timeout: 30000, + }, + eventName: HookEventName.AfterTool, + }, + ]; + const mockAggregated = { + success: true, + allOutputs: [], + errors: [], + totalDuration: 100, + }; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.AfterTool, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue( + mockResults, + ); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue( + mockAggregated, + ); + + const toolInput = { path: '/etc/passwd' }; + const toolResponse = { success: true, content: 'File content' }; + const mcpContext = { + server_name: 'my-mcp-server', + tool_name: 'read_file', + url: 'https://mcp.example.com', + }; + + const result = await hookEventHandler.fireAfterToolEvent( + 'my-mcp-server__read_file', + toolInput, + toolResponse, + mcpContext, + ); + + expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith( + [mockPlan[0].hookConfig], + HookEventName.AfterTool, + expect.objectContaining({ + tool_name: 'my-mcp-server__read_file', + tool_input: toolInput, + tool_response: toolResponse, + mcp_context: mcpContext, + }), + expect.any(Function), + expect.any(Function), + ); + + expect(result).toBe(mockAggregated); + }); }); describe('fireBeforeAgentEvent', () => { diff --git a/packages/core/src/hooks/hookEventHandler.ts b/packages/core/src/hooks/hookEventHandler.ts index e72aee913a..e208dd1ed4 100644 --- a/packages/core/src/hooks/hookEventHandler.ts +++ b/packages/core/src/hooks/hookEventHandler.ts @@ -29,6 +29,7 @@ import type { SessionEndReason, PreCompressTrigger, HookExecutionResult, + McpToolContext, } from './types.js'; import { defaultHookTranslator } from './hookTranslator.js'; import type { @@ -58,9 +59,11 @@ function isObject(value: unknown): value is Record { function validateBeforeToolInput(input: Record): { toolName: string; toolInput: Record; + mcpContext?: McpToolContext; } { const toolName = input['tool_name']; const toolInput = input['tool_input']; + const mcpContext = input['mcp_context']; if (typeof toolName !== 'string') { throw new Error( 'Invalid input for BeforeTool hook event: tool_name must be a string', @@ -71,7 +74,16 @@ function validateBeforeToolInput(input: Record): { 'Invalid input for BeforeTool hook event: tool_input must be an object', ); } - return { toolName, toolInput }; + if (mcpContext !== undefined && !isObject(mcpContext)) { + throw new Error( + 'Invalid input for BeforeTool hook event: mcp_context must be an object', + ); + } + return { + toolName, + toolInput, + mcpContext: mcpContext as McpToolContext | undefined, + }; } /** @@ -81,10 +93,12 @@ function validateAfterToolInput(input: Record): { toolName: string; toolInput: Record; toolResponse: Record; + mcpContext?: McpToolContext; } { const toolName = input['tool_name']; const toolInput = input['tool_input']; const toolResponse = input['tool_response']; + const mcpContext = input['mcp_context']; if (typeof toolName !== 'string') { throw new Error( 'Invalid input for AfterTool hook event: tool_name must be a string', @@ -100,7 +114,17 @@ function validateAfterToolInput(input: Record): { 'Invalid input for AfterTool hook event: tool_response must be an object', ); } - return { toolName, toolInput, toolResponse }; + if (mcpContext !== undefined && !isObject(mcpContext)) { + throw new Error( + 'Invalid input for AfterTool hook event: mcp_context must be an object', + ); + } + return { + toolName, + toolInput, + toolResponse, + mcpContext: mcpContext as McpToolContext | undefined, + }; } /** @@ -313,11 +337,13 @@ export class HookEventHandler { async fireBeforeToolEvent( toolName: string, toolInput: Record, + mcpContext?: McpToolContext, ): Promise { const input: BeforeToolInput = { ...this.createBaseInput(HookEventName.BeforeTool), tool_name: toolName, tool_input: toolInput, + ...(mcpContext && { mcp_context: mcpContext }), }; const context: HookEventContext = { toolName }; @@ -332,12 +358,14 @@ export class HookEventHandler { toolName: string, toolInput: Record, toolResponse: Record, + mcpContext?: McpToolContext, ): Promise { const input: AfterToolInput = { ...this.createBaseInput(HookEventName.AfterTool), tool_name: toolName, tool_input: toolInput, tool_response: toolResponse, + ...(mcpContext && { mcp_context: mcpContext }), }; const context: HookEventContext = { toolName }; @@ -725,18 +753,23 @@ export class HookEventHandler { // Route to appropriate event handler based on eventName switch (request.eventName) { case HookEventName.BeforeTool: { - const { toolName, toolInput } = + const { toolName, toolInput, mcpContext } = validateBeforeToolInput(enrichedInput); - result = await this.fireBeforeToolEvent(toolName, toolInput); + result = await this.fireBeforeToolEvent( + toolName, + toolInput, + mcpContext, + ); break; } case HookEventName.AfterTool: { - const { toolName, toolInput, toolResponse } = + const { toolName, toolInput, toolResponse, mcpContext } = validateAfterToolInput(enrichedInput); result = await this.fireAfterToolEvent( toolName, toolInput, toolResponse, + mcpContext, ); break; } diff --git a/packages/core/src/hooks/types.ts b/packages/core/src/hooks/types.ts index e54a03f840..5ca7bd5fb1 100644 --- a/packages/core/src/hooks/types.ts +++ b/packages/core/src/hooks/types.ts @@ -373,12 +373,37 @@ export class AfterModelHookOutput extends DefaultHookOutput { } } +/** + * Context for MCP tool executions. + * Contains non-sensitive connection information about the MCP server + * identity. Since server_name is user controlled and arbitrary, we + * also include connection information (e.g., command or url) to + * help identify the MCP server. + * + * NOTE: In the future, consider defining a shared sanitized interface + * from MCPServerConfig to avoid duplication and ensure consistency. + */ +export interface McpToolContext { + server_name: string; + tool_name: string; // Original tool name from the MCP server + + // Connection info (mutually exclusive based on transport type) + command?: string; // For stdio transport + args?: string[]; // For stdio transport + cwd?: string; // For stdio transport + + url?: string; // For SSE/HTTP transport + + tcp?: string; // For WebSocket transport +} + /** * BeforeTool hook input */ export interface BeforeToolInput extends HookInput { tool_name: string; tool_input: Record; + mcp_context?: McpToolContext; // Only present for MCP tools } /** @@ -398,6 +423,7 @@ export interface AfterToolInput extends HookInput { tool_name: string; tool_input: Record; tool_response: Record; + mcp_context?: McpToolContext; // Only present for MCP tools } /** diff --git a/packages/core/src/scheduler/tool-executor.ts b/packages/core/src/scheduler/tool-executor.ts index 8334168b93..233ff998ff 100644 --- a/packages/core/src/scheduler/tool-executor.ts +++ b/packages/core/src/scheduler/tool-executor.ts @@ -98,6 +98,7 @@ export class ToolExecutor { liveOutputCallback, shellExecutionConfig, setPidCallback, + this.config, ); } else { promise = executeToolWithHooks( @@ -109,6 +110,8 @@ export class ToolExecutor { tool, liveOutputCallback, shellExecutionConfig, + undefined, + this.config, ); } diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index 44a07d99e8..8259b6c2f3 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -59,7 +59,7 @@ type McpContentBlock = | McpResourceBlock | McpResourceLinkBlock; -class DiscoveredMCPToolInvocation extends BaseToolInvocation< +export class DiscoveredMCPToolInvocation extends BaseToolInvocation< ToolParams, ToolResult > {