mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
feat(hooks): add mcp_context to BeforeTool and AfterTool hook inputs (#15656)
Co-authored-by: Tommaso Sciortino <sciortino@gmail.com>
This commit is contained in:
@@ -46,6 +46,16 @@ specific event.
|
||||
- `tool_input`: (`object`) The arguments passed to the tool.
|
||||
- `tool_response`: (`object`, **AfterTool only**) The raw output from the tool
|
||||
execution.
|
||||
- `mcp_context`: (`object`, **optional**) Present only for MCP tool invocations.
|
||||
Contains server identity information:
|
||||
- `server_name`: (`string`) The configured name of the MCP server.
|
||||
- `tool_name`: (`string`) The original tool name from the MCP server.
|
||||
- `command`: (`string`, optional) For stdio transport, the command used to
|
||||
start the server.
|
||||
- `args`: (`string[]`, optional) For stdio transport, the command arguments.
|
||||
- `cwd`: (`string`, optional) For stdio transport, the working directory.
|
||||
- `url`: (`string`, optional) For SSE/HTTP transport, the server URL.
|
||||
- `tcp`: (`string`, optional) For WebSocket transport, the TCP address.
|
||||
|
||||
#### Agent Events (`BeforeAgent`, `AfterAgent`)
|
||||
|
||||
|
||||
@@ -14,8 +14,10 @@ import {
|
||||
createHookOutput,
|
||||
NotificationType,
|
||||
type DefaultHookOutput,
|
||||
type McpToolContext,
|
||||
BeforeToolHookOutput,
|
||||
} from '../hooks/types.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type {
|
||||
ToolCallConfirmationDetails,
|
||||
ToolResult,
|
||||
@@ -26,6 +28,7 @@ import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { AnsiOutput, ShellExecutionConfig } from '../index.js';
|
||||
import type { AnyToolInvocation } from '../tools/tools.js';
|
||||
import { ShellToolInvocation } from '../tools/shell.js';
|
||||
import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js';
|
||||
|
||||
/**
|
||||
* Serializable representation of tool confirmation details for hooks.
|
||||
@@ -154,18 +157,57 @@ export async function fireToolNotificationHook(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts MCP context from a tool invocation if it's an MCP tool.
|
||||
*
|
||||
* @param invocation The tool invocation
|
||||
* @param config Config to look up server details
|
||||
* @returns MCP context if this is an MCP tool, undefined otherwise
|
||||
*/
|
||||
function extractMcpContext(
|
||||
invocation: ShellToolInvocation | AnyToolInvocation,
|
||||
config: Config,
|
||||
): McpToolContext | undefined {
|
||||
if (!(invocation instanceof DiscoveredMCPToolInvocation)) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Get the server config
|
||||
const mcpServers =
|
||||
config.getMcpClientManager()?.getMcpServers() ??
|
||||
config.getMcpServers() ??
|
||||
{};
|
||||
const serverConfig = mcpServers[invocation.serverName];
|
||||
if (!serverConfig) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
server_name: invocation.serverName,
|
||||
tool_name: invocation.serverToolName,
|
||||
// Non-sensitive connection details only
|
||||
command: serverConfig.command,
|
||||
args: serverConfig.args,
|
||||
cwd: serverConfig.cwd,
|
||||
url: serverConfig.url ?? serverConfig.httpUrl,
|
||||
tcp: serverConfig.tcp,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Fires the BeforeTool hook and returns the hook output.
|
||||
*
|
||||
* @param messageBus The message bus to use for hook communication
|
||||
* @param toolName The name of the tool being executed
|
||||
* @param toolInput The input parameters for the tool
|
||||
* @param mcpContext Optional MCP context for MCP tools
|
||||
* @returns The hook output, or undefined if no hook was executed or on error
|
||||
*/
|
||||
export async function fireBeforeToolHook(
|
||||
messageBus: MessageBus,
|
||||
toolName: string,
|
||||
toolInput: Record<string, unknown>,
|
||||
mcpContext?: McpToolContext,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
try {
|
||||
const response = await messageBus.request<
|
||||
@@ -178,6 +220,7 @@ export async function fireBeforeToolHook(
|
||||
input: {
|
||||
tool_name: toolName,
|
||||
tool_input: toolInput,
|
||||
...(mcpContext && { mcp_context: mcpContext }),
|
||||
},
|
||||
},
|
||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
@@ -199,6 +242,7 @@ export async function fireBeforeToolHook(
|
||||
* @param toolName The name of the tool that was executed
|
||||
* @param toolInput The input parameters for the tool
|
||||
* @param toolResponse The result from the tool execution
|
||||
* @param mcpContext Optional MCP context for MCP tools
|
||||
* @returns The hook output, or undefined if no hook was executed or on error
|
||||
*/
|
||||
export async function fireAfterToolHook(
|
||||
@@ -210,6 +254,7 @@ export async function fireAfterToolHook(
|
||||
returnDisplay: ToolResult['returnDisplay'];
|
||||
error: ToolResult['error'];
|
||||
},
|
||||
mcpContext?: McpToolContext,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
try {
|
||||
const response = await messageBus.request<
|
||||
@@ -223,6 +268,7 @@ export async function fireAfterToolHook(
|
||||
tool_name: toolName,
|
||||
tool_input: toolInput,
|
||||
tool_response: toolResponse,
|
||||
...(mcpContext && { mcp_context: mcpContext }),
|
||||
},
|
||||
},
|
||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
@@ -248,6 +294,7 @@ export async function fireAfterToolHook(
|
||||
* @param liveOutputCallback Optional callback for live output updates
|
||||
* @param shellExecutionConfig Optional shell execution config
|
||||
* @param setPidCallback Optional callback to set the PID for shell invocations
|
||||
* @param config Config to look up MCP server details for hook context
|
||||
* @returns The tool result
|
||||
*/
|
||||
export async function executeToolWithHooks(
|
||||
@@ -260,17 +307,22 @@ export async function executeToolWithHooks(
|
||||
liveOutputCallback?: (outputChunk: string | AnsiOutput) => void,
|
||||
shellExecutionConfig?: ShellExecutionConfig,
|
||||
setPidCallback?: (pid: number) => void,
|
||||
config?: Config,
|
||||
): Promise<ToolResult> {
|
||||
const toolInput = (invocation.params || {}) as Record<string, unknown>;
|
||||
let inputWasModified = false;
|
||||
let modifiedKeys: string[] = [];
|
||||
|
||||
// Extract MCP context if this is an MCP tool (only if config is provided)
|
||||
const mcpContext = config ? extractMcpContext(invocation, config) : undefined;
|
||||
|
||||
// Fire BeforeTool hook through MessageBus (only if hooks are enabled)
|
||||
if (hooksEnabled && messageBus) {
|
||||
const beforeOutput = await fireBeforeToolHook(
|
||||
messageBus,
|
||||
toolName,
|
||||
toolInput,
|
||||
mcpContext,
|
||||
);
|
||||
|
||||
// Check if hook requested to stop entire agent execution
|
||||
@@ -378,6 +430,7 @@ export async function executeToolWithHooks(
|
||||
returnDisplay: toolResult.returnDisplay,
|
||||
error: toolResult.error,
|
||||
},
|
||||
mcpContext,
|
||||
);
|
||||
|
||||
// Check if hook requested to stop entire agent execution
|
||||
|
||||
@@ -258,6 +258,128 @@ describe('HookEventHandler', () => {
|
||||
expect.stringContaining('F12'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should fire BeforeTool event with MCP context when provided', async () => {
|
||||
const mockPlan = [
|
||||
{
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './test.sh',
|
||||
} as unknown as HookConfig,
|
||||
eventName: HookEventName.BeforeTool,
|
||||
},
|
||||
];
|
||||
const mockResults: HookExecutionResult[] = [
|
||||
{
|
||||
success: true,
|
||||
duration: 100,
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './test.sh',
|
||||
timeout: 30000,
|
||||
},
|
||||
eventName: HookEventName.BeforeTool,
|
||||
},
|
||||
];
|
||||
const mockAggregated = {
|
||||
success: true,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 100,
|
||||
};
|
||||
|
||||
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||
eventName: HookEventName.BeforeTool,
|
||||
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||
sequential: false,
|
||||
});
|
||||
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
|
||||
mockResults,
|
||||
);
|
||||
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
|
||||
mockAggregated,
|
||||
);
|
||||
|
||||
const mcpContext = {
|
||||
server_name: 'my-mcp-server',
|
||||
tool_name: 'read_file',
|
||||
command: 'npx',
|
||||
args: ['-y', '@my-org/mcp-server'],
|
||||
};
|
||||
|
||||
const result = await hookEventHandler.fireBeforeToolEvent(
|
||||
'my-mcp-server__read_file',
|
||||
{ path: '/etc/passwd' },
|
||||
mcpContext,
|
||||
);
|
||||
|
||||
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
|
||||
[mockPlan[0].hookConfig],
|
||||
HookEventName.BeforeTool,
|
||||
expect.objectContaining({
|
||||
session_id: 'test-session',
|
||||
cwd: '/test/project',
|
||||
hook_event_name: 'BeforeTool',
|
||||
tool_name: 'my-mcp-server__read_file',
|
||||
tool_input: { path: '/etc/passwd' },
|
||||
mcp_context: mcpContext,
|
||||
}),
|
||||
expect.any(Function),
|
||||
expect.any(Function),
|
||||
);
|
||||
|
||||
expect(result).toBe(mockAggregated);
|
||||
});
|
||||
|
||||
it('should not include mcp_context when not provided', async () => {
|
||||
const mockPlan = [
|
||||
{
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './test.sh',
|
||||
} as unknown as HookConfig,
|
||||
eventName: HookEventName.BeforeTool,
|
||||
},
|
||||
];
|
||||
const mockResults: HookExecutionResult[] = [
|
||||
{
|
||||
success: true,
|
||||
duration: 100,
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './test.sh',
|
||||
timeout: 30000,
|
||||
},
|
||||
eventName: HookEventName.BeforeTool,
|
||||
},
|
||||
];
|
||||
const mockAggregated = {
|
||||
success: true,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 100,
|
||||
};
|
||||
|
||||
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||
eventName: HookEventName.BeforeTool,
|
||||
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||
sequential: false,
|
||||
});
|
||||
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
|
||||
mockResults,
|
||||
);
|
||||
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
|
||||
mockAggregated,
|
||||
);
|
||||
|
||||
await hookEventHandler.fireBeforeToolEvent('EditTool', {
|
||||
file: 'test.txt',
|
||||
});
|
||||
|
||||
const callArgs = vi.mocked(mockHookRunner.executeHooksParallel).mock
|
||||
.calls[0][2];
|
||||
expect(callArgs).not.toHaveProperty('mcp_context');
|
||||
});
|
||||
});
|
||||
|
||||
describe('fireAfterToolEvent', () => {
|
||||
@@ -325,6 +447,78 @@ describe('HookEventHandler', () => {
|
||||
|
||||
expect(result).toBe(mockAggregated);
|
||||
});
|
||||
|
||||
it('should fire AfterTool event with MCP context when provided', async () => {
|
||||
const mockPlan = [
|
||||
{
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './after.sh',
|
||||
} as unknown as HookConfig,
|
||||
eventName: HookEventName.AfterTool,
|
||||
},
|
||||
];
|
||||
const mockResults: HookExecutionResult[] = [
|
||||
{
|
||||
success: true,
|
||||
duration: 100,
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './after.sh',
|
||||
timeout: 30000,
|
||||
},
|
||||
eventName: HookEventName.AfterTool,
|
||||
},
|
||||
];
|
||||
const mockAggregated = {
|
||||
success: true,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 100,
|
||||
};
|
||||
|
||||
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||
eventName: HookEventName.AfterTool,
|
||||
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||
sequential: false,
|
||||
});
|
||||
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
|
||||
mockResults,
|
||||
);
|
||||
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
|
||||
mockAggregated,
|
||||
);
|
||||
|
||||
const toolInput = { path: '/etc/passwd' };
|
||||
const toolResponse = { success: true, content: 'File content' };
|
||||
const mcpContext = {
|
||||
server_name: 'my-mcp-server',
|
||||
tool_name: 'read_file',
|
||||
url: 'https://mcp.example.com',
|
||||
};
|
||||
|
||||
const result = await hookEventHandler.fireAfterToolEvent(
|
||||
'my-mcp-server__read_file',
|
||||
toolInput,
|
||||
toolResponse,
|
||||
mcpContext,
|
||||
);
|
||||
|
||||
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
|
||||
[mockPlan[0].hookConfig],
|
||||
HookEventName.AfterTool,
|
||||
expect.objectContaining({
|
||||
tool_name: 'my-mcp-server__read_file',
|
||||
tool_input: toolInput,
|
||||
tool_response: toolResponse,
|
||||
mcp_context: mcpContext,
|
||||
}),
|
||||
expect.any(Function),
|
||||
expect.any(Function),
|
||||
);
|
||||
|
||||
expect(result).toBe(mockAggregated);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fireBeforeAgentEvent', () => {
|
||||
|
||||
@@ -29,6 +29,7 @@ import type {
|
||||
SessionEndReason,
|
||||
PreCompressTrigger,
|
||||
HookExecutionResult,
|
||||
McpToolContext,
|
||||
} from './types.js';
|
||||
import { defaultHookTranslator } from './hookTranslator.js';
|
||||
import type {
|
||||
@@ -58,9 +59,11 @@ function isObject(value: unknown): value is Record<string, unknown> {
|
||||
function validateBeforeToolInput(input: Record<string, unknown>): {
|
||||
toolName: string;
|
||||
toolInput: Record<string, unknown>;
|
||||
mcpContext?: McpToolContext;
|
||||
} {
|
||||
const toolName = input['tool_name'];
|
||||
const toolInput = input['tool_input'];
|
||||
const mcpContext = input['mcp_context'];
|
||||
if (typeof toolName !== 'string') {
|
||||
throw new Error(
|
||||
'Invalid input for BeforeTool hook event: tool_name must be a string',
|
||||
@@ -71,7 +74,16 @@ function validateBeforeToolInput(input: Record<string, unknown>): {
|
||||
'Invalid input for BeforeTool hook event: tool_input must be an object',
|
||||
);
|
||||
}
|
||||
return { toolName, toolInput };
|
||||
if (mcpContext !== undefined && !isObject(mcpContext)) {
|
||||
throw new Error(
|
||||
'Invalid input for BeforeTool hook event: mcp_context must be an object',
|
||||
);
|
||||
}
|
||||
return {
|
||||
toolName,
|
||||
toolInput,
|
||||
mcpContext: mcpContext as McpToolContext | undefined,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -81,10 +93,12 @@ function validateAfterToolInput(input: Record<string, unknown>): {
|
||||
toolName: string;
|
||||
toolInput: Record<string, unknown>;
|
||||
toolResponse: Record<string, unknown>;
|
||||
mcpContext?: McpToolContext;
|
||||
} {
|
||||
const toolName = input['tool_name'];
|
||||
const toolInput = input['tool_input'];
|
||||
const toolResponse = input['tool_response'];
|
||||
const mcpContext = input['mcp_context'];
|
||||
if (typeof toolName !== 'string') {
|
||||
throw new Error(
|
||||
'Invalid input for AfterTool hook event: tool_name must be a string',
|
||||
@@ -100,7 +114,17 @@ function validateAfterToolInput(input: Record<string, unknown>): {
|
||||
'Invalid input for AfterTool hook event: tool_response must be an object',
|
||||
);
|
||||
}
|
||||
return { toolName, toolInput, toolResponse };
|
||||
if (mcpContext !== undefined && !isObject(mcpContext)) {
|
||||
throw new Error(
|
||||
'Invalid input for AfterTool hook event: mcp_context must be an object',
|
||||
);
|
||||
}
|
||||
return {
|
||||
toolName,
|
||||
toolInput,
|
||||
toolResponse,
|
||||
mcpContext: mcpContext as McpToolContext | undefined,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -313,11 +337,13 @@ export class HookEventHandler {
|
||||
async fireBeforeToolEvent(
|
||||
toolName: string,
|
||||
toolInput: Record<string, unknown>,
|
||||
mcpContext?: McpToolContext,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: BeforeToolInput = {
|
||||
...this.createBaseInput(HookEventName.BeforeTool),
|
||||
tool_name: toolName,
|
||||
tool_input: toolInput,
|
||||
...(mcpContext && { mcp_context: mcpContext }),
|
||||
};
|
||||
|
||||
const context: HookEventContext = { toolName };
|
||||
@@ -332,12 +358,14 @@ export class HookEventHandler {
|
||||
toolName: string,
|
||||
toolInput: Record<string, unknown>,
|
||||
toolResponse: Record<string, unknown>,
|
||||
mcpContext?: McpToolContext,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: AfterToolInput = {
|
||||
...this.createBaseInput(HookEventName.AfterTool),
|
||||
tool_name: toolName,
|
||||
tool_input: toolInput,
|
||||
tool_response: toolResponse,
|
||||
...(mcpContext && { mcp_context: mcpContext }),
|
||||
};
|
||||
|
||||
const context: HookEventContext = { toolName };
|
||||
@@ -725,18 +753,23 @@ export class HookEventHandler {
|
||||
// Route to appropriate event handler based on eventName
|
||||
switch (request.eventName) {
|
||||
case HookEventName.BeforeTool: {
|
||||
const { toolName, toolInput } =
|
||||
const { toolName, toolInput, mcpContext } =
|
||||
validateBeforeToolInput(enrichedInput);
|
||||
result = await this.fireBeforeToolEvent(toolName, toolInput);
|
||||
result = await this.fireBeforeToolEvent(
|
||||
toolName,
|
||||
toolInput,
|
||||
mcpContext,
|
||||
);
|
||||
break;
|
||||
}
|
||||
case HookEventName.AfterTool: {
|
||||
const { toolName, toolInput, toolResponse } =
|
||||
const { toolName, toolInput, toolResponse, mcpContext } =
|
||||
validateAfterToolInput(enrichedInput);
|
||||
result = await this.fireAfterToolEvent(
|
||||
toolName,
|
||||
toolInput,
|
||||
toolResponse,
|
||||
mcpContext,
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -373,12 +373,37 @@ export class AfterModelHookOutput extends DefaultHookOutput {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Context for MCP tool executions.
|
||||
* Contains non-sensitive connection information about the MCP server
|
||||
* identity. Since server_name is user controlled and arbitrary, we
|
||||
* also include connection information (e.g., command or url) to
|
||||
* help identify the MCP server.
|
||||
*
|
||||
* NOTE: In the future, consider defining a shared sanitized interface
|
||||
* from MCPServerConfig to avoid duplication and ensure consistency.
|
||||
*/
|
||||
export interface McpToolContext {
|
||||
server_name: string;
|
||||
tool_name: string; // Original tool name from the MCP server
|
||||
|
||||
// Connection info (mutually exclusive based on transport type)
|
||||
command?: string; // For stdio transport
|
||||
args?: string[]; // For stdio transport
|
||||
cwd?: string; // For stdio transport
|
||||
|
||||
url?: string; // For SSE/HTTP transport
|
||||
|
||||
tcp?: string; // For WebSocket transport
|
||||
}
|
||||
|
||||
/**
|
||||
* BeforeTool hook input
|
||||
*/
|
||||
export interface BeforeToolInput extends HookInput {
|
||||
tool_name: string;
|
||||
tool_input: Record<string, unknown>;
|
||||
mcp_context?: McpToolContext; // Only present for MCP tools
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -398,6 +423,7 @@ export interface AfterToolInput extends HookInput {
|
||||
tool_name: string;
|
||||
tool_input: Record<string, unknown>;
|
||||
tool_response: Record<string, unknown>;
|
||||
mcp_context?: McpToolContext; // Only present for MCP tools
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -98,6 +98,7 @@ export class ToolExecutor {
|
||||
liveOutputCallback,
|
||||
shellExecutionConfig,
|
||||
setPidCallback,
|
||||
this.config,
|
||||
);
|
||||
} else {
|
||||
promise = executeToolWithHooks(
|
||||
@@ -109,6 +110,8 @@ export class ToolExecutor {
|
||||
tool,
|
||||
liveOutputCallback,
|
||||
shellExecutionConfig,
|
||||
undefined,
|
||||
this.config,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ type McpContentBlock =
|
||||
| McpResourceBlock
|
||||
| McpResourceLinkBlock;
|
||||
|
||||
class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
export class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
ToolParams,
|
||||
ToolResult
|
||||
> {
|
||||
|
||||
Reference in New Issue
Block a user