From 352fb0c97680b11578f88277f9e0265aa9a2191a Mon Sep 17 00:00:00 2001 From: Jacob Richman Date: Wed, 4 Mar 2026 06:46:17 -0800 Subject: [PATCH] fix(mcp): Notifications/tools/list_changed support not working (#21050) Co-authored-by: Bryan Morgan --- .../core/src/resources/resource-registry.ts | 13 + packages/core/src/tools/mcp-client-manager.ts | 50 +++- packages/core/src/tools/mcp-client.test.ts | 244 +++++++++++++++--- packages/core/src/tools/mcp-client.ts | 144 +++++++++-- 4 files changed, 378 insertions(+), 73 deletions(-) diff --git a/packages/core/src/resources/resource-registry.ts b/packages/core/src/resources/resource-registry.ts index 1c2c754504..ce30456df5 100644 --- a/packages/core/src/resources/resource-registry.ts +++ b/packages/core/src/resources/resource-registry.ts @@ -69,4 +69,17 @@ export class ResourceRegistry { clear(): void { this.resources.clear(); } + + /** + * Returns an array of resources registered from a specific MCP server. + */ + getResourcesByServer(serverName: string): MCPResource[] { + const serverResources: MCPResource[] = []; + for (const resource of this.resources.values()) { + if (resource.serverName === serverName) { + serverResources.push(resource); + } + } + return serverResources.sort((a, b) => a.uri.localeCompare(b.uri)); + } } diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index 96d7abf55c..43ea9715bc 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -173,7 +173,7 @@ export class McpClientManager { return Promise.resolve(); }), ); - await this.cliConfig.refreshMcpContext(); + await this.scheduleMcpContextRefresh(); } /** @@ -193,7 +193,7 @@ export class McpClientManager { }), ), ); - await this.cliConfig.refreshMcpContext(); + await this.scheduleMcpContextRefresh(); } /** @@ -251,7 +251,7 @@ export class McpClientManager { if (!skipRefresh) { // This is required to update the content generator configuration with the // new tool configuration and system instructions. - await this.cliConfig.refreshMcpContext(); + await this.scheduleMcpContextRefresh(); } } } @@ -321,7 +321,7 @@ export class McpClientManager { this.cliConfig.getDebugMode(), this.clientVersion, async () => { - debugLogger.log('Tools changed, updating Gemini context...'); + debugLogger.log(`🔔 Refreshing context for server '${name}'...`); await this.scheduleMcpContextRefresh(); }, ); @@ -431,7 +431,7 @@ export class McpClientManager { this.eventEmitter?.emit('mcp-client-update', this.clients); } - await this.cliConfig.refreshMcpContext(); + await this.scheduleMcpContextRefresh(); } /** @@ -451,7 +451,7 @@ export class McpClientManager { }, ), ); - await this.cliConfig.refreshMcpContext(); + await this.scheduleMcpContextRefresh(); } /** @@ -463,7 +463,7 @@ export class McpClientManager { throw new Error(`No MCP server registered with the name "${name}"`); } await this.maybeDiscoverMcpServer(name, config); - await this.cliConfig.refreshMcpContext(); + await this.scheduleMcpContextRefresh(); } /** @@ -517,21 +517,51 @@ export class McpClientManager { return instructions.join('\n\n'); } + private isRefreshingMcpContext: boolean = false; + private pendingMcpContextRefresh: boolean = false; + private async scheduleMcpContextRefresh(): Promise { + this.pendingMcpContextRefresh = true; + + if (this.isRefreshingMcpContext) { + debugLogger.log( + 'MCP context refresh already in progress, queuing trailing execution.', + ); + return this.pendingRefreshPromise ?? Promise.resolve(); + } + if (this.pendingRefreshPromise) { + debugLogger.log( + 'MCP context refresh already scheduled, coalescing with existing request.', + ); return this.pendingRefreshPromise; } + debugLogger.log('Scheduling MCP context refresh...'); this.pendingRefreshPromise = (async () => { - // Debounce to coalesce multiple rapid updates - await new Promise((resolve) => setTimeout(resolve, 300)); + this.isRefreshingMcpContext = true; try { - await this.cliConfig.refreshMcpContext(); + do { + this.pendingMcpContextRefresh = false; + debugLogger.log('Executing MCP context refresh...'); + await this.cliConfig.refreshMcpContext(); + debugLogger.log('MCP context refresh complete.'); + + // If more refresh requests came in during the execution, wait a bit + // to coalesce them before the next iteration. + if (this.pendingMcpContextRefresh) { + debugLogger.log( + 'Coalescing burst refresh requests (300ms delay)...', + ); + await new Promise((resolve) => setTimeout(resolve, 300)); + } + } while (this.pendingMcpContextRefresh); } catch (error) { debugLogger.error( `Error refreshing MCP context: ${getErrorMessage(error)}`, ); } finally { + this.isRefreshingMcpContext = false; this.pendingRefreshPromise = null; } })(); diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 126fb7ce68..0f7b58c39a 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -22,6 +22,7 @@ import { PromptListChangedNotificationSchema, ResourceListChangedNotificationSchema, ToolListChangedNotificationSchema, + ProgressNotificationSchema, } from '@modelcontextprotocol/sdk/types.js'; import type { DiscoveredMCPTool } from './mcp-tool.js'; @@ -102,6 +103,7 @@ describe('mcp-client', () => { afterEach(() => { vi.restoreAllMocks(); + vi.useRealTimers(); }); describe('McpClient', () => { @@ -140,13 +142,16 @@ describe('mcp-client', () => { const mockedToolRegistry = { registerTool: vi.fn(), sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const promptRegistry = { registerPrompt: vi.fn(), + getPromptsByServer: vi.fn().mockReturnValue([]), removePromptsByServer: vi.fn(), } as unknown as PromptRegistry; const resourceRegistry = { + getResourcesByServer: vi.fn().mockReturnValue([]), setResourcesForServer: vi.fn(), removeResourcesByServer: vi.fn(), } as unknown as ResourceRegistry; @@ -221,13 +226,16 @@ describe('mcp-client', () => { const mockedToolRegistry = { registerTool: vi.fn(), sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const promptRegistry = { registerPrompt: vi.fn(), + getPromptsByServer: vi.fn().mockReturnValue([]), removePromptsByServer: vi.fn(), } as unknown as PromptRegistry; const resourceRegistry = { + getResourcesByServer: vi.fn().mockReturnValue([]), setResourcesForServer: vi.fn(), removeResourcesByServer: vi.fn(), } as unknown as ResourceRegistry; @@ -328,13 +336,16 @@ describe('mcp-client', () => { const mockedToolRegistry = { registerTool: vi.fn(), sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const promptRegistry = { registerPrompt: vi.fn(), + getPromptsByServer: vi.fn().mockReturnValue([]), removePromptsByServer: vi.fn(), } as unknown as PromptRegistry; const resourceRegistry = { + getResourcesByServer: vi.fn().mockReturnValue([]), setResourcesForServer: vi.fn(), removeResourcesByServer: vi.fn(), } as unknown as ResourceRegistry; @@ -388,13 +399,16 @@ describe('mcp-client', () => { const mockedToolRegistry = { registerTool: vi.fn(), sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const promptRegistry = { registerPrompt: vi.fn(), + getPromptsByServer: vi.fn().mockReturnValue([]), removePromptsByServer: vi.fn(), } as unknown as PromptRegistry; const resourceRegistry = { + getResourcesByServer: vi.fn().mockReturnValue([]), setResourcesForServer: vi.fn(), removeResourcesByServer: vi.fn(), } as unknown as ResourceRegistry; @@ -701,13 +715,16 @@ describe('mcp-client', () => { const mockedToolRegistry = { registerTool: vi.fn(), sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const promptRegistry = { registerPrompt: vi.fn(), + getPromptsByServer: vi.fn().mockReturnValue([]), removePromptsByServer: vi.fn(), } as unknown as PromptRegistry; const resourceRegistry = { + getResourcesByServer: vi.fn().mockReturnValue([]), setResourcesForServer: vi.fn(), removeResourcesByServer: vi.fn(), } as unknown as ResourceRegistry; @@ -778,13 +795,16 @@ describe('mcp-client', () => { const mockedToolRegistry = { registerTool: vi.fn(), sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const promptRegistry = { registerPrompt: vi.fn(), + getPromptsByServer: vi.fn().mockReturnValue([]), removePromptsByServer: vi.fn(), } as unknown as PromptRegistry; const resourceRegistry = { + getResourcesByServer: vi.fn().mockReturnValue([]), setResourcesForServer: vi.fn(), removeResourcesByServer: vi.fn(), } as unknown as ResourceRegistry; @@ -864,13 +884,16 @@ describe('mcp-client', () => { const mockedToolRegistry = { registerTool: vi.fn(), sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const promptRegistry = { registerPrompt: vi.fn(), + getPromptsByServer: vi.fn().mockReturnValue([]), removePromptsByServer: vi.fn(), } as unknown as PromptRegistry; const resourceRegistry = { + getResourcesByServer: vi.fn().mockReturnValue([]), setResourcesForServer: vi.fn(), removeResourcesByServer: vi.fn(), } as unknown as ResourceRegistry; @@ -950,13 +973,16 @@ describe('mcp-client', () => { const mockedToolRegistry = { registerTool: vi.fn(), sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const promptRegistry = { registerPrompt: vi.fn(), + getPromptsByServer: vi.fn().mockReturnValue([]), removePromptsByServer: vi.fn(), } as unknown as PromptRegistry; const resourceRegistry = { + getResourcesByServer: vi.fn().mockReturnValue([]), setResourcesForServer: vi.fn(), removeResourcesByServer: vi.fn(), } as unknown as ResourceRegistry; @@ -1086,6 +1112,7 @@ describe('mcp-client', () => { setNotificationHandler: vi.fn(), listTools: vi.fn().mockResolvedValue({ tools: [] }), listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + listResources: vi.fn().mockResolvedValue({ resources: [] }), request: vi.fn().mockResolvedValue({}), }; @@ -1096,12 +1123,27 @@ describe('mcp-client', () => { {} as SdkClientStdioLib.StdioClientTransport, ); + const mockedToolRegistry = { + registerTool: vi.fn(), + sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), + getMessageBus: vi.fn().mockReturnValue(undefined), + } as unknown as ToolRegistry; + const client = new McpClient( 'test-server', { command: 'test-command' }, - {} as ToolRegistry, - {} as PromptRegistry, - {} as ResourceRegistry, + mockedToolRegistry, + { + getPromptsByServer: vi.fn().mockReturnValue([]), + registerPrompt: vi.fn(), + } as unknown as PromptRegistry, + { + getResourcesByServer: vi.fn().mockReturnValue([]), + registerResource: vi.fn(), + removeResourcesByServer: vi.fn(), + setResourcesForServer: vi.fn(), + } as unknown as ResourceRegistry, workspaceContext, MOCK_CONTEXT, false, @@ -1136,9 +1178,21 @@ describe('mcp-client', () => { const client = new McpClient( 'test-server', { command: 'test-command' }, - {} as ToolRegistry, - {} as PromptRegistry, - {} as ResourceRegistry, + { + getToolsByServer: vi.fn().mockReturnValue([]), + registerTool: vi.fn(), + sortTools: vi.fn(), + } as unknown as ToolRegistry, + { + getPromptsByServer: vi.fn().mockReturnValue([]), + registerPrompt: vi.fn(), + } as unknown as PromptRegistry, + { + getResourcesByServer: vi.fn().mockReturnValue([]), + registerResource: vi.fn(), + removeResourcesByServer: vi.fn(), + setResourcesForServer: vi.fn(), + } as unknown as ResourceRegistry, workspaceContext, MOCK_CONTEXT, false, @@ -1147,7 +1201,62 @@ describe('mcp-client', () => { await client.connect(); - expect(mockedClient.setNotificationHandler).toHaveBeenCalledOnce(); + // Should be called for ProgressNotificationSchema, even if no other capabilities + expect(mockedClient.setNotificationHandler).toHaveBeenCalled(); + const progressCall = mockedClient.setNotificationHandler.mock.calls.find( + (call) => call[0] === ProgressNotificationSchema, + ); + expect(progressCall).toBeDefined(); + }); + + it('should set up notification handler even if listChanged is false (robustness)', async () => { + // Setup mocks + const mockedClient = { + connect: vi.fn(), + getServerCapabilities: vi + .fn() + .mockReturnValue({ tools: { listChanged: false } }), + setNotificationHandler: vi.fn(), + request: vi.fn().mockResolvedValue({}), + registerCapabilities: vi.fn().mockResolvedValue({}), + setRequestHandler: vi.fn().mockResolvedValue({}), + }; + + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + + const client = new McpClient( + 'test-server', + { command: 'test-command' }, + { + getToolsByServer: vi.fn().mockReturnValue([]), + registerTool: vi.fn(), + sortTools: vi.fn(), + } as unknown as ToolRegistry, + { + getPromptsByServer: vi.fn().mockReturnValue([]), + registerPrompt: vi.fn(), + } as unknown as PromptRegistry, + { + getResourcesByServer: vi.fn().mockReturnValue([]), + registerResource: vi.fn(), + removeResourcesByServer: vi.fn(), + setResourcesForServer: vi.fn(), + } as unknown as ResourceRegistry, + workspaceContext, + MOCK_CONTEXT, + false, + '0.0.1', + ); + + await client.connect(); + + const toolUpdateCall = + mockedClient.setNotificationHandler.mock.calls.find( + (call) => call[0] === ToolListChangedNotificationSchema, + ); + expect(toolUpdateCall).toBeDefined(); }); it('should refresh tools and notify manager when notification is received', async () => { @@ -1167,6 +1276,7 @@ describe('mcp-client', () => { ], }), listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + listResources: vi.fn().mockResolvedValue({ resources: [] }), request: vi.fn().mockResolvedValue({}), registerCapabilities: vi.fn().mockResolvedValue({}), setRequestHandler: vi.fn().mockResolvedValue({}), @@ -1183,31 +1293,38 @@ describe('mcp-client', () => { removeMcpToolsByServer: vi.fn(), registerTool: vi.fn(), sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; - const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined); + const onContextUpdatedSpy = vi.fn().mockResolvedValue(undefined); - // Initialize client with onToolsUpdated callback + // Initialize client with onContextUpdated callback const client = new McpClient( 'test-server', { command: 'test-command' }, mockedToolRegistry, {} as PromptRegistry, - {} as ResourceRegistry, + { + removeMcpResourcesByServer: vi.fn(), + registerResource: vi.fn(), + } as unknown as ResourceRegistry, workspaceContext, MOCK_CONTEXT, false, '0.0.1', - onToolsUpdatedSpy, + onContextUpdatedSpy, ); // 1. Connect (sets up listener) await client.connect(); - // 2. Extract the callback passed to setNotificationHandler - const notificationCallback = - mockedClient.setNotificationHandler.mock.calls[0][1]; + // 2. Extract the callback passed to setNotificationHandler for tools + const toolUpdateCall = + mockedClient.setNotificationHandler.mock.calls.find( + (call) => call[0] === ToolListChangedNotificationSchema, + ); + const notificationCallback = toolUpdateCall![1]; // 3. Trigger the notification manually await notificationCallback(); @@ -1225,7 +1342,7 @@ describe('mcp-client', () => { expect(mockedToolRegistry.registerTool).toHaveBeenCalled(); // It should notify the manager - expect(onToolsUpdatedSpy).toHaveBeenCalled(); + expect(onContextUpdatedSpy).toHaveBeenCalled(); // It should emit feedback event expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith( @@ -1259,6 +1376,7 @@ describe('mcp-client', () => { const mockedToolRegistry = { removeMcpToolsByServer: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; @@ -1276,8 +1394,11 @@ describe('mcp-client', () => { await client.connect(); - const notificationCallback = - mockedClient.setNotificationHandler.mock.calls[0][1]; + const toolUpdateCall = + mockedClient.setNotificationHandler.mock.calls.find( + (call) => call[0] === ToolListChangedNotificationSchema, + ); + const notificationCallback = toolUpdateCall![1]; // Trigger notification - should fail internally but catch the error await notificationCallback(); @@ -1328,10 +1449,11 @@ describe('mcp-client', () => { removeMcpToolsByServer: vi.fn(), registerTool: vi.fn(), sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; - const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined); + const onContextUpdatedSpy = vi.fn().mockResolvedValue(undefined); const clientA = new McpClient( 'server-A', @@ -1343,7 +1465,7 @@ describe('mcp-client', () => { MOCK_CONTEXT, false, '0.0.1', - onToolsUpdatedSpy, + onContextUpdatedSpy, ); const clientB = new McpClient( @@ -1356,14 +1478,23 @@ describe('mcp-client', () => { MOCK_CONTEXT, false, '0.0.1', - onToolsUpdatedSpy, + onContextUpdatedSpy, ); await clientA.connect(); await clientB.connect(); - const handlerA = mockClientA.setNotificationHandler.mock.calls[0][1]; - const handlerB = mockClientB.setNotificationHandler.mock.calls[0][1]; + const toolUpdateCallA = + mockClientA.setNotificationHandler.mock.calls.find( + (call) => call[0] === ToolListChangedNotificationSchema, + ); + const handlerA = toolUpdateCallA![1]; + + const toolUpdateCallB = + mockClientB.setNotificationHandler.mock.calls.find( + (call) => call[0] === ToolListChangedNotificationSchema, + ); + const handlerB = toolUpdateCallB![1]; // Trigger burst updates simultaneously await Promise.all([handlerA(), handlerB()]); @@ -1383,12 +1514,11 @@ describe('mcp-client', () => { expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2); // Verify the update callback was triggered for both - expect(onToolsUpdatedSpy).toHaveBeenCalledTimes(2); + expect(onContextUpdatedSpy).toHaveBeenCalledTimes(2); }); it('should abort discovery and log error if timeout is exceeded during refresh', async () => { vi.useFakeTimers(); - const mockedClient = { connect: vi.fn(), getServerCapabilities: vi @@ -1412,6 +1542,7 @@ describe('mcp-client', () => { }), ), listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + listResources: vi.fn().mockResolvedValue({ resources: [] }), request: vi.fn().mockResolvedValue({}), registerCapabilities: vi.fn().mockResolvedValue({}), setRequestHandler: vi.fn().mockResolvedValue({}), @@ -1428,16 +1559,26 @@ describe('mcp-client', () => { removeMcpToolsByServer: vi.fn(), registerTool: vi.fn(), sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; const client = new McpClient( 'test-server', - // Set a short timeout - { command: 'test-command', timeout: 100 }, + // Set a very short timeout + { command: 'test-command', timeout: 50 }, mockedToolRegistry, - {} as PromptRegistry, - {} as ResourceRegistry, + { + getPromptsByServer: vi.fn().mockReturnValue([]), + registerPrompt: vi.fn(), + removePromptsByServer: vi.fn(), + } as unknown as PromptRegistry, + { + getResourcesByServer: vi.fn().mockReturnValue([]), + registerResource: vi.fn(), + removeResourcesByServer: vi.fn(), + setResourcesForServer: vi.fn(), + } as unknown as ResourceRegistry, workspaceContext, MOCK_CONTEXT, false, @@ -1446,13 +1587,16 @@ describe('mcp-client', () => { await client.connect(); - const notificationCallback = - mockedClient.setNotificationHandler.mock.calls[0][1]; + const toolUpdateCall = + mockedClient.setNotificationHandler.mock.calls.find( + (call) => call[0] === ToolListChangedNotificationSchema, + ); + const notificationCallback = toolUpdateCall![1]; const refreshPromise = notificationCallback(); - vi.advanceTimersByTime(150); - + // Advance timers to trigger the timeout (11 minutes to cover even the default timeout) + await vi.advanceTimersByTimeAsync(11 * 60 * 1000); await refreshPromise; expect(mockedClient.listTools).toHaveBeenCalledWith( @@ -1463,8 +1607,6 @@ describe('mcp-client', () => { ); expect(mockedToolRegistry.registerTool).not.toHaveBeenCalled(); - - vi.useRealTimers(); }); it('should pass abort signal to onToolsUpdated callback', async () => { @@ -1492,35 +1634,51 @@ describe('mcp-client', () => { removeMcpToolsByServer: vi.fn(), registerTool: vi.fn(), sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), getMessageBus: vi.fn().mockReturnValue(undefined), } as unknown as ToolRegistry; - const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined); + const onContextUpdatedSpy = vi.fn().mockResolvedValue(undefined); const client = new McpClient( 'test-server', { command: 'test-command' }, mockedToolRegistry, - {} as PromptRegistry, - {} as ResourceRegistry, + { + getPromptsByServer: vi.fn().mockReturnValue([]), + registerPrompt: vi.fn(), + removePromptsByServer: vi.fn(), + } as unknown as PromptRegistry, + { + getResourcesByServer: vi.fn().mockReturnValue([]), + registerResource: vi.fn(), + removeResourcesByServer: vi.fn(), + setResourcesForServer: vi.fn(), + } as unknown as ResourceRegistry, workspaceContext, MOCK_CONTEXT, false, '0.0.1', - onToolsUpdatedSpy, + onContextUpdatedSpy, ); await client.connect(); - const notificationCallback = - mockedClient.setNotificationHandler.mock.calls[0][1]; + const toolUpdateCall = + mockedClient.setNotificationHandler.mock.calls.find( + (call) => call[0] === ToolListChangedNotificationSchema, + ); + const notificationCallback = toolUpdateCall![1]; - await notificationCallback(); + vi.useFakeTimers(); + const refreshPromise = notificationCallback(); + await vi.advanceTimersByTimeAsync(500); + await refreshPromise; - expect(onToolsUpdatedSpy).toHaveBeenCalledWith(expect.any(AbortSignal)); + expect(onContextUpdatedSpy).toHaveBeenCalledWith(expect.any(AbortSignal)); // Verify the signal passed was not aborted (happy path) - const signal = onToolsUpdatedSpy.mock.calls[0][0]; + const signal = onContextUpdatedSpy.mock.calls[0][0]; expect(signal.aborted).toBe(false); }); }); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 6e0d1066de..af55facaa3 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -70,7 +70,10 @@ import type { ToolRegistry } from './tool-registry.js'; import { debugLogger } from '../utils/debugLogger.js'; import { type MessageBus } from '../confirmation-bus/message-bus.js'; import { coreEvents } from '../utils/events.js'; -import type { ResourceRegistry } from '../resources/resource-registry.js'; +import { + type ResourceRegistry, + type MCPResource, +} from '../resources/resource-registry.js'; import { validateMcpPolicyToolNames } from '../policy/toml-loader.js'; import { sanitizeEnvironment, @@ -156,7 +159,7 @@ export class McpClient implements McpProgressReporter { private readonly cliConfig: McpContext, private readonly debugMode: boolean, private readonly clientVersion: string, - private readonly onToolsUpdated?: (signal?: AbortSignal) => Promise, + private readonly onContextUpdated?: (signal?: AbortSignal) => Promise, ) {} /** @@ -352,10 +355,21 @@ export class McpClient implements McpProgressReporter { const capabilities = this.client.getServerCapabilities(); - if (capabilities?.tools?.listChanged) { - debugLogger.log( - `Server '${this.serverName}' supports tool updates. Listening for changes...`, - ); + debugLogger.log( + `Registering notification handlers for server '${this.serverName}'. Capabilities:`, + capabilities, + ); + + if (capabilities?.tools) { + if (capabilities.tools.listChanged) { + debugLogger.log( + `Server '${this.serverName}' supports tool updates. Listening for changes...`, + ); + } else { + debugLogger.log( + `Server '${this.serverName}' has tools but did not declare 'listChanged' capability. Listening anyway for robustness...`, + ); + } this.client.setNotificationHandler( ToolListChangedNotificationSchema, @@ -368,10 +382,16 @@ export class McpClient implements McpProgressReporter { ); } - if (capabilities?.resources?.listChanged) { - debugLogger.log( - `Server '${this.serverName}' supports resource updates. Listening for changes...`, - ); + if (capabilities?.resources) { + if (capabilities.resources.listChanged) { + debugLogger.log( + `Server '${this.serverName}' supports resource updates. Listening for changes...`, + ); + } else { + debugLogger.log( + `Server '${this.serverName}' has resources but did not declare 'listChanged' capability. Listening anyway for robustness...`, + ); + } this.client.setNotificationHandler( ResourceListChangedNotificationSchema, @@ -384,10 +404,16 @@ export class McpClient implements McpProgressReporter { ); } - if (capabilities?.prompts?.listChanged) { - debugLogger.log( - `Server '${this.serverName}' supports prompt updates. Listening for changes...`, - ); + if (capabilities?.prompts) { + if (capabilities.prompts.listChanged) { + debugLogger.log( + `Server '${this.serverName}' supports prompt updates. Listening for changes...`, + ); + } else { + debugLogger.log( + `Server '${this.serverName}' has prompts but did not declare 'listChanged' capability. Listening anyway for robustness...`, + ); + } this.client.setNotificationHandler( PromptListChangedNotificationSchema, @@ -451,6 +477,25 @@ export class McpClient implements McpProgressReporter { let newResources; try { newResources = await this.discoverResources(); + + // Verification Retry: If no resources are found or resources didn't change, + // wait briefly and try one more time. Some servers notify before they're fully ready. + const currentResources = + this.resourceRegistry.getResourcesByServer(this.serverName) || []; + const resourceMatch = + newResources.length === currentResources.length && + newResources.every((nr: Resource) => + currentResources.some((cr: MCPResource) => cr.uri === nr.uri), + ); + + if (resourceMatch && !this.pendingResourceRefresh) { + debugLogger.log( + `No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`, + ); + const retryDelay = 500; + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + newResources = await this.discoverResources(); + } } catch (err) { debugLogger.error( `Resource discovery failed during refresh: ${getErrorMessage(err)}`, @@ -461,6 +506,10 @@ export class McpClient implements McpProgressReporter { this.updateResourceRegistry(newResources); + if (this.onContextUpdated) { + await this.onContextUpdated(abortController.signal); + } + clearTimeout(timeoutId); this.cliConfig.emitMcpDiagnostic( @@ -476,7 +525,6 @@ export class McpClient implements McpProgressReporter { ); } finally { this.isRefreshingResources = false; - this.pendingResourceRefresh = false; } } @@ -519,9 +567,31 @@ export class McpClient implements McpProgressReporter { const timeoutId = setTimeout(() => abortController.abort(), timeoutMs); try { - const newPrompts = await this.fetchPrompts({ + let newPrompts = await this.fetchPrompts({ signal: abortController.signal, }); + + // Verification Retry: If no prompts are found or prompts didn't change, + // wait briefly and try one more time. Some servers notify before they're fully ready. + const currentPrompts = + this.promptRegistry.getPromptsByServer(this.serverName) || []; + const promptsMatch = + newPrompts.length === currentPrompts.length && + newPrompts.every((np) => + currentPrompts.some((cp) => cp.name === np.name), + ); + + if (promptsMatch && !this.pendingPromptRefresh) { + debugLogger.log( + `No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`, + ); + const retryDelay = 500; + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + newPrompts = await this.fetchPrompts({ + signal: abortController.signal, + }); + } + this.promptRegistry.removePromptsByServer(this.serverName); for (const prompt of newPrompts) { this.promptRegistry.registerPrompt(prompt); @@ -534,6 +604,10 @@ export class McpClient implements McpProgressReporter { break; } + if (this.onContextUpdated) { + await this.onContextUpdated(abortController.signal); + } + clearTimeout(timeoutId); this.cliConfig.emitMcpDiagnostic( @@ -549,7 +623,6 @@ export class McpClient implements McpProgressReporter { ); } finally { this.isRefreshingPrompts = false; - this.pendingPromptRefresh = false; } } @@ -594,6 +667,38 @@ export class McpClient implements McpProgressReporter { newTools = await this.discoverTools(this.cliConfig, { signal: abortController.signal, }); + debugLogger.log( + `Refresh for '${this.serverName}' discovered ${newTools.length} tools.`, + ); + + // Verification Retry (Option 3): If no tools are found or tools didn't change, + // wait briefly and try one more time. Some servers notify before they're fully ready. + const currentTools = + this.toolRegistry.getToolsByServer(this.serverName) || []; + const toolNamesMatch = + newTools.length === currentTools.length && + newTools.every((nt) => + currentTools.some( + (ct) => + ct.name === nt.name || + (ct instanceof DiscoveredMCPTool && + ct.serverToolName === nt.serverToolName), + ), + ); + + if (toolNamesMatch && !this.pendingToolRefresh) { + debugLogger.log( + `No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`, + ); + const retryDelay = 500; + await new Promise((resolve) => setTimeout(resolve, retryDelay)); + newTools = await this.discoverTools(this.cliConfig, { + signal: abortController.signal, + }); + debugLogger.log( + `Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`, + ); + } } catch (err) { debugLogger.error( `Discovery failed during refresh: ${getErrorMessage(err)}`, @@ -609,8 +714,8 @@ export class McpClient implements McpProgressReporter { } this.toolRegistry.sortTools(); - if (this.onToolsUpdated) { - await this.onToolsUpdated(abortController.signal); + if (this.onContextUpdated) { + await this.onContextUpdated(abortController.signal); } clearTimeout(timeoutId); @@ -628,7 +733,6 @@ export class McpClient implements McpProgressReporter { ); } finally { this.isRefreshingTools = false; - this.pendingToolRefresh = false; } } }