diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 45e06a43c6..4a07b1aa4f 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -223,7 +223,7 @@ describe('mcp-client', () => { consoleWarnSpy.mockRestore(); }); - it('should handle errors when discovering prompts', async () => { + it('should propagate errors when discovering prompts', async () => { const mockedClient = { connect: vi.fn(), discover: vi.fn(), @@ -269,9 +269,7 @@ describe('mcp-client', () => { '0.0.1', ); await client.connect(); - await expect(client.discover({} as Config)).rejects.toThrow( - 'No prompts, tools, or resources found on the server.', - ); + await expect(client.discover({} as Config)).rejects.toThrow('Test error'); expect(coreEvents.emitFeedback).toHaveBeenCalledWith( 'error', `Error discovering prompts from test-server: Test error`, @@ -640,6 +638,89 @@ describe('mcp-client', () => { ); }); + it('refreshes prompts when prompt list change notification is received', async () => { + let listCallCount = 0; + let promptListHandler: + | ((notification: unknown) => Promise | void) + | undefined; + const mockedClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + setNotificationHandler: vi.fn((_, handler) => { + promptListHandler = handler; + }), + getServerCapabilities: vi + .fn() + .mockReturnValue({ prompts: { listChanged: true } }), + listPrompts: vi.fn().mockImplementation(() => { + listCallCount += 1; + if (listCallCount === 1) { + return Promise.resolve({ + prompts: [{ name: 'one', description: 'first' }], + }); + } + return Promise.resolve({ + prompts: [{ name: 'two', description: 'second' }], + }); + }), + request: vi.fn().mockResolvedValue({ prompts: [] }), + } as unknown as ClientLib.Client; + vi.mocked(ClientLib.Client).mockReturnValue(mockedClient); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + const mockedToolRegistry = { + registerTool: vi.fn(), + sortTools: vi.fn(), + getMessageBus: vi.fn().mockReturnValue(undefined), + } as unknown as ToolRegistry; + const promptRegistry = { + registerPrompt: vi.fn(), + removePromptsByServer: vi.fn(), + } as unknown as PromptRegistry; + const resourceRegistry = { + setResourcesForServer: vi.fn(), + removeResourcesByServer: vi.fn(), + } as unknown as ResourceRegistry; + const client = new McpClient( + 'test-server', + { + command: 'test-command', + }, + mockedToolRegistry, + promptRegistry, + resourceRegistry, + workspaceContext, + { sanitizationConfig: EMPTY_CONFIG } as Config, + false, + '0.0.1', + ); + await client.connect(); + await client.discover({ sanitizationConfig: EMPTY_CONFIG } as Config); + + expect(mockedClient.setNotificationHandler).toHaveBeenCalledOnce(); + expect(promptListHandler).toBeDefined(); + + await promptListHandler?.({ + method: 'notifications/prompts/list_changed', + }); + + expect(promptRegistry.removePromptsByServer).toHaveBeenCalledWith( + 'test-server', + ); + expect(promptRegistry.registerPrompt).toHaveBeenLastCalledWith( + expect.objectContaining({ name: 'two' }), + ); + expect(coreEvents.emitFeedback).toHaveBeenCalledWith( + 'info', + 'Prompts updated for server: test-server', + ); + }); + it('should remove tools and prompts on disconnect', async () => { const mockedClient = { connect: vi.fn(), diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 872a5019d4..4533ce4236 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -29,6 +29,7 @@ import { ReadResourceResultSchema, ResourceListChangedNotificationSchema, ToolListChangedNotificationSchema, + PromptListChangedNotificationSchema, type Tool as McpTool, } from '@modelcontextprotocol/sdk/types.js'; import { parse } from 'shell-quote'; @@ -112,6 +113,8 @@ export class McpClient { private pendingToolRefresh: boolean = false; private isRefreshingResources: boolean = false; private pendingResourceRefresh: boolean = false; + private isRefreshingPrompts: boolean = false; + private pendingPromptRefresh: boolean = false; constructor( private readonly serverName: string, @@ -174,7 +177,7 @@ export class McpClient { async discover(cliConfig: Config): Promise { this.assertConnected(); - const prompts = await this.discoverPrompts(); + const prompts = await this.fetchPrompts(); const tools = await this.discoverTools(cliConfig); const resources = await this.discoverResources(); this.updateResourceRegistry(resources); @@ -183,6 +186,9 @@ export class McpClient { throw new Error('No prompts, tools, or resources found on the server.'); } + for (const prompt of prompts) { + this.promptRegistry.registerPrompt(prompt); + } for (const tool of tools) { this.toolRegistry.registerTool(tool); } @@ -248,9 +254,11 @@ export class McpClient { ); } - private async discoverPrompts(): Promise { + private async fetchPrompts(options?: { + signal?: AbortSignal; + }): Promise { this.assertConnected(); - return discoverPrompts(this.serverName, this.client!, this.promptRegistry); + return discoverPrompts(this.serverName, this.client!, options); } private async discoverResources(): Promise { @@ -315,6 +323,22 @@ export class McpClient { }, ); } + + if (capabilities?.prompts?.listChanged) { + debugLogger.log( + `Server '${this.serverName}' supports prompt updates. Listening for changes...`, + ); + + this.client.setNotificationHandler( + PromptListChangedNotificationSchema, + async () => { + debugLogger.log( + `🔔 Received prompt update notification from '${this.serverName}'`, + ); + await this.refreshPrompts(); + }, + ); + } } /** @@ -375,6 +399,63 @@ export class McpClient { } } + /** + * Refreshes prompts for this server by re-querying the MCP `prompts/list` endpoint. + */ + private async refreshPrompts(): Promise { + if (this.isRefreshingPrompts) { + debugLogger.log( + `Prompt refresh for '${this.serverName}' is already in progress. Pending update.`, + ); + this.pendingPromptRefresh = true; + return; + } + + this.isRefreshingPrompts = true; + + try { + do { + this.pendingPromptRefresh = false; + + if (this.status !== MCPServerStatus.CONNECTED || !this.client) break; + + const timeoutMs = this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC; + const abortController = new AbortController(); + const timeoutId = setTimeout(() => abortController.abort(), timeoutMs); + + try { + const newPrompts = await this.fetchPrompts({ + signal: abortController.signal, + }); + this.promptRegistry.removePromptsByServer(this.serverName); + for (const prompt of newPrompts) { + this.promptRegistry.registerPrompt(prompt); + } + } catch (err) { + debugLogger.error( + `Prompt discovery failed during refresh: ${getErrorMessage(err)}`, + ); + clearTimeout(timeoutId); + break; + } + + clearTimeout(timeoutId); + + coreEvents.emitFeedback( + 'info', + `Prompts updated for server: ${this.serverName}`, + ); + } while (this.pendingPromptRefresh); + } catch (error) { + debugLogger.error( + `Critical error in prompt refresh loop for ${this.serverName}: ${getErrorMessage(error)}`, + ); + } finally { + this.isRefreshingPrompts = false; + this.pendingPromptRefresh = false; + } + } + getServerConfig(): MCPServerConfig { return this.serverConfig; } @@ -840,11 +921,7 @@ export async function connectAndDiscover( }; // Attempt to discover both prompts and tools - const prompts = await discoverPrompts( - mcpServerName, - mcpClient, - promptRegistry, - ); + const prompts = await discoverPrompts(mcpServerName, mcpClient); const tools = await discoverTools( mcpServerName, mcpServerConfig, @@ -862,7 +939,10 @@ export async function connectAndDiscover( // If we found anything, the server is connected updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED); - // Register any discovered tools + // Register any discovered prompts and tools + for (const prompt of prompts) { + promptRegistry.registerPrompt(prompt); + } for (const tool of tools) { toolRegistry.registerTool(tool); } @@ -1038,39 +1118,32 @@ class McpCallableTool implements CallableTool { export async function discoverPrompts( mcpServerName: string, mcpClient: Client, - promptRegistry: PromptRegistry, -): Promise { + options?: { signal?: AbortSignal }, +): Promise { + // Only request prompts if the server supports them. + if (mcpClient.getServerCapabilities()?.prompts == null) return []; + try { - // Only request prompts if the server supports them. - if (mcpClient.getServerCapabilities()?.prompts == null) return []; - - const response = await mcpClient.listPrompts({}); - - for (const prompt of response.prompts) { - promptRegistry.registerPrompt({ - ...prompt, - serverName: mcpServerName, - invoke: (params: Record) => - invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params), - }); - } - return response.prompts; + const response = await mcpClient.listPrompts({}, options); + return response.prompts.map((prompt) => ({ + ...prompt, + serverName: mcpServerName, + invoke: (params: Record) => + invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params), + })); } catch (error) { - // It's okay if this fails, not all servers will have prompts. - // Don't log an error if the method is not found, which is a common case. - if ( - error instanceof Error && - !error.message?.includes('Method not found') - ) { - coreEvents.emitFeedback( - 'error', - `Error discovering prompts from ${mcpServerName}: ${getErrorMessage( - error, - )}`, - error, - ); + // It's okay if the method is not found, which is a common case. + if (error instanceof Error && error.message?.includes('Method not found')) { + return []; } - return []; + coreEvents.emitFeedback( + 'error', + `Error discovering prompts from ${mcpServerName}: ${getErrorMessage( + error, + )}`, + error, + ); + throw error; } }