mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-23 11:34:44 -07:00
Extensions MCP refactor (#12413)
This commit is contained in:
@@ -4,86 +4,193 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
import {
|
||||
afterEach,
|
||||
beforeEach,
|
||||
describe,
|
||||
expect,
|
||||
it,
|
||||
vi,
|
||||
type MockedObject,
|
||||
} from 'vitest';
|
||||
import { McpClientManager } from './mcp-client-manager.js';
|
||||
import { McpClient } from './mcp-client.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { SimpleExtensionLoader } from '../utils/extensionLoader.js';
|
||||
|
||||
vi.mock('./mcp-client.js', async () => {
|
||||
const originalModule = await vi.importActual('./mcp-client.js');
|
||||
return {
|
||||
...originalModule,
|
||||
McpClient: vi.fn(),
|
||||
populateMcpServerCommand: vi.fn(() => ({
|
||||
'test-server': {},
|
||||
})),
|
||||
};
|
||||
});
|
||||
|
||||
describe('McpClientManager', () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
let mockedMcpClient: MockedObject<McpClient>;
|
||||
let mockConfig: MockedObject<Config>;
|
||||
|
||||
it('should discover tools from all servers', async () => {
|
||||
const mockedMcpClient = {
|
||||
beforeEach(() => {
|
||||
mockedMcpClient = vi.mockObject({
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
};
|
||||
vi.mocked(McpClient).mockReturnValue(
|
||||
mockedMcpClient as unknown as McpClient,
|
||||
);
|
||||
const manager = new McpClientManager(
|
||||
{} as ToolRegistry,
|
||||
{
|
||||
isTrustedFolder: () => true,
|
||||
getExtensionLoader: () => new SimpleExtensionLoader([]),
|
||||
getMcpServers: () => ({
|
||||
'test-server': {},
|
||||
}),
|
||||
getMcpServerCommand: () => '',
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
getEnableExtensionReloading: () => false,
|
||||
} as unknown as Config,
|
||||
);
|
||||
await manager.discoverAllMcpTools();
|
||||
getServerConfig: vi.fn(),
|
||||
} as unknown as McpClient);
|
||||
vi.mocked(McpClient).mockReturnValue(mockedMcpClient);
|
||||
mockConfig = vi.mockObject({
|
||||
isTrustedFolder: vi.fn().mockReturnValue(true),
|
||||
getMcpServers: vi.fn().mockReturnValue({}),
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
getAllowedMcpServers: vi.fn().mockReturnValue([]),
|
||||
getBlockedMcpServers: vi.fn().mockReturnValue([]),
|
||||
getMcpServerCommand: vi.fn().mockReturnValue(''),
|
||||
getGeminiClient: vi.fn().mockReturnValue({
|
||||
isInitialized: vi.fn(),
|
||||
}),
|
||||
} as unknown as Config);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should discover tools from all configured', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
});
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should not discover tools if folder is not trusted', async () => {
|
||||
const mockedMcpClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
};
|
||||
vi.mocked(McpClient).mockReturnValue(
|
||||
mockedMcpClient as unknown as McpClient,
|
||||
);
|
||||
const manager = new McpClientManager(
|
||||
{} as ToolRegistry,
|
||||
{
|
||||
isTrustedFolder: () => false,
|
||||
getExtensionLoader: () => new SimpleExtensionLoader([]),
|
||||
getMcpServers: () => ({
|
||||
'test-server': {},
|
||||
}),
|
||||
getMcpServerCommand: () => '',
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
getEnableExtensionReloading: () => false,
|
||||
} as unknown as Config,
|
||||
);
|
||||
await manager.discoverAllMcpTools();
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
});
|
||||
mockConfig.isTrustedFolder.mockReturnValue(false);
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not start blocked servers', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
});
|
||||
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should only start allowed servers if allow list is not empty', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
'another-server': {},
|
||||
});
|
||||
mockConfig.getAllowedMcpServers.mockReturnValue(['another-server']);
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should start servers from extensions', async () => {
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startExtension({
|
||||
name: 'test-extension',
|
||||
mcpServers: {
|
||||
'test-server': {},
|
||||
},
|
||||
isActive: true,
|
||||
version: '1.0.0',
|
||||
path: '/some-path',
|
||||
contextFiles: [],
|
||||
id: '123',
|
||||
});
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should not start servers from disabled extensions', async () => {
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startExtension({
|
||||
name: 'test-extension',
|
||||
mcpServers: {
|
||||
'test-server': {},
|
||||
},
|
||||
isActive: false,
|
||||
version: '1.0.0',
|
||||
path: '/some-path',
|
||||
contextFiles: [],
|
||||
id: '123',
|
||||
});
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should add blocked servers to the blockedMcpServers list', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
});
|
||||
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(manager.getBlockedMcpServers()).toEqual([
|
||||
{ name: 'test-server', extensionName: '' },
|
||||
]);
|
||||
});
|
||||
|
||||
describe('restart', () => {
|
||||
it('should restart all running servers', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
});
|
||||
mockedMcpClient.getServerConfig.mockReturnValue({});
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1);
|
||||
await manager.restart();
|
||||
|
||||
expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('restartServer', () => {
|
||||
it('should restart the specified server', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': {},
|
||||
});
|
||||
mockedMcpClient.getServerConfig.mockReturnValue({});
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await manager.startConfiguredMcpServers();
|
||||
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1);
|
||||
|
||||
await manager.restartServer('test-server');
|
||||
|
||||
expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should throw an error if the server does not exist', async () => {
|
||||
const manager = new McpClientManager({} as ToolRegistry, mockConfig);
|
||||
await expect(manager.restartServer('non-existent')).rejects.toThrow(
|
||||
'No MCP server registered with the name "non-existent"',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -33,6 +33,10 @@ export class McpClientManager {
|
||||
private discoveryPromise: Promise<void> | undefined;
|
||||
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
|
||||
private readonly eventEmitter?: EventEmitter;
|
||||
private readonly blockedMcpServers: Array<{
|
||||
name: string;
|
||||
extensionName: string;
|
||||
}> = [];
|
||||
|
||||
constructor(
|
||||
toolRegistry: ToolRegistry,
|
||||
@@ -42,19 +46,10 @@ export class McpClientManager {
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.cliConfig = cliConfig;
|
||||
this.eventEmitter = eventEmitter;
|
||||
if (this.cliConfig.getEnableExtensionReloading()) {
|
||||
this.cliConfig
|
||||
.getExtensionLoader()
|
||||
.extensionEvents()
|
||||
.on('extensionLoaded', (event) => this.loadExtension(event.extension))
|
||||
.on('extensionEnabled', (event) => this.loadExtension(event.extension))
|
||||
.on('extensionDisabled', (event) =>
|
||||
this.unloadExtension(event.extension),
|
||||
)
|
||||
.on('extensionUnloaded', (event) =>
|
||||
this.unloadExtension(event.extension),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
getBlockedMcpServers() {
|
||||
return this.blockedMcpServers;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -64,21 +59,13 @@ export class McpClientManager {
|
||||
* - Disconnects all MCP clients from their servers.
|
||||
* - Updates the Gemini chat configuration to load the new tools.
|
||||
*/
|
||||
private async unloadExtension(extension: GeminiCLIExtension) {
|
||||
async stopExtension(extension: GeminiCLIExtension) {
|
||||
debugLogger.log(`Unloading extension: ${extension.name}`);
|
||||
await Promise.all(
|
||||
Object.keys(extension.mcpServers ?? {}).map((name) => {
|
||||
const newMcpServers = {
|
||||
...this.cliConfig.getMcpServers(),
|
||||
};
|
||||
delete newMcpServers[name];
|
||||
this.cliConfig.setMcpServers(newMcpServers);
|
||||
return this.disconnectClient(name);
|
||||
}),
|
||||
Object.keys(extension.mcpServers ?? {}).map(
|
||||
this.disconnectClient.bind(this),
|
||||
),
|
||||
);
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration.
|
||||
this.cliConfig.getGeminiClient().setTools();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -88,20 +75,36 @@ export class McpClientManager {
|
||||
* - Connects MCP clients to each server and discovers their tools.
|
||||
* - Updates the Gemini chat configuration to load the new tools.
|
||||
*/
|
||||
private async loadExtension(extension: GeminiCLIExtension) {
|
||||
async startExtension(extension: GeminiCLIExtension) {
|
||||
debugLogger.log(`Loading extension: ${extension.name}`);
|
||||
await Promise.all(
|
||||
Object.entries(extension.mcpServers ?? {}).map(([name, config]) => {
|
||||
this.cliConfig.setMcpServers({
|
||||
...this.cliConfig.getMcpServers(),
|
||||
[name]: config,
|
||||
});
|
||||
return this.discoverMcpTools(name, config);
|
||||
}),
|
||||
Object.entries(extension.mcpServers ?? {}).map(([name, config]) =>
|
||||
this.maybeDiscoverMcpServer(name, {
|
||||
...config,
|
||||
extension,
|
||||
}),
|
||||
),
|
||||
);
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration.
|
||||
this.cliConfig.getGeminiClient().setTools();
|
||||
}
|
||||
|
||||
private isAllowedMcpServer(name: string) {
|
||||
const allowedNames = this.cliConfig.getAllowedMcpServers();
|
||||
if (
|
||||
allowedNames &&
|
||||
allowedNames.length > 0 &&
|
||||
allowedNames.indexOf(name) === -1
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
const blockedNames = this.cliConfig.getBlockedMcpServers();
|
||||
if (
|
||||
blockedNames &&
|
||||
blockedNames.length > 0 &&
|
||||
blockedNames.indexOf(name) !== -1
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private async disconnectClient(name: string) {
|
||||
@@ -115,36 +118,68 @@ export class McpClientManager {
|
||||
debugLogger.warn(
|
||||
`Error stopping client '${name}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
} finally {
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration.
|
||||
const geminiClient = this.cliConfig.getGeminiClient();
|
||||
if (geminiClient.isInitialized()) {
|
||||
await geminiClient.setTools();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
discoverMcpTools(
|
||||
maybeDiscoverMcpServer(
|
||||
name: string,
|
||||
config: MCPServerConfig,
|
||||
): Promise<void> | void {
|
||||
if (!this.isAllowedMcpServer(name)) {
|
||||
if (!this.blockedMcpServers.find((s) => s.name === name)) {
|
||||
this.blockedMcpServers?.push({
|
||||
name,
|
||||
extensionName: config.extension?.name ?? '',
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (!this.cliConfig.isTrustedFolder()) {
|
||||
return;
|
||||
}
|
||||
if (config.extension && !config.extension.isActive) {
|
||||
return;
|
||||
}
|
||||
const existing = this.clients.get(name);
|
||||
if (existing && existing.getServerConfig().extension !== config.extension) {
|
||||
const extensionText = config.extension
|
||||
? ` from extension "${config.extension.name}"`
|
||||
: '';
|
||||
debugLogger.warn(
|
||||
`Skipping MCP config for server with name "${name}"${extensionText} as it already exists.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const currentDiscoveryPromise = new Promise<void>((resolve, _reject) => {
|
||||
(async () => {
|
||||
try {
|
||||
await this.disconnectClient(name);
|
||||
if (existing) {
|
||||
await existing.disconnect();
|
||||
}
|
||||
|
||||
const client = new McpClient(
|
||||
name,
|
||||
config,
|
||||
this.toolRegistry,
|
||||
this.cliConfig.getPromptRegistry(),
|
||||
this.cliConfig.getWorkspaceContext(),
|
||||
this.cliConfig.getDebugMode(),
|
||||
);
|
||||
this.clients.set(name, client);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
const client =
|
||||
existing ??
|
||||
new McpClient(
|
||||
name,
|
||||
config,
|
||||
this.toolRegistry,
|
||||
this.cliConfig.getPromptRegistry(),
|
||||
this.cliConfig.getWorkspaceContext(),
|
||||
this.cliConfig.getDebugMode(),
|
||||
);
|
||||
if (!existing) {
|
||||
this.clients.set(name, client);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
}
|
||||
try {
|
||||
await client.connect();
|
||||
await client.discover(this.cliConfig);
|
||||
@@ -161,6 +196,12 @@ export class McpClientManager {
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration.
|
||||
const geminiClient = this.cliConfig.getGeminiClient();
|
||||
if (geminiClient.isInitialized()) {
|
||||
await geminiClient.setTools();
|
||||
}
|
||||
resolve();
|
||||
}
|
||||
})();
|
||||
@@ -174,6 +215,7 @@ export class McpClientManager {
|
||||
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
this.discoveryPromise = currentDiscoveryPromise;
|
||||
}
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
const currentPromise = this.discoveryPromise;
|
||||
currentPromise.then((_) => {
|
||||
// If we are the last recorded discoveryPromise, then we are done, reset
|
||||
@@ -187,15 +229,21 @@ export class McpClientManager {
|
||||
}
|
||||
|
||||
/**
|
||||
* Initiates the tool discovery process for all configured MCP servers.
|
||||
* Initiates the tool discovery process for all configured MCP servers (via
|
||||
* gemini settings or command line arguments).
|
||||
*
|
||||
* It connects to each server, discovers its available tools, and registers
|
||||
* them with the `ToolRegistry`.
|
||||
*
|
||||
* For any server which is already connected, it will first be disconnected.
|
||||
*
|
||||
* This does NOT load extension MCP servers - this happens when the
|
||||
* ExtensionLoader explicitly calls `loadExtension`.
|
||||
*/
|
||||
async discoverAllMcpTools(): Promise<void> {
|
||||
async startConfiguredMcpServers(): Promise<void> {
|
||||
if (!this.cliConfig.isTrustedFolder()) {
|
||||
return;
|
||||
}
|
||||
await this.stop();
|
||||
|
||||
const servers = populateMcpServerCommand(
|
||||
this.cliConfig.getMcpServers() || {},
|
||||
@@ -204,12 +252,40 @@ export class McpClientManager {
|
||||
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
await Promise.all(
|
||||
Object.entries(servers).map(async ([name, config]) =>
|
||||
this.discoverMcpTools(name, config),
|
||||
Object.entries(servers).map(([name, config]) =>
|
||||
this.maybeDiscoverMcpServer(name, config),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Restarts all active MCP Clients.
|
||||
*/
|
||||
async restart(): Promise<void> {
|
||||
await Promise.all(
|
||||
Array.from(this.clients.entries()).map(async ([name, client]) => {
|
||||
try {
|
||||
await this.maybeDiscoverMcpServer(name, client.getServerConfig());
|
||||
} catch (error) {
|
||||
debugLogger.error(
|
||||
`Error restarting client '${name}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Restart a single MCP server by name.
|
||||
*/
|
||||
async restartServer(name: string) {
|
||||
const client = this.clients.get(name);
|
||||
if (!client) {
|
||||
throw new Error(`No MCP server registered with the name "${name}"`);
|
||||
}
|
||||
await this.maybeDiscoverMcpServer(name, client.getServerConfig());
|
||||
}
|
||||
|
||||
/**
|
||||
* Stops all running local MCP servers and closes all client connections.
|
||||
* This is the cleanup method to be called on application exit.
|
||||
@@ -236,4 +312,15 @@ export class McpClientManager {
|
||||
getDiscoveryState(): MCPDiscoveryState {
|
||||
return this.discoveryState;
|
||||
}
|
||||
|
||||
/**
|
||||
* All of the MCP server configurations currently loaded.
|
||||
*/
|
||||
getMcpServers(): Record<string, MCPServerConfig> {
|
||||
const mcpServers: Record<string, MCPServerConfig> = {};
|
||||
for (const [name, client] of this.clients.entries()) {
|
||||
mcpServers[name] = client.getServerConfig();
|
||||
}
|
||||
return mcpServers;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,6 +303,71 @@ describe('mcp-client', () => {
|
||||
expect(mockedMcpToTool).toHaveBeenCalledOnce();
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should remove tools and prompts on disconnect', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
close: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
getServerCapabilities: vi
|
||||
.fn()
|
||||
.mockReturnValue({ tools: {}, prompts: {} }),
|
||||
request: vi.fn().mockResolvedValue({
|
||||
prompts: [{ id: 'prompt1', text: 'a prompt' }],
|
||||
}),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () =>
|
||||
Promise.resolve({
|
||||
functionDeclarations: [
|
||||
{
|
||||
name: 'testTool',
|
||||
description: 'A test tool',
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
unregisterTool: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
} as unknown as ToolRegistry;
|
||||
const mockedPromptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
unregisterPrompt: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
mockedPromptRegistry,
|
||||
workspaceContext,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover({} as Config);
|
||||
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
expect(mockedPromptRegistry.registerPrompt).toHaveBeenCalledOnce();
|
||||
|
||||
await client.disconnect();
|
||||
|
||||
expect(mockedClient.close).toHaveBeenCalledOnce();
|
||||
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledOnce();
|
||||
expect(mockedPromptRegistry.removePromptsByServer).toHaveBeenCalledOnce();
|
||||
});
|
||||
});
|
||||
describe('appendMcpServerCommand', () => {
|
||||
it('should do nothing if no MCP servers or command are configured', () => {
|
||||
|
||||
@@ -161,6 +161,7 @@ export class McpClient {
|
||||
return;
|
||||
}
|
||||
this.toolRegistry.removeMcpToolsByServer(this.serverName);
|
||||
this.promptRegistry.removePromptsByServer(this.serverName);
|
||||
this.updateStatus(MCPServerStatus.DISCONNECTING);
|
||||
const client = this.client;
|
||||
this.client = undefined;
|
||||
@@ -208,6 +209,10 @@ export class McpClient {
|
||||
this.assertConnected();
|
||||
return discoverPrompts(this.serverName, this.client!, this.promptRegistry);
|
||||
}
|
||||
|
||||
getServerConfig(): MCPServerConfig {
|
||||
return this.serverConfig;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -19,20 +19,10 @@ import { spawn } from 'node:child_process';
|
||||
|
||||
import fs from 'node:fs';
|
||||
import { MockTool } from '../test-utils/mock-tool.js';
|
||||
|
||||
import { McpClientManager } from './mcp-client-manager.js';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
|
||||
vi.mock('node:fs');
|
||||
|
||||
// Mock ./mcp-client.js to control its behavior within tool-registry tests
|
||||
vi.mock('./mcp-client.js', async () => {
|
||||
const originalModule = await vi.importActual('./mcp-client.js');
|
||||
return {
|
||||
...originalModule,
|
||||
};
|
||||
});
|
||||
|
||||
// Mock node:child_process
|
||||
vi.mock('node:child_process', async () => {
|
||||
const actual = await vi.importActual('node:child_process');
|
||||
@@ -401,27 +391,6 @@ describe('ToolRegistry', () => {
|
||||
expect(result.llmContent).toContain('Stderr: Something went wrong');
|
||||
expect(result.llmContent).toContain('Exit Code: 1');
|
||||
});
|
||||
|
||||
it('should discover tools using MCP servers defined in getMcpServers', async () => {
|
||||
const discoverSpy = vi.spyOn(
|
||||
McpClientManager.prototype,
|
||||
'discoverAllMcpTools',
|
||||
);
|
||||
mockConfigGetToolDiscoveryCommand.mockReturnValue(undefined);
|
||||
vi.spyOn(config, 'getMcpServerCommand').mockReturnValue(undefined);
|
||||
const mcpServerConfigVal = {
|
||||
'my-mcp-server': {
|
||||
command: 'mcp-server-cmd',
|
||||
args: ['--port', '1234'],
|
||||
trust: true,
|
||||
},
|
||||
};
|
||||
vi.spyOn(config, 'getMcpServers').mockReturnValue(mcpServerConfigVal);
|
||||
|
||||
await toolRegistry.discoverAllTools();
|
||||
|
||||
expect(discoverSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('DiscoveredToolInvocation', () => {
|
||||
|
||||
@@ -14,13 +14,10 @@ import { Kind, BaseDeclarativeTool, BaseToolInvocation } from './tools.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { spawn } from 'node:child_process';
|
||||
import { StringDecoder } from 'node:string_decoder';
|
||||
import { connectAndDiscover } from './mcp-client.js';
|
||||
import { McpClientManager } from './mcp-client-manager.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
|
||||
import type { EventEmitter } from 'node:events';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
@@ -176,12 +173,10 @@ export class ToolRegistry {
|
||||
// The tools keyed by tool name as seen by the LLM.
|
||||
private tools: Map<string, AnyDeclarativeTool> = new Map();
|
||||
private config: Config;
|
||||
private mcpClientManager: McpClientManager;
|
||||
private messageBus?: MessageBus;
|
||||
|
||||
constructor(config: Config, eventEmitter?: EventEmitter) {
|
||||
constructor(config: Config) {
|
||||
this.config = config;
|
||||
this.mcpClientManager = new McpClientManager(this, config, eventEmitter);
|
||||
}
|
||||
|
||||
setMessageBus(messageBus: MessageBus): void {
|
||||
@@ -238,64 +233,7 @@ export class ToolRegistry {
|
||||
async discoverAllTools(): Promise<void> {
|
||||
// remove any previously discovered tools
|
||||
this.removeDiscoveredTools();
|
||||
|
||||
this.config.getPromptRegistry().clear();
|
||||
|
||||
await this.discoverAndRegisterToolsFromCommand();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await this.mcpClientManager.discoverAllMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers tools from project (if available and configured).
|
||||
* Can be called multiple times to update discovered tools.
|
||||
* This will NOT discover tools from the command line, only from MCP servers.
|
||||
*/
|
||||
async discoverMcpTools(): Promise<void> {
|
||||
// remove any previously discovered tools
|
||||
this.removeDiscoveredTools();
|
||||
|
||||
this.config.getPromptRegistry().clear();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await this.mcpClientManager.discoverAllMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
* Restarts all MCP servers and re-discovers tools.
|
||||
*/
|
||||
async restartMcpServers(): Promise<void> {
|
||||
await this.discoverMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover or re-discover tools for a single MCP server.
|
||||
* @param serverName - The name of the server to discover tools from.
|
||||
*/
|
||||
async discoverToolsForServer(serverName: string): Promise<void> {
|
||||
// Remove any previously discovered tools from this server
|
||||
for (const [name, tool] of this.tools.entries()) {
|
||||
if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) {
|
||||
this.tools.delete(name);
|
||||
}
|
||||
}
|
||||
|
||||
this.config.getPromptRegistry().removePromptsByServer(serverName);
|
||||
|
||||
const mcpServers = this.config.getMcpServers() ?? {};
|
||||
const serverConfig = mcpServers[serverName];
|
||||
if (serverConfig) {
|
||||
await connectAndDiscover(
|
||||
serverName,
|
||||
serverConfig,
|
||||
this,
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
this.config.getWorkspaceContext(),
|
||||
this.config,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private async discoverAndRegisterToolsFromCommand(): Promise<void> {
|
||||
|
||||
Reference in New Issue
Block a user