Merge branch 'main' into dyim/add-api-version-env-var

This commit is contained in:
Danielle Yim
2026-01-08 10:46:54 -08:00
committed by GitHub
17 changed files with 644 additions and 20 deletions
+1
View File
@@ -79,6 +79,7 @@ export async function loadConfig(
: settings.checkpointing?.enabled,
previewFeatures: settings.general?.previewFeatures,
interactive: true,
enableInteractiveShell: true,
};
const fileService = new FileDiscoveryService(workspaceDir);
+2
View File
@@ -637,6 +637,7 @@ export async function loadCliConfig(
const ptyInfo = await getPty();
const mcpEnabled = settings.admin?.mcp?.enabled ?? true;
const extensionsEnabled = settings.admin?.extensions?.enabled ?? true;
return new Config({
sessionId,
@@ -659,6 +660,7 @@ export async function loadCliConfig(
mcpServerCommand: mcpEnabled ? settings.mcp?.serverCommand : undefined,
mcpServers: mcpEnabled ? settings.mcpServers : {},
mcpEnabled,
extensionsEnabled,
allowedMcpServers: mcpEnabled
? (argv.allowedMcpServerNames ?? settings.mcp?.allowed)
: undefined,
+16 -6
View File
@@ -465,6 +465,12 @@ Would you like to attempt to install via "git clone" instead?`,
if (this.loadedExtensions) {
throw new Error('Extensions already loaded, only load extensions once.');
}
if (this.settings.admin?.extensions?.enabled === false) {
this.loadedExtensions = [];
return this.loadedExtensions;
}
const extensionsDir = ExtensionStorage.getUserExtensionsDir();
this.loadedExtensions = [];
if (!fs.existsSync(extensionsDir)) {
@@ -537,12 +543,16 @@ Would you like to attempt to install via "git clone" instead?`,
}
if (config.mcpServers) {
config.mcpServers = Object.fromEntries(
Object.entries(config.mcpServers).map(([key, value]) => [
key,
filterMcpConfig(value),
]),
);
if (this.settings.admin?.mcp?.enabled === false) {
config.mcpServers = undefined;
} else {
config.mcpServers = Object.fromEntries(
Object.entries(config.mcpServers).map(([key, value]) => [
key,
filterMcpConfig(value),
]),
);
}
}
const contextFiles = getContextFileNames(config)
+73
View File
@@ -632,6 +632,79 @@ describe('extension tests', () => {
expect(extension).toBeUndefined();
});
it('should not load any extensions if admin.extensions.enabled is false', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'test-extension',
version: '1.0.0',
});
const loadedSettings = loadSettings(tempWorkspaceDir).merged;
(loadedSettings.admin ??= {}).extensions ??= {};
loadedSettings.admin.extensions.enabled = false;
extensionManager = new ExtensionManager({
workspaceDir: tempWorkspaceDir,
requestConsent: mockRequestConsent,
requestSetting: mockPromptForSettings,
settings: loadedSettings,
});
const extensions = await extensionManager.loadExtensions();
expect(extensions).toEqual([]);
});
it('should not load mcpServers if admin.mcp.enabled is false', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'test-extension',
version: '1.0.0',
mcpServers: {
'test-server': { command: 'echo', args: ['hello'] },
},
});
const loadedSettings = loadSettings(tempWorkspaceDir).merged;
(loadedSettings.admin ??= {}).mcp ??= {};
loadedSettings.admin.mcp.enabled = false;
extensionManager = new ExtensionManager({
workspaceDir: tempWorkspaceDir,
requestConsent: mockRequestConsent,
requestSetting: mockPromptForSettings,
settings: loadedSettings,
});
const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1);
expect(extensions[0].mcpServers).toBeUndefined();
});
it('should load mcpServers if admin.mcp.enabled is true', async () => {
createExtension({
extensionsDir: userExtensionsDir,
name: 'test-extension',
version: '1.0.0',
mcpServers: {
'test-server': { command: 'echo', args: ['hello'] },
},
});
const loadedSettings = loadSettings(tempWorkspaceDir).merged;
(loadedSettings.admin ??= {}).mcp ??= {};
loadedSettings.admin.mcp.enabled = true;
extensionManager = new ExtensionManager({
workspaceDir: tempWorkspaceDir,
requestConsent: mockRequestConsent,
requestSetting: mockPromptForSettings,
settings: loadedSettings,
});
const extensions = await extensionManager.loadExtensions();
expect(extensions).toHaveLength(1);
expect(extensions[0].mcpServers).toEqual({
'test-server': { command: 'echo', args: ['hello'] },
});
});
describe('id generation', () => {
it.each([
{
@@ -102,6 +102,7 @@ describe('BuiltinCommandLoader', () => {
getEnableExtensionReloading: () => false,
getEnableHooks: () => false,
getEnableHooksUI: () => false,
getExtensionsEnabled: vi.fn().mockReturnValue(true),
isSkillsSupportEnabled: vi.fn().mockReturnValue(false),
getMcpEnabled: vi.fn().mockReturnValue(true),
getSkillManager: vi.fn().mockReturnValue({
@@ -201,6 +202,7 @@ describe('BuiltinCommandLoader profile', () => {
getEnableExtensionReloading: () => false,
getEnableHooks: () => false,
getEnableHooksUI: () => false,
getExtensionsEnabled: vi.fn().mockReturnValue(true),
isSkillsSupportEnabled: vi.fn().mockReturnValue(false),
getMcpEnabled: vi.fn().mockReturnValue(true),
getSkillManager: vi.fn().mockReturnValue({
@@ -76,7 +76,24 @@ export class BuiltinCommandLoader implements ICommandLoader {
docsCommand,
directoryCommand,
editorCommand,
extensionsCommand(this.config?.getEnableExtensionReloading()),
...(this.config?.getExtensionsEnabled() === false
? [
{
name: 'extensions',
description: 'Manage extensions',
kind: CommandKind.BUILT_IN,
autoExecute: false,
subCommands: [],
action: async (
_context: CommandContext,
): Promise<MessageActionReturn> => ({
type: 'message',
messageType: 'error',
content: 'Extensions are disabled by your admin.',
}),
},
]
: [extensionsCommand(this.config?.getEnableExtensionReloading())]),
helpCommand,
...(this.config?.getEnableHooksUI() ? [hooksCommand] : []),
await ideCommand(),
@@ -95,7 +112,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
): Promise<MessageActionReturn> => ({
type: 'message',
messageType: 'error',
content: 'MCP disabled by your admin.',
content: 'MCP is disabled by your admin.',
}),
},
]
@@ -31,6 +31,7 @@ import {
inferInstallMetadata,
} from '../../config/extension-manager.js';
import { SettingScope } from '../../config/settings.js';
import { stat } from 'node:fs/promises';
vi.mock('../../config/extension-manager.js', async (importOriginal) => {
const actual =
@@ -42,11 +43,16 @@ vi.mock('../../config/extension-manager.js', async (importOriginal) => {
});
import open from 'open';
import type { Stats } from 'node:fs';
vi.mock('open', () => ({
default: vi.fn(),
}));
vi.mock('node:fs/promises', () => ({
stat: vi.fn(),
}));
vi.mock('../../config/extensions/update.js', () => ({
updateExtension: vi.fn(),
checkForAllExtensionUpdates: vi.fn(),
@@ -493,34 +499,37 @@ describe('extensionsCommand', () => {
});
describe('when enableExtensionReloading is true', () => {
it('should include enable, disable, install, and uninstall subcommands', () => {
it('should include enable, disable, install, link, and uninstall subcommands', () => {
const command = extensionsCommand(true);
const subCommandNames = command.subCommands?.map((cmd) => cmd.name);
expect(subCommandNames).toContain('enable');
expect(subCommandNames).toContain('disable');
expect(subCommandNames).toContain('install');
expect(subCommandNames).toContain('link');
expect(subCommandNames).toContain('uninstall');
});
});
describe('when enableExtensionReloading is false', () => {
it('should not include enable, disable, install, and uninstall subcommands', () => {
it('should not include enable, disable, install, link, and uninstall subcommands', () => {
const command = extensionsCommand(false);
const subCommandNames = command.subCommands?.map((cmd) => cmd.name);
expect(subCommandNames).not.toContain('enable');
expect(subCommandNames).not.toContain('disable');
expect(subCommandNames).not.toContain('install');
expect(subCommandNames).not.toContain('link');
expect(subCommandNames).not.toContain('uninstall');
});
});
describe('when enableExtensionReloading is not provided', () => {
it('should not include enable, disable, install, and uninstall subcommands by default', () => {
it('should not include enable, disable, install, link, and uninstall subcommands by default', () => {
const command = extensionsCommand();
const subCommandNames = command.subCommands?.map((cmd) => cmd.name);
expect(subCommandNames).not.toContain('enable');
expect(subCommandNames).not.toContain('disable');
expect(subCommandNames).not.toContain('install');
expect(subCommandNames).not.toContain('link');
expect(subCommandNames).not.toContain('uninstall');
});
});
@@ -617,6 +626,88 @@ describe('extensionsCommand', () => {
});
});
describe('link', () => {
let linkAction: SlashCommand['action'];
beforeEach(() => {
linkAction = extensionsCommand(true).subCommands?.find(
(cmd) => cmd.name === 'link',
)?.action;
expect(linkAction).not.toBeNull();
mockContext.invocation!.name = 'link';
});
it('should show usage if no extension is provided', async () => {
await linkAction!(mockContext, '');
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
{
type: MessageType.ERROR,
text: 'Usage: /extensions link <source>',
},
expect.any(Number),
);
expect(mockInstallExtension).not.toHaveBeenCalled();
});
it('should call installExtension and show success message', async () => {
const packageName = 'test-extension-package';
mockInstallExtension.mockResolvedValue({ name: packageName });
vi.mocked(stat).mockResolvedValue({
size: 100,
} as Stats);
await linkAction!(mockContext, packageName);
expect(mockInstallExtension).toHaveBeenCalledWith({
source: packageName,
type: 'link',
});
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
{
type: MessageType.INFO,
text: `Linking extension from "${packageName}"...`,
},
expect.any(Number),
);
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
{
type: MessageType.INFO,
text: `Extension "${packageName}" linked successfully.`,
},
expect.any(Number),
);
});
it('should show error message on linking failure', async () => {
const packageName = 'test-extension-package';
const errorMessage = 'link failed';
mockInstallExtension.mockRejectedValue(new Error(errorMessage));
vi.mocked(stat).mockResolvedValue({
size: 100,
} as Stats);
await linkAction!(mockContext, packageName);
expect(mockInstallExtension).toHaveBeenCalledWith({
source: packageName,
type: 'link',
});
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
{
type: MessageType.ERROR,
text: `Failed to link extension from "${packageName}": ${errorMessage}`,
},
expect.any(Number),
);
});
it('should show error message for invalid source', async () => {
const packageName = 'test-extension-package';
const errorMessage = 'invalid path';
vi.mocked(stat).mockRejectedValue(new Error(errorMessage));
await linkAction!(mockContext, packageName);
expect(mockInstallExtension).not.toHaveBeenCalled();
});
});
describe('uninstall', () => {
let uninstallAction: SlashCommand['action'];
@@ -4,7 +4,11 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { debugLogger, listExtensions } from '@google/gemini-cli-core';
import {
debugLogger,
listExtensions,
type ExtensionInstallMetadata,
} from '@google/gemini-cli-core';
import type { ExtensionUpdateInfo } from '../../config/extension.js';
import { getErrorMessage } from '../../utils/errors.js';
import {
@@ -26,6 +30,7 @@ import {
} from '../../config/extension-manager.js';
import { SettingScope } from '../../config/settings.js';
import { theme } from '../semantic-colors.js';
import { stat } from 'node:fs/promises';
function showMessageIfNoExtensions(
context: CommandContext,
@@ -510,6 +515,88 @@ async function installAction(context: CommandContext, args: string) {
}
}
async function linkAction(context: CommandContext, args: string) {
const extensionLoader = context.services.config?.getExtensionLoader();
if (!(extensionLoader instanceof ExtensionManager)) {
debugLogger.error(
`Cannot ${context.invocation?.name} extensions in this environment`,
);
return;
}
const sourceFilepath = args.trim();
if (!sourceFilepath) {
context.ui.addItem(
{
type: MessageType.ERROR,
text: `Usage: /extensions link <source>`,
},
Date.now(),
);
return;
}
if (/[;&|`'"]/.test(sourceFilepath)) {
context.ui.addItem(
{
type: MessageType.ERROR,
text: `Source file path contains disallowed characters: ${sourceFilepath}`,
},
Date.now(),
);
return;
}
try {
await stat(sourceFilepath);
} catch (error) {
context.ui.addItem(
{
type: MessageType.ERROR,
text: `Invalid source: ${sourceFilepath}`,
},
Date.now(),
);
debugLogger.error(
`Failed to stat path "${sourceFilepath}": ${getErrorMessage(error)}`,
);
return;
}
context.ui.addItem(
{
type: MessageType.INFO,
text: `Linking extension from "${sourceFilepath}"...`,
},
Date.now(),
);
try {
const installMetadata: ExtensionInstallMetadata = {
source: sourceFilepath,
type: 'link',
};
const extension =
await extensionLoader.installOrUpdateExtension(installMetadata);
context.ui.addItem(
{
type: MessageType.INFO,
text: `Extension "${extension.name}" linked successfully.`,
},
Date.now(),
);
} catch (error) {
context.ui.addItem(
{
type: MessageType.ERROR,
text: `Failed to link extension from "${sourceFilepath}": ${getErrorMessage(
error,
)}`,
},
Date.now(),
);
}
}
async function uninstallAction(context: CommandContext, args: string) {
const extensionLoader = context.services.config?.getExtensionLoader();
if (!(extensionLoader instanceof ExtensionManager)) {
@@ -645,6 +732,14 @@ const installCommand: SlashCommand = {
action: installAction,
};
const linkCommand: SlashCommand = {
name: 'link',
description: 'Link an extension from a local path',
kind: CommandKind.BUILT_IN,
autoExecute: false,
action: linkAction,
};
const uninstallCommand: SlashCommand = {
name: 'uninstall',
description: 'Uninstall an extension',
@@ -675,7 +770,13 @@ export function extensionsCommand(
enableExtensionReloading?: boolean,
): SlashCommand {
const conditionalCommands = enableExtensionReloading
? [disableCommand, enableCommand, installCommand, uninstallCommand]
? [
disableCommand,
enableCommand,
installCommand,
uninstallCommand,
linkCommand,
]
: [];
return {
name: 'extensions',
+7
View File
@@ -357,6 +357,7 @@ export interface ConfigParameters {
experimentalJitContext?: boolean;
onModelChange?: (model: string) => void;
mcpEnabled?: boolean;
extensionsEnabled?: boolean;
onReload?: () => Promise<{ disabledSkills?: string[] }>;
}
@@ -391,6 +392,7 @@ export class Config {
private readonly toolCallCommand: string | undefined;
private readonly mcpServerCommand: string | undefined;
private readonly mcpEnabled: boolean;
private readonly extensionsEnabled: boolean;
private mcpServers: Record<string, MCPServerConfig> | undefined;
private userMemory: string;
private geminiMdFileCount: number;
@@ -517,6 +519,7 @@ export class Config {
this.mcpServerCommand = params.mcpServerCommand;
this.mcpServers = params.mcpServers;
this.mcpEnabled = params.mcpEnabled ?? true;
this.extensionsEnabled = params.extensionsEnabled ?? true;
this.allowedMcpServers = params.allowedMcpServers ?? [];
this.blockedMcpServers = params.blockedMcpServers ?? [];
this.allowedEnvironmentVariables = params.allowedEnvironmentVariables ?? [];
@@ -1143,6 +1146,10 @@ export class Config {
return this.mcpEnabled;
}
getExtensionsEnabled(): boolean {
return this.extensionsEnabled;
}
getMcpClientManager(): McpClientManager | undefined {
return this.mcpClientManager;
}
@@ -14,8 +14,10 @@ import {
createHookOutput,
NotificationType,
type DefaultHookOutput,
type McpToolContext,
BeforeToolHookOutput,
} from '../hooks/types.js';
import type { Config } from '../config/config.js';
import type {
ToolCallConfirmationDetails,
ToolResult,
@@ -26,6 +28,7 @@ import { debugLogger } from '../utils/debugLogger.js';
import type { AnsiOutput, ShellExecutionConfig } from '../index.js';
import type { AnyToolInvocation } from '../tools/tools.js';
import { ShellToolInvocation } from '../tools/shell.js';
import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js';
/**
* Serializable representation of tool confirmation details for hooks.
@@ -154,18 +157,57 @@ export async function fireToolNotificationHook(
}
}
/**
* Extracts MCP context from a tool invocation if it's an MCP tool.
*
* @param invocation The tool invocation
* @param config Config to look up server details
* @returns MCP context if this is an MCP tool, undefined otherwise
*/
function extractMcpContext(
invocation: ShellToolInvocation | AnyToolInvocation,
config: Config,
): McpToolContext | undefined {
if (!(invocation instanceof DiscoveredMCPToolInvocation)) {
return undefined;
}
// Get the server config
const mcpServers =
config.getMcpClientManager()?.getMcpServers() ??
config.getMcpServers() ??
{};
const serverConfig = mcpServers[invocation.serverName];
if (!serverConfig) {
return undefined;
}
return {
server_name: invocation.serverName,
tool_name: invocation.serverToolName,
// Non-sensitive connection details only
command: serverConfig.command,
args: serverConfig.args,
cwd: serverConfig.cwd,
url: serverConfig.url ?? serverConfig.httpUrl,
tcp: serverConfig.tcp,
};
}
/**
* Fires the BeforeTool hook and returns the hook output.
*
* @param messageBus The message bus to use for hook communication
* @param toolName The name of the tool being executed
* @param toolInput The input parameters for the tool
* @param mcpContext Optional MCP context for MCP tools
* @returns The hook output, or undefined if no hook was executed or on error
*/
export async function fireBeforeToolHook(
messageBus: MessageBus,
toolName: string,
toolInput: Record<string, unknown>,
mcpContext?: McpToolContext,
): Promise<DefaultHookOutput | undefined> {
try {
const response = await messageBus.request<
@@ -178,6 +220,7 @@ export async function fireBeforeToolHook(
input: {
tool_name: toolName,
tool_input: toolInput,
...(mcpContext && { mcp_context: mcpContext }),
},
},
MessageBusType.HOOK_EXECUTION_RESPONSE,
@@ -199,6 +242,7 @@ export async function fireBeforeToolHook(
* @param toolName The name of the tool that was executed
* @param toolInput The input parameters for the tool
* @param toolResponse The result from the tool execution
* @param mcpContext Optional MCP context for MCP tools
* @returns The hook output, or undefined if no hook was executed or on error
*/
export async function fireAfterToolHook(
@@ -210,6 +254,7 @@ export async function fireAfterToolHook(
returnDisplay: ToolResult['returnDisplay'];
error: ToolResult['error'];
},
mcpContext?: McpToolContext,
): Promise<DefaultHookOutput | undefined> {
try {
const response = await messageBus.request<
@@ -223,6 +268,7 @@ export async function fireAfterToolHook(
tool_name: toolName,
tool_input: toolInput,
tool_response: toolResponse,
...(mcpContext && { mcp_context: mcpContext }),
},
},
MessageBusType.HOOK_EXECUTION_RESPONSE,
@@ -248,6 +294,7 @@ export async function fireAfterToolHook(
* @param liveOutputCallback Optional callback for live output updates
* @param shellExecutionConfig Optional shell execution config
* @param setPidCallback Optional callback to set the PID for shell invocations
* @param config Config to look up MCP server details for hook context
* @returns The tool result
*/
export async function executeToolWithHooks(
@@ -260,17 +307,22 @@ export async function executeToolWithHooks(
liveOutputCallback?: (outputChunk: string | AnsiOutput) => void,
shellExecutionConfig?: ShellExecutionConfig,
setPidCallback?: (pid: number) => void,
config?: Config,
): Promise<ToolResult> {
const toolInput = (invocation.params || {}) as Record<string, unknown>;
let inputWasModified = false;
let modifiedKeys: string[] = [];
// Extract MCP context if this is an MCP tool (only if config is provided)
const mcpContext = config ? extractMcpContext(invocation, config) : undefined;
// Fire BeforeTool hook through MessageBus (only if hooks are enabled)
if (hooksEnabled && messageBus) {
const beforeOutput = await fireBeforeToolHook(
messageBus,
toolName,
toolInput,
mcpContext,
);
// Check if hook requested to stop entire agent execution
@@ -378,6 +430,7 @@ export async function executeToolWithHooks(
returnDisplay: toolResult.returnDisplay,
error: toolResult.error,
},
mcpContext,
);
// Check if hook requested to stop entire agent execution
@@ -258,6 +258,128 @@ describe('HookEventHandler', () => {
expect.stringContaining('F12'),
);
});
it('should fire BeforeTool event with MCP context when provided', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './test.sh',
} as unknown as HookConfig,
eventName: HookEventName.BeforeTool,
},
];
const mockResults: HookExecutionResult[] = [
{
success: true,
duration: 100,
hookConfig: {
type: HookType.Command,
command: './test.sh',
timeout: 30000,
},
eventName: HookEventName.BeforeTool,
},
];
const mockAggregated = {
success: true,
allOutputs: [],
errors: [],
totalDuration: 100,
};
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.BeforeTool,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
mockResults,
);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
mockAggregated,
);
const mcpContext = {
server_name: 'my-mcp-server',
tool_name: 'read_file',
command: 'npx',
args: ['-y', '@my-org/mcp-server'],
};
const result = await hookEventHandler.fireBeforeToolEvent(
'my-mcp-server__read_file',
{ path: '/etc/passwd' },
mcpContext,
);
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
[mockPlan[0].hookConfig],
HookEventName.BeforeTool,
expect.objectContaining({
session_id: 'test-session',
cwd: '/test/project',
hook_event_name: 'BeforeTool',
tool_name: 'my-mcp-server__read_file',
tool_input: { path: '/etc/passwd' },
mcp_context: mcpContext,
}),
expect.any(Function),
expect.any(Function),
);
expect(result).toBe(mockAggregated);
});
it('should not include mcp_context when not provided', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './test.sh',
} as unknown as HookConfig,
eventName: HookEventName.BeforeTool,
},
];
const mockResults: HookExecutionResult[] = [
{
success: true,
duration: 100,
hookConfig: {
type: HookType.Command,
command: './test.sh',
timeout: 30000,
},
eventName: HookEventName.BeforeTool,
},
];
const mockAggregated = {
success: true,
allOutputs: [],
errors: [],
totalDuration: 100,
};
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.BeforeTool,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
mockResults,
);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
mockAggregated,
);
await hookEventHandler.fireBeforeToolEvent('EditTool', {
file: 'test.txt',
});
const callArgs = vi.mocked(mockHookRunner.executeHooksParallel).mock
.calls[0][2];
expect(callArgs).not.toHaveProperty('mcp_context');
});
});
describe('fireAfterToolEvent', () => {
@@ -325,6 +447,78 @@ describe('HookEventHandler', () => {
expect(result).toBe(mockAggregated);
});
it('should fire AfterTool event with MCP context when provided', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './after.sh',
} as unknown as HookConfig,
eventName: HookEventName.AfterTool,
},
];
const mockResults: HookExecutionResult[] = [
{
success: true,
duration: 100,
hookConfig: {
type: HookType.Command,
command: './after.sh',
timeout: 30000,
},
eventName: HookEventName.AfterTool,
},
];
const mockAggregated = {
success: true,
allOutputs: [],
errors: [],
totalDuration: 100,
};
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.AfterTool,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
mockResults,
);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
mockAggregated,
);
const toolInput = { path: '/etc/passwd' };
const toolResponse = { success: true, content: 'File content' };
const mcpContext = {
server_name: 'my-mcp-server',
tool_name: 'read_file',
url: 'https://mcp.example.com',
};
const result = await hookEventHandler.fireAfterToolEvent(
'my-mcp-server__read_file',
toolInput,
toolResponse,
mcpContext,
);
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
[mockPlan[0].hookConfig],
HookEventName.AfterTool,
expect.objectContaining({
tool_name: 'my-mcp-server__read_file',
tool_input: toolInput,
tool_response: toolResponse,
mcp_context: mcpContext,
}),
expect.any(Function),
expect.any(Function),
);
expect(result).toBe(mockAggregated);
});
});
describe('fireBeforeAgentEvent', () => {
+38 -5
View File
@@ -29,6 +29,7 @@ import type {
SessionEndReason,
PreCompressTrigger,
HookExecutionResult,
McpToolContext,
} from './types.js';
import { defaultHookTranslator } from './hookTranslator.js';
import type {
@@ -58,9 +59,11 @@ function isObject(value: unknown): value is Record<string, unknown> {
function validateBeforeToolInput(input: Record<string, unknown>): {
toolName: string;
toolInput: Record<string, unknown>;
mcpContext?: McpToolContext;
} {
const toolName = input['tool_name'];
const toolInput = input['tool_input'];
const mcpContext = input['mcp_context'];
if (typeof toolName !== 'string') {
throw new Error(
'Invalid input for BeforeTool hook event: tool_name must be a string',
@@ -71,7 +74,16 @@ function validateBeforeToolInput(input: Record<string, unknown>): {
'Invalid input for BeforeTool hook event: tool_input must be an object',
);
}
return { toolName, toolInput };
if (mcpContext !== undefined && !isObject(mcpContext)) {
throw new Error(
'Invalid input for BeforeTool hook event: mcp_context must be an object',
);
}
return {
toolName,
toolInput,
mcpContext: mcpContext as McpToolContext | undefined,
};
}
/**
@@ -81,10 +93,12 @@ function validateAfterToolInput(input: Record<string, unknown>): {
toolName: string;
toolInput: Record<string, unknown>;
toolResponse: Record<string, unknown>;
mcpContext?: McpToolContext;
} {
const toolName = input['tool_name'];
const toolInput = input['tool_input'];
const toolResponse = input['tool_response'];
const mcpContext = input['mcp_context'];
if (typeof toolName !== 'string') {
throw new Error(
'Invalid input for AfterTool hook event: tool_name must be a string',
@@ -100,7 +114,17 @@ function validateAfterToolInput(input: Record<string, unknown>): {
'Invalid input for AfterTool hook event: tool_response must be an object',
);
}
return { toolName, toolInput, toolResponse };
if (mcpContext !== undefined && !isObject(mcpContext)) {
throw new Error(
'Invalid input for AfterTool hook event: mcp_context must be an object',
);
}
return {
toolName,
toolInput,
toolResponse,
mcpContext: mcpContext as McpToolContext | undefined,
};
}
/**
@@ -313,11 +337,13 @@ export class HookEventHandler {
async fireBeforeToolEvent(
toolName: string,
toolInput: Record<string, unknown>,
mcpContext?: McpToolContext,
): Promise<AggregatedHookResult> {
const input: BeforeToolInput = {
...this.createBaseInput(HookEventName.BeforeTool),
tool_name: toolName,
tool_input: toolInput,
...(mcpContext && { mcp_context: mcpContext }),
};
const context: HookEventContext = { toolName };
@@ -332,12 +358,14 @@ export class HookEventHandler {
toolName: string,
toolInput: Record<string, unknown>,
toolResponse: Record<string, unknown>,
mcpContext?: McpToolContext,
): Promise<AggregatedHookResult> {
const input: AfterToolInput = {
...this.createBaseInput(HookEventName.AfterTool),
tool_name: toolName,
tool_input: toolInput,
tool_response: toolResponse,
...(mcpContext && { mcp_context: mcpContext }),
};
const context: HookEventContext = { toolName };
@@ -725,18 +753,23 @@ export class HookEventHandler {
// Route to appropriate event handler based on eventName
switch (request.eventName) {
case HookEventName.BeforeTool: {
const { toolName, toolInput } =
const { toolName, toolInput, mcpContext } =
validateBeforeToolInput(enrichedInput);
result = await this.fireBeforeToolEvent(toolName, toolInput);
result = await this.fireBeforeToolEvent(
toolName,
toolInput,
mcpContext,
);
break;
}
case HookEventName.AfterTool: {
const { toolName, toolInput, toolResponse } =
const { toolName, toolInput, toolResponse, mcpContext } =
validateAfterToolInput(enrichedInput);
result = await this.fireAfterToolEvent(
toolName,
toolInput,
toolResponse,
mcpContext,
);
break;
}
+2 -1
View File
@@ -12,6 +12,7 @@ import type {
FunctionCallingConfig,
} from '@google/genai';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { getResponseText } from '../utils/partUtils.js';
/**
* Decoupled LLM request format - stable across Gemini CLI versions
@@ -267,7 +268,7 @@ export class HookTranslatorGenAIv1 extends HookTranslator {
*/
toHookLLMResponse(sdkResponse: GenerateContentResponse): LLMResponse {
return {
text: sdkResponse.text,
text: getResponseText(sdkResponse) ?? undefined,
candidates: (sdkResponse.candidates || []).map((candidate) => {
// Extract text parts from the candidate
const textParts =
+26
View File
@@ -373,12 +373,37 @@ export class AfterModelHookOutput extends DefaultHookOutput {
}
}
/**
* Context for MCP tool executions.
* Contains non-sensitive connection information about the MCP server
* identity. Since server_name is user controlled and arbitrary, we
* also include connection information (e.g., command or url) to
* help identify the MCP server.
*
* NOTE: In the future, consider defining a shared sanitized interface
* from MCPServerConfig to avoid duplication and ensure consistency.
*/
export interface McpToolContext {
server_name: string;
tool_name: string; // Original tool name from the MCP server
// Connection info (mutually exclusive based on transport type)
command?: string; // For stdio transport
args?: string[]; // For stdio transport
cwd?: string; // For stdio transport
url?: string; // For SSE/HTTP transport
tcp?: string; // For WebSocket transport
}
/**
* BeforeTool hook input
*/
export interface BeforeToolInput extends HookInput {
tool_name: string;
tool_input: Record<string, unknown>;
mcp_context?: McpToolContext; // Only present for MCP tools
}
/**
@@ -398,6 +423,7 @@ export interface AfterToolInput extends HookInput {
tool_name: string;
tool_input: Record<string, unknown>;
tool_response: Record<string, unknown>;
mcp_context?: McpToolContext; // Only present for MCP tools
}
/**
@@ -98,6 +98,7 @@ export class ToolExecutor {
liveOutputCallback,
shellExecutionConfig,
setPidCallback,
this.config,
);
} else {
promise = executeToolWithHooks(
@@ -109,6 +110,8 @@ export class ToolExecutor {
tool,
liveOutputCallback,
shellExecutionConfig,
undefined,
this.config,
);
}
+1 -1
View File
@@ -59,7 +59,7 @@ type McpContentBlock =
| McpResourceBlock
| McpResourceLinkBlock;
class DiscoveredMCPToolInvocation extends BaseToolInvocation<
export class DiscoveredMCPToolInvocation extends BaseToolInvocation<
ToolParams,
ToolResult
> {