mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-02 07:54:48 -07:00
feat(core): multi-registry architecture and tool filtering for subagents (#22712)
This commit is contained in:
@@ -14,9 +14,11 @@ import {
|
||||
type MockedObject,
|
||||
} from 'vitest';
|
||||
import { McpClientManager } from './mcp-client-manager.js';
|
||||
import { McpClient, MCPDiscoveryState } from './mcp-client.js';
|
||||
import { McpClient, MCPDiscoveryState, MCPServerStatus } from './mcp-client.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import type { Config, GeminiCLIExtension } from '../config/config.js';
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import type { ResourceRegistry } from '../resources/resource-registry.js';
|
||||
|
||||
vi.mock('./mcp-client.js', async () => {
|
||||
const originalModule = await vi.importActual('./mcp-client.js');
|
||||
@@ -34,21 +36,25 @@ describe('McpClientManager', () => {
|
||||
beforeEach(() => {
|
||||
mockedMcpClient = vi.mockObject({
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
discoverInto: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
getStatus: vi.fn().mockReturnValue(MCPServerStatus.DISCONNECTED),
|
||||
getServerConfig: vi.fn(),
|
||||
getServerName: vi.fn().mockReturnValue('test-server'),
|
||||
} as unknown as McpClient);
|
||||
vi.mocked(McpClient).mockReturnValue(mockedMcpClient);
|
||||
mockConfig = vi.mockObject({
|
||||
isTrustedFolder: vi.fn().mockReturnValue(true),
|
||||
getMcpServers: vi.fn().mockReturnValue({}),
|
||||
getPromptRegistry: () => {},
|
||||
getResourceRegistry: () => {},
|
||||
getPromptRegistry: vi.fn().mockReturnValue({ registerPrompt: vi.fn() }),
|
||||
getResourceRegistry: vi
|
||||
.fn()
|
||||
.mockReturnValue({ setResourcesForServer: vi.fn() }),
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
getWorkspaceContext: () => ({ getDirectories: () => [] }),
|
||||
getAllowedMcpServers: vi.fn().mockReturnValue([]),
|
||||
getBlockedMcpServers: vi.fn().mockReturnValue([]),
|
||||
getExcludedMcpServers: vi.fn().mockReturnValue([]),
|
||||
getMcpServerCommand: vi.fn().mockReturnValue(''),
|
||||
getMcpEnablementCallbacks: vi.fn().mockReturnValue(undefined),
|
||||
getGeminiClient: vi.fn().mockReturnValue({
|
||||
@@ -56,21 +62,39 @@ describe('McpClientManager', () => {
|
||||
}),
|
||||
refreshMcpContext: vi.fn(),
|
||||
} as unknown as Config);
|
||||
toolRegistry = {} as ToolRegistry;
|
||||
toolRegistry = vi.mockObject({
|
||||
registerTool: vi.fn(),
|
||||
unregisterTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue({}),
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
} as unknown as ToolRegistry);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
const setupManager = (manager: McpClientManager) => {
|
||||
manager.setMainRegistries({
|
||||
toolRegistry,
|
||||
promptRegistry:
|
||||
mockConfig.getPromptRegistry() as unknown as PromptRegistry,
|
||||
resourceRegistry:
|
||||
mockConfig.getResourceRegistry() as unknown as ResourceRegistry,
|
||||
});
|
||||
return manager;
|
||||
};
|
||||
|
||||
it('should discover tools from all configured', async () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': { command: 'node' },
|
||||
});
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce();
|
||||
expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
@@ -80,12 +104,12 @@ describe('McpClientManager', () => {
|
||||
'server-2': { command: 'node' },
|
||||
'server-3': { command: 'node' },
|
||||
});
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await manager.startConfiguredMcpServers();
|
||||
|
||||
// Each client should be connected/discovered
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(3);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(3);
|
||||
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(3);
|
||||
|
||||
// But context refresh should happen only once
|
||||
expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce();
|
||||
@@ -95,7 +119,7 @@ describe('McpClientManager', () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': { command: 'node' },
|
||||
});
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.NOT_STARTED);
|
||||
const promise = manager.startConfiguredMcpServers();
|
||||
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS);
|
||||
@@ -112,7 +136,7 @@ describe('McpClientManager', () => {
|
||||
isFileEnabled: vi.fn().mockResolvedValue(false),
|
||||
});
|
||||
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
const promise = manager.startConfiguredMcpServers();
|
||||
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS);
|
||||
await promise;
|
||||
@@ -120,7 +144,7 @@ describe('McpClientManager', () => {
|
||||
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.COMPLETED);
|
||||
expect(manager.getMcpServerCount()).toBe(0);
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should mark discovery completed when all configured servers are blocked', async () => {
|
||||
@@ -129,7 +153,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
|
||||
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
const promise = manager.startConfiguredMcpServers();
|
||||
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS);
|
||||
await promise;
|
||||
@@ -137,7 +161,7 @@ describe('McpClientManager', () => {
|
||||
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.COMPLETED);
|
||||
expect(manager.getMcpServerCount()).toBe(0);
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not discover tools if folder is not trusted', async () => {
|
||||
@@ -145,10 +169,10 @@ describe('McpClientManager', () => {
|
||||
'test-server': { command: 'node' },
|
||||
});
|
||||
mockConfig.isTrustedFolder.mockReturnValue(false);
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not start blocked servers', async () => {
|
||||
@@ -156,10 +180,10 @@ describe('McpClientManager', () => {
|
||||
'test-server': { command: 'node' },
|
||||
});
|
||||
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should only start allowed servers if allow list is not empty', async () => {
|
||||
@@ -168,14 +192,14 @@ describe('McpClientManager', () => {
|
||||
'another-server': { command: 'node' },
|
||||
});
|
||||
mockConfig.getAllowedMcpServers.mockReturnValue(['another-server']);
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should start servers from extensions', async () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await manager.startExtension({
|
||||
name: 'test-extension',
|
||||
mcpServers: {
|
||||
@@ -188,11 +212,11 @@ describe('McpClientManager', () => {
|
||||
id: '123',
|
||||
});
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should not start servers from disabled extensions', async () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await manager.startExtension({
|
||||
name: 'test-extension',
|
||||
mcpServers: {
|
||||
@@ -205,7 +229,7 @@ describe('McpClientManager', () => {
|
||||
id: '123',
|
||||
});
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should add blocked servers to the blockedMcpServers list', async () => {
|
||||
@@ -213,7 +237,7 @@ describe('McpClientManager', () => {
|
||||
'test-server': { command: 'node' },
|
||||
});
|
||||
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(manager.getBlockedMcpServers()).toEqual([
|
||||
{ name: 'test-server', extensionName: '' },
|
||||
@@ -224,10 +248,10 @@ describe('McpClientManager', () => {
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'test-server': { excludeTools: ['dangerous_tool'] },
|
||||
});
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await manager.startConfiguredMcpServers();
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
|
||||
|
||||
// But it should still be tracked in allServerConfigs
|
||||
expect(manager.getMcpServers()).toHaveProperty('test-server');
|
||||
@@ -240,16 +264,16 @@ describe('McpClientManager', () => {
|
||||
'test-server': serverConfig,
|
||||
});
|
||||
mockedMcpClient.getServerConfig.mockReturnValue(serverConfig);
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await manager.startConfiguredMcpServers();
|
||||
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(1);
|
||||
await manager.restart();
|
||||
|
||||
expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2);
|
||||
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -260,21 +284,21 @@ describe('McpClientManager', () => {
|
||||
'test-server': serverConfig,
|
||||
});
|
||||
mockedMcpClient.getServerConfig.mockReturnValue(serverConfig);
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await manager.startConfiguredMcpServers();
|
||||
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(1);
|
||||
|
||||
await manager.restartServer('test-server');
|
||||
|
||||
expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1);
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2);
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2);
|
||||
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should throw an error if the server does not exist', async () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await expect(manager.restartServer('non-existent')).rejects.toThrow(
|
||||
'No MCP server registered with the name "non-existent"',
|
||||
);
|
||||
@@ -296,7 +320,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
mockedMcpClient.getServerConfig.mockReturnValue(originalConfig);
|
||||
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
await manager.startConfiguredMcpServers();
|
||||
|
||||
// First call should use the original config
|
||||
@@ -321,9 +345,10 @@ describe('McpClientManager', () => {
|
||||
(name, config) =>
|
||||
({
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
discoverInto: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getServerConfig: vi.fn().mockReturnValue(config),
|
||||
getServerName: vi.fn().mockReturnValue(name),
|
||||
getInstructions: vi
|
||||
.fn()
|
||||
.mockReturnValue(
|
||||
@@ -333,12 +358,7 @@ describe('McpClientManager', () => {
|
||||
),
|
||||
}) as unknown as McpClient,
|
||||
);
|
||||
|
||||
const manager = new McpClientManager(
|
||||
'0.0.1',
|
||||
{} as ToolRegistry,
|
||||
mockConfig,
|
||||
);
|
||||
const manager = new McpClientManager('0.0.1', mockConfig);
|
||||
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
'server-with-instructions': { command: 'node' },
|
||||
@@ -373,11 +393,7 @@ describe('McpClientManager', () => {
|
||||
'test-server': { command: 'node' },
|
||||
});
|
||||
|
||||
const manager = new McpClientManager(
|
||||
'0.0.1',
|
||||
{} as ToolRegistry,
|
||||
mockConfig,
|
||||
);
|
||||
const manager = new McpClientManager('0.0.1', mockConfig);
|
||||
|
||||
await expect(manager.startConfiguredMcpServers()).resolves.not.toThrow();
|
||||
});
|
||||
@@ -396,11 +412,8 @@ describe('McpClientManager', () => {
|
||||
'test-server': { command: 'node' },
|
||||
});
|
||||
|
||||
const manager = new McpClientManager(
|
||||
'0.0.1',
|
||||
{} as ToolRegistry,
|
||||
mockConfig,
|
||||
);
|
||||
const manager = new McpClientManager('0.0.1', mockConfig);
|
||||
|
||||
await manager.startConfiguredMcpServers();
|
||||
|
||||
await expect(manager.restartServer('test-server')).resolves.not.toThrow();
|
||||
@@ -409,7 +422,7 @@ describe('McpClientManager', () => {
|
||||
|
||||
describe('Extension handling', () => {
|
||||
it('should remove mcp servers from allServerConfigs when stopExtension is called', async () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
const mcpServers = {
|
||||
'test-server': { command: 'node', args: ['server.js'] },
|
||||
};
|
||||
@@ -431,7 +444,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
|
||||
it('should merge extension configuration with an existing user-configured server', async () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
const userConfig = { command: 'node', args: ['user-server.js'] };
|
||||
|
||||
mockConfig.getMcpServers.mockReturnValue({
|
||||
@@ -468,7 +481,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
|
||||
it('should securely merge tool lists and env variables regardless of load order', async () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
|
||||
const userConfig = {
|
||||
excludeTools: ['user-tool'],
|
||||
@@ -523,7 +536,7 @@ describe('McpClientManager', () => {
|
||||
|
||||
// Reset for Case 2
|
||||
vi.mocked(McpClient).mockClear();
|
||||
const manager2 = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager2 = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
|
||||
// Case 2: User config loads first, then Extension loads
|
||||
// This call will skip discovery because userConfig has no connection details
|
||||
@@ -551,7 +564,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
|
||||
it('should result in empty includeTools if intersection is empty', async () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
const userConfig = { includeTools: ['user-tool'] };
|
||||
const extConfig = {
|
||||
command: 'node',
|
||||
@@ -567,7 +580,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
|
||||
it('should respect a single allowlist if only one is provided', async () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
const userConfig = { includeTools: ['user-tool'] };
|
||||
const extConfig = { command: 'node', args: ['ext.js'] };
|
||||
|
||||
@@ -579,7 +592,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
|
||||
it('should allow partial overrides of connection properties', async () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
const extConfig = { command: 'node', args: ['ext.js'], timeout: 1000 };
|
||||
const userOverride = { args: ['overridden.js'] };
|
||||
|
||||
@@ -599,7 +612,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
|
||||
it('should prevent one extension from hijacking another extension server name', async () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
|
||||
const extension1: GeminiCLIExtension = {
|
||||
name: 'extension-1',
|
||||
@@ -641,7 +654,7 @@ describe('McpClientManager', () => {
|
||||
|
||||
it('should remove servers from blockedMcpServers when stopExtension is called', async () => {
|
||||
mockConfig.getBlockedMcpServers.mockReturnValue(['blocked-server']);
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
const mcpServers = {
|
||||
'blocked-server': { command: 'node', args: ['server.js'] },
|
||||
};
|
||||
@@ -679,7 +692,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
|
||||
it('should emit hint instead of full error when user has not interacted with MCP', () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
manager.emitDiagnostic(
|
||||
'error',
|
||||
'Something went wrong',
|
||||
@@ -698,7 +711,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
|
||||
it('should emit full error when user has interacted with MCP', () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
manager.setUserInteractedWithMcp();
|
||||
manager.emitDiagnostic(
|
||||
'error',
|
||||
@@ -714,7 +727,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
|
||||
it('should still deduplicate diagnostic messages after user interaction', () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
manager.setUserInteractedWithMcp();
|
||||
|
||||
manager.emitDiagnostic('error', 'Same error');
|
||||
@@ -724,7 +737,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
|
||||
it('should only show hint once per session', () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
|
||||
manager.emitDiagnostic('error', 'Error 1');
|
||||
manager.emitDiagnostic('error', 'Error 2');
|
||||
@@ -737,7 +750,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
|
||||
it('should capture last error for a server even when silenced', () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
|
||||
manager.emitDiagnostic(
|
||||
'error',
|
||||
@@ -752,7 +765,7 @@ describe('McpClientManager', () => {
|
||||
});
|
||||
|
||||
it('should show previously deduplicated errors after interaction clears state', () => {
|
||||
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
|
||||
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
|
||||
|
||||
manager.emitDiagnostic('error', 'Same error');
|
||||
expect(coreEventsMock.emitFeedback).toHaveBeenCalledTimes(1); // The hint
|
||||
|
||||
@@ -13,6 +13,7 @@ import type { ToolRegistry } from './tool-registry.js';
|
||||
import {
|
||||
McpClient,
|
||||
MCPDiscoveryState,
|
||||
MCPServerStatus,
|
||||
populateMcpServerCommand,
|
||||
} from './mcp-client.js';
|
||||
import { getErrorMessage, isAuthenticationError } from '../utils/errors.js';
|
||||
@@ -20,6 +21,11 @@ import type { EventEmitter } from 'node:events';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
import { createHash } from 'node:crypto';
|
||||
import { stableStringify } from '../policy/stable-stringify.js';
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import type { ResourceRegistry } from '../resources/resource-registry.js';
|
||||
|
||||
/**
|
||||
* Manages the lifecycle of multiple MCP clients, including local child processes.
|
||||
* This class is responsible for starting, stopping, and discovering tools from
|
||||
@@ -30,7 +36,6 @@ export class McpClientManager {
|
||||
// Track all configured servers (including disabled ones) for UI display
|
||||
private allServerConfigs: Map<string, MCPServerConfig> = new Map();
|
||||
private readonly clientVersion: string;
|
||||
private readonly toolRegistry: ToolRegistry;
|
||||
private readonly cliConfig: Config;
|
||||
// If we have ongoing MCP client discovery, this completes once that is done.
|
||||
private discoveryPromise: Promise<void> | undefined;
|
||||
@@ -42,6 +47,10 @@ export class McpClientManager {
|
||||
extensionName: string;
|
||||
}> = [];
|
||||
|
||||
private mainToolRegistry: ToolRegistry | undefined;
|
||||
private mainPromptRegistry: PromptRegistry | undefined;
|
||||
private mainResourceRegistry: ResourceRegistry | undefined;
|
||||
|
||||
/**
|
||||
* Track whether the user has explicitly interacted with MCP in this session
|
||||
* (e.g. by running an /mcp command).
|
||||
@@ -66,16 +75,24 @@ export class McpClientManager {
|
||||
|
||||
constructor(
|
||||
clientVersion: string,
|
||||
toolRegistry: ToolRegistry,
|
||||
cliConfig: Config,
|
||||
eventEmitter?: EventEmitter,
|
||||
) {
|
||||
this.clientVersion = clientVersion;
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.cliConfig = cliConfig;
|
||||
this.eventEmitter = eventEmitter;
|
||||
}
|
||||
|
||||
setMainRegistries(registries: {
|
||||
toolRegistry: ToolRegistry;
|
||||
promptRegistry: PromptRegistry;
|
||||
resourceRegistry: ResourceRegistry;
|
||||
}) {
|
||||
this.mainToolRegistry = registries.toolRegistry;
|
||||
this.mainPromptRegistry = registries.promptRegistry;
|
||||
this.mainResourceRegistry = registries.resourceRegistry;
|
||||
}
|
||||
|
||||
setUserInteractedWithMcp() {
|
||||
this.userInteractedWithMcp = true;
|
||||
}
|
||||
@@ -147,6 +164,16 @@ export class McpClientManager {
|
||||
return this.clients.get(serverName);
|
||||
}
|
||||
|
||||
removeRegistries(registries: {
|
||||
toolRegistry: ToolRegistry;
|
||||
promptRegistry: PromptRegistry;
|
||||
resourceRegistry: ResourceRegistry;
|
||||
}): void {
|
||||
for (const client of this.clients.values()) {
|
||||
client.removeRegistries(registries);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* For all the MCP servers associated with this extension:
|
||||
*
|
||||
@@ -236,16 +263,17 @@ export class McpClientManager {
|
||||
return false;
|
||||
}
|
||||
|
||||
private async disconnectClient(name: string, skipRefresh = false) {
|
||||
const existing = this.clients.get(name);
|
||||
private async disconnectClient(clientKey: string, skipRefresh = false) {
|
||||
const existing = this.clients.get(clientKey);
|
||||
if (existing) {
|
||||
const serverName = existing.getServerName();
|
||||
try {
|
||||
this.clients.delete(name);
|
||||
this.clients.delete(clientKey);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
await existing.disconnect();
|
||||
} catch (error) {
|
||||
debugLogger.warn(
|
||||
`Error stopping client '${name}': ${getErrorMessage(error)}`,
|
||||
`Error stopping client '${serverName}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
} finally {
|
||||
if (!skipRefresh) {
|
||||
@@ -257,6 +285,16 @@ export class McpClientManager {
|
||||
}
|
||||
}
|
||||
|
||||
private getClientKey(name: string, config: MCPServerConfig): string {
|
||||
const { extension, ...rest } = config;
|
||||
const keyData = {
|
||||
name,
|
||||
config: rest,
|
||||
extensionId: extension?.id,
|
||||
};
|
||||
return createHash('sha256').update(stableStringify(keyData)).digest('hex');
|
||||
}
|
||||
|
||||
/**
|
||||
* Merges two MCP configurations. The second configuration (override)
|
||||
* takes precedence for scalar properties, but array properties are
|
||||
@@ -305,6 +343,11 @@ export class McpClientManager {
|
||||
async maybeDiscoverMcpServer(
|
||||
name: string,
|
||||
config: MCPServerConfig,
|
||||
registries?: {
|
||||
toolRegistry: ToolRegistry;
|
||||
promptRegistry: PromptRegistry;
|
||||
resourceRegistry: ResourceRegistry;
|
||||
},
|
||||
): Promise<void> {
|
||||
const existingConfig = this.allServerConfigs.get(name);
|
||||
if (
|
||||
@@ -337,11 +380,27 @@ export class McpClientManager {
|
||||
// Always track server config for UI display
|
||||
this.allServerConfigs.set(name, finalConfig);
|
||||
|
||||
// Capture the existing client synchronously here before any asynchronous
|
||||
// operations. This ensures that if multiple discovery turns happen
|
||||
// concurrently, this turn only replaces/disconnects the client that was
|
||||
// present when this specific configuration update request began.
|
||||
const existing = this.clients.get(name);
|
||||
const clientKey = this.getClientKey(name, finalConfig);
|
||||
|
||||
// If no registries are provided (main agent) and a server with this name already exists
|
||||
// but with a different configuration, handle potential conflicts.
|
||||
if (!registries) {
|
||||
const existingSameName = Array.from(this.clients.values()).find(
|
||||
(c) => c.getServerName() === name,
|
||||
);
|
||||
if (existingSameName) {
|
||||
const existingConfigFromClient = existingSameName.getServerConfig();
|
||||
const existingKey = this.getClientKey(name, existingConfigFromClient);
|
||||
|
||||
if (existingKey !== clientKey) {
|
||||
// This is a configuration update (hot-reload).
|
||||
// We should stop the old client before starting the new one.
|
||||
await this.disconnectClient(existingKey, true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const existing = this.clients.get(clientKey);
|
||||
|
||||
// If no connection details are provided, we can't discover this server.
|
||||
// This often happens when a user provides only overrides (like excludeTools)
|
||||
@@ -363,7 +422,7 @@ export class McpClientManager {
|
||||
// User-disabled servers: disconnect if running, don't start
|
||||
if (await this.isDisabledByUser(name)) {
|
||||
if (existing) {
|
||||
await this.disconnectClient(name);
|
||||
await this.disconnectClient(clientKey);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -374,34 +433,48 @@ export class McpClientManager {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentDiscoveryPromise = new Promise<void>((resolve, reject) => {
|
||||
(async () => {
|
||||
const currentDiscoveryPromise = new Promise<void>((resolve) => {
|
||||
void (async () => {
|
||||
try {
|
||||
if (existing) {
|
||||
this.clients.delete(name);
|
||||
await existing.disconnect();
|
||||
let client = existing;
|
||||
if (!client) {
|
||||
client = new McpClient(
|
||||
name,
|
||||
finalConfig,
|
||||
this.cliConfig.getWorkspaceContext(),
|
||||
this.cliConfig,
|
||||
this.cliConfig.getDebugMode(),
|
||||
this.clientVersion,
|
||||
async () => {
|
||||
debugLogger.log(
|
||||
`🔔 Refreshing context for server '${name}'...`,
|
||||
);
|
||||
await this.scheduleMcpContextRefresh();
|
||||
},
|
||||
);
|
||||
this.clients.set(clientKey, client);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
}
|
||||
|
||||
const client = new McpClient(
|
||||
name,
|
||||
finalConfig,
|
||||
this.toolRegistry,
|
||||
this.cliConfig.getPromptRegistry(),
|
||||
this.cliConfig.getResourceRegistry(),
|
||||
this.cliConfig.getWorkspaceContext(),
|
||||
this.cliConfig,
|
||||
this.cliConfig.getDebugMode(),
|
||||
this.clientVersion,
|
||||
async () => {
|
||||
debugLogger.log(`🔔 Refreshing context for server '${name}'...`);
|
||||
await this.scheduleMcpContextRefresh();
|
||||
},
|
||||
);
|
||||
this.clients.set(name, client);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
const targetRegistries =
|
||||
registries ??
|
||||
(this.mainToolRegistry &&
|
||||
this.mainPromptRegistry &&
|
||||
this.mainResourceRegistry
|
||||
? {
|
||||
toolRegistry: this.mainToolRegistry,
|
||||
promptRegistry: this.mainPromptRegistry,
|
||||
resourceRegistry: this.mainResourceRegistry,
|
||||
}
|
||||
: undefined);
|
||||
|
||||
try {
|
||||
await client.connect();
|
||||
await client.discover(this.cliConfig);
|
||||
if (client.getStatus() === MCPServerStatus.DISCONNECTED) {
|
||||
await client.connect();
|
||||
}
|
||||
if (targetRegistries) {
|
||||
await client.discoverInto(this.cliConfig, targetRegistries);
|
||||
}
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
} catch (error) {
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
@@ -421,13 +494,13 @@ export class McpClientManager {
|
||||
const errorMessage = getErrorMessage(error);
|
||||
this.emitDiagnostic(
|
||||
'error',
|
||||
`Error initializing MCP server '${name}': ${errorMessage}`,
|
||||
`Fatal error ensuring MCP server '${name}' is connected: ${errorMessage}`,
|
||||
error,
|
||||
);
|
||||
} finally {
|
||||
resolve();
|
||||
}
|
||||
})().catch(reject);
|
||||
})();
|
||||
});
|
||||
|
||||
if (this.discoveryPromise) {
|
||||
@@ -510,6 +583,11 @@ export class McpClientManager {
|
||||
* Restarts all MCP servers (including newly enabled ones).
|
||||
*/
|
||||
async restart(): Promise<void> {
|
||||
const disconnectionPromises = Array.from(this.clients.keys()).map((key) =>
|
||||
this.disconnectClient(key, true),
|
||||
);
|
||||
await Promise.all(disconnectionPromises);
|
||||
|
||||
await Promise.all(
|
||||
Array.from(this.allServerConfigs.entries()).map(
|
||||
async ([name, config]) => {
|
||||
@@ -534,6 +612,8 @@ export class McpClientManager {
|
||||
if (!config) {
|
||||
throw new Error(`No MCP server registered with the name "${name}"`);
|
||||
}
|
||||
const clientKey = this.getClientKey(name, config);
|
||||
await this.disconnectClient(clientKey, true);
|
||||
await this.maybeDiscoverMcpServer(name, config);
|
||||
await this.scheduleMcpContextRefresh();
|
||||
}
|
||||
@@ -578,11 +658,12 @@ export class McpClientManager {
|
||||
|
||||
getMcpInstructions(): string {
|
||||
const instructions: string[] = [];
|
||||
for (const [name, client] of this.clients) {
|
||||
for (const client of this.clients.values()) {
|
||||
const serverName = client.getServerName();
|
||||
const clientInstructions = client.getInstructions();
|
||||
if (clientInstructions) {
|
||||
instructions.push(
|
||||
`The following are instructions provided by the tool server '${name}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`,
|
||||
`The following are instructions provided by the tool server '${serverName}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
@@ -160,16 +161,17 @@ describe('mcp-client', () => {
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover(MOCK_CONTEXT);
|
||||
await client.discoverInto(MOCK_CONTEXT, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
});
|
||||
expect(mockedClient.listTools).toHaveBeenCalledWith(
|
||||
{},
|
||||
expect.objectContaining({ timeout: 600000, progressReporter: client }),
|
||||
@@ -244,16 +246,17 @@ describe('mcp-client', () => {
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover(MOCK_CONTEXT);
|
||||
await client.discoverInto(MOCK_CONTEXT, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
});
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2);
|
||||
expect(consoleWarnSpy).not.toHaveBeenCalled();
|
||||
consoleWarnSpy.mockRestore();
|
||||
@@ -296,16 +299,19 @@ describe('mcp-client', () => {
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
await client.connect();
|
||||
await expect(client.discover(MOCK_CONTEXT)).rejects.toThrow('Test error');
|
||||
await expect(
|
||||
client.discoverInto(MOCK_CONTEXT, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
}),
|
||||
).rejects.toThrow('Test error');
|
||||
expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith(
|
||||
'error',
|
||||
`Error discovering prompts from test-server: Test error`,
|
||||
@@ -354,18 +360,19 @@ describe('mcp-client', () => {
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
await client.connect();
|
||||
await expect(client.discover(MOCK_CONTEXT)).rejects.toThrow(
|
||||
'No prompts, tools, or resources found on the server.',
|
||||
);
|
||||
await expect(
|
||||
client.discoverInto(MOCK_CONTEXT, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
}),
|
||||
).rejects.toThrow('No prompts, tools, or resources found on the server.');
|
||||
});
|
||||
|
||||
it('should discover tools if server supports them', async () => {
|
||||
@@ -417,16 +424,17 @@ describe('mcp-client', () => {
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover(MOCK_CONTEXT);
|
||||
await client.discoverInto(MOCK_CONTEXT, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
});
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
@@ -485,9 +493,6 @@ describe('mcp-client', () => {
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -495,7 +500,11 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
await client.discover(mockConfig);
|
||||
await client.discoverInto(mockConfig, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
});
|
||||
|
||||
// Verify tool registration
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
@@ -566,9 +575,6 @@ describe('mcp-client', () => {
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -576,7 +582,11 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
await client.discover(mockConfig);
|
||||
await client.discoverInto(mockConfig, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
});
|
||||
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
expect(mockPolicyEngine.addRule).not.toHaveBeenCalled();
|
||||
@@ -644,9 +654,6 @@ describe('mcp-client', () => {
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -654,7 +661,11 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
await client.discover(mockConfig);
|
||||
await client.discoverInto(mockConfig, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
});
|
||||
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
|
||||
@@ -733,16 +744,17 @@ describe('mcp-client', () => {
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover(MOCK_CONTEXT);
|
||||
await client.discoverInto(MOCK_CONTEXT, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
});
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
const registeredTool = vi.mocked(mockedToolRegistry.registerTool).mock
|
||||
.calls[0][0];
|
||||
@@ -818,16 +830,17 @@ describe('mcp-client', () => {
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover(MOCK_CONTEXT);
|
||||
await client.discoverInto(MOCK_CONTEXT, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
});
|
||||
expect(resourceRegistry.setResourcesForServer).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
[
|
||||
@@ -907,16 +920,17 @@ describe('mcp-client', () => {
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover(MOCK_CONTEXT);
|
||||
await client.discoverInto(MOCK_CONTEXT, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
});
|
||||
|
||||
expect(mockedClient.setNotificationHandler).toHaveBeenCalledTimes(2);
|
||||
expect(resourceListHandler).toBeDefined();
|
||||
@@ -996,16 +1010,17 @@ describe('mcp-client', () => {
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover(MOCK_CONTEXT);
|
||||
await client.discoverInto(MOCK_CONTEXT, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
});
|
||||
|
||||
expect(mockedClient.setNotificationHandler).toHaveBeenCalledTimes(2);
|
||||
expect(promptListHandler).toBeDefined();
|
||||
@@ -1080,16 +1095,17 @@ describe('mcp-client', () => {
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
mockedPromptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover(MOCK_CONTEXT);
|
||||
await client.discoverInto(MOCK_CONTEXT, {
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry: mockedPromptRegistry,
|
||||
resourceRegistry,
|
||||
});
|
||||
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
expect(mockedPromptRegistry.registerPrompt).toHaveBeenCalledOnce();
|
||||
@@ -1138,17 +1154,6 @@ describe('mcp-client', () => {
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
{
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
{
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -1156,6 +1161,20 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
// INJECTED REGISTRIES
|
||||
(client as any).registeredRegistries?.add({
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry: {
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
resourceRegistry: {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
});
|
||||
|
||||
expect(mockedClient.setNotificationHandler).toHaveBeenCalledWith(
|
||||
ToolListChangedNotificationSchema,
|
||||
@@ -1183,21 +1202,6 @@ describe('mcp-client', () => {
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
{
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
} as unknown as ToolRegistry,
|
||||
{
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
{
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -1205,6 +1209,24 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
// INJECTED REGISTRIES
|
||||
(client as any).registeredRegistries?.add({
|
||||
toolRegistry: {
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
} as unknown as ToolRegistry,
|
||||
promptRegistry: {
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
resourceRegistry: {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
});
|
||||
|
||||
// Should be called for ProgressNotificationSchema, even if no other capabilities
|
||||
expect(mockedClient.setNotificationHandler).toHaveBeenCalled();
|
||||
@@ -1234,21 +1256,6 @@ describe('mcp-client', () => {
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
{
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
} as unknown as ToolRegistry,
|
||||
{
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
{
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -1256,6 +1263,24 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
// INJECTED REGISTRIES
|
||||
(client as any).registeredRegistries?.add({
|
||||
toolRegistry: {
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
} as unknown as ToolRegistry,
|
||||
promptRegistry: {
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
resourceRegistry: {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
});
|
||||
|
||||
const toolUpdateCall =
|
||||
mockedClient.setNotificationHandler.mock.calls.find(
|
||||
@@ -1308,12 +1333,6 @@ describe('mcp-client', () => {
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{
|
||||
removeMcpResourcesByServer: vi.fn(),
|
||||
registerResource: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -1323,6 +1342,15 @@ describe('mcp-client', () => {
|
||||
|
||||
// 1. Connect (sets up listener)
|
||||
await client.connect();
|
||||
// INJECTED REGISTRIES
|
||||
(client as any).registeredRegistries?.add({
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry: {} as PromptRegistry,
|
||||
resourceRegistry: {
|
||||
removeMcpResourcesByServer: vi.fn(),
|
||||
registerResource: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
});
|
||||
|
||||
// 2. Extract the callback passed to setNotificationHandler for tools
|
||||
const toolUpdateCall =
|
||||
@@ -1388,9 +1416,6 @@ describe('mcp-client', () => {
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -1398,6 +1423,12 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
// INJECTED REGISTRIES
|
||||
(client as any).registeredRegistries?.add({
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry: {} as PromptRegistry,
|
||||
resourceRegistry: {} as ResourceRegistry,
|
||||
});
|
||||
|
||||
const toolUpdateCall =
|
||||
mockedClient.setNotificationHandler.mock.calls.find(
|
||||
@@ -1463,9 +1494,6 @@ describe('mcp-client', () => {
|
||||
const clientA = new McpClient(
|
||||
'server-A',
|
||||
{ command: 'cmd-a' },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -1476,9 +1504,6 @@ describe('mcp-client', () => {
|
||||
const clientB = new McpClient(
|
||||
'server-B',
|
||||
{ command: 'cmd-b' },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -1487,7 +1512,19 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
await clientA.connect();
|
||||
// INJECTED REGISTRIES
|
||||
(clientA as any).registeredRegistries?.add({
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry: {} as PromptRegistry,
|
||||
resourceRegistry: {} as ResourceRegistry,
|
||||
});
|
||||
await clientB.connect();
|
||||
// INJECTED REGISTRIES
|
||||
(clientB as any).registeredRegistries?.add({
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry: {} as PromptRegistry,
|
||||
resourceRegistry: {} as ResourceRegistry,
|
||||
});
|
||||
|
||||
const toolUpdateCallA =
|
||||
mockClientA.setNotificationHandler.mock.calls.find(
|
||||
@@ -1572,18 +1609,6 @@ describe('mcp-client', () => {
|
||||
'test-server',
|
||||
// Set a very short timeout
|
||||
{ command: 'test-command', timeout: 50 },
|
||||
mockedToolRegistry,
|
||||
{
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
{
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -1591,6 +1616,21 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
// INJECTED REGISTRIES
|
||||
(client as any).registeredRegistries?.add({
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry: {
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
resourceRegistry: {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
});
|
||||
|
||||
const toolUpdateCall =
|
||||
mockedClient.setNotificationHandler.mock.calls.find(
|
||||
@@ -1648,18 +1688,6 @@ describe('mcp-client', () => {
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
{
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
{
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -1668,6 +1696,21 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
// INJECTED REGISTRIES
|
||||
(client as any).registeredRegistries?.add({
|
||||
toolRegistry: mockedToolRegistry,
|
||||
promptRegistry: {
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
resourceRegistry: {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
});
|
||||
|
||||
const toolUpdateCall =
|
||||
mockedClient.setNotificationHandler.mock.calls.find(
|
||||
|
||||
@@ -130,6 +130,12 @@ export interface McpProgressReporter {
|
||||
unregisterProgressToken(token: string | number): void;
|
||||
}
|
||||
|
||||
export interface RegistrySet {
|
||||
toolRegistry: ToolRegistry;
|
||||
promptRegistry: PromptRegistry;
|
||||
resourceRegistry: ResourceRegistry;
|
||||
}
|
||||
|
||||
/**
|
||||
* A client for a single MCP server.
|
||||
*
|
||||
@@ -147,6 +153,8 @@ export class McpClient implements McpProgressReporter {
|
||||
private isRefreshingPrompts: boolean = false;
|
||||
private pendingPromptRefresh: boolean = false;
|
||||
|
||||
private readonly registeredRegistries = new Set<RegistrySet>();
|
||||
|
||||
/**
|
||||
* Map of progress tokens to tool call IDs.
|
||||
* This allows us to route progress notifications to the correct tool call.
|
||||
@@ -156,9 +164,6 @@ export class McpClient implements McpProgressReporter {
|
||||
constructor(
|
||||
private readonly serverName: string,
|
||||
private readonly serverConfig: MCPServerConfig,
|
||||
private readonly toolRegistry: ToolRegistry,
|
||||
private readonly promptRegistry: PromptRegistry,
|
||||
private readonly resourceRegistry: ResourceRegistry,
|
||||
private readonly workspaceContext: WorkspaceContext,
|
||||
private readonly cliConfig: McpContext,
|
||||
private readonly debugMode: boolean,
|
||||
@@ -166,6 +171,10 @@ export class McpClient implements McpProgressReporter {
|
||||
private readonly onContextUpdated?: (signal?: AbortSignal) => Promise<void>,
|
||||
) {}
|
||||
|
||||
getServerName(): string {
|
||||
return this.serverName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connects to the MCP server.
|
||||
*/
|
||||
@@ -210,27 +219,34 @@ export class McpClient implements McpProgressReporter {
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers tools and prompts from the MCP server.
|
||||
* Discovers tools and prompts from the MCP server into the specified registries.
|
||||
*/
|
||||
async discover(cliConfig: McpContext): Promise<void> {
|
||||
async discoverInto(
|
||||
cliConfig: McpContext,
|
||||
registries: RegistrySet,
|
||||
): Promise<void> {
|
||||
this.assertConnected();
|
||||
this.registeredRegistries.add(registries);
|
||||
|
||||
const prompts = await this.fetchPrompts();
|
||||
const tools = await this.discoverTools(cliConfig);
|
||||
const tools = await this.discoverTools(
|
||||
cliConfig,
|
||||
registries.toolRegistry.getMessageBus(),
|
||||
);
|
||||
const resources = await this.discoverResources();
|
||||
this.updateResourceRegistry(resources);
|
||||
this.updateResourceRegistry(resources, registries.resourceRegistry);
|
||||
|
||||
if (prompts.length === 0 && tools.length === 0 && resources.length === 0) {
|
||||
throw new Error('No prompts, tools, or resources found on the server.');
|
||||
}
|
||||
|
||||
for (const prompt of prompts) {
|
||||
this.promptRegistry.registerPrompt(prompt);
|
||||
registries.promptRegistry.registerPrompt(prompt);
|
||||
}
|
||||
for (const tool of tools) {
|
||||
this.toolRegistry.registerTool(tool);
|
||||
registries.toolRegistry.registerTool(tool);
|
||||
}
|
||||
this.toolRegistry.sortTools();
|
||||
registries.toolRegistry.sortTools();
|
||||
|
||||
// Validate MCP tool names in policy rules against discovered tools
|
||||
try {
|
||||
@@ -250,6 +266,14 @@ export class McpClient implements McpProgressReporter {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Unregisters registries so this client will no longer update them when it receives
|
||||
* list_changed notifications from the server.
|
||||
*/
|
||||
removeRegistries(registries: RegistrySet): void {
|
||||
this.registeredRegistries.delete(registries);
|
||||
}
|
||||
|
||||
/**
|
||||
* Disconnects from the MCP server.
|
||||
*/
|
||||
@@ -257,9 +281,11 @@ export class McpClient implements McpProgressReporter {
|
||||
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||
return;
|
||||
}
|
||||
this.toolRegistry.removeMcpToolsByServer(this.serverName);
|
||||
this.promptRegistry.removePromptsByServer(this.serverName);
|
||||
this.resourceRegistry.removeResourcesByServer(this.serverName);
|
||||
for (const registries of this.registeredRegistries) {
|
||||
registries.toolRegistry.removeMcpToolsByServer(this.serverName);
|
||||
registries.promptRegistry.removePromptsByServer(this.serverName);
|
||||
registries.resourceRegistry.removeResourcesByServer(this.serverName);
|
||||
}
|
||||
this.updateStatus(MCPServerStatus.DISCONNECTING);
|
||||
const client = this.client;
|
||||
this.client = undefined;
|
||||
@@ -294,6 +320,7 @@ export class McpClient implements McpProgressReporter {
|
||||
|
||||
private async discoverTools(
|
||||
cliConfig: McpContext,
|
||||
messageBus: MessageBus,
|
||||
options?: { timeout?: number; signal?: AbortSignal },
|
||||
): Promise<DiscoveredMCPTool[]> {
|
||||
this.assertConnected();
|
||||
@@ -302,7 +329,7 @@ export class McpClient implements McpProgressReporter {
|
||||
this.serverConfig,
|
||||
this.client!,
|
||||
cliConfig,
|
||||
this.toolRegistry.messageBus,
|
||||
messageBus,
|
||||
{
|
||||
...(options ?? {
|
||||
timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
@@ -329,8 +356,11 @@ export class McpClient implements McpProgressReporter {
|
||||
return discoverResources(this.serverName, this.client!, this.cliConfig);
|
||||
}
|
||||
|
||||
private updateResourceRegistry(resources: Resource[]): void {
|
||||
this.resourceRegistry.setResourcesForServer(this.serverName, resources);
|
||||
private updateResourceRegistry(
|
||||
resources: Resource[],
|
||||
resourceRegistry: ResourceRegistry,
|
||||
): void {
|
||||
resourceRegistry.setResourcesForServer(this.serverName, resources);
|
||||
}
|
||||
|
||||
async readResource(
|
||||
@@ -482,23 +512,32 @@ export class McpClient implements McpProgressReporter {
|
||||
try {
|
||||
newResources = await this.discoverResources();
|
||||
|
||||
// Verification Retry: If no resources are found or resources didn't change,
|
||||
// wait briefly and try one more time. Some servers notify before they're fully ready.
|
||||
const currentResources =
|
||||
this.resourceRegistry.getResourcesByServer(this.serverName) || [];
|
||||
const resourceMatch =
|
||||
newResources.length === currentResources.length &&
|
||||
newResources.every((nr: Resource) =>
|
||||
currentResources.some((cr: MCPResource) => cr.uri === nr.uri),
|
||||
);
|
||||
for (const registries of this.registeredRegistries) {
|
||||
// Verification Retry: If no resources are found or resources didn't change,
|
||||
// wait briefly and try one more time. Some servers notify before they're fully ready.
|
||||
const currentResources =
|
||||
registries.resourceRegistry.getResourcesByServer(
|
||||
this.serverName,
|
||||
) || [];
|
||||
const resourceMatch =
|
||||
newResources.length === currentResources.length &&
|
||||
newResources.every((nr: Resource) =>
|
||||
currentResources.some((cr: MCPResource) => cr.uri === nr.uri),
|
||||
);
|
||||
|
||||
if (resourceMatch && !this.pendingResourceRefresh) {
|
||||
debugLogger.log(
|
||||
`No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`,
|
||||
if (resourceMatch && !this.pendingResourceRefresh) {
|
||||
debugLogger.log(
|
||||
`No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`,
|
||||
);
|
||||
const retryDelay = 500;
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
newResources = await this.discoverResources();
|
||||
}
|
||||
|
||||
this.updateResourceRegistry(
|
||||
newResources,
|
||||
registries.resourceRegistry,
|
||||
);
|
||||
const retryDelay = 500;
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
newResources = await this.discoverResources();
|
||||
}
|
||||
} catch (err) {
|
||||
debugLogger.error(
|
||||
@@ -508,8 +547,6 @@ export class McpClient implements McpProgressReporter {
|
||||
break;
|
||||
}
|
||||
|
||||
this.updateResourceRegistry(newResources);
|
||||
|
||||
if (this.onContextUpdated) {
|
||||
await this.onContextUpdated(abortController.signal);
|
||||
}
|
||||
@@ -575,30 +612,33 @@ export class McpClient implements McpProgressReporter {
|
||||
signal: abortController.signal,
|
||||
});
|
||||
|
||||
// Verification Retry: If no prompts are found or prompts didn't change,
|
||||
// wait briefly and try one more time. Some servers notify before they're fully ready.
|
||||
const currentPrompts =
|
||||
this.promptRegistry.getPromptsByServer(this.serverName) || [];
|
||||
const promptsMatch =
|
||||
newPrompts.length === currentPrompts.length &&
|
||||
newPrompts.every((np) =>
|
||||
currentPrompts.some((cp) => cp.name === np.name),
|
||||
);
|
||||
for (const registries of this.registeredRegistries) {
|
||||
// Verification Retry: If no prompts are found or prompts didn't change,
|
||||
// wait briefly and try one more time. Some servers notify before they're fully ready.
|
||||
const currentPrompts =
|
||||
registries.promptRegistry.getPromptsByServer(this.serverName) ||
|
||||
[];
|
||||
const promptsMatch =
|
||||
newPrompts.length === currentPrompts.length &&
|
||||
newPrompts.every((np) =>
|
||||
currentPrompts.some((cp) => cp.name === np.name),
|
||||
);
|
||||
|
||||
if (promptsMatch && !this.pendingPromptRefresh) {
|
||||
debugLogger.log(
|
||||
`No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`,
|
||||
);
|
||||
const retryDelay = 500;
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
newPrompts = await this.fetchPrompts({
|
||||
signal: abortController.signal,
|
||||
});
|
||||
}
|
||||
if (promptsMatch && !this.pendingPromptRefresh) {
|
||||
debugLogger.log(
|
||||
`No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`,
|
||||
);
|
||||
const retryDelay = 500;
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
newPrompts = await this.fetchPrompts({
|
||||
signal: abortController.signal,
|
||||
});
|
||||
}
|
||||
|
||||
this.promptRegistry.removePromptsByServer(this.serverName);
|
||||
for (const prompt of newPrompts) {
|
||||
this.promptRegistry.registerPrompt(prompt);
|
||||
registries.promptRegistry.removePromptsByServer(this.serverName);
|
||||
for (const prompt of newPrompts) {
|
||||
registries.promptRegistry.registerPrompt(prompt);
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
debugLogger.error(
|
||||
@@ -666,42 +706,58 @@ export class McpClient implements McpProgressReporter {
|
||||
const abortController = new AbortController();
|
||||
const timeoutId = setTimeout(() => abortController.abort(), timeoutMs);
|
||||
|
||||
let newTools;
|
||||
try {
|
||||
newTools = await this.discoverTools(this.cliConfig, {
|
||||
signal: abortController.signal,
|
||||
});
|
||||
debugLogger.log(
|
||||
`Refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
|
||||
);
|
||||
|
||||
// Verification Retry (Option 3): If no tools are found or tools didn't change,
|
||||
// wait briefly and try one more time. Some servers notify before they're fully ready.
|
||||
const currentTools =
|
||||
this.toolRegistry.getToolsByServer(this.serverName) || [];
|
||||
const toolNamesMatch =
|
||||
newTools.length === currentTools.length &&
|
||||
newTools.every((nt) =>
|
||||
currentTools.some(
|
||||
(ct) =>
|
||||
ct.name === nt.name ||
|
||||
(ct instanceof DiscoveredMCPTool &&
|
||||
ct.serverToolName === nt.serverToolName),
|
||||
),
|
||||
for (const registries of this.registeredRegistries) {
|
||||
let newTools = await this.discoverTools(
|
||||
this.cliConfig,
|
||||
registries.toolRegistry.getMessageBus(),
|
||||
{
|
||||
signal: abortController.signal,
|
||||
},
|
||||
);
|
||||
debugLogger.log(
|
||||
`Refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
|
||||
);
|
||||
|
||||
if (toolNamesMatch && !this.pendingToolRefresh) {
|
||||
debugLogger.log(
|
||||
`No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`,
|
||||
);
|
||||
const retryDelay = 500;
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
newTools = await this.discoverTools(this.cliConfig, {
|
||||
signal: abortController.signal,
|
||||
});
|
||||
debugLogger.log(
|
||||
`Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
|
||||
);
|
||||
// Verification Retry (Option 3): If no tools are found or tools didn't change,
|
||||
// wait briefly and try one more time. Some servers notify before they're fully ready.
|
||||
const currentTools =
|
||||
registries.toolRegistry.getToolsByServer(this.serverName) || [];
|
||||
const toolNamesMatch =
|
||||
newTools.length === currentTools.length &&
|
||||
newTools.every((nt) =>
|
||||
currentTools.some(
|
||||
(ct) =>
|
||||
ct.name === nt.name ||
|
||||
(ct instanceof DiscoveredMCPTool &&
|
||||
ct.serverToolName === nt.serverToolName),
|
||||
),
|
||||
);
|
||||
|
||||
if (toolNamesMatch && !this.pendingToolRefresh) {
|
||||
debugLogger.log(
|
||||
`No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`,
|
||||
);
|
||||
const retryDelay = 500;
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
newTools = await this.discoverTools(
|
||||
this.cliConfig,
|
||||
registries.toolRegistry.getMessageBus(),
|
||||
{
|
||||
signal: abortController.signal,
|
||||
},
|
||||
);
|
||||
debugLogger.log(
|
||||
`Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
|
||||
);
|
||||
}
|
||||
|
||||
registries.toolRegistry.removeMcpToolsByServer(this.serverName);
|
||||
|
||||
for (const tool of newTools) {
|
||||
registries.toolRegistry.registerTool(tool);
|
||||
}
|
||||
registries.toolRegistry.sortTools();
|
||||
}
|
||||
} catch (err) {
|
||||
debugLogger.error(
|
||||
@@ -711,13 +767,6 @@ export class McpClient implements McpProgressReporter {
|
||||
break;
|
||||
}
|
||||
|
||||
this.toolRegistry.removeMcpToolsByServer(this.serverName);
|
||||
|
||||
for (const tool of newTools) {
|
||||
this.toolRegistry.registerTool(tool);
|
||||
}
|
||||
this.toolRegistry.sortTools();
|
||||
|
||||
if (this.onContextUpdated) {
|
||||
await this.onContextUpdated(abortController.signal);
|
||||
}
|
||||
|
||||
@@ -284,6 +284,26 @@ describe('ToolRegistry', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('removeMcpToolsByServer', () => {
|
||||
it('should remove all tools from a specific server', () => {
|
||||
const serverName = 'test-server';
|
||||
const mcpTool1 = createMCPTool(serverName, 'tool1', 'desc1');
|
||||
const mcpTool2 = createMCPTool(serverName, 'tool2', 'desc2');
|
||||
const otherTool = createMCPTool('other-server', 'tool3', 'desc3');
|
||||
|
||||
toolRegistry.registerTool(mcpTool1);
|
||||
toolRegistry.registerTool(mcpTool2);
|
||||
toolRegistry.registerTool(otherTool);
|
||||
|
||||
expect(toolRegistry.getToolsByServer(serverName)).toHaveLength(2);
|
||||
|
||||
toolRegistry.removeMcpToolsByServer(serverName);
|
||||
|
||||
expect(toolRegistry.getToolsByServer(serverName)).toHaveLength(0);
|
||||
expect(toolRegistry.getToolsByServer('other-server')).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('excluded tools', () => {
|
||||
const simpleTool = new MockTool({
|
||||
name: 'tool-a',
|
||||
|
||||
@@ -223,10 +223,16 @@ export class ToolRegistry {
|
||||
private allKnownTools: Map<string, AnyDeclarativeTool> = new Map();
|
||||
private config: Config;
|
||||
readonly messageBus: MessageBus;
|
||||
private isMainRegistry: boolean;
|
||||
|
||||
constructor(config: Config, messageBus: MessageBus) {
|
||||
constructor(
|
||||
config: Config,
|
||||
messageBus: MessageBus,
|
||||
isMainRegistry: boolean = false,
|
||||
) {
|
||||
this.config = config;
|
||||
this.messageBus = messageBus;
|
||||
this.isMainRegistry = isMainRegistry;
|
||||
}
|
||||
|
||||
getMessageBus(): MessageBus {
|
||||
@@ -599,6 +605,10 @@ export class ToolRegistry {
|
||||
const declarations: FunctionDeclaration[] = [];
|
||||
const seenNames = new Set<string>();
|
||||
|
||||
const mainAgentTools = this.isMainRegistry
|
||||
? this.config.getMainAgentTools()
|
||||
: undefined;
|
||||
|
||||
this.getActiveTools().forEach((tool) => {
|
||||
const toolName =
|
||||
tool instanceof DiscoveredMCPTool
|
||||
@@ -608,6 +618,16 @@ export class ToolRegistry {
|
||||
if (seenNames.has(toolName)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (
|
||||
mainAgentTools &&
|
||||
!mainAgentTools.includes(toolName) &&
|
||||
!mainAgentTools.includes(tool.constructor.name) &&
|
||||
!mainAgentTools.some((t) => t.startsWith(`${tool.constructor.name}(`))
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
seenNames.add(toolName);
|
||||
|
||||
let schema = tool.getSchema(modelId);
|
||||
|
||||
Reference in New Issue
Block a user