refactor(core): architectural decoupling of MCP management and tool isolation

This commit implements a proper architectural decoupling of MCP servers from the global ToolRegistry, eliminating the need for the `__agent__` naming prefix while maintaining perfect isolation.

Key changes:
1. McpClientManager now acts as a pure connection pool, keying clients by a hash of their configuration. This allows multiple agents or extensions to define servers with the same name (e.g. 'github') without collision.
2. McpClient supports multiple 'RegistrySets', allowing it to push discovered tools, prompts, and resources into arbitrary isolated registries.
3. LocalAgentExecutor now creates and manages its own isolated Tool, Prompt, and Resource registries. The `__agent__` prefix is removed, and tools retain their standard `mcp_{server}_{tool}` FQN.
4. CoreToolScheduler and policy checks are reverted to use standard names, as isolation is now handled at the registry level rather than via string namespacing.
5. Proxied the Config object within subagents to ensure system-wide components (like prompt templates) automatically use the agent-specific registries.
6. Verified through comprehensive updates to core tests for agents, MCP management, and registries.
This commit is contained in:
Akhilesh Kumar
2026-03-13 19:23:33 +00:00
parent ee425228fe
commit 54a9bce2b7
13 changed files with 421 additions and 307 deletions
+17 -1
View File
@@ -11,6 +11,8 @@ import type {
CompletedToolCall,
} from '../scheduler/types.js';
import type { ToolRegistry } from '../tools/tool-registry.js';
import type { PromptRegistry } from '../prompts/prompt-registry.js';
import type { ResourceRegistry } from '../resources/resource-registry.js';
import type { EditorType } from '../utils/editor.js';
/**
@@ -25,6 +27,10 @@ export interface AgentSchedulingOptions {
parentCallId?: string;
/** The tool registry specific to this agent. */
toolRegistry: ToolRegistry;
/** The prompt registry specific to this agent. */
promptRegistry?: PromptRegistry;
/** The resource registry specific to this agent. */
resourceRegistry?: ResourceRegistry;
/** AbortSignal for cancellation. */
signal: AbortSignal;
/** Optional function to get the preferred editor for tool modifications. */
@@ -51,16 +57,26 @@ export async function scheduleAgentTools(
subagent,
parentCallId,
toolRegistry,
promptRegistry,
resourceRegistry,
signal,
getPreferredEditor,
onWaitingForConfirmation,
} = options;
// Create a proxy/override of the config to provide the agent-specific tool registry.
// Create a proxy/override of the config to provide the agent-specific registries.
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
const agentConfig: Config = Object.create(config);
agentConfig.getToolRegistry = () => toolRegistry;
agentConfig.getMessageBus = () => toolRegistry.getMessageBus();
if (promptRegistry) {
agentConfig.getPromptRegistry = () => promptRegistry;
}
if (resourceRegistry) {
agentConfig.getResourceRegistry = () => resourceRegistry;
}
// Override toolRegistry property so AgentLoopContext reads the agent-specific registry.
Object.defineProperty(agentConfig, 'toolRegistry', {
get: () => toolRegistry,
-6
View File
@@ -550,12 +550,6 @@ export function markdownToAgentDefinition(
config.description,
config.include_tools,
config.exclude_tools,
undefined, // extension
undefined, // oauth
undefined, // authProviderType
undefined, // targetAudience
undefined, // targetServiceAccount
name, // originalName
);
}
}
+13 -16
View File
@@ -48,6 +48,8 @@ import { debugLogger } from '../utils/debugLogger.js';
import { LocalAgentExecutor, type ActivityCallback } from './local-executor.js';
import { makeFakeConfig } from '../test-utils/config.js';
import { ToolRegistry } from '../tools/tool-registry.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
import { ResourceRegistry } from '../resources/resource-registry.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import { LSTool } from '../tools/ls.js';
import { LS_TOOL_NAME, READ_FILE_TOOL_NAME } from '../tools/tool-names.js';
@@ -2506,31 +2508,27 @@ describe('LocalAgentExecutor', () => {
const mcpManager = mockConfig.getMcpClientManager();
expect(mcpManager?.maybeDiscoverMcpServer).toHaveBeenCalledWith(
'__agent__TestAgent__test-server',
'test-server',
mcpServers['test-server'],
expect.objectContaining({
toolRegistry: expect.any(ToolRegistry),
promptRegistry: expect.any(PromptRegistry),
resourceRegistry: expect.any(ResourceRegistry),
}),
);
});
it('should filter out other agents MCP tools when inheriting tools from parent registry', async () => {
const parentMcpTool1 = new DiscoveredMCPTool(
it('should inherit main registry tools', async () => {
const parentMcpTool = new DiscoveredMCPTool(
{} as unknown as CallableTool,
'__agent__OtherAgent__server1',
'main-server',
'tool1',
'desc1',
{},
mockConfig.getMessageBus(),
);
const parentMcpTool2 = new DiscoveredMCPTool(
{} as unknown as CallableTool,
'__agent__TestAgent__server2',
'tool2',
'desc2',
{},
mockConfig.getMessageBus(),
);
parentToolRegistry.registerTool(parentMcpTool1);
parentToolRegistry.registerTool(parentMcpTool2);
parentToolRegistry.registerTool(parentMcpTool);
const definition = createTestDefinition();
definition.toolConfig = undefined; // trigger inheritance
@@ -2549,8 +2547,7 @@ describe('LocalAgentExecutor', () => {
executor as unknown as { toolRegistry: ToolRegistry }
).toolRegistry.getAllToolNames();
expect(agentTools).toContain(parentMcpTool2.name);
expect(agentTools).not.toContain(parentMcpTool1.name);
expect(agentTools).toContain(parentMcpTool.name);
});
});
});
+40 -33
View File
@@ -17,6 +17,8 @@ import {
type Schema,
} from '@google/genai';
import { ToolRegistry } from '../tools/tool-registry.js';
import { PromptRegistry } from '../prompts/prompt-registry.js';
import { ResourceRegistry } from '../resources/resource-registry.js';
import { type AnyDeclarativeTool } from '../tools/tools.js';
import {
DiscoveredMCPTool,
@@ -99,6 +101,8 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
private readonly agentId: string;
private readonly toolRegistry: ToolRegistry;
private readonly promptRegistry: PromptRegistry;
private readonly resourceRegistry: ResourceRegistry;
private readonly context: AgentLoopContext;
private readonly onActivity?: ActivityCallback;
private readonly compressionService: ChatCompressionService;
@@ -106,7 +110,18 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
private hasFailedCompressionAttempt = false;
private get config(): Config {
return this.context.config;
// eslint-disable-next-line @typescript-eslint/no-unsafe-assignment
const agentConfig: Config = Object.create(this.context.config);
agentConfig.getToolRegistry = () => this.toolRegistry;
agentConfig.getPromptRegistry = () => this.promptRegistry;
agentConfig.getResourceRegistry = () => this.resourceRegistry;
agentConfig.getMessageBus = () => this.toolRegistry.getMessageBus();
Object.defineProperty(agentConfig, 'toolRegistry', {
get: () => this.toolRegistry,
configurable: true,
});
return agentConfig;
}
/**
@@ -142,17 +157,23 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
return parentMessageBus.publish(message);
};
// Create an isolated tool registry for this agent instance.
// Create isolated registries for this agent instance.
const agentToolRegistry = new ToolRegistry(
context.config,
subagentMessageBus,
);
const agentPromptRegistry = new PromptRegistry();
const agentResourceRegistry = new ResourceRegistry();
if (definition.mcpServers) {
const globalMcpManager = context.config.getMcpClientManager();
if (globalMcpManager) {
for (const [name, config] of Object.entries(definition.mcpServers)) {
const prefixedName = `__agent__${definition.name}__${name}`;
await globalMcpManager.maybeDiscoverMcpServer(prefixedName, config);
await globalMcpManager.maybeDiscoverMcpServer(name, config, {
toolRegistry: agentToolRegistry,
promptRegistry: agentPromptRegistry,
resourceRegistry: agentResourceRegistry,
});
}
}
}
@@ -233,32 +254,10 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
} else {
// If no tools are explicitly configured, default to all available tools.
for (const toolName of parentToolRegistry.getAllToolNames()) {
const tool = parentToolRegistry.getTool(toolName);
if (
tool instanceof DiscoveredMCPTool &&
tool.serverName.startsWith('__agent__')
) {
if (!tool.serverName.startsWith(`__agent__${definition.name}__`)) {
continue; // Skip other agents' MCP tools
}
}
registerToolByName(toolName);
}
}
// Always ensure this agent's own MCP servers are included, even if toolConfig is restricted
parentToolRegistry.getAllTools().forEach((tool) => {
if (
tool instanceof DiscoveredMCPTool &&
tool.serverName.startsWith(`__agent__${definition.name}__`)
) {
const qualifiedName = tool.name;
if (!agentToolRegistry.getTool(qualifiedName)) {
registerToolByName(qualifiedName);
}
}
});
agentToolRegistry.sortTools();
// Get the parent prompt ID from context
@@ -271,10 +270,12 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
return new LocalAgentExecutor(
definition,
context,
agentToolRegistry,
parentPromptId,
parentCallId,
agentToolRegistry,
agentPromptRegistry,
agentResourceRegistry,
onActivity,
parentCallId,
);
}
@@ -287,14 +288,18 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
private constructor(
definition: LocalAgentDefinition<TOutput>,
context: AgentLoopContext,
toolRegistry: ToolRegistry,
parentPromptId: string | undefined,
parentCallId: string | undefined,
toolRegistry: ToolRegistry,
promptRegistry: PromptRegistry,
resourceRegistry: ResourceRegistry,
onActivity?: ActivityCallback,
parentCallId?: string,
) {
this.definition = definition;
this.context = context;
this.toolRegistry = toolRegistry;
this.promptRegistry = promptRegistry;
this.resourceRegistry = resourceRegistry;
this.onActivity = onActivity;
this.compressionService = new ChatCompressionService();
this.parentCallId = parentCallId;
@@ -538,7 +543,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
const combinedSignal = AbortSignal.any([signal, deadlineTimer.signal]);
logAgentStart(
this.config,
this.context.config,
new AgentStartEvent(this.agentId, this.definition.name),
);
@@ -745,7 +750,7 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
} finally {
deadlineTimer.abort();
logAgentFinish(
this.config,
this.context.config,
new AgentFinishEvent(
this.agentId,
this.definition.name,
@@ -1165,10 +1170,12 @@ export class LocalAgentExecutor<TOutput extends z.ZodTypeAny> {
this.config,
toolRequests,
{
schedulerId: this.agentId,
schedulerId: promptId,
subagent: this.definition.name,
parentCallId: this.parentCallId,
toolRegistry: this.toolRegistry,
promptRegistry: this.promptRegistry,
resourceRegistry: this.resourceRegistry,
signal,
onWaitingForConfirmation,
},
+2
View File
@@ -367,6 +367,7 @@ describe('Server Config (config.ts)', () => {
mcpStarted = true;
}),
getMcpInstructions: vi.fn(),
setMainRegistries: vi.fn(),
}) as Partial<McpClientManager> as McpClientManager,
);
@@ -400,6 +401,7 @@ describe('Server Config (config.ts)', () => {
mcpStarted = true;
}),
getMcpInstructions: vi.fn(),
setMainRegistries: vi.fn(),
}) as Partial<McpClientManager> as McpClientManager,
);
+5 -6
View File
@@ -445,11 +445,6 @@ export class MCPServerConfig {
readonly targetAudience?: string,
/* targetServiceAccount format: <service-account-name>@<project-num>.iam.gserviceaccount.com */
readonly targetServiceAccount?: string,
/**
* The original name of the server before any prefixing (e.g. for subagents).
* This is used by the policy engine to match rules.
*/
readonly originalName?: string,
) {}
}
@@ -1202,10 +1197,14 @@ export class Config implements McpContext, AgentLoopContext {
discoverToolsHandle?.end();
this.mcpClientManager = new McpClientManager(
this.clientVersion,
this._toolRegistry,
this,
this.eventEmitter,
);
this.mcpClientManager.setMainRegistries({
toolRegistry: this._toolRegistry,
promptRegistry: this.promptRegistry,
resourceRegistry: this.resourceRegistry,
});
// We do not await this promise so that the CLI can start up even if
// MCP servers are slow to connect.
this.mcpInitializationPromise = Promise.allSettled([
+3 -11
View File
@@ -47,11 +47,7 @@ import {
type ToolCallResponseInfo,
} from '../scheduler/types.js';
import { ToolExecutor } from '../scheduler/tool-executor.js';
import {
DiscoveredMCPTool,
MCP_TOOL_PREFIX,
MCP_QUALIFIED_NAME_SEPARATOR,
} from '../tools/mcp-tool.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import { getPolicyDenialError } from '../scheduler/policy.js';
import { GeminiCliOperation } from '../telemetry/constants.js';
@@ -642,16 +638,12 @@ export class CoreToolScheduler {
// Policy Check using PolicyEngine
// We must reconstruct the FunctionCall format expected by PolicyEngine
const toolCallForPolicy = {
name:
toolCall.tool instanceof DiscoveredMCPTool &&
toolCall.tool.originalServerName
? `${MCP_TOOL_PREFIX}${toolCall.tool.originalServerName}${MCP_QUALIFIED_NAME_SEPARATOR}${toolCall.tool.serverToolName}`
: toolCall.request.name,
name: toolCall.request.name,
args: toolCall.request.args,
};
const serverName =
toolCall.tool instanceof DiscoveredMCPTool
? (toolCall.tool.originalServerName ?? toolCall.tool.serverName)
? toolCall.tool.serverName
: undefined;
const toolAnnotations = toolCall.tool.toolAnnotations;
@@ -34,21 +34,25 @@ describe('McpClientManager', () => {
beforeEach(() => {
mockedMcpClient = vi.mockObject({
connect: vi.fn(),
discover: vi.fn(),
discoverInto: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
getServerConfig: vi.fn(),
getServerName: vi.fn().mockReturnValue('test-server'),
} as unknown as McpClient);
vi.mocked(McpClient).mockReturnValue(mockedMcpClient);
mockConfig = vi.mockObject({
isTrustedFolder: vi.fn().mockReturnValue(true),
getMcpServers: vi.fn().mockReturnValue({}),
getPromptRegistry: () => {},
getResourceRegistry: () => {},
getPromptRegistry: vi.fn().mockReturnValue({ registerPrompt: vi.fn() }),
getResourceRegistry: vi
.fn()
.mockReturnValue({ setResourcesForServer: vi.fn() }),
getDebugMode: () => false,
getWorkspaceContext: () => {},
getWorkspaceContext: () => ({ getDirectories: () => [] }),
getAllowedMcpServers: vi.fn().mockReturnValue([]),
getBlockedMcpServers: vi.fn().mockReturnValue([]),
getExcludedMcpServers: vi.fn().mockReturnValue([]),
getMcpServerCommand: vi.fn().mockReturnValue(''),
getMcpEnablementCallbacks: vi.fn().mockReturnValue(undefined),
getGeminiClient: vi.fn().mockReturnValue({
@@ -56,21 +60,39 @@ describe('McpClientManager', () => {
}),
refreshMcpContext: vi.fn(),
} as unknown as Config);
toolRegistry = {} as ToolRegistry;
toolRegistry = vi.mockObject({
registerTool: vi.fn(),
unregisterTool: vi.fn(),
sortTools: vi.fn(),
getMessageBus: vi.fn().mockReturnValue({}),
removeMcpToolsByServer: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
} as unknown as ToolRegistry);
});
afterEach(() => {
vi.restoreAllMocks();
});
const setupManager = (manager: McpClientManager) => {
manager.setMainRegistries({
toolRegistry,
promptRegistry:
mockConfig.getPromptRegistry() as unknown as PromptRegistry,
resourceRegistry:
mockConfig.getResourceRegistry() as unknown as ResourceRegistry,
});
return manager;
};
it('should discover tools from all configured', async () => {
mockConfig.getMcpServers.mockReturnValue({
'test-server': {},
});
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce();
expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce();
});
@@ -80,12 +102,12 @@ describe('McpClientManager', () => {
'server-2': {},
'server-3': {},
});
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
// Each client should be connected/discovered
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(3);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(3);
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(3);
// But context refresh should happen only once
expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce();
@@ -95,7 +117,7 @@ describe('McpClientManager', () => {
mockConfig.getMcpServers.mockReturnValue({
'test-server': {},
});
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.NOT_STARTED);
const promise = manager.startConfiguredMcpServers();
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS);
@@ -112,7 +134,7 @@ describe('McpClientManager', () => {
isFileEnabled: vi.fn().mockResolvedValue(false),
});
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const promise = manager.startConfiguredMcpServers();
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS);
await promise;
@@ -120,7 +142,7 @@ describe('McpClientManager', () => {
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.COMPLETED);
expect(manager.getMcpServerCount()).toBe(0);
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
});
it('should mark discovery completed when all configured servers are blocked', async () => {
@@ -129,7 +151,7 @@ describe('McpClientManager', () => {
});
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const promise = manager.startConfiguredMcpServers();
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS);
await promise;
@@ -137,7 +159,7 @@ describe('McpClientManager', () => {
expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.COMPLETED);
expect(manager.getMcpServerCount()).toBe(0);
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
});
it('should not discover tools if folder is not trusted', async () => {
@@ -145,10 +167,10 @@ describe('McpClientManager', () => {
'test-server': {},
});
mockConfig.isTrustedFolder.mockReturnValue(false);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
});
it('should not start blocked servers', async () => {
@@ -156,10 +178,10 @@ describe('McpClientManager', () => {
'test-server': {},
});
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
});
it('should only start allowed servers if allow list is not empty', async () => {
@@ -168,14 +190,14 @@ describe('McpClientManager', () => {
'another-server': {},
});
mockConfig.getAllowedMcpServers.mockReturnValue(['another-server']);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce();
});
it('should start servers from extensions', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startExtension({
name: 'test-extension',
mcpServers: {
@@ -188,11 +210,11 @@ describe('McpClientManager', () => {
id: '123',
});
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce();
});
it('should not start servers from disabled extensions', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startExtension({
name: 'test-extension',
mcpServers: {
@@ -205,7 +227,7 @@ describe('McpClientManager', () => {
id: '123',
});
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled();
});
it('should add blocked servers to the blockedMcpServers list', async () => {
@@ -213,7 +235,7 @@ describe('McpClientManager', () => {
'test-server': {},
});
mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(manager.getBlockedMcpServers()).toEqual([
{ name: 'test-server', extensionName: '' },
@@ -226,16 +248,16 @@ describe('McpClientManager', () => {
'test-server': {},
});
mockedMcpClient.getServerConfig.mockReturnValue({});
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(1);
await manager.restart();
expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2);
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(2);
});
});
@@ -245,21 +267,21 @@ describe('McpClientManager', () => {
'test-server': {},
});
mockedMcpClient.getServerConfig.mockReturnValue({});
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(1);
await manager.restartServer('test-server');
expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1);
expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2);
expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2);
expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(2);
});
it('should throw an error if the server does not exist', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await expect(manager.restartServer('non-existent')).rejects.toThrow(
'No MCP server registered with the name "non-existent"',
);
@@ -281,7 +303,7 @@ describe('McpClientManager', () => {
});
mockedMcpClient.getServerConfig.mockReturnValue(originalConfig);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
await manager.startConfiguredMcpServers();
// First call should use the original config
@@ -306,9 +328,10 @@ describe('McpClientManager', () => {
(name, config) =>
({
connect: vi.fn(),
discover: vi.fn(),
discoverInto: vi.fn(),
disconnect: vi.fn(),
getServerConfig: vi.fn().mockReturnValue(config),
getServerName: vi.fn().mockReturnValue(name),
getInstructions: vi
.fn()
.mockReturnValue(
@@ -318,12 +341,7 @@ describe('McpClientManager', () => {
),
}) as unknown as McpClient,
);
const manager = new McpClientManager(
'0.0.1',
{} as ToolRegistry,
mockConfig,
);
const manager = new McpClientManager('0.0.1', mockConfig);
mockConfig.getMcpServers.mockReturnValue({
'server-with-instructions': {},
@@ -358,11 +376,7 @@ describe('McpClientManager', () => {
'test-server': {},
});
const manager = new McpClientManager(
'0.0.1',
{} as ToolRegistry,
mockConfig,
);
const manager = new McpClientManager('0.0.1', mockConfig);
await expect(manager.startConfiguredMcpServers()).resolves.not.toThrow();
});
@@ -381,11 +395,8 @@ describe('McpClientManager', () => {
'test-server': {},
});
const manager = new McpClientManager(
'0.0.1',
{} as ToolRegistry,
mockConfig,
);
const manager = new McpClientManager('0.0.1', mockConfig);
await manager.startConfiguredMcpServers();
await expect(manager.restartServer('test-server')).resolves.not.toThrow();
@@ -394,7 +405,7 @@ describe('McpClientManager', () => {
describe('Extension handling', () => {
it('should remove mcp servers from allServerConfigs when stopExtension is called', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const mcpServers = {
'test-server': { command: 'node', args: ['server.js'] },
};
@@ -416,7 +427,7 @@ describe('McpClientManager', () => {
});
it('should ignore an extension attempting to register a server with an existing name', async () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const userConfig = { command: 'node', args: ['user-server.js'] };
mockConfig.getMcpServers.mockReturnValue({
@@ -447,7 +458,7 @@ describe('McpClientManager', () => {
it('should remove servers from blockedMcpServers when stopExtension is called', async () => {
mockConfig.getBlockedMcpServers.mockReturnValue(['blocked-server']);
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
const mcpServers = {
'blocked-server': { command: 'node', args: ['server.js'] },
};
@@ -485,7 +496,7 @@ describe('McpClientManager', () => {
});
it('should emit hint instead of full error when user has not interacted with MCP', () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
manager.emitDiagnostic(
'error',
'Something went wrong',
@@ -504,7 +515,7 @@ describe('McpClientManager', () => {
});
it('should emit full error when user has interacted with MCP', () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
manager.setUserInteractedWithMcp();
manager.emitDiagnostic(
'error',
@@ -520,7 +531,7 @@ describe('McpClientManager', () => {
});
it('should still deduplicate diagnostic messages after user interaction', () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
manager.setUserInteractedWithMcp();
manager.emitDiagnostic('error', 'Same error');
@@ -530,7 +541,7 @@ describe('McpClientManager', () => {
});
it('should only show hint once per session', () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
manager.emitDiagnostic('error', 'Error 1');
manager.emitDiagnostic('error', 'Error 2');
@@ -543,7 +554,7 @@ describe('McpClientManager', () => {
});
it('should capture last error for a server even when silenced', () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
manager.emitDiagnostic(
'error',
@@ -558,7 +569,7 @@ describe('McpClientManager', () => {
});
it('should show previously deduplicated errors after interaction clears state', () => {
const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig);
const manager = setupManager(new McpClientManager('0.0.1', mockConfig));
manager.emitDiagnostic('error', 'Same error');
expect(coreEventsMock.emitFeedback).toHaveBeenCalledTimes(1); // The hint
+122 -46
View File
@@ -20,6 +20,11 @@ import type { EventEmitter } from 'node:events';
import { coreEvents } from '../utils/events.js';
import { debugLogger } from '../utils/debugLogger.js';
import { createHash } from 'node:crypto';
import { stableStringify } from '../policy/stable-stringify.js';
import type { PromptRegistry } from '../prompts/prompt-registry.js';
import type { ResourceRegistry } from '../resources/resource-registry.js';
/**
* Manages the lifecycle of multiple MCP clients, including local child processes.
* This class is responsible for starting, stopping, and discovering tools from
@@ -30,7 +35,6 @@ export class McpClientManager {
// Track all configured servers (including disabled ones) for UI display
private allServerConfigs: Map<string, MCPServerConfig> = new Map();
private readonly clientVersion: string;
private readonly toolRegistry: ToolRegistry;
private readonly cliConfig: Config;
// If we have ongoing MCP client discovery, this completes once that is done.
private discoveryPromise: Promise<void> | undefined;
@@ -42,6 +46,10 @@ export class McpClientManager {
extensionName: string;
}> = [];
private mainToolRegistry: ToolRegistry | undefined;
private mainPromptRegistry: PromptRegistry | undefined;
private mainResourceRegistry: ResourceRegistry | undefined;
/**
* Track whether the user has explicitly interacted with MCP in this session
* (e.g. by running an /mcp command).
@@ -66,16 +74,24 @@ export class McpClientManager {
constructor(
clientVersion: string,
toolRegistry: ToolRegistry,
cliConfig: Config,
eventEmitter?: EventEmitter,
) {
this.clientVersion = clientVersion;
this.toolRegistry = toolRegistry;
this.cliConfig = cliConfig;
this.eventEmitter = eventEmitter;
}
setMainRegistries(registries: {
toolRegistry: ToolRegistry;
promptRegistry: PromptRegistry;
resourceRegistry: ResourceRegistry;
}) {
this.mainToolRegistry = registries.toolRegistry;
this.mainPromptRegistry = registries.promptRegistry;
this.mainResourceRegistry = registries.resourceRegistry;
}
setUserInteractedWithMcp() {
this.userInteractedWithMcp = true;
}
@@ -236,16 +252,17 @@ export class McpClientManager {
return false;
}
private async disconnectClient(name: string, skipRefresh = false) {
const existing = this.clients.get(name);
private async disconnectClient(clientKey: string, skipRefresh = false) {
const existing = this.clients.get(clientKey);
if (existing) {
const serverName = existing.getServerName();
try {
this.clients.delete(name);
this.clients.delete(clientKey);
this.eventEmitter?.emit('mcp-client-update', this.clients);
await existing.disconnect();
} catch (error) {
debugLogger.warn(
`Error stopping client '${name}': ${getErrorMessage(error)}`,
`Error stopping client '${serverName}': ${getErrorMessage(error)}`,
);
} finally {
if (!skipRefresh) {
@@ -257,22 +274,61 @@ export class McpClientManager {
}
}
private getClientKey(name: string, config: MCPServerConfig): string {
const { extension, ...rest } = config;
const keyData = {
name,
config: rest,
extensionId: extension?.id,
};
return createHash('sha256').update(stableStringify(keyData)).digest('hex');
}
async maybeDiscoverMcpServer(
name: string,
config: MCPServerConfig,
registries?: {
toolRegistry: ToolRegistry;
promptRegistry: PromptRegistry;
resourceRegistry: ResourceRegistry;
},
): Promise<void> {
const existing = this.clients.get(name);
if (
existing &&
existing.getServerConfig().extension?.id !== config.extension?.id
) {
const extensionText = config.extension
? ` from extension "${config.extension.name}"`
: '';
debugLogger.warn(
`Skipping MCP config for server with name "${name}"${extensionText} as it already exists.`,
const clientKey = this.getClientKey(name, config);
const existing = this.clients.get(clientKey);
// If no registries are provided (main agent) and a server with this name already exists
// but with a different configuration, handle potential conflicts.
if (!registries) {
const existingSameName = Array.from(this.clients.values()).find(
(c) => c.getServerName() === name,
);
return;
if (existingSameName) {
const existingConfig = existingSameName.getServerConfig();
const existingKey = this.getClientKey(name, existingConfig);
if (existingKey !== clientKey) {
const bothMain = !config.extension && !existingConfig.extension;
const sameExtension =
config.extension &&
existingConfig.extension &&
config.extension.id === existingConfig.extension.id;
if (bothMain || sameExtension) {
// This is a configuration update from the same source (hot-reload).
// We should stop the old client before starting the new one.
await this.disconnectClient(existingKey, true);
} else {
// This is a conflict (e.g. an extension trying to overwrite a main server).
const extensionText = config.extension
? ` from extension "${config.extension.name}"`
: '';
debugLogger.warn(
`Skipping MCP config for server with name "${name}"${extensionText} as a server with that name already exists from a different source.`,
);
return;
}
}
}
}
// Always track server config for UI display
@@ -291,7 +347,7 @@ export class McpClientManager {
// User-disabled servers: disconnect if running, don't start
if (await this.isDisabledByUser(name)) {
if (existing) {
await this.disconnectClient(name);
await this.disconnectClient(clientKey);
}
return;
}
@@ -302,34 +358,46 @@ export class McpClientManager {
return;
}
const currentDiscoveryPromise = new Promise<void>((resolve, reject) => {
(async () => {
const currentDiscoveryPromise = new Promise<void>((resolve) => {
void (async () => {
try {
if (existing) {
this.clients.delete(name);
await existing.disconnect();
let client = existing;
if (!client) {
client = new McpClient(
name,
config,
this.cliConfig.getWorkspaceContext(),
this.cliConfig,
this.cliConfig.getDebugMode(),
this.clientVersion,
async () => {
debugLogger.log(
`🔔 Refreshing context for server '${name}'...`,
);
await this.scheduleMcpContextRefresh();
},
);
this.clients.set(clientKey, client);
this.eventEmitter?.emit('mcp-client-update', this.clients);
}
const client = new McpClient(
name,
config,
this.toolRegistry,
this.cliConfig.getPromptRegistry(),
this.cliConfig.getResourceRegistry(),
this.cliConfig.getWorkspaceContext(),
this.cliConfig,
this.cliConfig.getDebugMode(),
this.clientVersion,
async () => {
debugLogger.log(`🔔 Refreshing context for server '${name}'...`);
await this.scheduleMcpContextRefresh();
},
);
this.clients.set(name, client);
this.eventEmitter?.emit('mcp-client-update', this.clients);
const targetRegistries =
registries ??
(this.mainToolRegistry &&
this.mainPromptRegistry &&
this.mainResourceRegistry
? {
toolRegistry: this.mainToolRegistry,
promptRegistry: this.mainPromptRegistry,
resourceRegistry: this.mainResourceRegistry,
}
: undefined);
try {
await client.connect();
await client.discover(this.cliConfig);
if (targetRegistries) {
await client.discoverInto(this.cliConfig, targetRegistries);
}
this.eventEmitter?.emit('mcp-client-update', this.clients);
} catch (error) {
this.eventEmitter?.emit('mcp-client-update', this.clients);
@@ -349,13 +417,13 @@ export class McpClientManager {
const errorMessage = getErrorMessage(error);
this.emitDiagnostic(
'error',
`Error initializing MCP server '${name}': ${errorMessage}`,
`Fatal error ensuring MCP server '${name}' is connected: ${errorMessage}`,
error,
);
} finally {
resolve();
}
})().catch(reject);
})();
});
if (this.discoveryPromise) {
@@ -438,6 +506,11 @@ export class McpClientManager {
* Restarts all MCP servers (including newly enabled ones).
*/
async restart(): Promise<void> {
const disconnectionPromises = Array.from(this.clients.keys()).map((key) =>
this.disconnectClient(key, true),
);
await Promise.all(disconnectionPromises);
await Promise.all(
Array.from(this.allServerConfigs.entries()).map(
async ([name, config]) => {
@@ -462,6 +535,8 @@ export class McpClientManager {
if (!config) {
throw new Error(`No MCP server registered with the name "${name}"`);
}
const clientKey = this.getClientKey(name, config);
await this.disconnectClient(clientKey, true);
await this.maybeDiscoverMcpServer(name, config);
await this.scheduleMcpContextRefresh();
}
@@ -506,11 +581,12 @@ export class McpClientManager {
getMcpInstructions(): string {
const instructions: string[] = [];
for (const [name, client] of this.clients) {
for (const client of this.clients.values()) {
const serverName = client.getServerName();
const clientInstructions = client.getInstructions();
if (clientInstructions) {
instructions.push(
`The following are instructions provided by the tool server '${name}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`,
`The following are instructions provided by the tool server '${serverName}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`,
);
}
}
+136 -96
View File
@@ -130,6 +130,12 @@ export interface McpProgressReporter {
unregisterProgressToken(token: string | number): void;
}
export interface RegistrySet {
toolRegistry: ToolRegistry;
promptRegistry: PromptRegistry;
resourceRegistry: ResourceRegistry;
}
/**
* A client for a single MCP server.
*
@@ -147,6 +153,8 @@ export class McpClient implements McpProgressReporter {
private isRefreshingPrompts: boolean = false;
private pendingPromptRefresh: boolean = false;
private readonly registeredRegistries = new Set<RegistrySet>();
/**
* Map of progress tokens to tool call IDs.
* This allows us to route progress notifications to the correct tool call.
@@ -156,9 +164,6 @@ export class McpClient implements McpProgressReporter {
constructor(
private readonly serverName: string,
private readonly serverConfig: MCPServerConfig,
private readonly toolRegistry: ToolRegistry,
private readonly promptRegistry: PromptRegistry,
private readonly resourceRegistry: ResourceRegistry,
private readonly workspaceContext: WorkspaceContext,
private readonly cliConfig: McpContext,
private readonly debugMode: boolean,
@@ -166,6 +171,10 @@ export class McpClient implements McpProgressReporter {
private readonly onContextUpdated?: (signal?: AbortSignal) => Promise<void>,
) {}
getServerName(): string {
return this.serverName;
}
/**
* Connects to the MCP server.
*/
@@ -210,27 +219,34 @@ export class McpClient implements McpProgressReporter {
}
/**
* Discovers tools and prompts from the MCP server.
* Discovers tools and prompts from the MCP server into the specified registries.
*/
async discover(cliConfig: McpContext): Promise<void> {
async discoverInto(
cliConfig: McpContext,
registries: RegistrySet,
): Promise<void> {
this.assertConnected();
this.registeredRegistries.add(registries);
const prompts = await this.fetchPrompts();
const tools = await this.discoverTools(cliConfig);
const tools = await this.discoverTools(
cliConfig,
registries.toolRegistry.getMessageBus(),
);
const resources = await this.discoverResources();
this.updateResourceRegistry(resources);
this.updateResourceRegistry(resources, registries.resourceRegistry);
if (prompts.length === 0 && tools.length === 0 && resources.length === 0) {
throw new Error('No prompts, tools, or resources found on the server.');
}
for (const prompt of prompts) {
this.promptRegistry.registerPrompt(prompt);
registries.promptRegistry.registerPrompt(prompt);
}
for (const tool of tools) {
this.toolRegistry.registerTool(tool);
registries.toolRegistry.registerTool(tool);
}
this.toolRegistry.sortTools();
registries.toolRegistry.sortTools();
// Validate MCP tool names in policy rules against discovered tools
try {
@@ -257,9 +273,11 @@ export class McpClient implements McpProgressReporter {
if (this.status !== MCPServerStatus.CONNECTED) {
return;
}
this.toolRegistry.removeMcpToolsByServer(this.serverName);
this.promptRegistry.removePromptsByServer(this.serverName);
this.resourceRegistry.removeResourcesByServer(this.serverName);
for (const registries of this.registeredRegistries) {
registries.toolRegistry.removeMcpToolsByServer(this.serverName);
registries.promptRegistry.removePromptsByServer(this.serverName);
registries.resourceRegistry.removeResourcesByServer(this.serverName);
}
this.updateStatus(MCPServerStatus.DISCONNECTING);
const client = this.client;
this.client = undefined;
@@ -294,6 +312,7 @@ export class McpClient implements McpProgressReporter {
private async discoverTools(
cliConfig: McpContext,
messageBus: MessageBus,
options?: { timeout?: number; signal?: AbortSignal },
): Promise<DiscoveredMCPTool[]> {
this.assertConnected();
@@ -302,7 +321,7 @@ export class McpClient implements McpProgressReporter {
this.serverConfig,
this.client!,
cliConfig,
this.toolRegistry.getMessageBus(),
messageBus,
{
...(options ?? {
timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
@@ -329,8 +348,11 @@ export class McpClient implements McpProgressReporter {
return discoverResources(this.serverName, this.client!, this.cliConfig);
}
private updateResourceRegistry(resources: Resource[]): void {
this.resourceRegistry.setResourcesForServer(this.serverName, resources);
private updateResourceRegistry(
resources: Resource[],
resourceRegistry: ResourceRegistry,
): void {
resourceRegistry.setResourcesForServer(this.serverName, resources);
}
async readResource(
@@ -482,23 +504,32 @@ export class McpClient implements McpProgressReporter {
try {
newResources = await this.discoverResources();
// Verification Retry: If no resources are found or resources didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentResources =
this.resourceRegistry.getResourcesByServer(this.serverName) || [];
const resourceMatch =
newResources.length === currentResources.length &&
newResources.every((nr: Resource) =>
currentResources.some((cr: MCPResource) => cr.uri === nr.uri),
);
for (const registries of this.registeredRegistries) {
// Verification Retry: If no resources are found or resources didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentResources =
registries.resourceRegistry.getResourcesByServer(
this.serverName,
) || [];
const resourceMatch =
newResources.length === currentResources.length &&
newResources.every((nr: Resource) =>
currentResources.some((cr: MCPResource) => cr.uri === nr.uri),
);
if (resourceMatch && !this.pendingResourceRefresh) {
debugLogger.log(
`No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`,
if (resourceMatch && !this.pendingResourceRefresh) {
debugLogger.log(
`No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newResources = await this.discoverResources();
}
this.updateResourceRegistry(
newResources,
registries.resourceRegistry,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newResources = await this.discoverResources();
}
} catch (err) {
debugLogger.error(
@@ -508,8 +539,6 @@ export class McpClient implements McpProgressReporter {
break;
}
this.updateResourceRegistry(newResources);
if (this.onContextUpdated) {
await this.onContextUpdated(abortController.signal);
}
@@ -575,30 +604,33 @@ export class McpClient implements McpProgressReporter {
signal: abortController.signal,
});
// Verification Retry: If no prompts are found or prompts didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentPrompts =
this.promptRegistry.getPromptsByServer(this.serverName) || [];
const promptsMatch =
newPrompts.length === currentPrompts.length &&
newPrompts.every((np) =>
currentPrompts.some((cp) => cp.name === np.name),
);
for (const registries of this.registeredRegistries) {
// Verification Retry: If no prompts are found or prompts didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentPrompts =
registries.promptRegistry.getPromptsByServer(this.serverName) ||
[];
const promptsMatch =
newPrompts.length === currentPrompts.length &&
newPrompts.every((np) =>
currentPrompts.some((cp) => cp.name === np.name),
);
if (promptsMatch && !this.pendingPromptRefresh) {
debugLogger.log(
`No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newPrompts = await this.fetchPrompts({
signal: abortController.signal,
});
}
if (promptsMatch && !this.pendingPromptRefresh) {
debugLogger.log(
`No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newPrompts = await this.fetchPrompts({
signal: abortController.signal,
});
}
this.promptRegistry.removePromptsByServer(this.serverName);
for (const prompt of newPrompts) {
this.promptRegistry.registerPrompt(prompt);
registries.promptRegistry.removePromptsByServer(this.serverName);
for (const prompt of newPrompts) {
registries.promptRegistry.registerPrompt(prompt);
}
}
} catch (err) {
debugLogger.error(
@@ -666,42 +698,58 @@ export class McpClient implements McpProgressReporter {
const abortController = new AbortController();
const timeoutId = setTimeout(() => abortController.abort(), timeoutMs);
let newTools;
try {
newTools = await this.discoverTools(this.cliConfig, {
signal: abortController.signal,
});
debugLogger.log(
`Refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
);
// Verification Retry (Option 3): If no tools are found or tools didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentTools =
this.toolRegistry.getToolsByServer(this.serverName) || [];
const toolNamesMatch =
newTools.length === currentTools.length &&
newTools.every((nt) =>
currentTools.some(
(ct) =>
ct.name === nt.name ||
(ct instanceof DiscoveredMCPTool &&
ct.serverToolName === nt.serverToolName),
),
for (const registries of this.registeredRegistries) {
let newTools = await this.discoverTools(
this.cliConfig,
registries.toolRegistry.getMessageBus(),
{
signal: abortController.signal,
},
);
debugLogger.log(
`Refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
);
if (toolNamesMatch && !this.pendingToolRefresh) {
debugLogger.log(
`No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newTools = await this.discoverTools(this.cliConfig, {
signal: abortController.signal,
});
debugLogger.log(
`Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
);
// Verification Retry (Option 3): If no tools are found or tools didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentTools =
registries.toolRegistry.getToolsByServer(this.serverName) || [];
const toolNamesMatch =
newTools.length === currentTools.length &&
newTools.every((nt) =>
currentTools.some(
(ct) =>
ct.name === nt.name ||
(ct instanceof DiscoveredMCPTool &&
ct.serverToolName === nt.serverToolName),
),
);
if (toolNamesMatch && !this.pendingToolRefresh) {
debugLogger.log(
`No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newTools = await this.discoverTools(
this.cliConfig,
registries.toolRegistry.getMessageBus(),
{
signal: abortController.signal,
},
);
debugLogger.log(
`Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
);
}
registries.toolRegistry.removeMcpToolsByServer(this.serverName);
for (const tool of newTools) {
registries.toolRegistry.registerTool(tool);
}
registries.toolRegistry.sortTools();
}
} catch (err) {
debugLogger.error(
@@ -711,13 +759,6 @@ export class McpClient implements McpProgressReporter {
break;
}
this.toolRegistry.removeMcpToolsByServer(this.serverName);
for (const tool of newTools) {
this.toolRegistry.registerTool(tool);
}
this.toolRegistry.sortTools();
if (this.onContextUpdated) {
await this.onContextUpdated(abortController.signal);
}
@@ -1266,7 +1307,6 @@ export async function discoverTools(
mcpServerConfig.extension?.name,
mcpServerConfig.extension?.id,
annotations as Record<string, unknown> | undefined,
mcpServerConfig.originalName,
);
discoveredTools.push(tool);
-1
View File
@@ -348,7 +348,6 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
override readonly extensionName?: string,
override readonly extensionId?: string,
private readonly _toolAnnotations?: Record<string, unknown>,
readonly originalServerName?: string,
) {
super(
nameOverride ??
+13 -24
View File
@@ -284,34 +284,23 @@ describe('ToolRegistry', () => {
});
});
describe('subagent MCP tools filtering', () => {
it('should hide __agent__ prefixed tools when isMainRegistry is true', async () => {
const mainRegistry = new ToolRegistry(config, mockMessageBus, true);
const subagentRegistry = new ToolRegistry(config, mockMessageBus, false);
describe('removeMcpToolsByServer', () => {
it('should remove all tools from a specific server', () => {
const serverName = 'test-server';
const mcpTool1 = createMCPTool(serverName, 'tool1', 'desc1');
const mcpTool2 = createMCPTool(serverName, 'tool2', 'desc2');
const otherTool = createMCPTool('other-server', 'tool3', 'desc3');
const mcpTool = createMCPTool(
'__agent__TestAgent__myServer',
'my-tool',
'description',
);
vi.spyOn(mcpTool, 'getSchema').mockReturnValue({
name: 'my_tool',
description: 'description',
} as unknown as FunctionDeclaration);
toolRegistry.registerTool(mcpTool1);
toolRegistry.registerTool(mcpTool2);
toolRegistry.registerTool(otherTool);
mainRegistry.registerTool(mcpTool);
subagentRegistry.registerTool(mcpTool);
expect(toolRegistry.getToolsByServer(serverName)).toHaveLength(2);
const mainDeclarations =
mainRegistry.getFunctionDeclarations('test-model');
const subagentDeclarations =
subagentRegistry.getFunctionDeclarations('test-model');
toolRegistry.removeMcpToolsByServer(serverName);
expect(mainDeclarations.length).toBe(0);
expect(subagentDeclarations.length).toBe(1);
expect(subagentDeclarations[0].name).toBe(
'mcp___agent__TestAgent__myServer_my-tool',
);
expect(toolRegistry.getToolsByServer(serverName)).toHaveLength(0);
expect(toolRegistry.getToolsByServer('other-server')).toHaveLength(1);
});
});
-8
View File
@@ -561,14 +561,6 @@ export class ToolRegistry {
return;
}
if (
this.isMainRegistry &&
tool instanceof DiscoveredMCPTool &&
tool.serverName.startsWith('__agent__')
) {
return;
}
if (
mainAgentTools &&
!mainAgentTools.includes(toolName) &&