diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index c62e89184a..408f20007e 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -174,7 +174,15 @@ export class McpClientManager { this.toolRegistry, this.cliConfig.getPromptRegistry(), this.cliConfig.getWorkspaceContext(), + this.cliConfig, this.cliConfig.getDebugMode(), + async () => { + debugLogger.log('Tools changed, updating Gemini context...'); + const geminiClient = this.cliConfig.getGeminiClient(); + if (geminiClient.isInitialized()) { + await geminiClient.setTools(); + } + }, ); if (!existing) { this.clients.set(name, client); diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 6ddd5b1271..2075d13c3c 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -18,6 +18,8 @@ import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; import { OAuthUtils } from '../mcp/oauth-utils.js'; import type { PromptRegistry } from '../prompts/prompt-registry.js'; +import { ToolListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js'; + import { WorkspaceContext } from '../utils/workspaceContext.js'; import { connectToMcpServer, @@ -111,11 +113,15 @@ describe('mcp-client', () => { mockedToolRegistry, {} as PromptRegistry, workspaceContext, + {} as Config, false, ); await client.connect(); await client.discover({} as Config); - expect(mockedClient.listTools).toHaveBeenCalledWith({}); + expect(mockedClient.listTools).toHaveBeenCalledWith( + {}, + { timeout: 600000 }, + ); }); it('should not skip tools even if a parameter is missing a type', async () => { @@ -177,6 +183,7 @@ describe('mcp-client', () => { mockedToolRegistry, {} as PromptRegistry, workspaceContext, + {} as Config, false, ); await client.connect(); @@ -217,6 +224,7 @@ describe('mcp-client', () => { mockedToolRegistry, {} as PromptRegistry, workspaceContext, + {} as Config, false, ); await client.connect(); @@ -261,6 +269,7 @@ describe('mcp-client', () => { mockedToolRegistry, {} as PromptRegistry, workspaceContext, + {} as Config, false, ); await client.connect(); @@ -309,6 +318,7 @@ describe('mcp-client', () => { mockedToolRegistry, {} as PromptRegistry, workspaceContext, + {} as Config, false, ); await client.connect(); @@ -371,6 +381,7 @@ describe('mcp-client', () => { mockedToolRegistry, {} as PromptRegistry, workspaceContext, + {} as Config, false, ); await client.connect(); @@ -444,6 +455,7 @@ describe('mcp-client', () => { mockedToolRegistry, mockedPromptRegistry, workspaceContext, + {} as Config, false, ); await client.connect(); @@ -459,6 +471,439 @@ describe('mcp-client', () => { expect(mockedPromptRegistry.removePromptsByServer).toHaveBeenCalledOnce(); }); }); + + describe('Dynamic Tool Updates', () => { + it('should set up notification handler if server supports tool list changes', async () => { + const mockedClient = { + connect: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + // Capability enables the listener + getServerCapabilities: vi + .fn() + .mockReturnValue({ tools: { listChanged: true } }), + setNotificationHandler: vi.fn(), + listTools: vi.fn().mockResolvedValue({ tools: [] }), + listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + request: vi.fn().mockResolvedValue({}), + }; + + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + + const client = new McpClient( + 'test-server', + { command: 'test-command' }, + {} as ToolRegistry, + {} as PromptRegistry, + workspaceContext, + {} as Config, + false, + ); + + await client.connect(); + + expect(mockedClient.setNotificationHandler).toHaveBeenCalledWith( + ToolListChangedNotificationSchema, + expect.any(Function), + ); + }); + + it('should NOT set up notification handler if server lacks capability', async () => { + const mockedClient = { + connect: vi.fn(), + getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), // No listChanged + setNotificationHandler: vi.fn(), + request: vi.fn().mockResolvedValue({}), + registerCapabilities: vi.fn().mockResolvedValue({}), + setRequestHandler: vi.fn().mockResolvedValue({}), + }; + + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + + const client = new McpClient( + 'test-server', + { command: 'test-command' }, + {} as ToolRegistry, + {} as PromptRegistry, + workspaceContext, + {} as Config, + false, + ); + + await client.connect(); + + expect(mockedClient.setNotificationHandler).not.toHaveBeenCalled(); + }); + + it('should refresh tools and notify manager when notification is received', async () => { + // Setup mocks + const mockedClient = { + connect: vi.fn(), + getServerCapabilities: vi + .fn() + .mockReturnValue({ tools: { listChanged: true } }), + setNotificationHandler: vi.fn(), + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: 'newTool', + inputSchema: { type: 'object', properties: {} }, + }, + ], + }), + listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + request: vi.fn().mockResolvedValue({}), + registerCapabilities: vi.fn().mockResolvedValue({}), + setRequestHandler: vi.fn().mockResolvedValue({}), + }; + + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + + const mockedToolRegistry = { + removeMcpToolsByServer: vi.fn(), + registerTool: vi.fn(), + sortTools: vi.fn(), + getMessageBus: vi.fn().mockReturnValue(undefined), + } as unknown as ToolRegistry; + + const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined); + + // Initialize client with onToolsUpdated callback + const client = new McpClient( + 'test-server', + { command: 'test-command' }, + mockedToolRegistry, + {} as PromptRegistry, + workspaceContext, + {} as Config, + false, + onToolsUpdatedSpy, + ); + + // 1. Connect (sets up listener) + await client.connect(); + + // 2. Extract the callback passed to setNotificationHandler + const notificationCallback = + mockedClient.setNotificationHandler.mock.calls[0][1]; + + // 3. Trigger the notification manually + await notificationCallback(); + + // 4. Assertions + // It should clear old tools + expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledWith( + 'test-server', + ); + + // It should fetch new tools (listTools called inside discoverTools) + expect(mockedClient.listTools).toHaveBeenCalled(); + + // It should register the new tool + expect(mockedToolRegistry.registerTool).toHaveBeenCalled(); + + // It should notify the manager + expect(onToolsUpdatedSpy).toHaveBeenCalled(); + + // It should emit feedback event + expect(coreEvents.emitFeedback).toHaveBeenCalledWith( + 'info', + 'Tools updated for server: test-server', + ); + }); + + it('should handle errors during tool refresh gracefully', async () => { + const mockedClient = { + connect: vi.fn(), + getServerCapabilities: vi + .fn() + .mockReturnValue({ tools: { listChanged: true } }), + setNotificationHandler: vi.fn(), + // Simulate error during discovery + listTools: vi.fn().mockRejectedValue(new Error('Network blip')), + request: vi.fn().mockResolvedValue({}), + registerCapabilities: vi.fn().mockResolvedValue({}), + setRequestHandler: vi.fn().mockResolvedValue({}), + }; + + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + + const mockedToolRegistry = { + removeMcpToolsByServer: vi.fn(), + getMessageBus: vi.fn().mockReturnValue(undefined), + } as unknown as ToolRegistry; + + const client = new McpClient( + 'test-server', + { command: 'test-command' }, + mockedToolRegistry, + {} as PromptRegistry, + workspaceContext, + {} as Config, + false, + ); + + await client.connect(); + + const notificationCallback = + mockedClient.setNotificationHandler.mock.calls[0][1]; + + // Trigger notification - should fail internally but catch the error + await notificationCallback(); + + // Should try to remove tools + expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalled(); + + // Should NOT emit success feedback + expect(coreEvents.emitFeedback).not.toHaveBeenCalledWith( + 'info', + expect.stringContaining('Tools updated'), + ); + }); + + it('should handle concurrent updates from multiple servers', async () => { + const createMockSdkClient = (toolName: string) => ({ + connect: vi.fn(), + getServerCapabilities: vi + .fn() + .mockReturnValue({ tools: { listChanged: true } }), + setNotificationHandler: vi.fn(), + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: toolName, + inputSchema: { type: 'object', properties: {} }, + }, + ], + }), + listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + request: vi.fn().mockResolvedValue({}), + registerCapabilities: vi.fn().mockResolvedValue({}), + setRequestHandler: vi.fn().mockResolvedValue({}), + }); + + const mockClientA = createMockSdkClient('tool-from-A'); + const mockClientB = createMockSdkClient('tool-from-B'); + + vi.mocked(ClientLib.Client) + .mockReturnValueOnce(mockClientA as unknown as ClientLib.Client) + .mockReturnValueOnce(mockClientB as unknown as ClientLib.Client); + + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + + const mockedToolRegistry = { + removeMcpToolsByServer: vi.fn(), + registerTool: vi.fn(), + sortTools: vi.fn(), + getMessageBus: vi.fn().mockReturnValue(undefined), + } as unknown as ToolRegistry; + + const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined); + + const clientA = new McpClient( + 'server-A', + { command: 'cmd-a' }, + mockedToolRegistry, + {} as PromptRegistry, + workspaceContext, + {} as Config, + false, + onToolsUpdatedSpy, + ); + + const clientB = new McpClient( + 'server-B', + { command: 'cmd-b' }, + mockedToolRegistry, + {} as PromptRegistry, + workspaceContext, + {} as Config, + false, + onToolsUpdatedSpy, + ); + + await clientA.connect(); + await clientB.connect(); + + const handlerA = mockClientA.setNotificationHandler.mock.calls[0][1]; + const handlerB = mockClientB.setNotificationHandler.mock.calls[0][1]; + + // Trigger burst updates simultaneously + await Promise.all([handlerA(), handlerB()]); + + expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledWith( + 'server-A', + ); + expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledWith( + 'server-B', + ); + + // Verify fetching happened on both clients + expect(mockClientA.listTools).toHaveBeenCalled(); + expect(mockClientB.listTools).toHaveBeenCalled(); + + // Verify tools from both servers were registered (2 total calls) + expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2); + + // Verify the update callback was triggered for both + expect(onToolsUpdatedSpy).toHaveBeenCalledTimes(2); + }); + + it('should abort discovery and log error if timeout is exceeded during refresh', async () => { + vi.useFakeTimers(); + + const mockedClient = { + connect: vi.fn(), + getServerCapabilities: vi + .fn() + .mockReturnValue({ tools: { listChanged: true } }), + setNotificationHandler: vi.fn(), + // Mock listTools to simulate a long running process that respects the abort signal + listTools: vi.fn().mockImplementation( + async (params, options) => + new Promise((resolve, reject) => { + if (options?.signal?.aborted) { + return reject(new Error('Operation aborted')); + } + options?.signal?.addEventListener('abort', () => { + reject(new Error('Operation aborted')); + }); + // Intentionally do not resolve immediately to simulate lag + }), + ), + listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + request: vi.fn().mockResolvedValue({}), + registerCapabilities: vi.fn().mockResolvedValue({}), + setRequestHandler: vi.fn().mockResolvedValue({}), + }; + + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + + const mockedToolRegistry = { + removeMcpToolsByServer: vi.fn(), + registerTool: vi.fn(), + sortTools: vi.fn(), + getMessageBus: vi.fn().mockReturnValue(undefined), + } as unknown as ToolRegistry; + + const client = new McpClient( + 'test-server', + // Set a short timeout + { command: 'test-command', timeout: 100 }, + mockedToolRegistry, + {} as PromptRegistry, + workspaceContext, + {} as Config, + false, + ); + + await client.connect(); + + const notificationCallback = + mockedClient.setNotificationHandler.mock.calls[0][1]; + + const refreshPromise = notificationCallback(); + + vi.advanceTimersByTime(150); + + await refreshPromise; + + expect(mockedClient.listTools).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + signal: expect.any(AbortSignal), + }), + ); + + expect(mockedToolRegistry.registerTool).not.toHaveBeenCalled(); + + vi.useRealTimers(); + }); + + it('should pass abort signal to onToolsUpdated callback', async () => { + const mockedClient = { + connect: vi.fn(), + getServerCapabilities: vi + .fn() + .mockReturnValue({ tools: { listChanged: true } }), + setNotificationHandler: vi.fn(), + listTools: vi.fn().mockResolvedValue({ tools: [] }), + listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + request: vi.fn().mockResolvedValue({}), + registerCapabilities: vi.fn().mockResolvedValue({}), + setRequestHandler: vi.fn().mockResolvedValue({}), + }; + + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + + const mockedToolRegistry = { + removeMcpToolsByServer: vi.fn(), + registerTool: vi.fn(), + sortTools: vi.fn(), + getMessageBus: vi.fn().mockReturnValue(undefined), + } as unknown as ToolRegistry; + + const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined); + + const client = new McpClient( + 'test-server', + { command: 'test-command' }, + mockedToolRegistry, + {} as PromptRegistry, + workspaceContext, + {} as Config, + false, + onToolsUpdatedSpy, + ); + + await client.connect(); + + const notificationCallback = + mockedClient.setNotificationHandler.mock.calls[0][1]; + + await notificationCallback(); + + expect(onToolsUpdatedSpy).toHaveBeenCalledWith(expect.any(AbortSignal)); + + // Verify the signal passed was not aborted (happy path) + const signal = onToolsUpdatedSpy.mock.calls[0][0]; + expect(signal.aborted).toBe(false); + }); + }); + describe('appendMcpServerCommand', () => { it('should do nothing if no MCP servers or command are configured', () => { const out = populateMcpServerCommand({}, undefined); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index e988e3b346..21dd018e55 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -23,6 +23,7 @@ import type { } from '@modelcontextprotocol/sdk/types.js'; import { ListRootsRequestSchema, + ToolListChangedNotificationSchema, type Tool as McpTool, } from '@modelcontextprotocol/sdk/types.js'; import { parse } from 'shell-quote'; @@ -97,6 +98,8 @@ export class McpClient { private client: Client | undefined; private transport: Transport | undefined; private status: MCPServerStatus = MCPServerStatus.DISCONNECTED; + private isRefreshing: boolean = false; + private pendingRefresh: boolean = false; constructor( private readonly serverName: string, @@ -104,7 +107,9 @@ export class McpClient { private readonly toolRegistry: ToolRegistry, private readonly promptRegistry: PromptRegistry, private readonly workspaceContext: WorkspaceContext, + private readonly cliConfig: Config, private readonly debugMode: boolean, + private readonly onToolsUpdated?: (signal?: AbortSignal) => Promise, ) {} /** @@ -124,6 +129,25 @@ export class McpClient { this.debugMode, this.workspaceContext, ); + + // setup dynamic tool listener + const capabilities = this.client.getServerCapabilities(); + + if (capabilities?.tools?.listChanged) { + debugLogger.log( + `Server '${this.serverName}' supports tool updates. Listening for changes...`, + ); + + this.client.setNotificationHandler( + ToolListChangedNotificationSchema, + async () => { + debugLogger.log( + `🔔 Received tool update notification from '${this.serverName}'`, + ); + await this.refreshTools(); + }, + ); + } const originalOnError = this.client.onerror; this.client.onerror = (error) => { if (this.status !== MCPServerStatus.CONNECTED) { @@ -204,7 +228,10 @@ export class McpClient { } } - private async discoverTools(cliConfig: Config): Promise { + private async discoverTools( + cliConfig: Config, + options?: { timeout?: number; signal?: AbortSignal }, + ): Promise { this.assertConnected(); return discoverTools( this.serverName, @@ -212,6 +239,9 @@ export class McpClient { this.client!, cliConfig, this.toolRegistry.getMessageBus(), + options ?? { + timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + }, ); } @@ -227,6 +257,75 @@ export class McpClient { getInstructions(): string | undefined { return this.client?.getInstructions(); } + + /** + * Refreshes the tools for this server by re-querying the MCP `tools/list` endpoint. + * + * This method implements a **Coalescing Pattern** to handle rapid bursts of notifications + * (e.g., during server startup or bulk updates) without overwhelming the server or + * creating race conditions in the global ToolRegistry. + */ + private async refreshTools(): Promise { + if (this.isRefreshing) { + debugLogger.log( + `Tool refresh for '${this.serverName}' is already in progress. Pending update.`, + ); + this.pendingRefresh = true; + return; + } + + this.isRefreshing = true; + + try { + do { + this.pendingRefresh = false; + + if (this.status !== MCPServerStatus.CONNECTED || !this.client) break; + + const timeoutMs = this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC; + const abortController = new AbortController(); + const timeoutId = setTimeout(() => abortController.abort(), timeoutMs); + + let newTools; + try { + newTools = await this.discoverTools(this.cliConfig, { + signal: abortController.signal, + }); + } catch (err) { + debugLogger.error( + `Discovery failed during refresh: ${getErrorMessage(err)}`, + ); + clearTimeout(timeoutId); + break; + } + + this.toolRegistry.removeMcpToolsByServer(this.serverName); + + for (const tool of newTools) { + this.toolRegistry.registerTool(tool); + } + this.toolRegistry.sortTools(); + + if (this.onToolsUpdated) { + await this.onToolsUpdated(abortController.signal); + } + + clearTimeout(timeoutId); + + coreEvents.emitFeedback( + 'info', + `Tools updated for server: ${this.serverName}`, + ); + } while (this.pendingRefresh); + } catch (error) { + debugLogger.error( + `Critical error in refresh loop for ${this.serverName}: ${getErrorMessage(error)}`, + ); + } finally { + this.isRefreshing = false; + this.pendingRefresh = false; + } + } } /** @@ -622,6 +721,7 @@ export async function connectAndDiscover( mcpClient, cliConfig, toolRegistry.getMessageBus(), + { timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC }, ); // If we have neither prompts nor tools, it's a failed discovery @@ -671,12 +771,13 @@ export async function discoverTools( mcpClient: Client, cliConfig: Config, messageBus?: MessageBus, + options?: { timeout?: number; signal?: AbortSignal }, ): Promise { try { // Only request tools if the server supports them. if (mcpClient.getServerCapabilities()?.tools == null) return []; - const response = await mcpClient.listTools({}); + const response = await mcpClient.listTools({}, options); const discoveredTools: DiscoveredMCPTool[] = []; for (const toolDef of response.tools) { try {