feat(core): multi-registry architecture and tool filtering for subagents (#22712)

This commit is contained in:
AK
2026-03-17 13:54:07 -07:00
committed by GitHub
parent a361a84708
commit 2504105a1c
8 changed files with 586 additions and 342 deletions
@@ -14,9 +14,11 @@ import {
type MockedObject,
} from 'vitest';
import { McpClientManager } from './mcp-client-manager.js';
import { McpClient, MCPDiscoveryState } from './mcp-client.js';
import { McpClient, MCPDiscoveryState, MCPServerStatus } from './mcp-client.js';
import type { ToolRegistry } from './tool-registry.js';
import type { Config, GeminiCLIExtension } from '../config/config.js';
import type { PromptRegistry } from '../prompts/prompt-registry.js';
import type { ResourceRegistry } from '../resources/resource-registry.js';
vi.mock('./mcp-client.js', async () => {
const originalModule = await vi.importActual('./mcp-client.js');
@@ -34,21 +36,25 @@ describe('McpClientManager', () => {
beforeEach(() => {
mockedMcpClient = vi.mockObject({
connect: vi.fn(),
discover: vi.fn(),
discoverInto: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
getStatus: vi.fn().mockReturnValue(MCPServerStatus.DISCONNECTED),
getServerConfig: vi.fn(),
getServerName: vi.fn().mockReturnValue('test-server'),
} as unknown as McpClient);
vi.mocked(McpClient).mockReturnValue(mockedMcpClient);
mockConfig = vi.mockObject({
isTrustedFolder: vi.fn().mockReturnValue(true),
getMcpServers: vi.fn().mockReturnValue({}),
getPromptRegistry: () => {},
getResourceRegistry: () => {},
getPromptRegistry: vi.fn().mockReturnValue({ registerPrompt: vi.fn() }),
getResourceRegistry: vi
.fn()
.mockReturnValue({ setResourcesForServer: vi.fn() }),
getDebugMode: () => false,
getWorkspaceContext: () => {},
getWorkspaceContext: () => ({ getDirectories: () => [] }),
getAllowedMcpServers: vi.fn().mockReturnValue([]),
getBlockedMcpServers: vi.fn().mockReturnValue([]),
getExcludedMcpServers: vi.fn().mockReturnValue([]),
getMcpServerCommand: vi.fn().mockReturnValue(''),
getMcpEnablementCallbacks: vi.fn().mockReturnValue(undefined),
getGeminiClient: vi.fn().mockReturnValue({
@@ -56,21 +62,39 @@ describe('McpClientManager', () => {
}),
refreshMcpContext: vi.fn(),
} as unknown as Config);
toolRegistry = {} as ToolRegistry;
toolRegistry = vi.mockObject({
registerTool: vi.fn(),
unregisterTool: vi.fn(),
sortTools: vi.fn(),
getMessageBus: vi.fn().mockReturnValue({}),
removeMcpToolsByServer: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
} as unknown as ToolRegistry);
});
afterEach(() => {
vi.restoreAllMocks();
});
const setupManager = (manager: McpClientManager) => {
manager.setMainRegistries({
toolRegistry,
promptRegistry:
mockConfig.getPromptRegistry() as unknown as PromptRegistry,
resourceRegistry:
mockConfig.getResourceRegistry() as unknown as ResourceRegistry,
});
return manager;
};
it('should discover tools from all configured', async () => {
mockConfig.getMcpServers.mockReturnValue({
'test-server': { command: 'node' },
});
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce();
expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce();
});
@@ -80,12 +104,12 @@ describe('McpClientManager', () => {
'server-2': { command: 'node' },
'server-3': { command: 'node' },
});
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
// Each client should be connected/discovered
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(3);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(3);
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(3);
// But context refresh should happen only once
expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce();
@@ -95,7 +119,7 @@ describe('McpClientManager', () => {
mockConfig.getMcpServers.mockReturnValue({
'test-server': { command: 'node' },
});
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.NOT_STARTED);
const promise = manager.startConfiguredMcpServers();
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS);
@@ -112,7 +136,7 @@ describe('McpClientManager', () => {
isFileEnabled: vi.fn().mockResolvedValue(false),
});
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const promise = manager.startConfiguredMcpServers();
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS);
await promise;
@@ -120,7 +144,7 @@ describe('McpClientManager', () => {
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.COMPLETED);
expect(manager.getMcpServerCount()).toBe(0);
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
});
it('should mark discovery completed when all configured servers are blocked', async () => {
@@ -129,7 +153,7 @@ describe('McpClientManager', () => {
});
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const promise = manager.startConfiguredMcpServers();
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS);
await promise;
@@ -137,7 +161,7 @@ describe('McpClientManager', () => {
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.COMPLETED);
expect(manager.getMcpServerCount()).toBe(0);
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
});
it('should not discover tools if folder is not trusted', async () => {
@@ -145,10 +169,10 @@ describe('McpClientManager', () => {
'test-server': { command: 'node' },
});
mockConfig.isTrustedFolder.mockReturnValue(false);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
});
it('should not start blocked servers', async () => {
@@ -156,10 +180,10 @@ describe('McpClientManager', () => {
'test-server': { command: 'node' },
});
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
});
it('should only start allowed servers if allow list is not empty', async () => {
@@ -168,14 +192,14 @@ describe('McpClientManager', () => {
'another-server': { command: 'node' },
});
mockConfig.getAllowedMcpServers.mockReturnValue(['another-server']);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce();
});
it('should start servers from extensions', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startExtension({
name: 'test-extension',
mcpServers: {
@@ -188,11 +212,11 @@ describe('McpClientManager', () => {
id: '123',
});
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce();
});
it('should not start servers from disabled extensions', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startExtension({
name: 'test-extension',
mcpServers: {
@@ -205,7 +229,7 @@ describe('McpClientManager', () => {
id: '123',
});
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
});
it('should add blocked servers to the blockedMcpServers list', async () => {
@@ -213,7 +237,7 @@ describe('McpClientManager', () => {
'test-server': { command: 'node' },
});
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(manager.getBlockedMcpServers()).toEqual([
{ name: 'test-server', extensionName: '' },
@@ -224,10 +248,10 @@ describe('McpClientManager', () => {
mockConfig.getMcpServers.mockReturnValue({
'test-server': { excludeTools: ['dangerous_tool'] },
});
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
// But it should still be tracked in allServerConfigs
expect(manager.getMcpServers()).toHaveProperty('test-server');
@@ -240,16 +264,16 @@ describe('McpClientManager', () => {
'test-server': serverConfig,
});
mockedMcpClient.getServerConfig.mockReturnValue(serverConfig);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(1);
await manager.restart();
expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2);
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(2);
});
});
@@ -260,21 +284,21 @@ describe('McpClientManager', () => {
'test-server': serverConfig,
});
mockedMcpClient.getServerConfig.mockReturnValue(serverConfig);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(1);
await manager.restartServer('test-server');
expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2);
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(2);
});
it('should throw an error if the server does not exist', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await expect(manager.restartServer('non-existent')).rejects.toThrow(
'No MCP server registered with the name "non-existent"',
);
@@ -296,7 +320,7 @@ describe('McpClientManager', () => {
});
mockedMcpClient.getServerConfig.mockReturnValue(originalConfig);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
// First call should use the original config
@@ -321,9 +345,10 @@ describe('McpClientManager', () => {
(name, config) =>
({
connect: vi.fn(),
discover: vi.fn(),
discoverInto: vi.fn(),
disconnect: vi.fn(),
getServerConfig: vi.fn().mockReturnValue(config),
getServerName: vi.fn().mockReturnValue(name),
getInstructions: vi
.fn()
.mockReturnValue(
@@ -333,12 +358,7 @@ describe('McpClientManager', () => {
),
}) as unknown as McpClient,
);
const manager = new McpClientManager(
'0.0.1',
{} as ToolRegistry,
mockConfig,
);
const manager = new McpClientManager('0.0.1', mockConfig);
mockConfig.getMcpServers.mockReturnValue({
'server-with-instructions': { command: 'node' },
@@ -373,11 +393,7 @@ describe('McpClientManager', () => {
'test-server': { command: 'node' },
});
const manager = new McpClientManager(
'0.0.1',
{} as ToolRegistry,
mockConfig,
);
const manager = new McpClientManager('0.0.1', mockConfig);
await expect(manager.startConfiguredMcpServers()).resolves.not.toThrow();
});
@@ -396,11 +412,8 @@ describe('McpClientManager', () => {
'test-server': { command: 'node' },
});
const manager = new McpClientManager(
'0.0.1',
{} as ToolRegistry,
mockConfig,
);
const manager = new McpClientManager('0.0.1', mockConfig);
await manager.startConfiguredMcpServers();
await expect(manager.restartServer('test-server')).resolves.not.toThrow();
@@ -409,7 +422,7 @@ describe('McpClientManager', () => {
describe('Extension handling', () => {
it('should remove mcp servers from allServerConfigs when stopExtension is called', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const mcpServers = {
'test-server': { command: 'node', args: ['server.js'] },
};
@@ -431,7 +444,7 @@ describe('McpClientManager', () => {
});
it('should merge extension configuration with an existing user-configured server', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const userConfig = { command: 'node', args: ['user-server.js'] };
mockConfig.getMcpServers.mockReturnValue({
@@ -468,7 +481,7 @@ describe('McpClientManager', () => {
});
it('should securely merge tool lists and env variables regardless of load order', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const userConfig = {
excludeTools: ['user-tool'],
@@ -523,7 +536,7 @@ describe('McpClientManager', () => {
// Reset for Case 2
vi.mocked(McpClient).mockClear();
const manager2 = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager2 = setupManager(new McpClientManager('0.0.1', mockConfig));
// Case 2: User config loads first, then Extension loads
// This call will skip discovery because userConfig has no connection details
@@ -551,7 +564,7 @@ describe('McpClientManager', () => {
});
it('should result in empty includeTools if intersection is empty', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const userConfig = { includeTools: ['user-tool'] };
const extConfig = {
command: 'node',
@@ -567,7 +580,7 @@ describe('McpClientManager', () => {
});
it('should respect a single allowlist if only one is provided', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const userConfig = { includeTools: ['user-tool'] };
const extConfig = { command: 'node', args: ['ext.js'] };
@@ -579,7 +592,7 @@ describe('McpClientManager', () => {
});
it('should allow partial overrides of connection properties', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const extConfig = { command: 'node', args: ['ext.js'], timeout: 1000 };
const userOverride = { args: ['overridden.js'] };
@@ -599,7 +612,7 @@ describe('McpClientManager', () => {
});
it('should prevent one extension from hijacking another extension server name', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const extension1: GeminiCLIExtension = {
name: 'extension-1',
@@ -641,7 +654,7 @@ describe('McpClientManager', () => {
it('should remove servers from blockedMcpServers when stopExtension is called', async () => {
mockConfig.getBlockedMcpServers.mockReturnValue(['blocked-server']);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const mcpServers = {
'blocked-server': { command: 'node', args: ['server.js'] },
};
@@ -679,7 +692,7 @@ describe('McpClientManager', () => {
});
it('should emit hint instead of full error when user has not interacted with MCP', () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
manager.emitDiagnostic(
'error',
'Something went wrong',
@@ -698,7 +711,7 @@ describe('McpClientManager', () => {
});
it('should emit full error when user has interacted with MCP', () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
manager.setUserInteractedWithMcp();
manager.emitDiagnostic(
'error',
@@ -714,7 +727,7 @@ describe('McpClientManager', () => {
});
it('should still deduplicate diagnostic messages after user interaction', () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
manager.setUserInteractedWithMcp();
manager.emitDiagnostic('error', 'Same error');
@@ -724,7 +737,7 @@ describe('McpClientManager', () => {
});
it('should only show hint once per session', () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
manager.emitDiagnostic('error', 'Error 1');
manager.emitDiagnostic('error', 'Error 2');
@@ -737,7 +750,7 @@ describe('McpClientManager', () => {
});
it('should capture last error for a server even when silenced', () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
manager.emitDiagnostic(
'error',
@@ -752,7 +765,7 @@ describe('McpClientManager', () => {
});
it('should show previously deduplicated errors after interaction clears state', () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
manager.emitDiagnostic('error', 'Same error');
expect(coreEventsMock.emitFeedback).toHaveBeenCalledTimes(1); // The hint
+122 -41
View File
@@ -13,6 +13,7 @@ import type { ToolRegistry } from './tool-registry.js';
import {
McpClient,
MCPDiscoveryState,
MCPServerStatus,
populateMcpServerCommand,
} from './mcp-client.js';
import { getErrorMessage, isAuthenticationError } from '../utils/errors.js';
@@ -20,6 +21,11 @@ import type { EventEmitter } from 'node:events';
import { coreEvents } from '../utils/events.js';
import { debugLogger } from '../utils/debugLogger.js';
import { createHash } from 'node:crypto';
import { stableStringify } from '../policy/stable-stringify.js';
import type { PromptRegistry } from '../prompts/prompt-registry.js';
import type { ResourceRegistry } from '../resources/resource-registry.js';
/**
* Manages the lifecycle of multiple MCP clients, including local child processes.
* This class is responsible for starting, stopping, and discovering tools from
@@ -30,7 +36,6 @@ export class McpClientManager {
// Track all configured servers (including disabled ones) for UI display
private allServerConfigs: Map<string, MCPServerConfig> = new Map();
private readonly clientVersion: string;
private readonly toolRegistry: ToolRegistry;
private readonly cliConfig: Config;
// If we have ongoing MCP client discovery, this completes once that is done.
private discoveryPromise: Promise<void> | undefined;
@@ -42,6 +47,10 @@ export class McpClientManager {
extensionName: string;
}> = [];
private mainToolRegistry: ToolRegistry | undefined;
private mainPromptRegistry: PromptRegistry | undefined;
private mainResourceRegistry: ResourceRegistry | undefined;
/**
* Track whether the user has explicitly interacted with MCP in this session
* (e.g. by running an /mcp command).
@@ -66,16 +75,24 @@ export class McpClientManager {
constructor(
clientVersion: string,
toolRegistry: ToolRegistry,
cliConfig: Config,
eventEmitter?: EventEmitter,
) {
this.clientVersion = clientVersion;
this.toolRegistry = toolRegistry;
this.cliConfig = cliConfig;
this.eventEmitter = eventEmitter;
}
setMainRegistries(registries: {
toolRegistry: ToolRegistry;
promptRegistry: PromptRegistry;
resourceRegistry: ResourceRegistry;
}) {
this.mainToolRegistry = registries.toolRegistry;
this.mainPromptRegistry = registries.promptRegistry;
this.mainResourceRegistry = registries.resourceRegistry;
}
setUserInteractedWithMcp() {
this.userInteractedWithMcp = true;
}
@@ -147,6 +164,16 @@ export class McpClientManager {
return this.clients.get(serverName);
}
removeRegistries(registries: {
toolRegistry: ToolRegistry;
promptRegistry: PromptRegistry;
resourceRegistry: ResourceRegistry;
}): void {
for (const client of this.clients.values()) {
client.removeRegistries(registries);
}
}
/**
* For all the MCP servers associated with this extension:
*
@@ -236,16 +263,17 @@ export class McpClientManager {
return false;
}
private async disconnectClient(name: string, skipRefresh = false) {
const existing = this.clients.get(name);
private async disconnectClient(clientKey: string, skipRefresh = false) {
const existing = this.clients.get(clientKey);
if (existing) {
const serverName = existing.getServerName();
try {
this.clients.delete(name);
this.clients.delete(clientKey);
this.eventEmitter?.emit('mcp-client-update', this.clients);
await existing.disconnect();
} catch (error) {
debugLogger.warn(
`Error stopping client '${name}': ${getErrorMessage(error)}`,
`Error stopping client '${serverName}': ${getErrorMessage(error)}`,
);
} finally {
if (!skipRefresh) {
@@ -257,6 +285,16 @@ export class McpClientManager {
}
}
private getClientKey(name: string, config: MCPServerConfig): string {
const { extension, ...rest } = config;
const keyData = {
name,
config: rest,
extensionId: extension?.id,
};
return createHash('sha256').update(stableStringify(keyData)).digest('hex');
}
/**
* Merges two MCP configurations. The second configuration (override)
* takes precedence for scalar properties, but array properties are
@@ -305,6 +343,11 @@ export class McpClientManager {
async maybeDiscoverMcpServer(
name: string,
config: MCPServerConfig,
registries?: {
toolRegistry: ToolRegistry;
promptRegistry: PromptRegistry;
resourceRegistry: ResourceRegistry;
},
): Promise<void> {
const existingConfig = this.allServerConfigs.get(name);
if (
@@ -337,11 +380,27 @@ export class McpClientManager {
// Always track server config for UI display
this.allServerConfigs.set(name, finalConfig);
// Capture the existing client synchronously here before any asynchronous
// operations. This ensures that if multiple discovery turns happen
// concurrently, this turn only replaces/disconnects the client that was
// present when this specific configuration update request began.
const existing = this.clients.get(name);
const clientKey = this.getClientKey(name, finalConfig);
// If no registries are provided (main agent) and a server with this name already exists
// but with a different configuration, handle potential conflicts.
if (!registries) {
const existingSameName = Array.from(this.clients.values()).find(
(c) => c.getServerName() === name,
);
if (existingSameName) {
const existingConfigFromClient = existingSameName.getServerConfig();
const existingKey = this.getClientKey(name, existingConfigFromClient);
if (existingKey !== clientKey) {
// This is a configuration update (hot-reload).
// We should stop the old client before starting the new one.
await this.disconnectClient(existingKey, true);
}
}
}
const existing = this.clients.get(clientKey);
// If no connection details are provided, we can't discover this server.
// This often happens when a user provides only overrides (like excludeTools)
@@ -363,7 +422,7 @@ export class McpClientManager {
// User-disabled servers: disconnect if running, don't start
if (await this.isDisabledByUser(name)) {
if (existing) {
await this.disconnectClient(name);
await this.disconnectClient(clientKey);
}
return;
}
@@ -374,34 +433,48 @@ export class McpClientManager {
return;
}
const currentDiscoveryPromise = new Promise<void>((resolve, reject) => {
(async () => {
const currentDiscoveryPromise = new Promise<void>((resolve) => {
void (async () => {
try {
if (existing) {
this.clients.delete(name);
await existing.disconnect();
let client = existing;
if (!client) {
client = new McpClient(
name,
finalConfig,
this.cliConfig.getWorkspaceContext(),
this.cliConfig,
this.cliConfig.getDebugMode(),
this.clientVersion,
async () => {
debugLogger.log(
`🔔 Refreshing context for server '${name}'...`,
);
await this.scheduleMcpContextRefresh();
},
);
this.clients.set(clientKey, client);
this.eventEmitter?.emit('mcp-client-update', this.clients);
}
const client = new McpClient(
name,
finalConfig,
this.toolRegistry,
this.cliConfig.getPromptRegistry(),
this.cliConfig.getResourceRegistry(),
this.cliConfig.getWorkspaceContext(),
this.cliConfig,
this.cliConfig.getDebugMode(),
this.clientVersion,
async () => {
debugLogger.log(`🔔 Refreshing context for server '${name}'...`);
await this.scheduleMcpContextRefresh();
},
);
this.clients.set(name, client);
this.eventEmitter?.emit('mcp-client-update', this.clients);
const targetRegistries =
registries ??
(this.mainToolRegistry &&
this.mainPromptRegistry &&
this.mainResourceRegistry
? {
toolRegistry: this.mainToolRegistry,
promptRegistry: this.mainPromptRegistry,
resourceRegistry: this.mainResourceRegistry,
}
: undefined);
try {
await client.connect();
await client.discover(this.cliConfig);
if (client.getStatus() === MCPServerStatus.DISCONNECTED) {
await client.connect();
}
if (targetRegistries) {
await client.discoverInto(this.cliConfig, targetRegistries);
}
this.eventEmitter?.emit('mcp-client-update', this.clients);
} catch (error) {
this.eventEmitter?.emit('mcp-client-update', this.clients);
@@ -421,13 +494,13 @@ export class McpClientManager {
const errorMessage = getErrorMessage(error);
this.emitDiagnostic(
'error',
`Error initializing MCP server '${name}': ${errorMessage}`,
`Fatal error ensuring MCP server '${name}' is connected: ${errorMessage}`,
error,
);
} finally {
resolve();
}
})().catch(reject);
})();
});
if (this.discoveryPromise) {
@@ -510,6 +583,11 @@ export class McpClientManager {
* Restarts all MCP servers (including newly enabled ones).
*/
async restart(): Promise<void> {
const disconnectionPromises = Array.from(this.clients.keys()).map((key) =>
this.disconnectClient(key, true),
);
await Promise.all(disconnectionPromises);
await Promise.all(
Array.from(this.allServerConfigs.entries()).map(
async ([name, config]) => {
@@ -534,6 +612,8 @@ export class McpClientManager {
if (!config) {
throw new Error(`No MCP server registered with the name "${name}"`);
}
const clientKey = this.getClientKey(name, config);
await this.disconnectClient(clientKey, true);
await this.maybeDiscoverMcpServer(name, config);
await this.scheduleMcpContextRefresh();
}
@@ -578,11 +658,12 @@ export class McpClientManager {
getMcpInstructions(): string {
const instructions: string[] = [];
for (const [name, client] of this.clients) {
for (const client of this.clients.values()) {
const serverName = client.getServerName();
const clientInstructions = client.getInstructions();
if (clientInstructions) {
instructions.push(
`The following are instructions provided by the tool server '${name}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`,
`The following are instructions provided by the tool server '${serverName}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`,
);
}
}
+177 -134
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -160,16 +161,17 @@ describe('mcp-client', () => {
{
command: 'test-command',
},
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
'0.0.1',
);
await client.connect();
await client.discover(MOCK_CONTEXT);
await client.discoverInto(MOCK_CONTEXT, {
toolRegistry: mockedToolRegistry,
promptRegistry,
resourceRegistry,
});
expect(mockedClient.listTools).toHaveBeenCalledWith(
{},
expect.objectContaining({ timeout: 600000, progressReporter: client }),
@@ -244,16 +246,17 @@ describe('mcp-client', () => {
{
command: 'test-command',
},
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
'0.0.1',
);
await client.connect();
await client.discover(MOCK_CONTEXT);
await client.discoverInto(MOCK_CONTEXT, {
toolRegistry: mockedToolRegistry,
promptRegistry,
resourceRegistry,
});
expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2);
expect(consoleWarnSpy).not.toHaveBeenCalled();
consoleWarnSpy.mockRestore();
@@ -296,16 +299,19 @@ describe('mcp-client', () => {
{
command: 'test-command',
},
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
'0.0.1',
);
await client.connect();
await expect(client.discover(MOCK_CONTEXT)).rejects.toThrow('Test error');
await expect(
client.discoverInto(MOCK_CONTEXT, {
toolRegistry: mockedToolRegistry,
promptRegistry,
resourceRegistry,
}),
).rejects.toThrow('Test error');
expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith(
'error',
`Error discovering prompts from test-server: Test error`,
@@ -354,18 +360,19 @@ describe('mcp-client', () => {
{
command: 'test-command',
},
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
'0.0.1',
);
await client.connect();
await expect(client.discover(MOCK_CONTEXT)).rejects.toThrow(
'No prompts, tools, or resources found on the server.',
);
await expect(
client.discoverInto(MOCK_CONTEXT, {
toolRegistry: mockedToolRegistry,
promptRegistry,
resourceRegistry,
}),
).rejects.toThrow('No prompts, tools, or resources found on the server.');
});
it('should discover tools if server supports them', async () => {
@@ -417,16 +424,17 @@ describe('mcp-client', () => {
{
command: 'test-command',
},
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
'0.0.1',
);
await client.connect();
await client.discover(MOCK_CONTEXT);
await client.discoverInto(MOCK_CONTEXT, {
toolRegistry: mockedToolRegistry,
promptRegistry,
resourceRegistry,
});
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
});
@@ -485,9 +493,6 @@ describe('mcp-client', () => {
const client = new McpClient(
'test-server',
{ command: 'test-command' },
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
@@ -495,7 +500,11 @@ describe('mcp-client', () => {
);
await client.connect();
await client.discover(mockConfig);
await client.discoverInto(mockConfig, {
toolRegistry: mockedToolRegistry,
promptRegistry,
resourceRegistry,
});
// Verify tool registration
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
@@ -566,9 +575,6 @@ describe('mcp-client', () => {
const client = new McpClient(
'test-server',
{ command: 'test-command' },
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
@@ -576,7 +582,11 @@ describe('mcp-client', () => {
);
await client.connect();
await client.discover(mockConfig);
await client.discoverInto(mockConfig, {
toolRegistry: mockedToolRegistry,
promptRegistry,
resourceRegistry,
});
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
expect(mockPolicyEngine.addRule).not.toHaveBeenCalled();
@@ -644,9 +654,6 @@ describe('mcp-client', () => {
const client = new McpClient(
'test-server',
{ command: 'test-command' },
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
@@ -654,7 +661,11 @@ describe('mcp-client', () => {
);
await client.connect();
await client.discover(mockConfig);
await client.discoverInto(mockConfig, {
toolRegistry: mockedToolRegistry,
promptRegistry,
resourceRegistry,
});
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
@@ -733,16 +744,17 @@ describe('mcp-client', () => {
{
command: 'test-command',
},
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
'0.0.1',
);
await client.connect();
await client.discover(MOCK_CONTEXT);
await client.discoverInto(MOCK_CONTEXT, {
toolRegistry: mockedToolRegistry,
promptRegistry,
resourceRegistry,
});
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
const registeredTool = vi.mocked(mockedToolRegistry.registerTool).mock
.calls[0][0];
@@ -818,16 +830,17 @@ describe('mcp-client', () => {
{
command: 'test-command',
},
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
'0.0.1',
);
await client.connect();
await client.discover(MOCK_CONTEXT);
await client.discoverInto(MOCK_CONTEXT, {
toolRegistry: mockedToolRegistry,
promptRegistry,
resourceRegistry,
});
expect(resourceRegistry.setResourcesForServer).toHaveBeenCalledWith(
'test-server',
[
@@ -907,16 +920,17 @@ describe('mcp-client', () => {
{
command: 'test-command',
},
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
'0.0.1',
);
await client.connect();
await client.discover(MOCK_CONTEXT);
await client.discoverInto(MOCK_CONTEXT, {
toolRegistry: mockedToolRegistry,
promptRegistry,
resourceRegistry,
});
expect(mockedClient.setNotificationHandler).toHaveBeenCalledTimes(2);
expect(resourceListHandler).toBeDefined();
@@ -996,16 +1010,17 @@ describe('mcp-client', () => {
{
command: 'test-command',
},
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
'0.0.1',
);
await client.connect();
await client.discover(MOCK_CONTEXT);
await client.discoverInto(MOCK_CONTEXT, {
toolRegistry: mockedToolRegistry,
promptRegistry,
resourceRegistry,
});
expect(mockedClient.setNotificationHandler).toHaveBeenCalledTimes(2);
expect(promptListHandler).toBeDefined();
@@ -1080,16 +1095,17 @@ describe('mcp-client', () => {
{
command: 'test-command',
},
mockedToolRegistry,
mockedPromptRegistry,
resourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
'0.0.1',
);
await client.connect();
await client.discover(MOCK_CONTEXT);
await client.discoverInto(MOCK_CONTEXT, {
toolRegistry: mockedToolRegistry,
promptRegistry: mockedPromptRegistry,
resourceRegistry,
});
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
expect(mockedPromptRegistry.registerPrompt).toHaveBeenCalledOnce();
@@ -1138,17 +1154,6 @@ describe('mcp-client', () => {
const client = new McpClient(
'test-server',
{ command: 'test-command' },
mockedToolRegistry,
{
getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
} as unknown as PromptRegistry,
{
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
@@ -1156,6 +1161,20 @@ describe('mcp-client', () => {
);
await client.connect();
// INJECTED REGISTRIES
(client as any).registeredRegistries?.add({
toolRegistry: mockedToolRegistry,
promptRegistry: {
getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
} as unknown as PromptRegistry,
resourceRegistry: {
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
});
expect(mockedClient.setNotificationHandler).toHaveBeenCalledWith(
ToolListChangedNotificationSchema,
@@ -1183,21 +1202,6 @@ describe('mcp-client', () => {
const client = new McpClient(
'test-server',
{ command: 'test-command' },
{
getToolsByServer: vi.fn().mockReturnValue([]),
registerTool: vi.fn(),
sortTools: vi.fn(),
} as unknown as ToolRegistry,
{
getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
} as unknown as PromptRegistry,
{
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
@@ -1205,6 +1209,24 @@ describe('mcp-client', () => {
);
await client.connect();
// INJECTED REGISTRIES
(client as any).registeredRegistries?.add({
toolRegistry: {
getToolsByServer: vi.fn().mockReturnValue([]),
registerTool: vi.fn(),
sortTools: vi.fn(),
} as unknown as ToolRegistry,
promptRegistry: {
getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
} as unknown as PromptRegistry,
resourceRegistry: {
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
});
// Should be called for ProgressNotificationSchema, even if no other capabilities
expect(mockedClient.setNotificationHandler).toHaveBeenCalled();
@@ -1234,21 +1256,6 @@ describe('mcp-client', () => {
const client = new McpClient(
'test-server',
{ command: 'test-command' },
{
getToolsByServer: vi.fn().mockReturnValue([]),
registerTool: vi.fn(),
sortTools: vi.fn(),
} as unknown as ToolRegistry,
{
getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
} as unknown as PromptRegistry,
{
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
@@ -1256,6 +1263,24 @@ describe('mcp-client', () => {
);
await client.connect();
// INJECTED REGISTRIES
(client as any).registeredRegistries?.add({
toolRegistry: {
getToolsByServer: vi.fn().mockReturnValue([]),
registerTool: vi.fn(),
sortTools: vi.fn(),
} as unknown as ToolRegistry,
promptRegistry: {
getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
} as unknown as PromptRegistry,
resourceRegistry: {
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
});
const toolUpdateCall =
mockedClient.setNotificationHandler.mock.calls.find(
@@ -1308,12 +1333,6 @@ describe('mcp-client', () => {
const client = new McpClient(
'test-server',
{ command: 'test-command' },
mockedToolRegistry,
{} as PromptRegistry,
{
removeMcpResourcesByServer: vi.fn(),
registerResource: vi.fn(),
} as unknown as ResourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
@@ -1323,6 +1342,15 @@ describe('mcp-client', () => {
// 1. Connect (sets up listener)
await client.connect();
// INJECTED REGISTRIES
(client as any).registeredRegistries?.add({
toolRegistry: mockedToolRegistry,
promptRegistry: {} as PromptRegistry,
resourceRegistry: {
removeMcpResourcesByServer: vi.fn(),
registerResource: vi.fn(),
} as unknown as ResourceRegistry,
});
// 2. Extract the callback passed to setNotificationHandler for tools
const toolUpdateCall =
@@ -1388,9 +1416,6 @@ describe('mcp-client', () => {
const client = new McpClient(
'test-server',
{ command: 'test-command' },
mockedToolRegistry,
{} as PromptRegistry,
{} as ResourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
@@ -1398,6 +1423,12 @@ describe('mcp-client', () => {
);
await client.connect();
// INJECTED REGISTRIES
(client as any).registeredRegistries?.add({
toolRegistry: mockedToolRegistry,
promptRegistry: {} as PromptRegistry,
resourceRegistry: {} as ResourceRegistry,
});
const toolUpdateCall =
mockedClient.setNotificationHandler.mock.calls.find(
@@ -1463,9 +1494,6 @@ describe('mcp-client', () => {
const clientA = new McpClient(
'server-A',
{ command: 'cmd-a' },
mockedToolRegistry,
{} as PromptRegistry,
{} as ResourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
@@ -1476,9 +1504,6 @@ describe('mcp-client', () => {
const clientB = new McpClient(
'server-B',
{ command: 'cmd-b' },
mockedToolRegistry,
{} as PromptRegistry,
{} as ResourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
@@ -1487,7 +1512,19 @@ describe('mcp-client', () => {
);
await clientA.connect();
// INJECTED REGISTRIES
(clientA as any).registeredRegistries?.add({
toolRegistry: mockedToolRegistry,
promptRegistry: {} as PromptRegistry,
resourceRegistry: {} as ResourceRegistry,
});
await clientB.connect();
// INJECTED REGISTRIES
(clientB as any).registeredRegistries?.add({
toolRegistry: mockedToolRegistry,
promptRegistry: {} as PromptRegistry,
resourceRegistry: {} as ResourceRegistry,
});
const toolUpdateCallA =
mockClientA.setNotificationHandler.mock.calls.find(
@@ -1572,18 +1609,6 @@ describe('mcp-client', () => {
'test-server',
// Set a very short timeout
{ command: 'test-command', timeout: 50 },
mockedToolRegistry,
{
getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry,
{
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
@@ -1591,6 +1616,21 @@ describe('mcp-client', () => {
);
await client.connect();
// INJECTED REGISTRIES
(client as any).registeredRegistries?.add({
toolRegistry: mockedToolRegistry,
promptRegistry: {
getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry,
resourceRegistry: {
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
});
const toolUpdateCall =
mockedClient.setNotificationHandler.mock.calls.find(
@@ -1648,18 +1688,6 @@ describe('mcp-client', () => {
const client = new McpClient(
'test-server',
{ command: 'test-command' },
mockedToolRegistry,
{
getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry,
{
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
@@ -1668,6 +1696,21 @@ describe('mcp-client', () => {
);
await client.connect();
// INJECTED REGISTRIES
(client as any).registeredRegistries?.add({
toolRegistry: mockedToolRegistry,
promptRegistry: {
getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry,
resourceRegistry: {
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
});
const toolUpdateCall =
mockedClient.setNotificationHandler.mock.calls.find(
+144 -95
View File
@@ -130,6 +130,12 @@ export interface McpProgressReporter {
unregisterProgressToken(token: string | number): void;
}
export interface RegistrySet {
toolRegistry: ToolRegistry;
promptRegistry: PromptRegistry;
resourceRegistry: ResourceRegistry;
}
/**
* A client for a single MCP server.
*
@@ -147,6 +153,8 @@ export class McpClient implements McpProgressReporter {
private isRefreshingPrompts: boolean = false;
private pendingPromptRefresh: boolean = false;
private readonly registeredRegistries = new Set<RegistrySet>();
/**
* Map of progress tokens to tool call IDs.
* This allows us to route progress notifications to the correct tool call.
@@ -156,9 +164,6 @@ export class McpClient implements McpProgressReporter {
constructor(
private readonly serverName: string,
private readonly serverConfig: MCPServerConfig,
private readonly toolRegistry: ToolRegistry,
private readonly promptRegistry: PromptRegistry,
private readonly resourceRegistry: ResourceRegistry,
private readonly workspaceContext: WorkspaceContext,
private readonly cliConfig: McpContext,
private readonly debugMode: boolean,
@@ -166,6 +171,10 @@ export class McpClient implements McpProgressReporter {
private readonly onContextUpdated?: (signal?: AbortSignal) => Promise<void>,
) {}
getServerName(): string {
return this.serverName;
}
/**
* Connects to the MCP server.
*/
@@ -210,27 +219,34 @@ export class McpClient implements McpProgressReporter {
}
/**
* Discovers tools and prompts from the MCP server.
* Discovers tools and prompts from the MCP server into the specified registries.
*/
async discover(cliConfig: McpContext): Promise<void> {
async discoverInto(
cliConfig: McpContext,
registries: RegistrySet,
): Promise<void> {
this.assertConnected();
this.registeredRegistries.add(registries);
const prompts = await this.fetchPrompts();
const tools = await this.discoverTools(cliConfig);
const tools = await this.discoverTools(
cliConfig,
registries.toolRegistry.getMessageBus(),
);
const resources = await this.discoverResources();
this.updateResourceRegistry(resources);
this.updateResourceRegistry(resources, registries.resourceRegistry);
if (prompts.length === 0 && tools.length === 0 && resources.length === 0) {
throw new Error('No prompts, tools, or resources found on the server.');
}
for (const prompt of prompts) {
this.promptRegistry.registerPrompt(prompt);
registries.promptRegistry.registerPrompt(prompt);
}
for (const tool of tools) {
this.toolRegistry.registerTool(tool);
registries.toolRegistry.registerTool(tool);
}
this.toolRegistry.sortTools();
registries.toolRegistry.sortTools();
// Validate MCP tool names in policy rules against discovered tools
try {
@@ -250,6 +266,14 @@ export class McpClient implements McpProgressReporter {
}
}
/**
* Unregisters registries so this client will no longer update them when it receives
* list_changed notifications from the server.
*/
removeRegistries(registries: RegistrySet): void {
this.registeredRegistries.delete(registries);
}
/**
* Disconnects from the MCP server.
*/
@@ -257,9 +281,11 @@ export class McpClient implements McpProgressReporter {
if (this.status !== MCPServerStatus.CONNECTED) {
return;
}
this.toolRegistry.removeMcpToolsByServer(this.serverName);
this.promptRegistry.removePromptsByServer(this.serverName);
this.resourceRegistry.removeResourcesByServer(this.serverName);
for (const registries of this.registeredRegistries) {
registries.toolRegistry.removeMcpToolsByServer(this.serverName);
registries.promptRegistry.removePromptsByServer(this.serverName);
registries.resourceRegistry.removeResourcesByServer(this.serverName);
}
this.updateStatus(MCPServerStatus.DISCONNECTING);
const client = this.client;
this.client = undefined;
@@ -294,6 +320,7 @@ export class McpClient implements McpProgressReporter {
private async discoverTools(
cliConfig: McpContext,
messageBus: MessageBus,
options?: { timeout?: number; signal?: AbortSignal },
): Promise<DiscoveredMCPTool[]> {
this.assertConnected();
@@ -302,7 +329,7 @@ export class McpClient implements McpProgressReporter {
this.serverConfig,
this.client!,
cliConfig,
this.toolRegistry.messageBus,
messageBus,
{
...(options ?? {
timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
@@ -329,8 +356,11 @@ export class McpClient implements McpProgressReporter {
return discoverResources(this.serverName, this.client!, this.cliConfig);
}
private updateResourceRegistry(resources: Resource[]): void {
this.resourceRegistry.setResourcesForServer(this.serverName, resources);
private updateResourceRegistry(
resources: Resource[],
resourceRegistry: ResourceRegistry,
): void {
resourceRegistry.setResourcesForServer(this.serverName, resources);
}
async readResource(
@@ -482,23 +512,32 @@ export class McpClient implements McpProgressReporter {
try {
newResources = await this.discoverResources();
// Verification Retry: If no resources are found or resources didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentResources =
this.resourceRegistry.getResourcesByServer(this.serverName) || [];
const resourceMatch =
newResources.length === currentResources.length &&
newResources.every((nr: Resource) =>
currentResources.some((cr: MCPResource) => cr.uri === nr.uri),
);
for (const registries of this.registeredRegistries) {
// Verification Retry: If no resources are found or resources didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentResources =
registries.resourceRegistry.getResourcesByServer(
this.serverName,
) || [];
const resourceMatch =
newResources.length === currentResources.length &&
newResources.every((nr: Resource) =>
currentResources.some((cr: MCPResource) => cr.uri === nr.uri),
);
if (resourceMatch && !this.pendingResourceRefresh) {
debugLogger.log(
`No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`,
if (resourceMatch && !this.pendingResourceRefresh) {
debugLogger.log(
`No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newResources = await this.discoverResources();
}
this.updateResourceRegistry(
newResources,
registries.resourceRegistry,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newResources = await this.discoverResources();
}
} catch (err) {
debugLogger.error(
@@ -508,8 +547,6 @@ export class McpClient implements McpProgressReporter {
break;
}
this.updateResourceRegistry(newResources);
if (this.onContextUpdated) {
await this.onContextUpdated(abortController.signal);
}
@@ -575,30 +612,33 @@ export class McpClient implements McpProgressReporter {
signal: abortController.signal,
});
// Verification Retry: If no prompts are found or prompts didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentPrompts =
this.promptRegistry.getPromptsByServer(this.serverName) || [];
const promptsMatch =
newPrompts.length === currentPrompts.length &&
newPrompts.every((np) =>
currentPrompts.some((cp) => cp.name === np.name),
);
for (const registries of this.registeredRegistries) {
// Verification Retry: If no prompts are found or prompts didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentPrompts =
registries.promptRegistry.getPromptsByServer(this.serverName) ||
[];
const promptsMatch =
newPrompts.length === currentPrompts.length &&
newPrompts.every((np) =>
currentPrompts.some((cp) => cp.name === np.name),
);
if (promptsMatch && !this.pendingPromptRefresh) {
debugLogger.log(
`No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newPrompts = await this.fetchPrompts({
signal: abortController.signal,
});
}
if (promptsMatch && !this.pendingPromptRefresh) {
debugLogger.log(
`No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newPrompts = await this.fetchPrompts({
signal: abortController.signal,
});
}
this.promptRegistry.removePromptsByServer(this.serverName);
for (const prompt of newPrompts) {
this.promptRegistry.registerPrompt(prompt);
registries.promptRegistry.removePromptsByServer(this.serverName);
for (const prompt of newPrompts) {
registries.promptRegistry.registerPrompt(prompt);
}
}
} catch (err) {
debugLogger.error(
@@ -666,42 +706,58 @@ export class McpClient implements McpProgressReporter {
const abortController = new AbortController();
const timeoutId = setTimeout(() => abortController.abort(), timeoutMs);
let newTools;
try {
newTools = await this.discoverTools(this.cliConfig, {
signal: abortController.signal,
});
debugLogger.log(
`Refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
);
// Verification Retry (Option 3): If no tools are found or tools didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentTools =
this.toolRegistry.getToolsByServer(this.serverName) || [];
const toolNamesMatch =
newTools.length === currentTools.length &&
newTools.every((nt) =>
currentTools.some(
(ct) =>
ct.name === nt.name ||
(ct instanceof DiscoveredMCPTool &&
ct.serverToolName === nt.serverToolName),
),
for (const registries of this.registeredRegistries) {
let newTools = await this.discoverTools(
this.cliConfig,
registries.toolRegistry.getMessageBus(),
{
signal: abortController.signal,
},
);
debugLogger.log(
`Refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
);
if (toolNamesMatch && !this.pendingToolRefresh) {
debugLogger.log(
`No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newTools = await this.discoverTools(this.cliConfig, {
signal: abortController.signal,
});
debugLogger.log(
`Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
);
// Verification Retry (Option 3): If no tools are found or tools didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentTools =
registries.toolRegistry.getToolsByServer(this.serverName) || [];
const toolNamesMatch =
newTools.length === currentTools.length &&
newTools.every((nt) =>
currentTools.some(
(ct) =>
ct.name === nt.name ||
(ct instanceof DiscoveredMCPTool &&
ct.serverToolName === nt.serverToolName),
),
);
if (toolNamesMatch && !this.pendingToolRefresh) {
debugLogger.log(
`No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newTools = await this.discoverTools(
this.cliConfig,
registries.toolRegistry.getMessageBus(),
{
signal: abortController.signal,
},
);
debugLogger.log(
`Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
);
}
registries.toolRegistry.removeMcpToolsByServer(this.serverName);
for (const tool of newTools) {
registries.toolRegistry.registerTool(tool);
}
registries.toolRegistry.sortTools();
}
} catch (err) {
debugLogger.error(
@@ -711,13 +767,6 @@ export class McpClient implements McpProgressReporter {
break;
}
this.toolRegistry.removeMcpToolsByServer(this.serverName);
for (const tool of newTools) {
this.toolRegistry.registerTool(tool);
}
this.toolRegistry.sortTools();
if (this.onContextUpdated) {
await this.onContextUpdated(abortController.signal);
}
@@ -284,6 +284,26 @@ describe('ToolRegistry', () => {
});
});
describe('removeMcpToolsByServer', () => {
it('should remove all tools from a specific server', () => {
const serverName = 'test-server';
const mcpTool1 = createMCPTool(serverName, 'tool1', 'desc1');
const mcpTool2 = createMCPTool(serverName, 'tool2', 'desc2');
const otherTool = createMCPTool('other-server', 'tool3', 'desc3');
toolRegistry.registerTool(mcpTool1);
toolRegistry.registerTool(mcpTool2);
toolRegistry.registerTool(otherTool);
expect(toolRegistry.getToolsByServer(serverName)).toHaveLength(2);
toolRegistry.removeMcpToolsByServer(serverName);
expect(toolRegistry.getToolsByServer(serverName)).toHaveLength(0);
expect(toolRegistry.getToolsByServer('other-server')).toHaveLength(1);
});
});
describe('excluded tools', () => {
const simpleTool = new MockTool({
name: 'tool-a',
+21 -1
View File
@@ -223,10 +223,16 @@ export class ToolRegistry {
private allKnownTools: Map<string, AnyDeclarativeTool> = new Map();
private config: Config;
readonly messageBus: MessageBus;
private isMainRegistry: boolean;
constructor(config: Config, messageBus: MessageBus) {
constructor(
config: Config,
messageBus: MessageBus,
isMainRegistry: boolean = false,
) {
this.config = config;
this.messageBus = messageBus;
this.isMainRegistry = isMainRegistry;
}
getMessageBus(): MessageBus {
@@ -599,6 +605,10 @@ export class ToolRegistry {
const declarations: FunctionDeclaration[] = [];
const seenNames = new Set<string>();
const mainAgentTools = this.isMainRegistry
? this.config.getMainAgentTools()
: undefined;
this.getActiveTools().forEach((tool) => {
const toolName =
tool instanceof DiscoveredMCPTool
@@ -608,6 +618,16 @@ export class ToolRegistry {
if (seenNames.has(toolName)) {
return;
}
if (
mainAgentTools &&
!mainAgentTools.includes(toolName) &&
!mainAgentTools.includes(tool.constructor.name) &&
!mainAgentTools.some((t) => t.startsWith(`${tool.constructor.name}(`))
) {
return;
}
seenNames.add(toolName);
let schema = tool.getSchema(modelId);