diff --git a/docs/hooks/reference.md b/docs/hooks/reference.md index bc7b6e5fa2..b5174f827e 100644 --- a/docs/hooks/reference.md +++ b/docs/hooks/reference.md @@ -46,6 +46,16 @@ specific event. - `tool_input`: (`object`) The arguments passed to the tool. - `tool_response`: (`object`, **AfterTool only**) The raw output from the tool execution. +- `mcp_context`: (`object`, **optional**) Present only for MCP tool invocations. + Contains server identity information: + - `server_name`: (`string`) The configured name of the MCP server. + - `tool_name`: (`string`) The original tool name from the MCP server. + - `command`: (`string`, optional) For stdio transport, the command used to + start the server. + - `args`: (`string[]`, optional) For stdio transport, the command arguments. + - `cwd`: (`string`, optional) For stdio transport, the working directory. + - `url`: (`string`, optional) For SSE/HTTP transport, the server URL. + - `tcp`: (`string`, optional) For WebSocket transport, the TCP address. #### Agent Events (`BeforeAgent`, `AfterAgent`) diff --git a/packages/a2a-server/src/config/config.ts b/packages/a2a-server/src/config/config.ts index 9c26173a69..a748c0b2d7 100644 --- a/packages/a2a-server/src/config/config.ts +++ b/packages/a2a-server/src/config/config.ts @@ -79,6 +79,7 @@ export async function loadConfig( : settings.checkpointing?.enabled, previewFeatures: settings.general?.previewFeatures, interactive: true, + enableInteractiveShell: true, }; const fileService = new FileDiscoveryService(workspaceDir); diff --git a/packages/cli/src/config/config.ts b/packages/cli/src/config/config.ts index 2e2ecbd87f..7ca8d2934d 100755 --- a/packages/cli/src/config/config.ts +++ b/packages/cli/src/config/config.ts @@ -637,6 +637,7 @@ export async function loadCliConfig( const ptyInfo = await getPty(); const mcpEnabled = settings.admin?.mcp?.enabled ?? true; + const extensionsEnabled = settings.admin?.extensions?.enabled ?? true; return new Config({ sessionId, @@ -659,6 +660,7 @@ export async function loadCliConfig( mcpServerCommand: mcpEnabled ? settings.mcp?.serverCommand : undefined, mcpServers: mcpEnabled ? settings.mcpServers : {}, mcpEnabled, + extensionsEnabled, allowedMcpServers: mcpEnabled ? (argv.allowedMcpServerNames ?? settings.mcp?.allowed) : undefined, diff --git a/packages/cli/src/config/extension-manager.ts b/packages/cli/src/config/extension-manager.ts index 3c4ed226c8..998b91529c 100644 --- a/packages/cli/src/config/extension-manager.ts +++ b/packages/cli/src/config/extension-manager.ts @@ -465,6 +465,12 @@ Would you like to attempt to install via "git clone" instead?`, if (this.loadedExtensions) { throw new Error('Extensions already loaded, only load extensions once.'); } + + if (this.settings.admin?.extensions?.enabled === false) { + this.loadedExtensions = []; + return this.loadedExtensions; + } + const extensionsDir = ExtensionStorage.getUserExtensionsDir(); this.loadedExtensions = []; if (!fs.existsSync(extensionsDir)) { @@ -537,12 +543,16 @@ Would you like to attempt to install via "git clone" instead?`, } if (config.mcpServers) { - config.mcpServers = Object.fromEntries( - Object.entries(config.mcpServers).map(([key, value]) => [ - key, - filterMcpConfig(value), - ]), - ); + if (this.settings.admin?.mcp?.enabled === false) { + config.mcpServers = undefined; + } else { + config.mcpServers = Object.fromEntries( + Object.entries(config.mcpServers).map(([key, value]) => [ + key, + filterMcpConfig(value), + ]), + ); + } } const contextFiles = getContextFileNames(config) diff --git a/packages/cli/src/config/extension.test.ts b/packages/cli/src/config/extension.test.ts index 0bfa7a0358..1807144e82 100644 --- a/packages/cli/src/config/extension.test.ts +++ b/packages/cli/src/config/extension.test.ts @@ -632,6 +632,79 @@ describe('extension tests', () => { expect(extension).toBeUndefined(); }); + it('should not load any extensions if admin.extensions.enabled is false', async () => { + createExtension({ + extensionsDir: userExtensionsDir, + name: 'test-extension', + version: '1.0.0', + }); + const loadedSettings = loadSettings(tempWorkspaceDir).merged; + (loadedSettings.admin ??= {}).extensions ??= {}; + loadedSettings.admin.extensions.enabled = false; + + extensionManager = new ExtensionManager({ + workspaceDir: tempWorkspaceDir, + requestConsent: mockRequestConsent, + requestSetting: mockPromptForSettings, + settings: loadedSettings, + }); + + const extensions = await extensionManager.loadExtensions(); + expect(extensions).toEqual([]); + }); + + it('should not load mcpServers if admin.mcp.enabled is false', async () => { + createExtension({ + extensionsDir: userExtensionsDir, + name: 'test-extension', + version: '1.0.0', + mcpServers: { + 'test-server': { command: 'echo', args: ['hello'] }, + }, + }); + const loadedSettings = loadSettings(tempWorkspaceDir).merged; + (loadedSettings.admin ??= {}).mcp ??= {}; + loadedSettings.admin.mcp.enabled = false; + + extensionManager = new ExtensionManager({ + workspaceDir: tempWorkspaceDir, + requestConsent: mockRequestConsent, + requestSetting: mockPromptForSettings, + settings: loadedSettings, + }); + + const extensions = await extensionManager.loadExtensions(); + expect(extensions).toHaveLength(1); + expect(extensions[0].mcpServers).toBeUndefined(); + }); + + it('should load mcpServers if admin.mcp.enabled is true', async () => { + createExtension({ + extensionsDir: userExtensionsDir, + name: 'test-extension', + version: '1.0.0', + mcpServers: { + 'test-server': { command: 'echo', args: ['hello'] }, + }, + }); + const loadedSettings = loadSettings(tempWorkspaceDir).merged; + (loadedSettings.admin ??= {}).mcp ??= {}; + loadedSettings.admin.mcp.enabled = true; + + extensionManager = new ExtensionManager({ + workspaceDir: tempWorkspaceDir, + requestConsent: mockRequestConsent, + requestSetting: mockPromptForSettings, + settings: loadedSettings, + }); + + const extensions = await extensionManager.loadExtensions(); + expect(extensions).toHaveLength(1); + expect(extensions[0].mcpServers).toEqual({ + 'test-server': { command: 'echo', args: ['hello'] }, + }); + }); + describe('id generation', () => { it.each([ { diff --git a/packages/cli/src/services/BuiltinCommandLoader.test.ts b/packages/cli/src/services/BuiltinCommandLoader.test.ts index 6bebf0b06e..22b7a47ffc 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.test.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.test.ts @@ -102,6 +102,7 @@ describe('BuiltinCommandLoader', () => { getEnableExtensionReloading: () => false, getEnableHooks: () => false, getEnableHooksUI: () => false, + getExtensionsEnabled: vi.fn().mockReturnValue(true), isSkillsSupportEnabled: vi.fn().mockReturnValue(false), getMcpEnabled: vi.fn().mockReturnValue(true), getSkillManager: vi.fn().mockReturnValue({ @@ -201,6 +202,7 @@ describe('BuiltinCommandLoader profile', () => { getEnableExtensionReloading: () => false, getEnableHooks: () => false, getEnableHooksUI: () => false, + getExtensionsEnabled: vi.fn().mockReturnValue(true), isSkillsSupportEnabled: vi.fn().mockReturnValue(false), getMcpEnabled: vi.fn().mockReturnValue(true), getSkillManager: vi.fn().mockReturnValue({ diff --git a/packages/cli/src/services/BuiltinCommandLoader.ts b/packages/cli/src/services/BuiltinCommandLoader.ts index ea72ecdb05..4320217220 100644 --- a/packages/cli/src/services/BuiltinCommandLoader.ts +++ b/packages/cli/src/services/BuiltinCommandLoader.ts @@ -76,7 +76,24 @@ export class BuiltinCommandLoader implements ICommandLoader { docsCommand, directoryCommand, editorCommand, - extensionsCommand(this.config?.getEnableExtensionReloading()), + ...(this.config?.getExtensionsEnabled() === false + ? [ + { + name: 'extensions', + description: 'Manage extensions', + kind: CommandKind.BUILT_IN, + autoExecute: false, + subCommands: [], + action: async ( + _context: CommandContext, + ): Promise => ({ + type: 'message', + messageType: 'error', + content: 'Extensions are disabled by your admin.', + }), + }, + ] + : [extensionsCommand(this.config?.getEnableExtensionReloading())]), helpCommand, ...(this.config?.getEnableHooksUI() ? [hooksCommand] : []), await ideCommand(), @@ -95,7 +112,7 @@ export class BuiltinCommandLoader implements ICommandLoader { ): Promise => ({ type: 'message', messageType: 'error', - content: 'MCP disabled by your admin.', + content: 'MCP is disabled by your admin.', }), }, ] diff --git a/packages/cli/src/ui/commands/extensionsCommand.test.ts b/packages/cli/src/ui/commands/extensionsCommand.test.ts index 4af145b631..55f20eb25d 100644 --- a/packages/cli/src/ui/commands/extensionsCommand.test.ts +++ b/packages/cli/src/ui/commands/extensionsCommand.test.ts @@ -31,6 +31,7 @@ import { inferInstallMetadata, } from '../../config/extension-manager.js'; import { SettingScope } from '../../config/settings.js'; +import { stat } from 'node:fs/promises'; vi.mock('../../config/extension-manager.js', async (importOriginal) => { const actual = @@ -42,11 +43,16 @@ vi.mock('../../config/extension-manager.js', async (importOriginal) => { }); import open from 'open'; +import type { Stats } from 'node:fs'; vi.mock('open', () => ({ default: vi.fn(), })); +vi.mock('node:fs/promises', () => ({ + stat: vi.fn(), +})); + vi.mock('../../config/extensions/update.js', () => ({ updateExtension: vi.fn(), checkForAllExtensionUpdates: vi.fn(), @@ -493,34 +499,37 @@ describe('extensionsCommand', () => { }); describe('when enableExtensionReloading is true', () => { - it('should include enable, disable, install, and uninstall subcommands', () => { + it('should include enable, disable, install, link, and uninstall subcommands', () => { const command = extensionsCommand(true); const subCommandNames = command.subCommands?.map((cmd) => cmd.name); expect(subCommandNames).toContain('enable'); expect(subCommandNames).toContain('disable'); expect(subCommandNames).toContain('install'); + expect(subCommandNames).toContain('link'); expect(subCommandNames).toContain('uninstall'); }); }); describe('when enableExtensionReloading is false', () => { - it('should not include enable, disable, install, and uninstall subcommands', () => { + it('should not include enable, disable, install, link, and uninstall subcommands', () => { const command = extensionsCommand(false); const subCommandNames = command.subCommands?.map((cmd) => cmd.name); expect(subCommandNames).not.toContain('enable'); expect(subCommandNames).not.toContain('disable'); expect(subCommandNames).not.toContain('install'); + expect(subCommandNames).not.toContain('link'); expect(subCommandNames).not.toContain('uninstall'); }); }); describe('when enableExtensionReloading is not provided', () => { - it('should not include enable, disable, install, and uninstall subcommands by default', () => { + it('should not include enable, disable, install, link, and uninstall subcommands by default', () => { const command = extensionsCommand(); const subCommandNames = command.subCommands?.map((cmd) => cmd.name); expect(subCommandNames).not.toContain('enable'); expect(subCommandNames).not.toContain('disable'); expect(subCommandNames).not.toContain('install'); + expect(subCommandNames).not.toContain('link'); expect(subCommandNames).not.toContain('uninstall'); }); }); @@ -617,6 +626,88 @@ describe('extensionsCommand', () => { }); }); + describe('link', () => { + let linkAction: SlashCommand['action']; + + beforeEach(() => { + linkAction = extensionsCommand(true).subCommands?.find( + (cmd) => cmd.name === 'link', + )?.action; + + expect(linkAction).not.toBeNull(); + mockContext.invocation!.name = 'link'; + }); + + it('should show usage if no extension is provided', async () => { + await linkAction!(mockContext, ''); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.ERROR, + text: 'Usage: /extensions link ', + }, + expect.any(Number), + ); + expect(mockInstallExtension).not.toHaveBeenCalled(); + }); + + it('should call installExtension and show success message', async () => { + const packageName = 'test-extension-package'; + mockInstallExtension.mockResolvedValue({ name: packageName }); + vi.mocked(stat).mockResolvedValue({ + size: 100, + } as Stats); + await linkAction!(mockContext, packageName); + expect(mockInstallExtension).toHaveBeenCalledWith({ + source: packageName, + type: 'link', + }); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.INFO, + text: `Linking extension from "${packageName}"...`, + }, + expect.any(Number), + ); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.INFO, + text: `Extension "${packageName}" linked successfully.`, + }, + expect.any(Number), + ); + }); + + it('should show error message on linking failure', async () => { + const packageName = 'test-extension-package'; + const errorMessage = 'link failed'; + mockInstallExtension.mockRejectedValue(new Error(errorMessage)); + vi.mocked(stat).mockResolvedValue({ + size: 100, + } as Stats); + + await linkAction!(mockContext, packageName); + expect(mockInstallExtension).toHaveBeenCalledWith({ + source: packageName, + type: 'link', + }); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + { + type: MessageType.ERROR, + text: `Failed to link extension from "${packageName}": ${errorMessage}`, + }, + expect.any(Number), + ); + }); + + it('should show error message for invalid source', async () => { + const packageName = 'test-extension-package'; + const errorMessage = 'invalid path'; + vi.mocked(stat).mockRejectedValue(new Error(errorMessage)); + await linkAction!(mockContext, packageName); + expect(mockInstallExtension).not.toHaveBeenCalled(); + }); + }); + describe('uninstall', () => { let uninstallAction: SlashCommand['action']; diff --git a/packages/cli/src/ui/commands/extensionsCommand.ts b/packages/cli/src/ui/commands/extensionsCommand.ts index 99ea05bccf..7c21115880 100644 --- a/packages/cli/src/ui/commands/extensionsCommand.ts +++ b/packages/cli/src/ui/commands/extensionsCommand.ts @@ -4,7 +4,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { debugLogger, listExtensions } from '@google/gemini-cli-core'; +import { + debugLogger, + listExtensions, + type ExtensionInstallMetadata, +} from '@google/gemini-cli-core'; import type { ExtensionUpdateInfo } from '../../config/extension.js'; import { getErrorMessage } from '../../utils/errors.js'; import { @@ -26,6 +30,7 @@ import { } from '../../config/extension-manager.js'; import { SettingScope } from '../../config/settings.js'; import { theme } from '../semantic-colors.js'; +import { stat } from 'node:fs/promises'; function showMessageIfNoExtensions( context: CommandContext, @@ -510,6 +515,88 @@ async function installAction(context: CommandContext, args: string) { } } +async function linkAction(context: CommandContext, args: string) { + const extensionLoader = context.services.config?.getExtensionLoader(); + if (!(extensionLoader instanceof ExtensionManager)) { + debugLogger.error( + `Cannot ${context.invocation?.name} extensions in this environment`, + ); + return; + } + + const sourceFilepath = args.trim(); + if (!sourceFilepath) { + context.ui.addItem( + { + type: MessageType.ERROR, + text: `Usage: /extensions link `, + }, + Date.now(), + ); + return; + } + if (/[;&|`'"]/.test(sourceFilepath)) { + context.ui.addItem( + { + type: MessageType.ERROR, + text: `Source file path contains disallowed characters: ${sourceFilepath}`, + }, + Date.now(), + ); + return; + } + + try { + await stat(sourceFilepath); + } catch (error) { + context.ui.addItem( + { + type: MessageType.ERROR, + text: `Invalid source: ${sourceFilepath}`, + }, + Date.now(), + ); + debugLogger.error( + `Failed to stat path "${sourceFilepath}": ${getErrorMessage(error)}`, + ); + return; + } + + context.ui.addItem( + { + type: MessageType.INFO, + text: `Linking extension from "${sourceFilepath}"...`, + }, + Date.now(), + ); + + try { + const installMetadata: ExtensionInstallMetadata = { + source: sourceFilepath, + type: 'link', + }; + const extension = + await extensionLoader.installOrUpdateExtension(installMetadata); + context.ui.addItem( + { + type: MessageType.INFO, + text: `Extension "${extension.name}" linked successfully.`, + }, + Date.now(), + ); + } catch (error) { + context.ui.addItem( + { + type: MessageType.ERROR, + text: `Failed to link extension from "${sourceFilepath}": ${getErrorMessage( + error, + )}`, + }, + Date.now(), + ); + } +} + async function uninstallAction(context: CommandContext, args: string) { const extensionLoader = context.services.config?.getExtensionLoader(); if (!(extensionLoader instanceof ExtensionManager)) { @@ -645,6 +732,14 @@ const installCommand: SlashCommand = { action: installAction, }; +const linkCommand: SlashCommand = { + name: 'link', + description: 'Link an extension from a local path', + kind: CommandKind.BUILT_IN, + autoExecute: false, + action: linkAction, +}; + const uninstallCommand: SlashCommand = { name: 'uninstall', description: 'Uninstall an extension', @@ -675,7 +770,13 @@ export function extensionsCommand( enableExtensionReloading?: boolean, ): SlashCommand { const conditionalCommands = enableExtensionReloading - ? [disableCommand, enableCommand, installCommand, uninstallCommand] + ? [ + disableCommand, + enableCommand, + installCommand, + uninstallCommand, + linkCommand, + ] : []; return { name: 'extensions', diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 8e7a8e42cb..01615c1081 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -357,6 +357,7 @@ export interface ConfigParameters { experimentalJitContext?: boolean; onModelChange?: (model: string) => void; mcpEnabled?: boolean; + extensionsEnabled?: boolean; onReload?: () => Promise<{ disabledSkills?: string[] }>; } @@ -391,6 +392,7 @@ export class Config { private readonly toolCallCommand: string | undefined; private readonly mcpServerCommand: string | undefined; private readonly mcpEnabled: boolean; + private readonly extensionsEnabled: boolean; private mcpServers: Record | undefined; private userMemory: string; private geminiMdFileCount: number; @@ -517,6 +519,7 @@ export class Config { this.mcpServerCommand = params.mcpServerCommand; this.mcpServers = params.mcpServers; this.mcpEnabled = params.mcpEnabled ?? true; + this.extensionsEnabled = params.extensionsEnabled ?? true; this.allowedMcpServers = params.allowedMcpServers ?? []; this.blockedMcpServers = params.blockedMcpServers ?? []; this.allowedEnvironmentVariables = params.allowedEnvironmentVariables ?? []; @@ -1143,6 +1146,10 @@ export class Config { return this.mcpEnabled; } + getExtensionsEnabled(): boolean { + return this.extensionsEnabled; + } + getMcpClientManager(): McpClientManager | undefined { return this.mcpClientManager; } diff --git a/packages/core/src/core/coreToolHookTriggers.ts b/packages/core/src/core/coreToolHookTriggers.ts index 70f9e93c1d..ca1467518b 100644 --- a/packages/core/src/core/coreToolHookTriggers.ts +++ b/packages/core/src/core/coreToolHookTriggers.ts @@ -14,8 +14,10 @@ import { createHookOutput, NotificationType, type DefaultHookOutput, + type McpToolContext, BeforeToolHookOutput, } from '../hooks/types.js'; +import type { Config } from '../config/config.js'; import type { ToolCallConfirmationDetails, ToolResult, @@ -26,6 +28,7 @@ import { debugLogger } from '../utils/debugLogger.js'; import type { AnsiOutput, ShellExecutionConfig } from '../index.js'; import type { AnyToolInvocation } from '../tools/tools.js'; import { ShellToolInvocation } from '../tools/shell.js'; +import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js'; /** * Serializable representation of tool confirmation details for hooks. @@ -154,18 +157,57 @@ export async function fireToolNotificationHook( } } +/** + * Extracts MCP context from a tool invocation if it's an MCP tool. + * + * @param invocation The tool invocation + * @param config Config to look up server details + * @returns MCP context if this is an MCP tool, undefined otherwise + */ +function extractMcpContext( + invocation: ShellToolInvocation | AnyToolInvocation, + config: Config, +): McpToolContext | undefined { + if (!(invocation instanceof DiscoveredMCPToolInvocation)) { + return undefined; + } + + // Get the server config + const mcpServers = + config.getMcpClientManager()?.getMcpServers() ?? + config.getMcpServers() ?? + {}; + const serverConfig = mcpServers[invocation.serverName]; + if (!serverConfig) { + return undefined; + } + + return { + server_name: invocation.serverName, + tool_name: invocation.serverToolName, + // Non-sensitive connection details only + command: serverConfig.command, + args: serverConfig.args, + cwd: serverConfig.cwd, + url: serverConfig.url ?? serverConfig.httpUrl, + tcp: serverConfig.tcp, + }; +} + /** * Fires the BeforeTool hook and returns the hook output. * * @param messageBus The message bus to use for hook communication * @param toolName The name of the tool being executed * @param toolInput The input parameters for the tool + * @param mcpContext Optional MCP context for MCP tools * @returns The hook output, or undefined if no hook was executed or on error */ export async function fireBeforeToolHook( messageBus: MessageBus, toolName: string, toolInput: Record, + mcpContext?: McpToolContext, ): Promise { try { const response = await messageBus.request< @@ -178,6 +220,7 @@ export async function fireBeforeToolHook( input: { tool_name: toolName, tool_input: toolInput, + ...(mcpContext && { mcp_context: mcpContext }), }, }, MessageBusType.HOOK_EXECUTION_RESPONSE, @@ -199,6 +242,7 @@ export async function fireBeforeToolHook( * @param toolName The name of the tool that was executed * @param toolInput The input parameters for the tool * @param toolResponse The result from the tool execution + * @param mcpContext Optional MCP context for MCP tools * @returns The hook output, or undefined if no hook was executed or on error */ export async function fireAfterToolHook( @@ -210,6 +254,7 @@ export async function fireAfterToolHook( returnDisplay: ToolResult['returnDisplay']; error: ToolResult['error']; }, + mcpContext?: McpToolContext, ): Promise { try { const response = await messageBus.request< @@ -223,6 +268,7 @@ export async function fireAfterToolHook( tool_name: toolName, tool_input: toolInput, tool_response: toolResponse, + ...(mcpContext && { mcp_context: mcpContext }), }, }, MessageBusType.HOOK_EXECUTION_RESPONSE, @@ -248,6 +294,7 @@ export async function fireAfterToolHook( * @param liveOutputCallback Optional callback for live output updates * @param shellExecutionConfig Optional shell execution config * @param setPidCallback Optional callback to set the PID for shell invocations + * @param config Config to look up MCP server details for hook context * @returns The tool result */ export async function executeToolWithHooks( @@ -260,17 +307,22 @@ export async function executeToolWithHooks( liveOutputCallback?: (outputChunk: string | AnsiOutput) => void, shellExecutionConfig?: ShellExecutionConfig, setPidCallback?: (pid: number) => void, + config?: Config, ): Promise { const toolInput = (invocation.params || {}) as Record; let inputWasModified = false; let modifiedKeys: string[] = []; + // Extract MCP context if this is an MCP tool (only if config is provided) + const mcpContext = config ? extractMcpContext(invocation, config) : undefined; + // Fire BeforeTool hook through MessageBus (only if hooks are enabled) if (hooksEnabled && messageBus) { const beforeOutput = await fireBeforeToolHook( messageBus, toolName, toolInput, + mcpContext, ); // Check if hook requested to stop entire agent execution @@ -378,6 +430,7 @@ export async function executeToolWithHooks( returnDisplay: toolResult.returnDisplay, error: toolResult.error, }, + mcpContext, ); // Check if hook requested to stop entire agent execution diff --git a/packages/core/src/hooks/hookEventHandler.test.ts b/packages/core/src/hooks/hookEventHandler.test.ts index 2bffc805b6..af7a6be37a 100644 --- a/packages/core/src/hooks/hookEventHandler.test.ts +++ b/packages/core/src/hooks/hookEventHandler.test.ts @@ -258,6 +258,128 @@ describe('HookEventHandler', () => { expect.stringContaining('F12'), ); }); + + it('should fire BeforeTool event with MCP context when provided', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './test.sh', + } as unknown as HookConfig, + eventName: HookEventName.BeforeTool, + }, + ]; + const mockResults: HookExecutionResult[] = [ + { + success: true, + duration: 100, + hookConfig: { + type: HookType.Command, + command: './test.sh', + timeout: 30000, + }, + eventName: HookEventName.BeforeTool, + }, + ]; + const mockAggregated = { + success: true, + allOutputs: [], + errors: [], + totalDuration: 100, + }; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.BeforeTool, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue( + mockResults, + ); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue( + mockAggregated, + ); + + const mcpContext = { + server_name: 'my-mcp-server', + tool_name: 'read_file', + command: 'npx', + args: ['-y', '@my-org/mcp-server'], + }; + + const result = await hookEventHandler.fireBeforeToolEvent( + 'my-mcp-server__read_file', + { path: '/etc/passwd' }, + mcpContext, + ); + + expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith( + [mockPlan[0].hookConfig], + HookEventName.BeforeTool, + expect.objectContaining({ + session_id: 'test-session', + cwd: '/test/project', + hook_event_name: 'BeforeTool', + tool_name: 'my-mcp-server__read_file', + tool_input: { path: '/etc/passwd' }, + mcp_context: mcpContext, + }), + expect.any(Function), + expect.any(Function), + ); + + expect(result).toBe(mockAggregated); + }); + + it('should not include mcp_context when not provided', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './test.sh', + } as unknown as HookConfig, + eventName: HookEventName.BeforeTool, + }, + ]; + const mockResults: HookExecutionResult[] = [ + { + success: true, + duration: 100, + hookConfig: { + type: HookType.Command, + command: './test.sh', + timeout: 30000, + }, + eventName: HookEventName.BeforeTool, + }, + ]; + const mockAggregated = { + success: true, + allOutputs: [], + errors: [], + totalDuration: 100, + }; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.BeforeTool, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue( + mockResults, + ); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue( + mockAggregated, + ); + + await hookEventHandler.fireBeforeToolEvent('EditTool', { + file: 'test.txt', + }); + + const callArgs = vi.mocked(mockHookRunner.executeHooksParallel).mock + .calls[0][2]; + expect(callArgs).not.toHaveProperty('mcp_context'); + }); }); describe('fireAfterToolEvent', () => { @@ -325,6 +447,78 @@ describe('HookEventHandler', () => { expect(result).toBe(mockAggregated); }); + + it('should fire AfterTool event with MCP context when provided', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './after.sh', + } as unknown as HookConfig, + eventName: HookEventName.AfterTool, + }, + ]; + const mockResults: HookExecutionResult[] = [ + { + success: true, + duration: 100, + hookConfig: { + type: HookType.Command, + command: './after.sh', + timeout: 30000, + }, + eventName: HookEventName.AfterTool, + }, + ]; + const mockAggregated = { + success: true, + allOutputs: [], + errors: [], + totalDuration: 100, + }; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.AfterTool, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue( + mockResults, + ); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue( + mockAggregated, + ); + + const toolInput = { path: '/etc/passwd' }; + const toolResponse = { success: true, content: 'File content' }; + const mcpContext = { + server_name: 'my-mcp-server', + tool_name: 'read_file', + url: 'https://mcp.example.com', + }; + + const result = await hookEventHandler.fireAfterToolEvent( + 'my-mcp-server__read_file', + toolInput, + toolResponse, + mcpContext, + ); + + expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith( + [mockPlan[0].hookConfig], + HookEventName.AfterTool, + expect.objectContaining({ + tool_name: 'my-mcp-server__read_file', + tool_input: toolInput, + tool_response: toolResponse, + mcp_context: mcpContext, + }), + expect.any(Function), + expect.any(Function), + ); + + expect(result).toBe(mockAggregated); + }); }); describe('fireBeforeAgentEvent', () => { diff --git a/packages/core/src/hooks/hookEventHandler.ts b/packages/core/src/hooks/hookEventHandler.ts index e72aee913a..e208dd1ed4 100644 --- a/packages/core/src/hooks/hookEventHandler.ts +++ b/packages/core/src/hooks/hookEventHandler.ts @@ -29,6 +29,7 @@ import type { SessionEndReason, PreCompressTrigger, HookExecutionResult, + McpToolContext, } from './types.js'; import { defaultHookTranslator } from './hookTranslator.js'; import type { @@ -58,9 +59,11 @@ function isObject(value: unknown): value is Record { function validateBeforeToolInput(input: Record): { toolName: string; toolInput: Record; + mcpContext?: McpToolContext; } { const toolName = input['tool_name']; const toolInput = input['tool_input']; + const mcpContext = input['mcp_context']; if (typeof toolName !== 'string') { throw new Error( 'Invalid input for BeforeTool hook event: tool_name must be a string', @@ -71,7 +74,16 @@ function validateBeforeToolInput(input: Record): { 'Invalid input for BeforeTool hook event: tool_input must be an object', ); } - return { toolName, toolInput }; + if (mcpContext !== undefined && !isObject(mcpContext)) { + throw new Error( + 'Invalid input for BeforeTool hook event: mcp_context must be an object', + ); + } + return { + toolName, + toolInput, + mcpContext: mcpContext as McpToolContext | undefined, + }; } /** @@ -81,10 +93,12 @@ function validateAfterToolInput(input: Record): { toolName: string; toolInput: Record; toolResponse: Record; + mcpContext?: McpToolContext; } { const toolName = input['tool_name']; const toolInput = input['tool_input']; const toolResponse = input['tool_response']; + const mcpContext = input['mcp_context']; if (typeof toolName !== 'string') { throw new Error( 'Invalid input for AfterTool hook event: tool_name must be a string', @@ -100,7 +114,17 @@ function validateAfterToolInput(input: Record): { 'Invalid input for AfterTool hook event: tool_response must be an object', ); } - return { toolName, toolInput, toolResponse }; + if (mcpContext !== undefined && !isObject(mcpContext)) { + throw new Error( + 'Invalid input for AfterTool hook event: mcp_context must be an object', + ); + } + return { + toolName, + toolInput, + toolResponse, + mcpContext: mcpContext as McpToolContext | undefined, + }; } /** @@ -313,11 +337,13 @@ export class HookEventHandler { async fireBeforeToolEvent( toolName: string, toolInput: Record, + mcpContext?: McpToolContext, ): Promise { const input: BeforeToolInput = { ...this.createBaseInput(HookEventName.BeforeTool), tool_name: toolName, tool_input: toolInput, + ...(mcpContext && { mcp_context: mcpContext }), }; const context: HookEventContext = { toolName }; @@ -332,12 +358,14 @@ export class HookEventHandler { toolName: string, toolInput: Record, toolResponse: Record, + mcpContext?: McpToolContext, ): Promise { const input: AfterToolInput = { ...this.createBaseInput(HookEventName.AfterTool), tool_name: toolName, tool_input: toolInput, tool_response: toolResponse, + ...(mcpContext && { mcp_context: mcpContext }), }; const context: HookEventContext = { toolName }; @@ -725,18 +753,23 @@ export class HookEventHandler { // Route to appropriate event handler based on eventName switch (request.eventName) { case HookEventName.BeforeTool: { - const { toolName, toolInput } = + const { toolName, toolInput, mcpContext } = validateBeforeToolInput(enrichedInput); - result = await this.fireBeforeToolEvent(toolName, toolInput); + result = await this.fireBeforeToolEvent( + toolName, + toolInput, + mcpContext, + ); break; } case HookEventName.AfterTool: { - const { toolName, toolInput, toolResponse } = + const { toolName, toolInput, toolResponse, mcpContext } = validateAfterToolInput(enrichedInput); result = await this.fireAfterToolEvent( toolName, toolInput, toolResponse, + mcpContext, ); break; } diff --git a/packages/core/src/hooks/hookTranslator.ts b/packages/core/src/hooks/hookTranslator.ts index 9cbcd903f8..56036a16db 100644 --- a/packages/core/src/hooks/hookTranslator.ts +++ b/packages/core/src/hooks/hookTranslator.ts @@ -12,6 +12,7 @@ import type { FunctionCallingConfig, } from '@google/genai'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { getResponseText } from '../utils/partUtils.js'; /** * Decoupled LLM request format - stable across Gemini CLI versions @@ -267,7 +268,7 @@ export class HookTranslatorGenAIv1 extends HookTranslator { */ toHookLLMResponse(sdkResponse: GenerateContentResponse): LLMResponse { return { - text: sdkResponse.text, + text: getResponseText(sdkResponse) ?? undefined, candidates: (sdkResponse.candidates || []).map((candidate) => { // Extract text parts from the candidate const textParts = diff --git a/packages/core/src/hooks/types.ts b/packages/core/src/hooks/types.ts index e54a03f840..5ca7bd5fb1 100644 --- a/packages/core/src/hooks/types.ts +++ b/packages/core/src/hooks/types.ts @@ -373,12 +373,37 @@ export class AfterModelHookOutput extends DefaultHookOutput { } } +/** + * Context for MCP tool executions. + * Contains non-sensitive connection information about the MCP server + * identity. Since server_name is user controlled and arbitrary, we + * also include connection information (e.g., command or url) to + * help identify the MCP server. + * + * NOTE: In the future, consider defining a shared sanitized interface + * from MCPServerConfig to avoid duplication and ensure consistency. + */ +export interface McpToolContext { + server_name: string; + tool_name: string; // Original tool name from the MCP server + + // Connection info (mutually exclusive based on transport type) + command?: string; // For stdio transport + args?: string[]; // For stdio transport + cwd?: string; // For stdio transport + + url?: string; // For SSE/HTTP transport + + tcp?: string; // For WebSocket transport +} + /** * BeforeTool hook input */ export interface BeforeToolInput extends HookInput { tool_name: string; tool_input: Record; + mcp_context?: McpToolContext; // Only present for MCP tools } /** @@ -398,6 +423,7 @@ export interface AfterToolInput extends HookInput { tool_name: string; tool_input: Record; tool_response: Record; + mcp_context?: McpToolContext; // Only present for MCP tools } /** diff --git a/packages/core/src/scheduler/tool-executor.ts b/packages/core/src/scheduler/tool-executor.ts index 8334168b93..233ff998ff 100644 --- a/packages/core/src/scheduler/tool-executor.ts +++ b/packages/core/src/scheduler/tool-executor.ts @@ -98,6 +98,7 @@ export class ToolExecutor { liveOutputCallback, shellExecutionConfig, setPidCallback, + this.config, ); } else { promise = executeToolWithHooks( @@ -109,6 +110,8 @@ export class ToolExecutor { tool, liveOutputCallback, shellExecutionConfig, + undefined, + this.config, ); } diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index 44a07d99e8..8259b6c2f3 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -59,7 +59,7 @@ type McpContentBlock = | McpResourceBlock | McpResourceLinkBlock; -class DiscoveredMCPToolInvocation extends BaseToolInvocation< +export class DiscoveredMCPToolInvocation extends BaseToolInvocation< ToolParams, ToolResult > {