diff --git a/packages/core/src/agents/agent-scheduler.ts b/packages/core/src/agents/agent-scheduler.ts index 38804bf01a..40dd6c4f2c 100644 --- a/packages/core/src/agents/agent-scheduler.ts +++ b/packages/core/src/agents/agent-scheduler.ts @@ -11,6 +11,8 @@ import type { CompletedToolCall, } from '../scheduler/types.js'; import type { ToolRegistry } from '../tools/tool-registry.js'; +import type { PromptRegistry } from '../prompts/prompt-registry.js'; +import type { ResourceRegistry } from '../resources/resource-registry.js'; import type { EditorType } from '../utils/editor.js'; /** @@ -25,6 +27,10 @@ export interface AgentSchedulingOptions { parentCallId?: string; /** The tool registry specific to this agent. */ toolRegistry: ToolRegistry; + /** The prompt registry specific to this agent. */ + promptRegistry?: PromptRegistry; + /** The resource registry specific to this agent. */ + resourceRegistry?: ResourceRegistry; /** AbortSignal for cancellation. */ signal: AbortSignal; /** Optional function to get the preferred editor for tool modifications. */ @@ -51,16 +57,26 @@ export async function scheduleAgentTools( subagent, parentCallId, toolRegistry, + promptRegistry, + resourceRegistry, signal, getPreferredEditor, onWaitingForConfirmation, } = options; - // Create a proxy/override of the config to provide the agent-specific tool registry. + // Create a proxy/override of the config to provide the agent-specific registries. // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment const agentConfig: Config = Object.create(config); agentConfig.getToolRegistry = () => toolRegistry; agentConfig.getMessageBus = () => toolRegistry.getMessageBus(); + + if (promptRegistry) { + agentConfig.getPromptRegistry = () => promptRegistry; + } + if (resourceRegistry) { + agentConfig.getResourceRegistry = () => resourceRegistry; + } + // Override toolRegistry property so AgentLoopContext reads the agent-specific registry. Object.defineProperty(agentConfig, 'toolRegistry', { get: () => toolRegistry, diff --git a/packages/core/src/agents/agentLoader.ts b/packages/core/src/agents/agentLoader.ts index 2141f6d915..2cb7b3c439 100644 --- a/packages/core/src/agents/agentLoader.ts +++ b/packages/core/src/agents/agentLoader.ts @@ -550,12 +550,6 @@ export function markdownToAgentDefinition( config.description, config.include_tools, config.exclude_tools, - undefined, // extension - undefined, // oauth - undefined, // authProviderType - undefined, // targetAudience - undefined, // targetServiceAccount - name, // originalName ); } } diff --git a/packages/core/src/agents/local-executor.test.ts b/packages/core/src/agents/local-executor.test.ts index 35208742f7..4e80bef781 100644 --- a/packages/core/src/agents/local-executor.test.ts +++ b/packages/core/src/agents/local-executor.test.ts @@ -48,6 +48,8 @@ import { debugLogger } from '../utils/debugLogger.js'; import { LocalAgentExecutor, type ActivityCallback } from './local-executor.js'; import { makeFakeConfig } from '../test-utils/config.js'; import { ToolRegistry } from '../tools/tool-registry.js'; +import { PromptRegistry } from '../prompts/prompt-registry.js'; +import { ResourceRegistry } from '../resources/resource-registry.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; import { LSTool } from '../tools/ls.js'; import { LS_TOOL_NAME, READ_FILE_TOOL_NAME } from '../tools/tool-names.js'; @@ -2506,31 +2508,27 @@ describe('LocalAgentExecutor', () => { const mcpManager = mockConfig.getMcpClientManager(); expect(mcpManager?.maybeDiscoverMcpServer).toHaveBeenCalledWith( - '__agent__TestAgent__test-server', + 'test-server', mcpServers['test-server'], + expect.objectContaining({ + toolRegistry: expect.any(ToolRegistry), + promptRegistry: expect.any(PromptRegistry), + resourceRegistry: expect.any(ResourceRegistry), + }), ); }); - it('should filter out other agents MCP tools when inheriting tools from parent registry', async () => { - const parentMcpTool1 = new DiscoveredMCPTool( + it('should inherit main registry tools', async () => { + const parentMcpTool = new DiscoveredMCPTool( {} as unknown as CallableTool, - '__agent__OtherAgent__server1', + 'main-server', 'tool1', 'desc1', {}, mockConfig.getMessageBus(), ); - const parentMcpTool2 = new DiscoveredMCPTool( - {} as unknown as CallableTool, - '__agent__TestAgent__server2', - 'tool2', - 'desc2', - {}, - mockConfig.getMessageBus(), - ); - parentToolRegistry.registerTool(parentMcpTool1); - parentToolRegistry.registerTool(parentMcpTool2); + parentToolRegistry.registerTool(parentMcpTool); const definition = createTestDefinition(); definition.toolConfig = undefined; // trigger inheritance @@ -2549,8 +2547,7 @@ describe('LocalAgentExecutor', () => { executor as unknown as { toolRegistry: ToolRegistry } ).toolRegistry.getAllToolNames(); - expect(agentTools).toContain(parentMcpTool2.name); - expect(agentTools).not.toContain(parentMcpTool1.name); + expect(agentTools).toContain(parentMcpTool.name); }); }); }); diff --git a/packages/core/src/agents/local-executor.ts b/packages/core/src/agents/local-executor.ts index 8aaecc5b6d..835a275211 100644 --- a/packages/core/src/agents/local-executor.ts +++ b/packages/core/src/agents/local-executor.ts @@ -17,6 +17,8 @@ import { type Schema, } from '@google/genai'; import { ToolRegistry } from '../tools/tool-registry.js'; +import { PromptRegistry } from '../prompts/prompt-registry.js'; +import { ResourceRegistry } from '../resources/resource-registry.js'; import { type AnyDeclarativeTool } from '../tools/tools.js'; import { DiscoveredMCPTool, @@ -99,6 +101,8 @@ export class LocalAgentExecutor { private readonly agentId: string; private readonly toolRegistry: ToolRegistry; + private readonly promptRegistry: PromptRegistry; + private readonly resourceRegistry: ResourceRegistry; private readonly context: AgentLoopContext; private readonly onActivity?: ActivityCallback; private readonly compressionService: ChatCompressionService; @@ -106,7 +110,18 @@ export class LocalAgentExecutor { private hasFailedCompressionAttempt = false; private get config(): Config { - return this.context.config; + // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment + const agentConfig: Config = Object.create(this.context.config); + agentConfig.getToolRegistry = () => this.toolRegistry; + agentConfig.getPromptRegistry = () => this.promptRegistry; + agentConfig.getResourceRegistry = () => this.resourceRegistry; + agentConfig.getMessageBus = () => this.toolRegistry.getMessageBus(); + + Object.defineProperty(agentConfig, 'toolRegistry', { + get: () => this.toolRegistry, + configurable: true, + }); + return agentConfig; } /** @@ -142,17 +157,23 @@ export class LocalAgentExecutor { return parentMessageBus.publish(message); }; - // Create an isolated tool registry for this agent instance. + // Create isolated registries for this agent instance. const agentToolRegistry = new ToolRegistry( context.config, subagentMessageBus, ); + const agentPromptRegistry = new PromptRegistry(); + const agentResourceRegistry = new ResourceRegistry(); + if (definition.mcpServers) { const globalMcpManager = context.config.getMcpClientManager(); if (globalMcpManager) { for (const [name, config] of Object.entries(definition.mcpServers)) { - const prefixedName = `__agent__${definition.name}__${name}`; - await globalMcpManager.maybeDiscoverMcpServer(prefixedName, config); + await globalMcpManager.maybeDiscoverMcpServer(name, config, { + toolRegistry: agentToolRegistry, + promptRegistry: agentPromptRegistry, + resourceRegistry: agentResourceRegistry, + }); } } } @@ -233,32 +254,10 @@ export class LocalAgentExecutor { } else { // If no tools are explicitly configured, default to all available tools. for (const toolName of parentToolRegistry.getAllToolNames()) { - const tool = parentToolRegistry.getTool(toolName); - if ( - tool instanceof DiscoveredMCPTool && - tool.serverName.startsWith('__agent__') - ) { - if (!tool.serverName.startsWith(`__agent__${definition.name}__`)) { - continue; // Skip other agents' MCP tools - } - } registerToolByName(toolName); } } - // Always ensure this agent's own MCP servers are included, even if toolConfig is restricted - parentToolRegistry.getAllTools().forEach((tool) => { - if ( - tool instanceof DiscoveredMCPTool && - tool.serverName.startsWith(`__agent__${definition.name}__`) - ) { - const qualifiedName = tool.name; - if (!agentToolRegistry.getTool(qualifiedName)) { - registerToolByName(qualifiedName); - } - } - }); - agentToolRegistry.sortTools(); // Get the parent prompt ID from context @@ -271,10 +270,12 @@ export class LocalAgentExecutor { return new LocalAgentExecutor( definition, context, - agentToolRegistry, parentPromptId, - parentCallId, + agentToolRegistry, + agentPromptRegistry, + agentResourceRegistry, onActivity, + parentCallId, ); } @@ -287,14 +288,18 @@ export class LocalAgentExecutor { private constructor( definition: LocalAgentDefinition, context: AgentLoopContext, - toolRegistry: ToolRegistry, parentPromptId: string | undefined, - parentCallId: string | undefined, + toolRegistry: ToolRegistry, + promptRegistry: PromptRegistry, + resourceRegistry: ResourceRegistry, onActivity?: ActivityCallback, + parentCallId?: string, ) { this.definition = definition; this.context = context; this.toolRegistry = toolRegistry; + this.promptRegistry = promptRegistry; + this.resourceRegistry = resourceRegistry; this.onActivity = onActivity; this.compressionService = new ChatCompressionService(); this.parentCallId = parentCallId; @@ -538,7 +543,7 @@ export class LocalAgentExecutor { const combinedSignal = AbortSignal.any([signal, deadlineTimer.signal]); logAgentStart( - this.config, + this.context.config, new AgentStartEvent(this.agentId, this.definition.name), ); @@ -745,7 +750,7 @@ export class LocalAgentExecutor { } finally { deadlineTimer.abort(); logAgentFinish( - this.config, + this.context.config, new AgentFinishEvent( this.agentId, this.definition.name, @@ -1165,10 +1170,12 @@ export class LocalAgentExecutor { this.config, toolRequests, { - schedulerId: this.agentId, + schedulerId: promptId, subagent: this.definition.name, parentCallId: this.parentCallId, toolRegistry: this.toolRegistry, + promptRegistry: this.promptRegistry, + resourceRegistry: this.resourceRegistry, signal, onWaitingForConfirmation, }, diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 1eca5d5a35..f4a1673a3a 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -367,6 +367,7 @@ describe('Server Config (config.ts)', () => { mcpStarted = true; }), getMcpInstructions: vi.fn(), + setMainRegistries: vi.fn(), }) as Partial as McpClientManager, ); @@ -400,6 +401,7 @@ describe('Server Config (config.ts)', () => { mcpStarted = true; }), getMcpInstructions: vi.fn(), + setMainRegistries: vi.fn(), }) as Partial as McpClientManager, ); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 55104e91a9..728e7bbe06 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -445,11 +445,6 @@ export class MCPServerConfig { readonly targetAudience?: string, /* targetServiceAccount format: @.iam.gserviceaccount.com */ readonly targetServiceAccount?: string, - /** - * The original name of the server before any prefixing (e.g. for subagents). - * This is used by the policy engine to match rules. - */ - readonly originalName?: string, ) {} } @@ -1202,10 +1197,14 @@ export class Config implements McpContext, AgentLoopContext { discoverToolsHandle?.end(); this.mcpClientManager = new McpClientManager( this.clientVersion, - this._toolRegistry, this, this.eventEmitter, ); + this.mcpClientManager.setMainRegistries({ + toolRegistry: this._toolRegistry, + promptRegistry: this.promptRegistry, + resourceRegistry: this.resourceRegistry, + }); // We do not await this promise so that the CLI can start up even if // MCP servers are slow to connect. this.mcpInitializationPromise = Promise.allSettled([ diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index 60cfeb930e..23473e199d 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -47,11 +47,7 @@ import { type ToolCallResponseInfo, } from '../scheduler/types.js'; import { ToolExecutor } from '../scheduler/tool-executor.js'; -import { - DiscoveredMCPTool, - MCP_TOOL_PREFIX, - MCP_QUALIFIED_NAME_SEPARATOR, -} from '../tools/mcp-tool.js'; +import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; import { getPolicyDenialError } from '../scheduler/policy.js'; import { GeminiCliOperation } from '../telemetry/constants.js'; @@ -642,16 +638,12 @@ export class CoreToolScheduler { // Policy Check using PolicyEngine // We must reconstruct the FunctionCall format expected by PolicyEngine const toolCallForPolicy = { - name: - toolCall.tool instanceof DiscoveredMCPTool && - toolCall.tool.originalServerName - ? `${MCP_TOOL_PREFIX}${toolCall.tool.originalServerName}${MCP_QUALIFIED_NAME_SEPARATOR}${toolCall.tool.serverToolName}` - : toolCall.request.name, + name: toolCall.request.name, args: toolCall.request.args, }; const serverName = toolCall.tool instanceof DiscoveredMCPTool - ? (toolCall.tool.originalServerName ?? toolCall.tool.serverName) + ? toolCall.tool.serverName : undefined; const toolAnnotations = toolCall.tool.toolAnnotations; diff --git a/packages/core/src/tools/mcp-client-manager.test.ts b/packages/core/src/tools/mcp-client-manager.test.ts index e436cea356..88e86efc92 100644 --- a/packages/core/src/tools/mcp-client-manager.test.ts +++ b/packages/core/src/tools/mcp-client-manager.test.ts @@ -34,21 +34,25 @@ describe('McpClientManager', () => { beforeEach(() => { mockedMcpClient = vi.mockObject({ connect: vi.fn(), - discover: vi.fn(), + discoverInto: vi.fn(), disconnect: vi.fn(), getStatus: vi.fn(), getServerConfig: vi.fn(), + getServerName: vi.fn().mockReturnValue('test-server'), } as unknown as McpClient); vi.mocked(McpClient).mockReturnValue(mockedMcpClient); mockConfig = vi.mockObject({ isTrustedFolder: vi.fn().mockReturnValue(true), getMcpServers: vi.fn().mockReturnValue({}), - getPromptRegistry: () => {}, - getResourceRegistry: () => {}, + getPromptRegistry: vi.fn().mockReturnValue({ registerPrompt: vi.fn() }), + getResourceRegistry: vi + .fn() + .mockReturnValue({ setResourcesForServer: vi.fn() }), getDebugMode: () => false, - getWorkspaceContext: () => {}, + getWorkspaceContext: () => ({ getDirectories: () => [] }), getAllowedMcpServers: vi.fn().mockReturnValue([]), getBlockedMcpServers: vi.fn().mockReturnValue([]), + getExcludedMcpServers: vi.fn().mockReturnValue([]), getMcpServerCommand: vi.fn().mockReturnValue(''), getMcpEnablementCallbacks: vi.fn().mockReturnValue(undefined), getGeminiClient: vi.fn().mockReturnValue({ @@ -56,21 +60,39 @@ describe('McpClientManager', () => { }), refreshMcpContext: vi.fn(), } as unknown as Config); - toolRegistry = {} as ToolRegistry; + toolRegistry = vi.mockObject({ + registerTool: vi.fn(), + unregisterTool: vi.fn(), + sortTools: vi.fn(), + getMessageBus: vi.fn().mockReturnValue({}), + removeMcpToolsByServer: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), + } as unknown as ToolRegistry); }); afterEach(() => { vi.restoreAllMocks(); }); + const setupManager = (manager: McpClientManager) => { + manager.setMainRegistries({ + toolRegistry, + promptRegistry: + mockConfig.getPromptRegistry() as unknown as PromptRegistry, + resourceRegistry: + mockConfig.getResourceRegistry() as unknown as ResourceRegistry, + }); + return manager; + }; + it('should discover tools from all configured', async () => { mockConfig.getMcpServers.mockReturnValue({ 'test-server': {}, }); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); - expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce(); expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce(); }); @@ -80,12 +102,12 @@ describe('McpClientManager', () => { 'server-2': {}, 'server-3': {}, }); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); // Each client should be connected/discovered expect(mockedMcpClient.connect).toHaveBeenCalledTimes(3); - expect(mockedMcpClient.discover).toHaveBeenCalledTimes(3); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(3); // But context refresh should happen only once expect(mockConfig.refreshMcpContext).toHaveBeenCalledOnce(); @@ -95,7 +117,7 @@ describe('McpClientManager', () => { mockConfig.getMcpServers.mockReturnValue({ 'test-server': {}, }); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.NOT_STARTED); const promise = manager.startConfiguredMcpServers(); expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS); @@ -112,7 +134,7 @@ describe('McpClientManager', () => { isFileEnabled: vi.fn().mockResolvedValue(false), }); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const promise = manager.startConfiguredMcpServers(); expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS); await promise; @@ -120,7 +142,7 @@ describe('McpClientManager', () => { expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.COMPLETED); expect(manager.getMcpServerCount()).toBe(0); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); - expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled(); }); it('should mark discovery completed when all configured servers are blocked', async () => { @@ -129,7 +151,7 @@ describe('McpClientManager', () => { }); mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const promise = manager.startConfiguredMcpServers(); expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.IN_PROGRESS); await promise; @@ -137,7 +159,7 @@ describe('McpClientManager', () => { expect(manager.getDiscoveryState()).toBe(MCPDiscoveryState.COMPLETED); expect(manager.getMcpServerCount()).toBe(0); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); - expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled(); }); it('should not discover tools if folder is not trusted', async () => { @@ -145,10 +167,10 @@ describe('McpClientManager', () => { 'test-server': {}, }); mockConfig.isTrustedFolder.mockReturnValue(false); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); - expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled(); }); it('should not start blocked servers', async () => { @@ -156,10 +178,10 @@ describe('McpClientManager', () => { 'test-server': {}, }); mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); - expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled(); }); it('should only start allowed servers if allow list is not empty', async () => { @@ -168,14 +190,14 @@ describe('McpClientManager', () => { 'another-server': {}, }); mockConfig.getAllowedMcpServers.mockReturnValue(['another-server']); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); - expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce(); }); it('should start servers from extensions', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startExtension({ name: 'test-extension', mcpServers: { @@ -188,11 +210,11 @@ describe('McpClientManager', () => { id: '123', }); expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); - expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledOnce(); }); it('should not start servers from disabled extensions', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startExtension({ name: 'test-extension', mcpServers: { @@ -205,7 +227,7 @@ describe('McpClientManager', () => { id: '123', }); expect(mockedMcpClient.connect).not.toHaveBeenCalled(); - expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + expect(mockedMcpClient.discoverInto).not.toHaveBeenCalled(); }); it('should add blocked servers to the blockedMcpServers list', async () => { @@ -213,7 +235,7 @@ describe('McpClientManager', () => { 'test-server': {}, }); mockConfig.getBlockedMcpServers.mockReturnValue(['test-server']); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(manager.getBlockedMcpServers()).toEqual([ { name: 'test-server', extensionName: '' }, @@ -226,16 +248,16 @@ describe('McpClientManager', () => { 'test-server': {}, }); mockedMcpClient.getServerConfig.mockReturnValue({}); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1); - expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(1); await manager.restart(); expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1); expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2); - expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(2); }); }); @@ -245,21 +267,21 @@ describe('McpClientManager', () => { 'test-server': {}, }); mockedMcpClient.getServerConfig.mockReturnValue({}); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); expect(mockedMcpClient.connect).toHaveBeenCalledTimes(1); - expect(mockedMcpClient.discover).toHaveBeenCalledTimes(1); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(1); await manager.restartServer('test-server'); expect(mockedMcpClient.disconnect).toHaveBeenCalledTimes(1); expect(mockedMcpClient.connect).toHaveBeenCalledTimes(2); - expect(mockedMcpClient.discover).toHaveBeenCalledTimes(2); + expect(mockedMcpClient.discoverInto).toHaveBeenCalledTimes(2); }); it('should throw an error if the server does not exist', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await expect(manager.restartServer('non-existent')).rejects.toThrow( 'No MCP server registered with the name "non-existent"', ); @@ -281,7 +303,7 @@ describe('McpClientManager', () => { }); mockedMcpClient.getServerConfig.mockReturnValue(originalConfig); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); await manager.startConfiguredMcpServers(); // First call should use the original config @@ -306,9 +328,10 @@ describe('McpClientManager', () => { (name, config) => ({ connect: vi.fn(), - discover: vi.fn(), + discoverInto: vi.fn(), disconnect: vi.fn(), getServerConfig: vi.fn().mockReturnValue(config), + getServerName: vi.fn().mockReturnValue(name), getInstructions: vi .fn() .mockReturnValue( @@ -318,12 +341,7 @@ describe('McpClientManager', () => { ), }) as unknown as McpClient, ); - - const manager = new McpClientManager( - '0.0.1', - {} as ToolRegistry, - mockConfig, - ); + const manager = new McpClientManager('0.0.1', mockConfig); mockConfig.getMcpServers.mockReturnValue({ 'server-with-instructions': {}, @@ -358,11 +376,7 @@ describe('McpClientManager', () => { 'test-server': {}, }); - const manager = new McpClientManager( - '0.0.1', - {} as ToolRegistry, - mockConfig, - ); + const manager = new McpClientManager('0.0.1', mockConfig); await expect(manager.startConfiguredMcpServers()).resolves.not.toThrow(); }); @@ -381,11 +395,8 @@ describe('McpClientManager', () => { 'test-server': {}, }); - const manager = new McpClientManager( - '0.0.1', - {} as ToolRegistry, - mockConfig, - ); + const manager = new McpClientManager('0.0.1', mockConfig); + await manager.startConfiguredMcpServers(); await expect(manager.restartServer('test-server')).resolves.not.toThrow(); @@ -394,7 +405,7 @@ describe('McpClientManager', () => { describe('Extension handling', () => { it('should remove mcp servers from allServerConfigs when stopExtension is called', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const mcpServers = { 'test-server': { command: 'node', args: ['server.js'] }, }; @@ -416,7 +427,7 @@ describe('McpClientManager', () => { }); it('should ignore an extension attempting to register a server with an existing name', async () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const userConfig = { command: 'node', args: ['user-server.js'] }; mockConfig.getMcpServers.mockReturnValue({ @@ -447,7 +458,7 @@ describe('McpClientManager', () => { it('should remove servers from blockedMcpServers when stopExtension is called', async () => { mockConfig.getBlockedMcpServers.mockReturnValue(['blocked-server']); - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); const mcpServers = { 'blocked-server': { command: 'node', args: ['server.js'] }, }; @@ -485,7 +496,7 @@ describe('McpClientManager', () => { }); it('should emit hint instead of full error when user has not interacted with MCP', () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); manager.emitDiagnostic( 'error', 'Something went wrong', @@ -504,7 +515,7 @@ describe('McpClientManager', () => { }); it('should emit full error when user has interacted with MCP', () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); manager.setUserInteractedWithMcp(); manager.emitDiagnostic( 'error', @@ -520,7 +531,7 @@ describe('McpClientManager', () => { }); it('should still deduplicate diagnostic messages after user interaction', () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); manager.setUserInteractedWithMcp(); manager.emitDiagnostic('error', 'Same error'); @@ -530,7 +541,7 @@ describe('McpClientManager', () => { }); it('should only show hint once per session', () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); manager.emitDiagnostic('error', 'Error 1'); manager.emitDiagnostic('error', 'Error 2'); @@ -543,7 +554,7 @@ describe('McpClientManager', () => { }); it('should capture last error for a server even when silenced', () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); manager.emitDiagnostic( 'error', @@ -558,7 +569,7 @@ describe('McpClientManager', () => { }); it('should show previously deduplicated errors after interaction clears state', () => { - const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const manager = setupManager(new McpClientManager('0.0.1', mockConfig)); manager.emitDiagnostic('error', 'Same error'); expect(coreEventsMock.emitFeedback).toHaveBeenCalledTimes(1); // The hint diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index 43ea9715bc..7f2752561e 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -20,6 +20,11 @@ import type { EventEmitter } from 'node:events'; import { coreEvents } from '../utils/events.js'; import { debugLogger } from '../utils/debugLogger.js'; +import { createHash } from 'node:crypto'; +import { stableStringify } from '../policy/stable-stringify.js'; +import type { PromptRegistry } from '../prompts/prompt-registry.js'; +import type { ResourceRegistry } from '../resources/resource-registry.js'; + /** * Manages the lifecycle of multiple MCP clients, including local child processes. * This class is responsible for starting, stopping, and discovering tools from @@ -30,7 +35,6 @@ export class McpClientManager { // Track all configured servers (including disabled ones) for UI display private allServerConfigs: Map = new Map(); private readonly clientVersion: string; - private readonly toolRegistry: ToolRegistry; private readonly cliConfig: Config; // If we have ongoing MCP client discovery, this completes once that is done. private discoveryPromise: Promise | undefined; @@ -42,6 +46,10 @@ export class McpClientManager { extensionName: string; }> = []; + private mainToolRegistry: ToolRegistry | undefined; + private mainPromptRegistry: PromptRegistry | undefined; + private mainResourceRegistry: ResourceRegistry | undefined; + /** * Track whether the user has explicitly interacted with MCP in this session * (e.g. by running an /mcp command). @@ -66,16 +74,24 @@ export class McpClientManager { constructor( clientVersion: string, - toolRegistry: ToolRegistry, cliConfig: Config, eventEmitter?: EventEmitter, ) { this.clientVersion = clientVersion; - this.toolRegistry = toolRegistry; this.cliConfig = cliConfig; this.eventEmitter = eventEmitter; } + setMainRegistries(registries: { + toolRegistry: ToolRegistry; + promptRegistry: PromptRegistry; + resourceRegistry: ResourceRegistry; + }) { + this.mainToolRegistry = registries.toolRegistry; + this.mainPromptRegistry = registries.promptRegistry; + this.mainResourceRegistry = registries.resourceRegistry; + } + setUserInteractedWithMcp() { this.userInteractedWithMcp = true; } @@ -236,16 +252,17 @@ export class McpClientManager { return false; } - private async disconnectClient(name: string, skipRefresh = false) { - const existing = this.clients.get(name); + private async disconnectClient(clientKey: string, skipRefresh = false) { + const existing = this.clients.get(clientKey); if (existing) { + const serverName = existing.getServerName(); try { - this.clients.delete(name); + this.clients.delete(clientKey); this.eventEmitter?.emit('mcp-client-update', this.clients); await existing.disconnect(); } catch (error) { debugLogger.warn( - `Error stopping client '${name}': ${getErrorMessage(error)}`, + `Error stopping client '${serverName}': ${getErrorMessage(error)}`, ); } finally { if (!skipRefresh) { @@ -257,22 +274,61 @@ export class McpClientManager { } } + private getClientKey(name: string, config: MCPServerConfig): string { + const { extension, ...rest } = config; + const keyData = { + name, + config: rest, + extensionId: extension?.id, + }; + return createHash('sha256').update(stableStringify(keyData)).digest('hex'); + } + async maybeDiscoverMcpServer( name: string, config: MCPServerConfig, + registries?: { + toolRegistry: ToolRegistry; + promptRegistry: PromptRegistry; + resourceRegistry: ResourceRegistry; + }, ): Promise { - const existing = this.clients.get(name); - if ( - existing && - existing.getServerConfig().extension?.id !== config.extension?.id - ) { - const extensionText = config.extension - ? ` from extension "${config.extension.name}"` - : ''; - debugLogger.warn( - `Skipping MCP config for server with name "${name}"${extensionText} as it already exists.`, + const clientKey = this.getClientKey(name, config); + const existing = this.clients.get(clientKey); + + // If no registries are provided (main agent) and a server with this name already exists + // but with a different configuration, handle potential conflicts. + if (!registries) { + const existingSameName = Array.from(this.clients.values()).find( + (c) => c.getServerName() === name, ); - return; + if (existingSameName) { + const existingConfig = existingSameName.getServerConfig(); + const existingKey = this.getClientKey(name, existingConfig); + + if (existingKey !== clientKey) { + const bothMain = !config.extension && !existingConfig.extension; + const sameExtension = + config.extension && + existingConfig.extension && + config.extension.id === existingConfig.extension.id; + + if (bothMain || sameExtension) { + // This is a configuration update from the same source (hot-reload). + // We should stop the old client before starting the new one. + await this.disconnectClient(existingKey, true); + } else { + // This is a conflict (e.g. an extension trying to overwrite a main server). + const extensionText = config.extension + ? ` from extension "${config.extension.name}"` + : ''; + debugLogger.warn( + `Skipping MCP config for server with name "${name}"${extensionText} as a server with that name already exists from a different source.`, + ); + return; + } + } + } } // Always track server config for UI display @@ -291,7 +347,7 @@ export class McpClientManager { // User-disabled servers: disconnect if running, don't start if (await this.isDisabledByUser(name)) { if (existing) { - await this.disconnectClient(name); + await this.disconnectClient(clientKey); } return; } @@ -302,34 +358,46 @@ export class McpClientManager { return; } - const currentDiscoveryPromise = new Promise((resolve, reject) => { - (async () => { + const currentDiscoveryPromise = new Promise((resolve) => { + void (async () => { try { - if (existing) { - this.clients.delete(name); - await existing.disconnect(); + let client = existing; + if (!client) { + client = new McpClient( + name, + config, + this.cliConfig.getWorkspaceContext(), + this.cliConfig, + this.cliConfig.getDebugMode(), + this.clientVersion, + async () => { + debugLogger.log( + `🔔 Refreshing context for server '${name}'...`, + ); + await this.scheduleMcpContextRefresh(); + }, + ); + this.clients.set(clientKey, client); + this.eventEmitter?.emit('mcp-client-update', this.clients); } - const client = new McpClient( - name, - config, - this.toolRegistry, - this.cliConfig.getPromptRegistry(), - this.cliConfig.getResourceRegistry(), - this.cliConfig.getWorkspaceContext(), - this.cliConfig, - this.cliConfig.getDebugMode(), - this.clientVersion, - async () => { - debugLogger.log(`🔔 Refreshing context for server '${name}'...`); - await this.scheduleMcpContextRefresh(); - }, - ); - this.clients.set(name, client); - this.eventEmitter?.emit('mcp-client-update', this.clients); + const targetRegistries = + registries ?? + (this.mainToolRegistry && + this.mainPromptRegistry && + this.mainResourceRegistry + ? { + toolRegistry: this.mainToolRegistry, + promptRegistry: this.mainPromptRegistry, + resourceRegistry: this.mainResourceRegistry, + } + : undefined); + try { await client.connect(); - await client.discover(this.cliConfig); + if (targetRegistries) { + await client.discoverInto(this.cliConfig, targetRegistries); + } this.eventEmitter?.emit('mcp-client-update', this.clients); } catch (error) { this.eventEmitter?.emit('mcp-client-update', this.clients); @@ -349,13 +417,13 @@ export class McpClientManager { const errorMessage = getErrorMessage(error); this.emitDiagnostic( 'error', - `Error initializing MCP server '${name}': ${errorMessage}`, + `Fatal error ensuring MCP server '${name}' is connected: ${errorMessage}`, error, ); } finally { resolve(); } - })().catch(reject); + })(); }); if (this.discoveryPromise) { @@ -438,6 +506,11 @@ export class McpClientManager { * Restarts all MCP servers (including newly enabled ones). */ async restart(): Promise { + const disconnectionPromises = Array.from(this.clients.keys()).map((key) => + this.disconnectClient(key, true), + ); + await Promise.all(disconnectionPromises); + await Promise.all( Array.from(this.allServerConfigs.entries()).map( async ([name, config]) => { @@ -462,6 +535,8 @@ export class McpClientManager { if (!config) { throw new Error(`No MCP server registered with the name "${name}"`); } + const clientKey = this.getClientKey(name, config); + await this.disconnectClient(clientKey, true); await this.maybeDiscoverMcpServer(name, config); await this.scheduleMcpContextRefresh(); } @@ -506,11 +581,12 @@ export class McpClientManager { getMcpInstructions(): string { const instructions: string[] = []; - for (const [name, client] of this.clients) { + for (const client of this.clients.values()) { + const serverName = client.getServerName(); const clientInstructions = client.getInstructions(); if (clientInstructions) { instructions.push( - `The following are instructions provided by the tool server '${name}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`, + `The following are instructions provided by the tool server '${serverName}':\n---[start of server instructions]---\n${clientInstructions}\n---[end of server instructions]---`, ); } } diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 9b501e8387..7beae97559 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -130,6 +130,12 @@ export interface McpProgressReporter { unregisterProgressToken(token: string | number): void; } +export interface RegistrySet { + toolRegistry: ToolRegistry; + promptRegistry: PromptRegistry; + resourceRegistry: ResourceRegistry; +} + /** * A client for a single MCP server. * @@ -147,6 +153,8 @@ export class McpClient implements McpProgressReporter { private isRefreshingPrompts: boolean = false; private pendingPromptRefresh: boolean = false; + private readonly registeredRegistries = new Set(); + /** * Map of progress tokens to tool call IDs. * This allows us to route progress notifications to the correct tool call. @@ -156,9 +164,6 @@ export class McpClient implements McpProgressReporter { constructor( private readonly serverName: string, private readonly serverConfig: MCPServerConfig, - private readonly toolRegistry: ToolRegistry, - private readonly promptRegistry: PromptRegistry, - private readonly resourceRegistry: ResourceRegistry, private readonly workspaceContext: WorkspaceContext, private readonly cliConfig: McpContext, private readonly debugMode: boolean, @@ -166,6 +171,10 @@ export class McpClient implements McpProgressReporter { private readonly onContextUpdated?: (signal?: AbortSignal) => Promise, ) {} + getServerName(): string { + return this.serverName; + } + /** * Connects to the MCP server. */ @@ -210,27 +219,34 @@ export class McpClient implements McpProgressReporter { } /** - * Discovers tools and prompts from the MCP server. + * Discovers tools and prompts from the MCP server into the specified registries. */ - async discover(cliConfig: McpContext): Promise { + async discoverInto( + cliConfig: McpContext, + registries: RegistrySet, + ): Promise { this.assertConnected(); + this.registeredRegistries.add(registries); const prompts = await this.fetchPrompts(); - const tools = await this.discoverTools(cliConfig); + const tools = await this.discoverTools( + cliConfig, + registries.toolRegistry.getMessageBus(), + ); const resources = await this.discoverResources(); - this.updateResourceRegistry(resources); + this.updateResourceRegistry(resources, registries.resourceRegistry); if (prompts.length === 0 && tools.length === 0 && resources.length === 0) { throw new Error('No prompts, tools, or resources found on the server.'); } for (const prompt of prompts) { - this.promptRegistry.registerPrompt(prompt); + registries.promptRegistry.registerPrompt(prompt); } for (const tool of tools) { - this.toolRegistry.registerTool(tool); + registries.toolRegistry.registerTool(tool); } - this.toolRegistry.sortTools(); + registries.toolRegistry.sortTools(); // Validate MCP tool names in policy rules against discovered tools try { @@ -257,9 +273,11 @@ export class McpClient implements McpProgressReporter { if (this.status !== MCPServerStatus.CONNECTED) { return; } - this.toolRegistry.removeMcpToolsByServer(this.serverName); - this.promptRegistry.removePromptsByServer(this.serverName); - this.resourceRegistry.removeResourcesByServer(this.serverName); + for (const registries of this.registeredRegistries) { + registries.toolRegistry.removeMcpToolsByServer(this.serverName); + registries.promptRegistry.removePromptsByServer(this.serverName); + registries.resourceRegistry.removeResourcesByServer(this.serverName); + } this.updateStatus(MCPServerStatus.DISCONNECTING); const client = this.client; this.client = undefined; @@ -294,6 +312,7 @@ export class McpClient implements McpProgressReporter { private async discoverTools( cliConfig: McpContext, + messageBus: MessageBus, options?: { timeout?: number; signal?: AbortSignal }, ): Promise { this.assertConnected(); @@ -302,7 +321,7 @@ export class McpClient implements McpProgressReporter { this.serverConfig, this.client!, cliConfig, - this.toolRegistry.getMessageBus(), + messageBus, { ...(options ?? { timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, @@ -329,8 +348,11 @@ export class McpClient implements McpProgressReporter { return discoverResources(this.serverName, this.client!, this.cliConfig); } - private updateResourceRegistry(resources: Resource[]): void { - this.resourceRegistry.setResourcesForServer(this.serverName, resources); + private updateResourceRegistry( + resources: Resource[], + resourceRegistry: ResourceRegistry, + ): void { + resourceRegistry.setResourcesForServer(this.serverName, resources); } async readResource( @@ -482,23 +504,32 @@ export class McpClient implements McpProgressReporter { try { newResources = await this.discoverResources(); - // Verification Retry: If no resources are found or resources didn't change, - // wait briefly and try one more time. Some servers notify before they're fully ready. - const currentResources = - this.resourceRegistry.getResourcesByServer(this.serverName) || []; - const resourceMatch = - newResources.length === currentResources.length && - newResources.every((nr: Resource) => - currentResources.some((cr: MCPResource) => cr.uri === nr.uri), - ); + for (const registries of this.registeredRegistries) { + // Verification Retry: If no resources are found or resources didn't change, + // wait briefly and try one more time. Some servers notify before they're fully ready. + const currentResources = + registries.resourceRegistry.getResourcesByServer( + this.serverName, + ) || []; + const resourceMatch = + newResources.length === currentResources.length && + newResources.every((nr: Resource) => + currentResources.some((cr: MCPResource) => cr.uri === nr.uri), + ); - if (resourceMatch && !this.pendingResourceRefresh) { - debugLogger.log( - `No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`, + if (resourceMatch && !this.pendingResourceRefresh) { + debugLogger.log( + `No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`, + ); + const retryDelay = 500; + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + newResources = await this.discoverResources(); + } + + this.updateResourceRegistry( + newResources, + registries.resourceRegistry, ); - const retryDelay = 500; - await new Promise((resolve) => setTimeout(resolve, retryDelay)); - newResources = await this.discoverResources(); } } catch (err) { debugLogger.error( @@ -508,8 +539,6 @@ export class McpClient implements McpProgressReporter { break; } - this.updateResourceRegistry(newResources); - if (this.onContextUpdated) { await this.onContextUpdated(abortController.signal); } @@ -575,30 +604,33 @@ export class McpClient implements McpProgressReporter { signal: abortController.signal, }); - // Verification Retry: If no prompts are found or prompts didn't change, - // wait briefly and try one more time. Some servers notify before they're fully ready. - const currentPrompts = - this.promptRegistry.getPromptsByServer(this.serverName) || []; - const promptsMatch = - newPrompts.length === currentPrompts.length && - newPrompts.every((np) => - currentPrompts.some((cp) => cp.name === np.name), - ); + for (const registries of this.registeredRegistries) { + // Verification Retry: If no prompts are found or prompts didn't change, + // wait briefly and try one more time. Some servers notify before they're fully ready. + const currentPrompts = + registries.promptRegistry.getPromptsByServer(this.serverName) || + []; + const promptsMatch = + newPrompts.length === currentPrompts.length && + newPrompts.every((np) => + currentPrompts.some((cp) => cp.name === np.name), + ); - if (promptsMatch && !this.pendingPromptRefresh) { - debugLogger.log( - `No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`, - ); - const retryDelay = 500; - await new Promise((resolve) => setTimeout(resolve, retryDelay)); - newPrompts = await this.fetchPrompts({ - signal: abortController.signal, - }); - } + if (promptsMatch && !this.pendingPromptRefresh) { + debugLogger.log( + `No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`, + ); + const retryDelay = 500; + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + newPrompts = await this.fetchPrompts({ + signal: abortController.signal, + }); + } - this.promptRegistry.removePromptsByServer(this.serverName); - for (const prompt of newPrompts) { - this.promptRegistry.registerPrompt(prompt); + registries.promptRegistry.removePromptsByServer(this.serverName); + for (const prompt of newPrompts) { + registries.promptRegistry.registerPrompt(prompt); + } } } catch (err) { debugLogger.error( @@ -666,42 +698,58 @@ export class McpClient implements McpProgressReporter { const abortController = new AbortController(); const timeoutId = setTimeout(() => abortController.abort(), timeoutMs); - let newTools; try { - newTools = await this.discoverTools(this.cliConfig, { - signal: abortController.signal, - }); - debugLogger.log( - `Refresh for '${this.serverName}' discovered ${newTools.length} tools.`, - ); - - // Verification Retry (Option 3): If no tools are found or tools didn't change, - // wait briefly and try one more time. Some servers notify before they're fully ready. - const currentTools = - this.toolRegistry.getToolsByServer(this.serverName) || []; - const toolNamesMatch = - newTools.length === currentTools.length && - newTools.every((nt) => - currentTools.some( - (ct) => - ct.name === nt.name || - (ct instanceof DiscoveredMCPTool && - ct.serverToolName === nt.serverToolName), - ), + for (const registries of this.registeredRegistries) { + let newTools = await this.discoverTools( + this.cliConfig, + registries.toolRegistry.getMessageBus(), + { + signal: abortController.signal, + }, + ); + debugLogger.log( + `Refresh for '${this.serverName}' discovered ${newTools.length} tools.`, ); - if (toolNamesMatch && !this.pendingToolRefresh) { - debugLogger.log( - `No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`, - ); - const retryDelay = 500; - await new Promise((resolve) => setTimeout(resolve, retryDelay)); - newTools = await this.discoverTools(this.cliConfig, { - signal: abortController.signal, - }); - debugLogger.log( - `Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`, - ); + // Verification Retry (Option 3): If no tools are found or tools didn't change, + // wait briefly and try one more time. Some servers notify before they're fully ready. + const currentTools = + registries.toolRegistry.getToolsByServer(this.serverName) || []; + const toolNamesMatch = + newTools.length === currentTools.length && + newTools.every((nt) => + currentTools.some( + (ct) => + ct.name === nt.name || + (ct instanceof DiscoveredMCPTool && + ct.serverToolName === nt.serverToolName), + ), + ); + + if (toolNamesMatch && !this.pendingToolRefresh) { + debugLogger.log( + `No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`, + ); + const retryDelay = 500; + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + newTools = await this.discoverTools( + this.cliConfig, + registries.toolRegistry.getMessageBus(), + { + signal: abortController.signal, + }, + ); + debugLogger.log( + `Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`, + ); + } + + registries.toolRegistry.removeMcpToolsByServer(this.serverName); + + for (const tool of newTools) { + registries.toolRegistry.registerTool(tool); + } + registries.toolRegistry.sortTools(); } } catch (err) { debugLogger.error( @@ -711,13 +759,6 @@ export class McpClient implements McpProgressReporter { break; } - this.toolRegistry.removeMcpToolsByServer(this.serverName); - - for (const tool of newTools) { - this.toolRegistry.registerTool(tool); - } - this.toolRegistry.sortTools(); - if (this.onContextUpdated) { await this.onContextUpdated(abortController.signal); } @@ -1266,7 +1307,6 @@ export async function discoverTools( mcpServerConfig.extension?.name, mcpServerConfig.extension?.id, annotations as Record | undefined, - mcpServerConfig.originalName, ); discoveredTools.push(tool); diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index 0d8717b378..5702f88a52 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -348,7 +348,6 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool< override readonly extensionName?: string, override readonly extensionId?: string, private readonly _toolAnnotations?: Record, - readonly originalServerName?: string, ) { super( nameOverride ?? diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index c8c2c2856c..291f43d908 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -284,34 +284,23 @@ describe('ToolRegistry', () => { }); }); - describe('subagent MCP tools filtering', () => { - it('should hide __agent__ prefixed tools when isMainRegistry is true', async () => { - const mainRegistry = new ToolRegistry(config, mockMessageBus, true); - const subagentRegistry = new ToolRegistry(config, mockMessageBus, false); + describe('removeMcpToolsByServer', () => { + it('should remove all tools from a specific server', () => { + const serverName = 'test-server'; + const mcpTool1 = createMCPTool(serverName, 'tool1', 'desc1'); + const mcpTool2 = createMCPTool(serverName, 'tool2', 'desc2'); + const otherTool = createMCPTool('other-server', 'tool3', 'desc3'); - const mcpTool = createMCPTool( - '__agent__TestAgent__myServer', - 'my-tool', - 'description', - ); - vi.spyOn(mcpTool, 'getSchema').mockReturnValue({ - name: 'my_tool', - description: 'description', - } as unknown as FunctionDeclaration); + toolRegistry.registerTool(mcpTool1); + toolRegistry.registerTool(mcpTool2); + toolRegistry.registerTool(otherTool); - mainRegistry.registerTool(mcpTool); - subagentRegistry.registerTool(mcpTool); + expect(toolRegistry.getToolsByServer(serverName)).toHaveLength(2); - const mainDeclarations = - mainRegistry.getFunctionDeclarations('test-model'); - const subagentDeclarations = - subagentRegistry.getFunctionDeclarations('test-model'); + toolRegistry.removeMcpToolsByServer(serverName); - expect(mainDeclarations.length).toBe(0); - expect(subagentDeclarations.length).toBe(1); - expect(subagentDeclarations[0].name).toBe( - 'mcp___agent__TestAgent__myServer_my-tool', - ); + expect(toolRegistry.getToolsByServer(serverName)).toHaveLength(0); + expect(toolRegistry.getToolsByServer('other-server')).toHaveLength(1); }); }); diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index ea2a973818..8789e82ca6 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -561,14 +561,6 @@ export class ToolRegistry { return; } - if ( - this.isMainRegistry && - tool instanceof DiscoveredMCPTool && - tool.serverName.startsWith('__agent__') - ) { - return; - } - if ( mainAgentTools && !mainAgentTools.includes(toolName) &&