mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 21:03:05 -07:00
Merge branch 'main' into dyim/add-api-version-env-var
This commit is contained in:
@@ -46,6 +46,16 @@ specific event.
|
|||||||
- `tool_input`: (`object`) The arguments passed to the tool.
|
- `tool_input`: (`object`) The arguments passed to the tool.
|
||||||
- `tool_response`: (`object`, **AfterTool only**) The raw output from the tool
|
- `tool_response`: (`object`, **AfterTool only**) The raw output from the tool
|
||||||
execution.
|
execution.
|
||||||
|
- `mcp_context`: (`object`, **optional**) Present only for MCP tool invocations.
|
||||||
|
Contains server identity information:
|
||||||
|
- `server_name`: (`string`) The configured name of the MCP server.
|
||||||
|
- `tool_name`: (`string`) The original tool name from the MCP server.
|
||||||
|
- `command`: (`string`, optional) For stdio transport, the command used to
|
||||||
|
start the server.
|
||||||
|
- `args`: (`string[]`, optional) For stdio transport, the command arguments.
|
||||||
|
- `cwd`: (`string`, optional) For stdio transport, the working directory.
|
||||||
|
- `url`: (`string`, optional) For SSE/HTTP transport, the server URL.
|
||||||
|
- `tcp`: (`string`, optional) For WebSocket transport, the TCP address.
|
||||||
|
|
||||||
#### Agent Events (`BeforeAgent`, `AfterAgent`)
|
#### Agent Events (`BeforeAgent`, `AfterAgent`)
|
||||||
|
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ export async function loadConfig(
|
|||||||
: settings.checkpointing?.enabled,
|
: settings.checkpointing?.enabled,
|
||||||
previewFeatures: settings.general?.previewFeatures,
|
previewFeatures: settings.general?.previewFeatures,
|
||||||
interactive: true,
|
interactive: true,
|
||||||
|
enableInteractiveShell: true,
|
||||||
};
|
};
|
||||||
|
|
||||||
const fileService = new FileDiscoveryService(workspaceDir);
|
const fileService = new FileDiscoveryService(workspaceDir);
|
||||||
|
|||||||
@@ -637,6 +637,7 @@ export async function loadCliConfig(
|
|||||||
const ptyInfo = await getPty();
|
const ptyInfo = await getPty();
|
||||||
|
|
||||||
const mcpEnabled = settings.admin?.mcp?.enabled ?? true;
|
const mcpEnabled = settings.admin?.mcp?.enabled ?? true;
|
||||||
|
const extensionsEnabled = settings.admin?.extensions?.enabled ?? true;
|
||||||
|
|
||||||
return new Config({
|
return new Config({
|
||||||
sessionId,
|
sessionId,
|
||||||
@@ -659,6 +660,7 @@ export async function loadCliConfig(
|
|||||||
mcpServerCommand: mcpEnabled ? settings.mcp?.serverCommand : undefined,
|
mcpServerCommand: mcpEnabled ? settings.mcp?.serverCommand : undefined,
|
||||||
mcpServers: mcpEnabled ? settings.mcpServers : {},
|
mcpServers: mcpEnabled ? settings.mcpServers : {},
|
||||||
mcpEnabled,
|
mcpEnabled,
|
||||||
|
extensionsEnabled,
|
||||||
allowedMcpServers: mcpEnabled
|
allowedMcpServers: mcpEnabled
|
||||||
? (argv.allowedMcpServerNames ?? settings.mcp?.allowed)
|
? (argv.allowedMcpServerNames ?? settings.mcp?.allowed)
|
||||||
: undefined,
|
: undefined,
|
||||||
|
|||||||
@@ -465,6 +465,12 @@ Would you like to attempt to install via "git clone" instead?`,
|
|||||||
if (this.loadedExtensions) {
|
if (this.loadedExtensions) {
|
||||||
throw new Error('Extensions already loaded, only load extensions once.');
|
throw new Error('Extensions already loaded, only load extensions once.');
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (this.settings.admin?.extensions?.enabled === false) {
|
||||||
|
this.loadedExtensions = [];
|
||||||
|
return this.loadedExtensions;
|
||||||
|
}
|
||||||
|
|
||||||
const extensionsDir = ExtensionStorage.getUserExtensionsDir();
|
const extensionsDir = ExtensionStorage.getUserExtensionsDir();
|
||||||
this.loadedExtensions = [];
|
this.loadedExtensions = [];
|
||||||
if (!fs.existsSync(extensionsDir)) {
|
if (!fs.existsSync(extensionsDir)) {
|
||||||
@@ -537,6 +543,9 @@ Would you like to attempt to install via "git clone" instead?`,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (config.mcpServers) {
|
if (config.mcpServers) {
|
||||||
|
if (this.settings.admin?.mcp?.enabled === false) {
|
||||||
|
config.mcpServers = undefined;
|
||||||
|
} else {
|
||||||
config.mcpServers = Object.fromEntries(
|
config.mcpServers = Object.fromEntries(
|
||||||
Object.entries(config.mcpServers).map(([key, value]) => [
|
Object.entries(config.mcpServers).map(([key, value]) => [
|
||||||
key,
|
key,
|
||||||
@@ -544,6 +553,7 @@ Would you like to attempt to install via "git clone" instead?`,
|
|||||||
]),
|
]),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const contextFiles = getContextFileNames(config)
|
const contextFiles = getContextFileNames(config)
|
||||||
.map((contextFileName) =>
|
.map((contextFileName) =>
|
||||||
|
|||||||
@@ -632,6 +632,79 @@ describe('extension tests', () => {
|
|||||||
expect(extension).toBeUndefined();
|
expect(extension).toBeUndefined();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should not load any extensions if admin.extensions.enabled is false', async () => {
|
||||||
|
createExtension({
|
||||||
|
extensionsDir: userExtensionsDir,
|
||||||
|
name: 'test-extension',
|
||||||
|
version: '1.0.0',
|
||||||
|
});
|
||||||
|
const loadedSettings = loadSettings(tempWorkspaceDir).merged;
|
||||||
|
(loadedSettings.admin ??= {}).extensions ??= {};
|
||||||
|
loadedSettings.admin.extensions.enabled = false;
|
||||||
|
|
||||||
|
extensionManager = new ExtensionManager({
|
||||||
|
workspaceDir: tempWorkspaceDir,
|
||||||
|
requestConsent: mockRequestConsent,
|
||||||
|
requestSetting: mockPromptForSettings,
|
||||||
|
settings: loadedSettings,
|
||||||
|
});
|
||||||
|
|
||||||
|
const extensions = await extensionManager.loadExtensions();
|
||||||
|
expect(extensions).toEqual([]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not load mcpServers if admin.mcp.enabled is false', async () => {
|
||||||
|
createExtension({
|
||||||
|
extensionsDir: userExtensionsDir,
|
||||||
|
name: 'test-extension',
|
||||||
|
version: '1.0.0',
|
||||||
|
mcpServers: {
|
||||||
|
'test-server': { command: 'echo', args: ['hello'] },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
const loadedSettings = loadSettings(tempWorkspaceDir).merged;
|
||||||
|
(loadedSettings.admin ??= {}).mcp ??= {};
|
||||||
|
loadedSettings.admin.mcp.enabled = false;
|
||||||
|
|
||||||
|
extensionManager = new ExtensionManager({
|
||||||
|
workspaceDir: tempWorkspaceDir,
|
||||||
|
requestConsent: mockRequestConsent,
|
||||||
|
requestSetting: mockPromptForSettings,
|
||||||
|
settings: loadedSettings,
|
||||||
|
});
|
||||||
|
|
||||||
|
const extensions = await extensionManager.loadExtensions();
|
||||||
|
expect(extensions).toHaveLength(1);
|
||||||
|
expect(extensions[0].mcpServers).toBeUndefined();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should load mcpServers if admin.mcp.enabled is true', async () => {
|
||||||
|
createExtension({
|
||||||
|
extensionsDir: userExtensionsDir,
|
||||||
|
name: 'test-extension',
|
||||||
|
version: '1.0.0',
|
||||||
|
mcpServers: {
|
||||||
|
'test-server': { command: 'echo', args: ['hello'] },
|
||||||
|
},
|
||||||
|
});
|
||||||
|
const loadedSettings = loadSettings(tempWorkspaceDir).merged;
|
||||||
|
(loadedSettings.admin ??= {}).mcp ??= {};
|
||||||
|
loadedSettings.admin.mcp.enabled = true;
|
||||||
|
|
||||||
|
extensionManager = new ExtensionManager({
|
||||||
|
workspaceDir: tempWorkspaceDir,
|
||||||
|
requestConsent: mockRequestConsent,
|
||||||
|
requestSetting: mockPromptForSettings,
|
||||||
|
settings: loadedSettings,
|
||||||
|
});
|
||||||
|
|
||||||
|
const extensions = await extensionManager.loadExtensions();
|
||||||
|
expect(extensions).toHaveLength(1);
|
||||||
|
expect(extensions[0].mcpServers).toEqual({
|
||||||
|
'test-server': { command: 'echo', args: ['hello'] },
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('id generation', () => {
|
describe('id generation', () => {
|
||||||
it.each([
|
it.each([
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ describe('BuiltinCommandLoader', () => {
|
|||||||
getEnableExtensionReloading: () => false,
|
getEnableExtensionReloading: () => false,
|
||||||
getEnableHooks: () => false,
|
getEnableHooks: () => false,
|
||||||
getEnableHooksUI: () => false,
|
getEnableHooksUI: () => false,
|
||||||
|
getExtensionsEnabled: vi.fn().mockReturnValue(true),
|
||||||
isSkillsSupportEnabled: vi.fn().mockReturnValue(false),
|
isSkillsSupportEnabled: vi.fn().mockReturnValue(false),
|
||||||
getMcpEnabled: vi.fn().mockReturnValue(true),
|
getMcpEnabled: vi.fn().mockReturnValue(true),
|
||||||
getSkillManager: vi.fn().mockReturnValue({
|
getSkillManager: vi.fn().mockReturnValue({
|
||||||
@@ -201,6 +202,7 @@ describe('BuiltinCommandLoader profile', () => {
|
|||||||
getEnableExtensionReloading: () => false,
|
getEnableExtensionReloading: () => false,
|
||||||
getEnableHooks: () => false,
|
getEnableHooks: () => false,
|
||||||
getEnableHooksUI: () => false,
|
getEnableHooksUI: () => false,
|
||||||
|
getExtensionsEnabled: vi.fn().mockReturnValue(true),
|
||||||
isSkillsSupportEnabled: vi.fn().mockReturnValue(false),
|
isSkillsSupportEnabled: vi.fn().mockReturnValue(false),
|
||||||
getMcpEnabled: vi.fn().mockReturnValue(true),
|
getMcpEnabled: vi.fn().mockReturnValue(true),
|
||||||
getSkillManager: vi.fn().mockReturnValue({
|
getSkillManager: vi.fn().mockReturnValue({
|
||||||
|
|||||||
@@ -76,7 +76,24 @@ export class BuiltinCommandLoader implements ICommandLoader {
|
|||||||
docsCommand,
|
docsCommand,
|
||||||
directoryCommand,
|
directoryCommand,
|
||||||
editorCommand,
|
editorCommand,
|
||||||
extensionsCommand(this.config?.getEnableExtensionReloading()),
|
...(this.config?.getExtensionsEnabled() === false
|
||||||
|
? [
|
||||||
|
{
|
||||||
|
name: 'extensions',
|
||||||
|
description: 'Manage extensions',
|
||||||
|
kind: CommandKind.BUILT_IN,
|
||||||
|
autoExecute: false,
|
||||||
|
subCommands: [],
|
||||||
|
action: async (
|
||||||
|
_context: CommandContext,
|
||||||
|
): Promise<MessageActionReturn> => ({
|
||||||
|
type: 'message',
|
||||||
|
messageType: 'error',
|
||||||
|
content: 'Extensions are disabled by your admin.',
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
: [extensionsCommand(this.config?.getEnableExtensionReloading())]),
|
||||||
helpCommand,
|
helpCommand,
|
||||||
...(this.config?.getEnableHooksUI() ? [hooksCommand] : []),
|
...(this.config?.getEnableHooksUI() ? [hooksCommand] : []),
|
||||||
await ideCommand(),
|
await ideCommand(),
|
||||||
@@ -95,7 +112,7 @@ export class BuiltinCommandLoader implements ICommandLoader {
|
|||||||
): Promise<MessageActionReturn> => ({
|
): Promise<MessageActionReturn> => ({
|
||||||
type: 'message',
|
type: 'message',
|
||||||
messageType: 'error',
|
messageType: 'error',
|
||||||
content: 'MCP disabled by your admin.',
|
content: 'MCP is disabled by your admin.',
|
||||||
}),
|
}),
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import {
|
|||||||
inferInstallMetadata,
|
inferInstallMetadata,
|
||||||
} from '../../config/extension-manager.js';
|
} from '../../config/extension-manager.js';
|
||||||
import { SettingScope } from '../../config/settings.js';
|
import { SettingScope } from '../../config/settings.js';
|
||||||
|
import { stat } from 'node:fs/promises';
|
||||||
|
|
||||||
vi.mock('../../config/extension-manager.js', async (importOriginal) => {
|
vi.mock('../../config/extension-manager.js', async (importOriginal) => {
|
||||||
const actual =
|
const actual =
|
||||||
@@ -42,11 +43,16 @@ vi.mock('../../config/extension-manager.js', async (importOriginal) => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
import open from 'open';
|
import open from 'open';
|
||||||
|
import type { Stats } from 'node:fs';
|
||||||
|
|
||||||
vi.mock('open', () => ({
|
vi.mock('open', () => ({
|
||||||
default: vi.fn(),
|
default: vi.fn(),
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
vi.mock('node:fs/promises', () => ({
|
||||||
|
stat: vi.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
vi.mock('../../config/extensions/update.js', () => ({
|
vi.mock('../../config/extensions/update.js', () => ({
|
||||||
updateExtension: vi.fn(),
|
updateExtension: vi.fn(),
|
||||||
checkForAllExtensionUpdates: vi.fn(),
|
checkForAllExtensionUpdates: vi.fn(),
|
||||||
@@ -493,34 +499,37 @@ describe('extensionsCommand', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
describe('when enableExtensionReloading is true', () => {
|
describe('when enableExtensionReloading is true', () => {
|
||||||
it('should include enable, disable, install, and uninstall subcommands', () => {
|
it('should include enable, disable, install, link, and uninstall subcommands', () => {
|
||||||
const command = extensionsCommand(true);
|
const command = extensionsCommand(true);
|
||||||
const subCommandNames = command.subCommands?.map((cmd) => cmd.name);
|
const subCommandNames = command.subCommands?.map((cmd) => cmd.name);
|
||||||
expect(subCommandNames).toContain('enable');
|
expect(subCommandNames).toContain('enable');
|
||||||
expect(subCommandNames).toContain('disable');
|
expect(subCommandNames).toContain('disable');
|
||||||
expect(subCommandNames).toContain('install');
|
expect(subCommandNames).toContain('install');
|
||||||
|
expect(subCommandNames).toContain('link');
|
||||||
expect(subCommandNames).toContain('uninstall');
|
expect(subCommandNames).toContain('uninstall');
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('when enableExtensionReloading is false', () => {
|
describe('when enableExtensionReloading is false', () => {
|
||||||
it('should not include enable, disable, install, and uninstall subcommands', () => {
|
it('should not include enable, disable, install, link, and uninstall subcommands', () => {
|
||||||
const command = extensionsCommand(false);
|
const command = extensionsCommand(false);
|
||||||
const subCommandNames = command.subCommands?.map((cmd) => cmd.name);
|
const subCommandNames = command.subCommands?.map((cmd) => cmd.name);
|
||||||
expect(subCommandNames).not.toContain('enable');
|
expect(subCommandNames).not.toContain('enable');
|
||||||
expect(subCommandNames).not.toContain('disable');
|
expect(subCommandNames).not.toContain('disable');
|
||||||
expect(subCommandNames).not.toContain('install');
|
expect(subCommandNames).not.toContain('install');
|
||||||
|
expect(subCommandNames).not.toContain('link');
|
||||||
expect(subCommandNames).not.toContain('uninstall');
|
expect(subCommandNames).not.toContain('uninstall');
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('when enableExtensionReloading is not provided', () => {
|
describe('when enableExtensionReloading is not provided', () => {
|
||||||
it('should not include enable, disable, install, and uninstall subcommands by default', () => {
|
it('should not include enable, disable, install, link, and uninstall subcommands by default', () => {
|
||||||
const command = extensionsCommand();
|
const command = extensionsCommand();
|
||||||
const subCommandNames = command.subCommands?.map((cmd) => cmd.name);
|
const subCommandNames = command.subCommands?.map((cmd) => cmd.name);
|
||||||
expect(subCommandNames).not.toContain('enable');
|
expect(subCommandNames).not.toContain('enable');
|
||||||
expect(subCommandNames).not.toContain('disable');
|
expect(subCommandNames).not.toContain('disable');
|
||||||
expect(subCommandNames).not.toContain('install');
|
expect(subCommandNames).not.toContain('install');
|
||||||
|
expect(subCommandNames).not.toContain('link');
|
||||||
expect(subCommandNames).not.toContain('uninstall');
|
expect(subCommandNames).not.toContain('uninstall');
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -617,6 +626,88 @@ describe('extensionsCommand', () => {
|
|||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('link', () => {
|
||||||
|
let linkAction: SlashCommand['action'];
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
linkAction = extensionsCommand(true).subCommands?.find(
|
||||||
|
(cmd) => cmd.name === 'link',
|
||||||
|
)?.action;
|
||||||
|
|
||||||
|
expect(linkAction).not.toBeNull();
|
||||||
|
mockContext.invocation!.name = 'link';
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should show usage if no extension is provided', async () => {
|
||||||
|
await linkAction!(mockContext, '');
|
||||||
|
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
|
||||||
|
{
|
||||||
|
type: MessageType.ERROR,
|
||||||
|
text: 'Usage: /extensions link <source>',
|
||||||
|
},
|
||||||
|
expect.any(Number),
|
||||||
|
);
|
||||||
|
expect(mockInstallExtension).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should call installExtension and show success message', async () => {
|
||||||
|
const packageName = 'test-extension-package';
|
||||||
|
mockInstallExtension.mockResolvedValue({ name: packageName });
|
||||||
|
vi.mocked(stat).mockResolvedValue({
|
||||||
|
size: 100,
|
||||||
|
} as Stats);
|
||||||
|
await linkAction!(mockContext, packageName);
|
||||||
|
expect(mockInstallExtension).toHaveBeenCalledWith({
|
||||||
|
source: packageName,
|
||||||
|
type: 'link',
|
||||||
|
});
|
||||||
|
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
|
||||||
|
{
|
||||||
|
type: MessageType.INFO,
|
||||||
|
text: `Linking extension from "${packageName}"...`,
|
||||||
|
},
|
||||||
|
expect.any(Number),
|
||||||
|
);
|
||||||
|
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
|
||||||
|
{
|
||||||
|
type: MessageType.INFO,
|
||||||
|
text: `Extension "${packageName}" linked successfully.`,
|
||||||
|
},
|
||||||
|
expect.any(Number),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should show error message on linking failure', async () => {
|
||||||
|
const packageName = 'test-extension-package';
|
||||||
|
const errorMessage = 'link failed';
|
||||||
|
mockInstallExtension.mockRejectedValue(new Error(errorMessage));
|
||||||
|
vi.mocked(stat).mockResolvedValue({
|
||||||
|
size: 100,
|
||||||
|
} as Stats);
|
||||||
|
|
||||||
|
await linkAction!(mockContext, packageName);
|
||||||
|
expect(mockInstallExtension).toHaveBeenCalledWith({
|
||||||
|
source: packageName,
|
||||||
|
type: 'link',
|
||||||
|
});
|
||||||
|
expect(mockContext.ui.addItem).toHaveBeenCalledWith(
|
||||||
|
{
|
||||||
|
type: MessageType.ERROR,
|
||||||
|
text: `Failed to link extension from "${packageName}": ${errorMessage}`,
|
||||||
|
},
|
||||||
|
expect.any(Number),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should show error message for invalid source', async () => {
|
||||||
|
const packageName = 'test-extension-package';
|
||||||
|
const errorMessage = 'invalid path';
|
||||||
|
vi.mocked(stat).mockRejectedValue(new Error(errorMessage));
|
||||||
|
await linkAction!(mockContext, packageName);
|
||||||
|
expect(mockInstallExtension).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('uninstall', () => {
|
describe('uninstall', () => {
|
||||||
let uninstallAction: SlashCommand['action'];
|
let uninstallAction: SlashCommand['action'];
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,11 @@
|
|||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { debugLogger, listExtensions } from '@google/gemini-cli-core';
|
import {
|
||||||
|
debugLogger,
|
||||||
|
listExtensions,
|
||||||
|
type ExtensionInstallMetadata,
|
||||||
|
} from '@google/gemini-cli-core';
|
||||||
import type { ExtensionUpdateInfo } from '../../config/extension.js';
|
import type { ExtensionUpdateInfo } from '../../config/extension.js';
|
||||||
import { getErrorMessage } from '../../utils/errors.js';
|
import { getErrorMessage } from '../../utils/errors.js';
|
||||||
import {
|
import {
|
||||||
@@ -26,6 +30,7 @@ import {
|
|||||||
} from '../../config/extension-manager.js';
|
} from '../../config/extension-manager.js';
|
||||||
import { SettingScope } from '../../config/settings.js';
|
import { SettingScope } from '../../config/settings.js';
|
||||||
import { theme } from '../semantic-colors.js';
|
import { theme } from '../semantic-colors.js';
|
||||||
|
import { stat } from 'node:fs/promises';
|
||||||
|
|
||||||
function showMessageIfNoExtensions(
|
function showMessageIfNoExtensions(
|
||||||
context: CommandContext,
|
context: CommandContext,
|
||||||
@@ -510,6 +515,88 @@ async function installAction(context: CommandContext, args: string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function linkAction(context: CommandContext, args: string) {
|
||||||
|
const extensionLoader = context.services.config?.getExtensionLoader();
|
||||||
|
if (!(extensionLoader instanceof ExtensionManager)) {
|
||||||
|
debugLogger.error(
|
||||||
|
`Cannot ${context.invocation?.name} extensions in this environment`,
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const sourceFilepath = args.trim();
|
||||||
|
if (!sourceFilepath) {
|
||||||
|
context.ui.addItem(
|
||||||
|
{
|
||||||
|
type: MessageType.ERROR,
|
||||||
|
text: `Usage: /extensions link <source>`,
|
||||||
|
},
|
||||||
|
Date.now(),
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (/[;&|`'"]/.test(sourceFilepath)) {
|
||||||
|
context.ui.addItem(
|
||||||
|
{
|
||||||
|
type: MessageType.ERROR,
|
||||||
|
text: `Source file path contains disallowed characters: ${sourceFilepath}`,
|
||||||
|
},
|
||||||
|
Date.now(),
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
await stat(sourceFilepath);
|
||||||
|
} catch (error) {
|
||||||
|
context.ui.addItem(
|
||||||
|
{
|
||||||
|
type: MessageType.ERROR,
|
||||||
|
text: `Invalid source: ${sourceFilepath}`,
|
||||||
|
},
|
||||||
|
Date.now(),
|
||||||
|
);
|
||||||
|
debugLogger.error(
|
||||||
|
`Failed to stat path "${sourceFilepath}": ${getErrorMessage(error)}`,
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
context.ui.addItem(
|
||||||
|
{
|
||||||
|
type: MessageType.INFO,
|
||||||
|
text: `Linking extension from "${sourceFilepath}"...`,
|
||||||
|
},
|
||||||
|
Date.now(),
|
||||||
|
);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const installMetadata: ExtensionInstallMetadata = {
|
||||||
|
source: sourceFilepath,
|
||||||
|
type: 'link',
|
||||||
|
};
|
||||||
|
const extension =
|
||||||
|
await extensionLoader.installOrUpdateExtension(installMetadata);
|
||||||
|
context.ui.addItem(
|
||||||
|
{
|
||||||
|
type: MessageType.INFO,
|
||||||
|
text: `Extension "${extension.name}" linked successfully.`,
|
||||||
|
},
|
||||||
|
Date.now(),
|
||||||
|
);
|
||||||
|
} catch (error) {
|
||||||
|
context.ui.addItem(
|
||||||
|
{
|
||||||
|
type: MessageType.ERROR,
|
||||||
|
text: `Failed to link extension from "${sourceFilepath}": ${getErrorMessage(
|
||||||
|
error,
|
||||||
|
)}`,
|
||||||
|
},
|
||||||
|
Date.now(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async function uninstallAction(context: CommandContext, args: string) {
|
async function uninstallAction(context: CommandContext, args: string) {
|
||||||
const extensionLoader = context.services.config?.getExtensionLoader();
|
const extensionLoader = context.services.config?.getExtensionLoader();
|
||||||
if (!(extensionLoader instanceof ExtensionManager)) {
|
if (!(extensionLoader instanceof ExtensionManager)) {
|
||||||
@@ -645,6 +732,14 @@ const installCommand: SlashCommand = {
|
|||||||
action: installAction,
|
action: installAction,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const linkCommand: SlashCommand = {
|
||||||
|
name: 'link',
|
||||||
|
description: 'Link an extension from a local path',
|
||||||
|
kind: CommandKind.BUILT_IN,
|
||||||
|
autoExecute: false,
|
||||||
|
action: linkAction,
|
||||||
|
};
|
||||||
|
|
||||||
const uninstallCommand: SlashCommand = {
|
const uninstallCommand: SlashCommand = {
|
||||||
name: 'uninstall',
|
name: 'uninstall',
|
||||||
description: 'Uninstall an extension',
|
description: 'Uninstall an extension',
|
||||||
@@ -675,7 +770,13 @@ export function extensionsCommand(
|
|||||||
enableExtensionReloading?: boolean,
|
enableExtensionReloading?: boolean,
|
||||||
): SlashCommand {
|
): SlashCommand {
|
||||||
const conditionalCommands = enableExtensionReloading
|
const conditionalCommands = enableExtensionReloading
|
||||||
? [disableCommand, enableCommand, installCommand, uninstallCommand]
|
? [
|
||||||
|
disableCommand,
|
||||||
|
enableCommand,
|
||||||
|
installCommand,
|
||||||
|
uninstallCommand,
|
||||||
|
linkCommand,
|
||||||
|
]
|
||||||
: [];
|
: [];
|
||||||
return {
|
return {
|
||||||
name: 'extensions',
|
name: 'extensions',
|
||||||
|
|||||||
@@ -357,6 +357,7 @@ export interface ConfigParameters {
|
|||||||
experimentalJitContext?: boolean;
|
experimentalJitContext?: boolean;
|
||||||
onModelChange?: (model: string) => void;
|
onModelChange?: (model: string) => void;
|
||||||
mcpEnabled?: boolean;
|
mcpEnabled?: boolean;
|
||||||
|
extensionsEnabled?: boolean;
|
||||||
onReload?: () => Promise<{ disabledSkills?: string[] }>;
|
onReload?: () => Promise<{ disabledSkills?: string[] }>;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -391,6 +392,7 @@ export class Config {
|
|||||||
private readonly toolCallCommand: string | undefined;
|
private readonly toolCallCommand: string | undefined;
|
||||||
private readonly mcpServerCommand: string | undefined;
|
private readonly mcpServerCommand: string | undefined;
|
||||||
private readonly mcpEnabled: boolean;
|
private readonly mcpEnabled: boolean;
|
||||||
|
private readonly extensionsEnabled: boolean;
|
||||||
private mcpServers: Record<string, MCPServerConfig> | undefined;
|
private mcpServers: Record<string, MCPServerConfig> | undefined;
|
||||||
private userMemory: string;
|
private userMemory: string;
|
||||||
private geminiMdFileCount: number;
|
private geminiMdFileCount: number;
|
||||||
@@ -517,6 +519,7 @@ export class Config {
|
|||||||
this.mcpServerCommand = params.mcpServerCommand;
|
this.mcpServerCommand = params.mcpServerCommand;
|
||||||
this.mcpServers = params.mcpServers;
|
this.mcpServers = params.mcpServers;
|
||||||
this.mcpEnabled = params.mcpEnabled ?? true;
|
this.mcpEnabled = params.mcpEnabled ?? true;
|
||||||
|
this.extensionsEnabled = params.extensionsEnabled ?? true;
|
||||||
this.allowedMcpServers = params.allowedMcpServers ?? [];
|
this.allowedMcpServers = params.allowedMcpServers ?? [];
|
||||||
this.blockedMcpServers = params.blockedMcpServers ?? [];
|
this.blockedMcpServers = params.blockedMcpServers ?? [];
|
||||||
this.allowedEnvironmentVariables = params.allowedEnvironmentVariables ?? [];
|
this.allowedEnvironmentVariables = params.allowedEnvironmentVariables ?? [];
|
||||||
@@ -1143,6 +1146,10 @@ export class Config {
|
|||||||
return this.mcpEnabled;
|
return this.mcpEnabled;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getExtensionsEnabled(): boolean {
|
||||||
|
return this.extensionsEnabled;
|
||||||
|
}
|
||||||
|
|
||||||
getMcpClientManager(): McpClientManager | undefined {
|
getMcpClientManager(): McpClientManager | undefined {
|
||||||
return this.mcpClientManager;
|
return this.mcpClientManager;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,10 @@ import {
|
|||||||
createHookOutput,
|
createHookOutput,
|
||||||
NotificationType,
|
NotificationType,
|
||||||
type DefaultHookOutput,
|
type DefaultHookOutput,
|
||||||
|
type McpToolContext,
|
||||||
BeforeToolHookOutput,
|
BeforeToolHookOutput,
|
||||||
} from '../hooks/types.js';
|
} from '../hooks/types.js';
|
||||||
|
import type { Config } from '../config/config.js';
|
||||||
import type {
|
import type {
|
||||||
ToolCallConfirmationDetails,
|
ToolCallConfirmationDetails,
|
||||||
ToolResult,
|
ToolResult,
|
||||||
@@ -26,6 +28,7 @@ import { debugLogger } from '../utils/debugLogger.js';
|
|||||||
import type { AnsiOutput, ShellExecutionConfig } from '../index.js';
|
import type { AnsiOutput, ShellExecutionConfig } from '../index.js';
|
||||||
import type { AnyToolInvocation } from '../tools/tools.js';
|
import type { AnyToolInvocation } from '../tools/tools.js';
|
||||||
import { ShellToolInvocation } from '../tools/shell.js';
|
import { ShellToolInvocation } from '../tools/shell.js';
|
||||||
|
import { DiscoveredMCPToolInvocation } from '../tools/mcp-tool.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Serializable representation of tool confirmation details for hooks.
|
* Serializable representation of tool confirmation details for hooks.
|
||||||
@@ -154,18 +157,57 @@ export async function fireToolNotificationHook(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts MCP context from a tool invocation if it's an MCP tool.
|
||||||
|
*
|
||||||
|
* @param invocation The tool invocation
|
||||||
|
* @param config Config to look up server details
|
||||||
|
* @returns MCP context if this is an MCP tool, undefined otherwise
|
||||||
|
*/
|
||||||
|
function extractMcpContext(
|
||||||
|
invocation: ShellToolInvocation | AnyToolInvocation,
|
||||||
|
config: Config,
|
||||||
|
): McpToolContext | undefined {
|
||||||
|
if (!(invocation instanceof DiscoveredMCPToolInvocation)) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the server config
|
||||||
|
const mcpServers =
|
||||||
|
config.getMcpClientManager()?.getMcpServers() ??
|
||||||
|
config.getMcpServers() ??
|
||||||
|
{};
|
||||||
|
const serverConfig = mcpServers[invocation.serverName];
|
||||||
|
if (!serverConfig) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
server_name: invocation.serverName,
|
||||||
|
tool_name: invocation.serverToolName,
|
||||||
|
// Non-sensitive connection details only
|
||||||
|
command: serverConfig.command,
|
||||||
|
args: serverConfig.args,
|
||||||
|
cwd: serverConfig.cwd,
|
||||||
|
url: serverConfig.url ?? serverConfig.httpUrl,
|
||||||
|
tcp: serverConfig.tcp,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Fires the BeforeTool hook and returns the hook output.
|
* Fires the BeforeTool hook and returns the hook output.
|
||||||
*
|
*
|
||||||
* @param messageBus The message bus to use for hook communication
|
* @param messageBus The message bus to use for hook communication
|
||||||
* @param toolName The name of the tool being executed
|
* @param toolName The name of the tool being executed
|
||||||
* @param toolInput The input parameters for the tool
|
* @param toolInput The input parameters for the tool
|
||||||
|
* @param mcpContext Optional MCP context for MCP tools
|
||||||
* @returns The hook output, or undefined if no hook was executed or on error
|
* @returns The hook output, or undefined if no hook was executed or on error
|
||||||
*/
|
*/
|
||||||
export async function fireBeforeToolHook(
|
export async function fireBeforeToolHook(
|
||||||
messageBus: MessageBus,
|
messageBus: MessageBus,
|
||||||
toolName: string,
|
toolName: string,
|
||||||
toolInput: Record<string, unknown>,
|
toolInput: Record<string, unknown>,
|
||||||
|
mcpContext?: McpToolContext,
|
||||||
): Promise<DefaultHookOutput | undefined> {
|
): Promise<DefaultHookOutput | undefined> {
|
||||||
try {
|
try {
|
||||||
const response = await messageBus.request<
|
const response = await messageBus.request<
|
||||||
@@ -178,6 +220,7 @@ export async function fireBeforeToolHook(
|
|||||||
input: {
|
input: {
|
||||||
tool_name: toolName,
|
tool_name: toolName,
|
||||||
tool_input: toolInput,
|
tool_input: toolInput,
|
||||||
|
...(mcpContext && { mcp_context: mcpContext }),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||||
@@ -199,6 +242,7 @@ export async function fireBeforeToolHook(
|
|||||||
* @param toolName The name of the tool that was executed
|
* @param toolName The name of the tool that was executed
|
||||||
* @param toolInput The input parameters for the tool
|
* @param toolInput The input parameters for the tool
|
||||||
* @param toolResponse The result from the tool execution
|
* @param toolResponse The result from the tool execution
|
||||||
|
* @param mcpContext Optional MCP context for MCP tools
|
||||||
* @returns The hook output, or undefined if no hook was executed or on error
|
* @returns The hook output, or undefined if no hook was executed or on error
|
||||||
*/
|
*/
|
||||||
export async function fireAfterToolHook(
|
export async function fireAfterToolHook(
|
||||||
@@ -210,6 +254,7 @@ export async function fireAfterToolHook(
|
|||||||
returnDisplay: ToolResult['returnDisplay'];
|
returnDisplay: ToolResult['returnDisplay'];
|
||||||
error: ToolResult['error'];
|
error: ToolResult['error'];
|
||||||
},
|
},
|
||||||
|
mcpContext?: McpToolContext,
|
||||||
): Promise<DefaultHookOutput | undefined> {
|
): Promise<DefaultHookOutput | undefined> {
|
||||||
try {
|
try {
|
||||||
const response = await messageBus.request<
|
const response = await messageBus.request<
|
||||||
@@ -223,6 +268,7 @@ export async function fireAfterToolHook(
|
|||||||
tool_name: toolName,
|
tool_name: toolName,
|
||||||
tool_input: toolInput,
|
tool_input: toolInput,
|
||||||
tool_response: toolResponse,
|
tool_response: toolResponse,
|
||||||
|
...(mcpContext && { mcp_context: mcpContext }),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||||
@@ -248,6 +294,7 @@ export async function fireAfterToolHook(
|
|||||||
* @param liveOutputCallback Optional callback for live output updates
|
* @param liveOutputCallback Optional callback for live output updates
|
||||||
* @param shellExecutionConfig Optional shell execution config
|
* @param shellExecutionConfig Optional shell execution config
|
||||||
* @param setPidCallback Optional callback to set the PID for shell invocations
|
* @param setPidCallback Optional callback to set the PID for shell invocations
|
||||||
|
* @param config Config to look up MCP server details for hook context
|
||||||
* @returns The tool result
|
* @returns The tool result
|
||||||
*/
|
*/
|
||||||
export async function executeToolWithHooks(
|
export async function executeToolWithHooks(
|
||||||
@@ -260,17 +307,22 @@ export async function executeToolWithHooks(
|
|||||||
liveOutputCallback?: (outputChunk: string | AnsiOutput) => void,
|
liveOutputCallback?: (outputChunk: string | AnsiOutput) => void,
|
||||||
shellExecutionConfig?: ShellExecutionConfig,
|
shellExecutionConfig?: ShellExecutionConfig,
|
||||||
setPidCallback?: (pid: number) => void,
|
setPidCallback?: (pid: number) => void,
|
||||||
|
config?: Config,
|
||||||
): Promise<ToolResult> {
|
): Promise<ToolResult> {
|
||||||
const toolInput = (invocation.params || {}) as Record<string, unknown>;
|
const toolInput = (invocation.params || {}) as Record<string, unknown>;
|
||||||
let inputWasModified = false;
|
let inputWasModified = false;
|
||||||
let modifiedKeys: string[] = [];
|
let modifiedKeys: string[] = [];
|
||||||
|
|
||||||
|
// Extract MCP context if this is an MCP tool (only if config is provided)
|
||||||
|
const mcpContext = config ? extractMcpContext(invocation, config) : undefined;
|
||||||
|
|
||||||
// Fire BeforeTool hook through MessageBus (only if hooks are enabled)
|
// Fire BeforeTool hook through MessageBus (only if hooks are enabled)
|
||||||
if (hooksEnabled && messageBus) {
|
if (hooksEnabled && messageBus) {
|
||||||
const beforeOutput = await fireBeforeToolHook(
|
const beforeOutput = await fireBeforeToolHook(
|
||||||
messageBus,
|
messageBus,
|
||||||
toolName,
|
toolName,
|
||||||
toolInput,
|
toolInput,
|
||||||
|
mcpContext,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Check if hook requested to stop entire agent execution
|
// Check if hook requested to stop entire agent execution
|
||||||
@@ -378,6 +430,7 @@ export async function executeToolWithHooks(
|
|||||||
returnDisplay: toolResult.returnDisplay,
|
returnDisplay: toolResult.returnDisplay,
|
||||||
error: toolResult.error,
|
error: toolResult.error,
|
||||||
},
|
},
|
||||||
|
mcpContext,
|
||||||
);
|
);
|
||||||
|
|
||||||
// Check if hook requested to stop entire agent execution
|
// Check if hook requested to stop entire agent execution
|
||||||
|
|||||||
@@ -258,6 +258,128 @@ describe('HookEventHandler', () => {
|
|||||||
expect.stringContaining('F12'),
|
expect.stringContaining('F12'),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should fire BeforeTool event with MCP context when provided', async () => {
|
||||||
|
const mockPlan = [
|
||||||
|
{
|
||||||
|
hookConfig: {
|
||||||
|
type: HookType.Command,
|
||||||
|
command: './test.sh',
|
||||||
|
} as unknown as HookConfig,
|
||||||
|
eventName: HookEventName.BeforeTool,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
const mockResults: HookExecutionResult[] = [
|
||||||
|
{
|
||||||
|
success: true,
|
||||||
|
duration: 100,
|
||||||
|
hookConfig: {
|
||||||
|
type: HookType.Command,
|
||||||
|
command: './test.sh',
|
||||||
|
timeout: 30000,
|
||||||
|
},
|
||||||
|
eventName: HookEventName.BeforeTool,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
const mockAggregated = {
|
||||||
|
success: true,
|
||||||
|
allOutputs: [],
|
||||||
|
errors: [],
|
||||||
|
totalDuration: 100,
|
||||||
|
};
|
||||||
|
|
||||||
|
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||||
|
eventName: HookEventName.BeforeTool,
|
||||||
|
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||||
|
sequential: false,
|
||||||
|
});
|
||||||
|
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
|
||||||
|
mockResults,
|
||||||
|
);
|
||||||
|
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
|
||||||
|
mockAggregated,
|
||||||
|
);
|
||||||
|
|
||||||
|
const mcpContext = {
|
||||||
|
server_name: 'my-mcp-server',
|
||||||
|
tool_name: 'read_file',
|
||||||
|
command: 'npx',
|
||||||
|
args: ['-y', '@my-org/mcp-server'],
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = await hookEventHandler.fireBeforeToolEvent(
|
||||||
|
'my-mcp-server__read_file',
|
||||||
|
{ path: '/etc/passwd' },
|
||||||
|
mcpContext,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
|
||||||
|
[mockPlan[0].hookConfig],
|
||||||
|
HookEventName.BeforeTool,
|
||||||
|
expect.objectContaining({
|
||||||
|
session_id: 'test-session',
|
||||||
|
cwd: '/test/project',
|
||||||
|
hook_event_name: 'BeforeTool',
|
||||||
|
tool_name: 'my-mcp-server__read_file',
|
||||||
|
tool_input: { path: '/etc/passwd' },
|
||||||
|
mcp_context: mcpContext,
|
||||||
|
}),
|
||||||
|
expect.any(Function),
|
||||||
|
expect.any(Function),
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result).toBe(mockAggregated);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should not include mcp_context when not provided', async () => {
|
||||||
|
const mockPlan = [
|
||||||
|
{
|
||||||
|
hookConfig: {
|
||||||
|
type: HookType.Command,
|
||||||
|
command: './test.sh',
|
||||||
|
} as unknown as HookConfig,
|
||||||
|
eventName: HookEventName.BeforeTool,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
const mockResults: HookExecutionResult[] = [
|
||||||
|
{
|
||||||
|
success: true,
|
||||||
|
duration: 100,
|
||||||
|
hookConfig: {
|
||||||
|
type: HookType.Command,
|
||||||
|
command: './test.sh',
|
||||||
|
timeout: 30000,
|
||||||
|
},
|
||||||
|
eventName: HookEventName.BeforeTool,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
const mockAggregated = {
|
||||||
|
success: true,
|
||||||
|
allOutputs: [],
|
||||||
|
errors: [],
|
||||||
|
totalDuration: 100,
|
||||||
|
};
|
||||||
|
|
||||||
|
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||||
|
eventName: HookEventName.BeforeTool,
|
||||||
|
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||||
|
sequential: false,
|
||||||
|
});
|
||||||
|
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
|
||||||
|
mockResults,
|
||||||
|
);
|
||||||
|
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
|
||||||
|
mockAggregated,
|
||||||
|
);
|
||||||
|
|
||||||
|
await hookEventHandler.fireBeforeToolEvent('EditTool', {
|
||||||
|
file: 'test.txt',
|
||||||
|
});
|
||||||
|
|
||||||
|
const callArgs = vi.mocked(mockHookRunner.executeHooksParallel).mock
|
||||||
|
.calls[0][2];
|
||||||
|
expect(callArgs).not.toHaveProperty('mcp_context');
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('fireAfterToolEvent', () => {
|
describe('fireAfterToolEvent', () => {
|
||||||
@@ -325,6 +447,78 @@ describe('HookEventHandler', () => {
|
|||||||
|
|
||||||
expect(result).toBe(mockAggregated);
|
expect(result).toBe(mockAggregated);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should fire AfterTool event with MCP context when provided', async () => {
|
||||||
|
const mockPlan = [
|
||||||
|
{
|
||||||
|
hookConfig: {
|
||||||
|
type: HookType.Command,
|
||||||
|
command: './after.sh',
|
||||||
|
} as unknown as HookConfig,
|
||||||
|
eventName: HookEventName.AfterTool,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
const mockResults: HookExecutionResult[] = [
|
||||||
|
{
|
||||||
|
success: true,
|
||||||
|
duration: 100,
|
||||||
|
hookConfig: {
|
||||||
|
type: HookType.Command,
|
||||||
|
command: './after.sh',
|
||||||
|
timeout: 30000,
|
||||||
|
},
|
||||||
|
eventName: HookEventName.AfterTool,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
const mockAggregated = {
|
||||||
|
success: true,
|
||||||
|
allOutputs: [],
|
||||||
|
errors: [],
|
||||||
|
totalDuration: 100,
|
||||||
|
};
|
||||||
|
|
||||||
|
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||||
|
eventName: HookEventName.AfterTool,
|
||||||
|
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||||
|
sequential: false,
|
||||||
|
});
|
||||||
|
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
|
||||||
|
mockResults,
|
||||||
|
);
|
||||||
|
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
|
||||||
|
mockAggregated,
|
||||||
|
);
|
||||||
|
|
||||||
|
const toolInput = { path: '/etc/passwd' };
|
||||||
|
const toolResponse = { success: true, content: 'File content' };
|
||||||
|
const mcpContext = {
|
||||||
|
server_name: 'my-mcp-server',
|
||||||
|
tool_name: 'read_file',
|
||||||
|
url: 'https://mcp.example.com',
|
||||||
|
};
|
||||||
|
|
||||||
|
const result = await hookEventHandler.fireAfterToolEvent(
|
||||||
|
'my-mcp-server__read_file',
|
||||||
|
toolInput,
|
||||||
|
toolResponse,
|
||||||
|
mcpContext,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
|
||||||
|
[mockPlan[0].hookConfig],
|
||||||
|
HookEventName.AfterTool,
|
||||||
|
expect.objectContaining({
|
||||||
|
tool_name: 'my-mcp-server__read_file',
|
||||||
|
tool_input: toolInput,
|
||||||
|
tool_response: toolResponse,
|
||||||
|
mcp_context: mcpContext,
|
||||||
|
}),
|
||||||
|
expect.any(Function),
|
||||||
|
expect.any(Function),
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result).toBe(mockAggregated);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('fireBeforeAgentEvent', () => {
|
describe('fireBeforeAgentEvent', () => {
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import type {
|
|||||||
SessionEndReason,
|
SessionEndReason,
|
||||||
PreCompressTrigger,
|
PreCompressTrigger,
|
||||||
HookExecutionResult,
|
HookExecutionResult,
|
||||||
|
McpToolContext,
|
||||||
} from './types.js';
|
} from './types.js';
|
||||||
import { defaultHookTranslator } from './hookTranslator.js';
|
import { defaultHookTranslator } from './hookTranslator.js';
|
||||||
import type {
|
import type {
|
||||||
@@ -58,9 +59,11 @@ function isObject(value: unknown): value is Record<string, unknown> {
|
|||||||
function validateBeforeToolInput(input: Record<string, unknown>): {
|
function validateBeforeToolInput(input: Record<string, unknown>): {
|
||||||
toolName: string;
|
toolName: string;
|
||||||
toolInput: Record<string, unknown>;
|
toolInput: Record<string, unknown>;
|
||||||
|
mcpContext?: McpToolContext;
|
||||||
} {
|
} {
|
||||||
const toolName = input['tool_name'];
|
const toolName = input['tool_name'];
|
||||||
const toolInput = input['tool_input'];
|
const toolInput = input['tool_input'];
|
||||||
|
const mcpContext = input['mcp_context'];
|
||||||
if (typeof toolName !== 'string') {
|
if (typeof toolName !== 'string') {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
'Invalid input for BeforeTool hook event: tool_name must be a string',
|
'Invalid input for BeforeTool hook event: tool_name must be a string',
|
||||||
@@ -71,7 +74,16 @@ function validateBeforeToolInput(input: Record<string, unknown>): {
|
|||||||
'Invalid input for BeforeTool hook event: tool_input must be an object',
|
'Invalid input for BeforeTool hook event: tool_input must be an object',
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return { toolName, toolInput };
|
if (mcpContext !== undefined && !isObject(mcpContext)) {
|
||||||
|
throw new Error(
|
||||||
|
'Invalid input for BeforeTool hook event: mcp_context must be an object',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
toolName,
|
||||||
|
toolInput,
|
||||||
|
mcpContext: mcpContext as McpToolContext | undefined,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -81,10 +93,12 @@ function validateAfterToolInput(input: Record<string, unknown>): {
|
|||||||
toolName: string;
|
toolName: string;
|
||||||
toolInput: Record<string, unknown>;
|
toolInput: Record<string, unknown>;
|
||||||
toolResponse: Record<string, unknown>;
|
toolResponse: Record<string, unknown>;
|
||||||
|
mcpContext?: McpToolContext;
|
||||||
} {
|
} {
|
||||||
const toolName = input['tool_name'];
|
const toolName = input['tool_name'];
|
||||||
const toolInput = input['tool_input'];
|
const toolInput = input['tool_input'];
|
||||||
const toolResponse = input['tool_response'];
|
const toolResponse = input['tool_response'];
|
||||||
|
const mcpContext = input['mcp_context'];
|
||||||
if (typeof toolName !== 'string') {
|
if (typeof toolName !== 'string') {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
'Invalid input for AfterTool hook event: tool_name must be a string',
|
'Invalid input for AfterTool hook event: tool_name must be a string',
|
||||||
@@ -100,7 +114,17 @@ function validateAfterToolInput(input: Record<string, unknown>): {
|
|||||||
'Invalid input for AfterTool hook event: tool_response must be an object',
|
'Invalid input for AfterTool hook event: tool_response must be an object',
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return { toolName, toolInput, toolResponse };
|
if (mcpContext !== undefined && !isObject(mcpContext)) {
|
||||||
|
throw new Error(
|
||||||
|
'Invalid input for AfterTool hook event: mcp_context must be an object',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
toolName,
|
||||||
|
toolInput,
|
||||||
|
toolResponse,
|
||||||
|
mcpContext: mcpContext as McpToolContext | undefined,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -313,11 +337,13 @@ export class HookEventHandler {
|
|||||||
async fireBeforeToolEvent(
|
async fireBeforeToolEvent(
|
||||||
toolName: string,
|
toolName: string,
|
||||||
toolInput: Record<string, unknown>,
|
toolInput: Record<string, unknown>,
|
||||||
|
mcpContext?: McpToolContext,
|
||||||
): Promise<AggregatedHookResult> {
|
): Promise<AggregatedHookResult> {
|
||||||
const input: BeforeToolInput = {
|
const input: BeforeToolInput = {
|
||||||
...this.createBaseInput(HookEventName.BeforeTool),
|
...this.createBaseInput(HookEventName.BeforeTool),
|
||||||
tool_name: toolName,
|
tool_name: toolName,
|
||||||
tool_input: toolInput,
|
tool_input: toolInput,
|
||||||
|
...(mcpContext && { mcp_context: mcpContext }),
|
||||||
};
|
};
|
||||||
|
|
||||||
const context: HookEventContext = { toolName };
|
const context: HookEventContext = { toolName };
|
||||||
@@ -332,12 +358,14 @@ export class HookEventHandler {
|
|||||||
toolName: string,
|
toolName: string,
|
||||||
toolInput: Record<string, unknown>,
|
toolInput: Record<string, unknown>,
|
||||||
toolResponse: Record<string, unknown>,
|
toolResponse: Record<string, unknown>,
|
||||||
|
mcpContext?: McpToolContext,
|
||||||
): Promise<AggregatedHookResult> {
|
): Promise<AggregatedHookResult> {
|
||||||
const input: AfterToolInput = {
|
const input: AfterToolInput = {
|
||||||
...this.createBaseInput(HookEventName.AfterTool),
|
...this.createBaseInput(HookEventName.AfterTool),
|
||||||
tool_name: toolName,
|
tool_name: toolName,
|
||||||
tool_input: toolInput,
|
tool_input: toolInput,
|
||||||
tool_response: toolResponse,
|
tool_response: toolResponse,
|
||||||
|
...(mcpContext && { mcp_context: mcpContext }),
|
||||||
};
|
};
|
||||||
|
|
||||||
const context: HookEventContext = { toolName };
|
const context: HookEventContext = { toolName };
|
||||||
@@ -725,18 +753,23 @@ export class HookEventHandler {
|
|||||||
// Route to appropriate event handler based on eventName
|
// Route to appropriate event handler based on eventName
|
||||||
switch (request.eventName) {
|
switch (request.eventName) {
|
||||||
case HookEventName.BeforeTool: {
|
case HookEventName.BeforeTool: {
|
||||||
const { toolName, toolInput } =
|
const { toolName, toolInput, mcpContext } =
|
||||||
validateBeforeToolInput(enrichedInput);
|
validateBeforeToolInput(enrichedInput);
|
||||||
result = await this.fireBeforeToolEvent(toolName, toolInput);
|
result = await this.fireBeforeToolEvent(
|
||||||
|
toolName,
|
||||||
|
toolInput,
|
||||||
|
mcpContext,
|
||||||
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case HookEventName.AfterTool: {
|
case HookEventName.AfterTool: {
|
||||||
const { toolName, toolInput, toolResponse } =
|
const { toolName, toolInput, toolResponse, mcpContext } =
|
||||||
validateAfterToolInput(enrichedInput);
|
validateAfterToolInput(enrichedInput);
|
||||||
result = await this.fireAfterToolEvent(
|
result = await this.fireAfterToolEvent(
|
||||||
toolName,
|
toolName,
|
||||||
toolInput,
|
toolInput,
|
||||||
toolResponse,
|
toolResponse,
|
||||||
|
mcpContext,
|
||||||
);
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import type {
|
|||||||
FunctionCallingConfig,
|
FunctionCallingConfig,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||||
|
import { getResponseText } from '../utils/partUtils.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Decoupled LLM request format - stable across Gemini CLI versions
|
* Decoupled LLM request format - stable across Gemini CLI versions
|
||||||
@@ -267,7 +268,7 @@ export class HookTranslatorGenAIv1 extends HookTranslator {
|
|||||||
*/
|
*/
|
||||||
toHookLLMResponse(sdkResponse: GenerateContentResponse): LLMResponse {
|
toHookLLMResponse(sdkResponse: GenerateContentResponse): LLMResponse {
|
||||||
return {
|
return {
|
||||||
text: sdkResponse.text,
|
text: getResponseText(sdkResponse) ?? undefined,
|
||||||
candidates: (sdkResponse.candidates || []).map((candidate) => {
|
candidates: (sdkResponse.candidates || []).map((candidate) => {
|
||||||
// Extract text parts from the candidate
|
// Extract text parts from the candidate
|
||||||
const textParts =
|
const textParts =
|
||||||
|
|||||||
@@ -373,12 +373,37 @@ export class AfterModelHookOutput extends DefaultHookOutput {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Context for MCP tool executions.
|
||||||
|
* Contains non-sensitive connection information about the MCP server
|
||||||
|
* identity. Since server_name is user controlled and arbitrary, we
|
||||||
|
* also include connection information (e.g., command or url) to
|
||||||
|
* help identify the MCP server.
|
||||||
|
*
|
||||||
|
* NOTE: In the future, consider defining a shared sanitized interface
|
||||||
|
* from MCPServerConfig to avoid duplication and ensure consistency.
|
||||||
|
*/
|
||||||
|
export interface McpToolContext {
|
||||||
|
server_name: string;
|
||||||
|
tool_name: string; // Original tool name from the MCP server
|
||||||
|
|
||||||
|
// Connection info (mutually exclusive based on transport type)
|
||||||
|
command?: string; // For stdio transport
|
||||||
|
args?: string[]; // For stdio transport
|
||||||
|
cwd?: string; // For stdio transport
|
||||||
|
|
||||||
|
url?: string; // For SSE/HTTP transport
|
||||||
|
|
||||||
|
tcp?: string; // For WebSocket transport
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* BeforeTool hook input
|
* BeforeTool hook input
|
||||||
*/
|
*/
|
||||||
export interface BeforeToolInput extends HookInput {
|
export interface BeforeToolInput extends HookInput {
|
||||||
tool_name: string;
|
tool_name: string;
|
||||||
tool_input: Record<string, unknown>;
|
tool_input: Record<string, unknown>;
|
||||||
|
mcp_context?: McpToolContext; // Only present for MCP tools
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -398,6 +423,7 @@ export interface AfterToolInput extends HookInput {
|
|||||||
tool_name: string;
|
tool_name: string;
|
||||||
tool_input: Record<string, unknown>;
|
tool_input: Record<string, unknown>;
|
||||||
tool_response: Record<string, unknown>;
|
tool_response: Record<string, unknown>;
|
||||||
|
mcp_context?: McpToolContext; // Only present for MCP tools
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ export class ToolExecutor {
|
|||||||
liveOutputCallback,
|
liveOutputCallback,
|
||||||
shellExecutionConfig,
|
shellExecutionConfig,
|
||||||
setPidCallback,
|
setPidCallback,
|
||||||
|
this.config,
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
promise = executeToolWithHooks(
|
promise = executeToolWithHooks(
|
||||||
@@ -109,6 +110,8 @@ export class ToolExecutor {
|
|||||||
tool,
|
tool,
|
||||||
liveOutputCallback,
|
liveOutputCallback,
|
||||||
shellExecutionConfig,
|
shellExecutionConfig,
|
||||||
|
undefined,
|
||||||
|
this.config,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ type McpContentBlock =
|
|||||||
| McpResourceBlock
|
| McpResourceBlock
|
||||||
| McpResourceLinkBlock;
|
| McpResourceLinkBlock;
|
||||||
|
|
||||||
class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
export class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||||
ToolParams,
|
ToolParams,
|
||||||
ToolResult
|
ToolResult
|
||||||
> {
|
> {
|
||||||
|
|||||||
Reference in New Issue
Block a user