fix(mcp): Notifications/tools/list_changed support not working (#21050)

Co-authored-by: Bryan Morgan <bryanmorgan@google.com>
This commit is contained in:
Jacob Richman
2026-03-04 06:46:17 -08:00
committed by GitHub
parent 7011c13ee6
commit 352fb0c976
4 changed files with 378 additions and 73 deletions
@@ -69,4 +69,17 @@ export class ResourceRegistry {
clear(): void { clear(): void {
this.resources.clear(); this.resources.clear();
} }
/**
* Returns an array of resources registered from a specific MCP server.
*/
getResourcesByServer(serverName: string): MCPResource[] {
const serverResources: MCPResource[] = [];
for (const resource of this.resources.values()) {
if (resource.serverName === serverName) {
serverResources.push(resource);
}
}
return serverResources.sort((a, b) => a.uri.localeCompare(b.uri));
}
} }
+40 -10
View File
@@ -173,7 +173,7 @@ export class McpClientManager {
return Promise.resolve(); return Promise.resolve();
}), }),
); );
await this.cliConfig.refreshMcpContext(); await this.scheduleMcpContextRefresh();
} }
/** /**
@@ -193,7 +193,7 @@ export class McpClientManager {
}), }),
), ),
); );
await this.cliConfig.refreshMcpContext(); await this.scheduleMcpContextRefresh();
} }
/** /**
@@ -251,7 +251,7 @@ export class McpClientManager {
if (!skipRefresh) { if (!skipRefresh) {
// This is required to update the content generator configuration with the // This is required to update the content generator configuration with the
// new tool configuration and system instructions. // new tool configuration and system instructions.
await this.cliConfig.refreshMcpContext(); await this.scheduleMcpContextRefresh();
} }
} }
} }
@@ -321,7 +321,7 @@ export class McpClientManager {
this.cliConfig.getDebugMode(), this.cliConfig.getDebugMode(),
this.clientVersion, this.clientVersion,
async () => { async () => {
debugLogger.log('Tools changed, updating Gemini context...'); debugLogger.log(`🔔 Refreshing context for server '${name}'...`);
await this.scheduleMcpContextRefresh(); await this.scheduleMcpContextRefresh();
}, },
); );
@@ -431,7 +431,7 @@ export class McpClientManager {
this.eventEmitter?.emit('mcp-client-update', this.clients); this.eventEmitter?.emit('mcp-client-update', this.clients);
} }
await this.cliConfig.refreshMcpContext(); await this.scheduleMcpContextRefresh();
} }
/** /**
@@ -451,7 +451,7 @@ export class McpClientManager {
}, },
), ),
); );
await this.cliConfig.refreshMcpContext(); await this.scheduleMcpContextRefresh();
} }
/** /**
@@ -463,7 +463,7 @@ export class McpClientManager {
throw new Error(`No MCP server registered with the name "${name}"`); throw new Error(`No MCP server registered with the name "${name}"`);
} }
await this.maybeDiscoverMcpServer(name, config); await this.maybeDiscoverMcpServer(name, config);
await this.cliConfig.refreshMcpContext(); await this.scheduleMcpContextRefresh();
} }
/** /**
@@ -517,21 +517,51 @@ export class McpClientManager {
return instructions.join('\n\n'); return instructions.join('\n\n');
} }
private isRefreshingMcpContext: boolean = false;
private pendingMcpContextRefresh: boolean = false;
private async scheduleMcpContextRefresh(): Promise<void> { private async scheduleMcpContextRefresh(): Promise<void> {
this.pendingMcpContextRefresh = true;
if (this.isRefreshingMcpContext) {
debugLogger.log(
'MCP context refresh already in progress, queuing trailing execution.',
);
return this.pendingRefreshPromise ?? Promise.resolve();
}
if (this.pendingRefreshPromise) { if (this.pendingRefreshPromise) {
debugLogger.log(
'MCP context refresh already scheduled, coalescing with existing request.',
);
return this.pendingRefreshPromise; return this.pendingRefreshPromise;
} }
debugLogger.log('Scheduling MCP context refresh...');
this.pendingRefreshPromise = (async () => { this.pendingRefreshPromise = (async () => {
// Debounce to coalesce multiple rapid updates this.isRefreshingMcpContext = true;
await new Promise((resolve) => setTimeout(resolve, 300));
try { try {
await this.cliConfig.refreshMcpContext(); do {
this.pendingMcpContextRefresh = false;
debugLogger.log('Executing MCP context refresh...');
await this.cliConfig.refreshMcpContext();
debugLogger.log('MCP context refresh complete.');
// If more refresh requests came in during the execution, wait a bit
// to coalesce them before the next iteration.
if (this.pendingMcpContextRefresh) {
debugLogger.log(
'Coalescing burst refresh requests (300ms delay)...',
);
await new Promise((resolve) => setTimeout(resolve, 300));
}
} while (this.pendingMcpContextRefresh);
} catch (error) { } catch (error) {
debugLogger.error( debugLogger.error(
`Error refreshing MCP context: ${getErrorMessage(error)}`, `Error refreshing MCP context: ${getErrorMessage(error)}`,
); );
} finally { } finally {
this.isRefreshingMcpContext = false;
this.pendingRefreshPromise = null; this.pendingRefreshPromise = null;
} }
})(); })();
+201 -43
View File
@@ -22,6 +22,7 @@ import {
PromptListChangedNotificationSchema, PromptListChangedNotificationSchema,
ResourceListChangedNotificationSchema, ResourceListChangedNotificationSchema,
ToolListChangedNotificationSchema, ToolListChangedNotificationSchema,
ProgressNotificationSchema,
} from '@modelcontextprotocol/sdk/types.js'; } from '@modelcontextprotocol/sdk/types.js';
import type { DiscoveredMCPTool } from './mcp-tool.js'; import type { DiscoveredMCPTool } from './mcp-tool.js';
@@ -102,6 +103,7 @@ describe('mcp-client', () => {
afterEach(() => { afterEach(() => {
vi.restoreAllMocks(); vi.restoreAllMocks();
vi.useRealTimers();
}); });
describe('McpClient', () => { describe('McpClient', () => {
@@ -140,13 +142,16 @@ describe('mcp-client', () => {
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const promptRegistry = { const promptRegistry = {
registerPrompt: vi.fn(), registerPrompt: vi.fn(),
getPromptsByServer: vi.fn().mockReturnValue([]),
removePromptsByServer: vi.fn(), removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry; } as unknown as PromptRegistry;
const resourceRegistry = { const resourceRegistry = {
getResourcesByServer: vi.fn().mockReturnValue([]),
setResourcesForServer: vi.fn(), setResourcesForServer: vi.fn(),
removeResourcesByServer: vi.fn(), removeResourcesByServer: vi.fn(),
} as unknown as ResourceRegistry; } as unknown as ResourceRegistry;
@@ -221,13 +226,16 @@ describe('mcp-client', () => {
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const promptRegistry = { const promptRegistry = {
registerPrompt: vi.fn(), registerPrompt: vi.fn(),
getPromptsByServer: vi.fn().mockReturnValue([]),
removePromptsByServer: vi.fn(), removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry; } as unknown as PromptRegistry;
const resourceRegistry = { const resourceRegistry = {
getResourcesByServer: vi.fn().mockReturnValue([]),
setResourcesForServer: vi.fn(), setResourcesForServer: vi.fn(),
removeResourcesByServer: vi.fn(), removeResourcesByServer: vi.fn(),
} as unknown as ResourceRegistry; } as unknown as ResourceRegistry;
@@ -328,13 +336,16 @@ describe('mcp-client', () => {
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const promptRegistry = { const promptRegistry = {
registerPrompt: vi.fn(), registerPrompt: vi.fn(),
getPromptsByServer: vi.fn().mockReturnValue([]),
removePromptsByServer: vi.fn(), removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry; } as unknown as PromptRegistry;
const resourceRegistry = { const resourceRegistry = {
getResourcesByServer: vi.fn().mockReturnValue([]),
setResourcesForServer: vi.fn(), setResourcesForServer: vi.fn(),
removeResourcesByServer: vi.fn(), removeResourcesByServer: vi.fn(),
} as unknown as ResourceRegistry; } as unknown as ResourceRegistry;
@@ -388,13 +399,16 @@ describe('mcp-client', () => {
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const promptRegistry = { const promptRegistry = {
registerPrompt: vi.fn(), registerPrompt: vi.fn(),
getPromptsByServer: vi.fn().mockReturnValue([]),
removePromptsByServer: vi.fn(), removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry; } as unknown as PromptRegistry;
const resourceRegistry = { const resourceRegistry = {
getResourcesByServer: vi.fn().mockReturnValue([]),
setResourcesForServer: vi.fn(), setResourcesForServer: vi.fn(),
removeResourcesByServer: vi.fn(), removeResourcesByServer: vi.fn(),
} as unknown as ResourceRegistry; } as unknown as ResourceRegistry;
@@ -701,13 +715,16 @@ describe('mcp-client', () => {
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const promptRegistry = { const promptRegistry = {
registerPrompt: vi.fn(), registerPrompt: vi.fn(),
getPromptsByServer: vi.fn().mockReturnValue([]),
removePromptsByServer: vi.fn(), removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry; } as unknown as PromptRegistry;
const resourceRegistry = { const resourceRegistry = {
getResourcesByServer: vi.fn().mockReturnValue([]),
setResourcesForServer: vi.fn(), setResourcesForServer: vi.fn(),
removeResourcesByServer: vi.fn(), removeResourcesByServer: vi.fn(),
} as unknown as ResourceRegistry; } as unknown as ResourceRegistry;
@@ -778,13 +795,16 @@ describe('mcp-client', () => {
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const promptRegistry = { const promptRegistry = {
registerPrompt: vi.fn(), registerPrompt: vi.fn(),
getPromptsByServer: vi.fn().mockReturnValue([]),
removePromptsByServer: vi.fn(), removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry; } as unknown as PromptRegistry;
const resourceRegistry = { const resourceRegistry = {
getResourcesByServer: vi.fn().mockReturnValue([]),
setResourcesForServer: vi.fn(), setResourcesForServer: vi.fn(),
removeResourcesByServer: vi.fn(), removeResourcesByServer: vi.fn(),
} as unknown as ResourceRegistry; } as unknown as ResourceRegistry;
@@ -864,13 +884,16 @@ describe('mcp-client', () => {
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const promptRegistry = { const promptRegistry = {
registerPrompt: vi.fn(), registerPrompt: vi.fn(),
getPromptsByServer: vi.fn().mockReturnValue([]),
removePromptsByServer: vi.fn(), removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry; } as unknown as PromptRegistry;
const resourceRegistry = { const resourceRegistry = {
getResourcesByServer: vi.fn().mockReturnValue([]),
setResourcesForServer: vi.fn(), setResourcesForServer: vi.fn(),
removeResourcesByServer: vi.fn(), removeResourcesByServer: vi.fn(),
} as unknown as ResourceRegistry; } as unknown as ResourceRegistry;
@@ -950,13 +973,16 @@ describe('mcp-client', () => {
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const promptRegistry = { const promptRegistry = {
registerPrompt: vi.fn(), registerPrompt: vi.fn(),
getPromptsByServer: vi.fn().mockReturnValue([]),
removePromptsByServer: vi.fn(), removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry; } as unknown as PromptRegistry;
const resourceRegistry = { const resourceRegistry = {
getResourcesByServer: vi.fn().mockReturnValue([]),
setResourcesForServer: vi.fn(), setResourcesForServer: vi.fn(),
removeResourcesByServer: vi.fn(), removeResourcesByServer: vi.fn(),
} as unknown as ResourceRegistry; } as unknown as ResourceRegistry;
@@ -1086,6 +1112,7 @@ describe('mcp-client', () => {
setNotificationHandler: vi.fn(), setNotificationHandler: vi.fn(),
listTools: vi.fn().mockResolvedValue({ tools: [] }), listTools: vi.fn().mockResolvedValue({ tools: [] }),
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
listResources: vi.fn().mockResolvedValue({ resources: [] }),
request: vi.fn().mockResolvedValue({}), request: vi.fn().mockResolvedValue({}),
}; };
@@ -1096,12 +1123,27 @@ describe('mcp-client', () => {
{} as SdkClientStdioLib.StdioClientTransport, {} as SdkClientStdioLib.StdioClientTransport,
); );
const mockedToolRegistry = {
registerTool: vi.fn(),
sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry;
const client = new McpClient( const client = new McpClient(
'test-server', 'test-server',
{ command: 'test-command' }, { command: 'test-command' },
{} as ToolRegistry, mockedToolRegistry,
{} as PromptRegistry, {
{} as ResourceRegistry, getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
} as unknown as PromptRegistry,
{
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
workspaceContext, workspaceContext,
MOCK_CONTEXT, MOCK_CONTEXT,
false, false,
@@ -1136,9 +1178,21 @@ describe('mcp-client', () => {
const client = new McpClient( const client = new McpClient(
'test-server', 'test-server',
{ command: 'test-command' }, { command: 'test-command' },
{} as ToolRegistry, {
{} as PromptRegistry, getToolsByServer: vi.fn().mockReturnValue([]),
{} as ResourceRegistry, registerTool: vi.fn(),
sortTools: vi.fn(),
} as unknown as ToolRegistry,
{
getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
} as unknown as PromptRegistry,
{
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
workspaceContext, workspaceContext,
MOCK_CONTEXT, MOCK_CONTEXT,
false, false,
@@ -1147,7 +1201,62 @@ describe('mcp-client', () => {
await client.connect(); await client.connect();
expect(mockedClient.setNotificationHandler).toHaveBeenCalledOnce(); // Should be called for ProgressNotificationSchema, even if no other capabilities
expect(mockedClient.setNotificationHandler).toHaveBeenCalled();
const progressCall = mockedClient.setNotificationHandler.mock.calls.find(
(call) => call[0] === ProgressNotificationSchema,
);
expect(progressCall).toBeDefined();
});
it('should set up notification handler even if listChanged is false (robustness)', async () => {
// Setup mocks
const mockedClient = {
connect: vi.fn(),
getServerCapabilities: vi
.fn()
.mockReturnValue({ tools: { listChanged: false } }),
setNotificationHandler: vi.fn(),
request: vi.fn().mockResolvedValue({}),
registerCapabilities: vi.fn().mockResolvedValue({}),
setRequestHandler: vi.fn().mockResolvedValue({}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
const client = new McpClient(
'test-server',
{ command: 'test-command' },
{
getToolsByServer: vi.fn().mockReturnValue([]),
registerTool: vi.fn(),
sortTools: vi.fn(),
} as unknown as ToolRegistry,
{
getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
} as unknown as PromptRegistry,
{
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
workspaceContext,
MOCK_CONTEXT,
false,
'0.0.1',
);
await client.connect();
const toolUpdateCall =
mockedClient.setNotificationHandler.mock.calls.find(
(call) => call[0] === ToolListChangedNotificationSchema,
);
expect(toolUpdateCall).toBeDefined();
}); });
it('should refresh tools and notify manager when notification is received', async () => { it('should refresh tools and notify manager when notification is received', async () => {
@@ -1167,6 +1276,7 @@ describe('mcp-client', () => {
], ],
}), }),
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
listResources: vi.fn().mockResolvedValue({ resources: [] }),
request: vi.fn().mockResolvedValue({}), request: vi.fn().mockResolvedValue({}),
registerCapabilities: vi.fn().mockResolvedValue({}), registerCapabilities: vi.fn().mockResolvedValue({}),
setRequestHandler: vi.fn().mockResolvedValue({}), setRequestHandler: vi.fn().mockResolvedValue({}),
@@ -1183,31 +1293,38 @@ describe('mcp-client', () => {
removeMcpToolsByServer: vi.fn(), removeMcpToolsByServer: vi.fn(),
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined); const onContextUpdatedSpy = vi.fn().mockResolvedValue(undefined);
// Initialize client with onToolsUpdated callback // Initialize client with onContextUpdated callback
const client = new McpClient( const client = new McpClient(
'test-server', 'test-server',
{ command: 'test-command' }, { command: 'test-command' },
mockedToolRegistry, mockedToolRegistry,
{} as PromptRegistry, {} as PromptRegistry,
{} as ResourceRegistry, {
removeMcpResourcesByServer: vi.fn(),
registerResource: vi.fn(),
} as unknown as ResourceRegistry,
workspaceContext, workspaceContext,
MOCK_CONTEXT, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
onToolsUpdatedSpy, onContextUpdatedSpy,
); );
// 1. Connect (sets up listener) // 1. Connect (sets up listener)
await client.connect(); await client.connect();
// 2. Extract the callback passed to setNotificationHandler // 2. Extract the callback passed to setNotificationHandler for tools
const notificationCallback = const toolUpdateCall =
mockedClient.setNotificationHandler.mock.calls[0][1]; mockedClient.setNotificationHandler.mock.calls.find(
(call) => call[0] === ToolListChangedNotificationSchema,
);
const notificationCallback = toolUpdateCall![1];
// 3. Trigger the notification manually // 3. Trigger the notification manually
await notificationCallback(); await notificationCallback();
@@ -1225,7 +1342,7 @@ describe('mcp-client', () => {
expect(mockedToolRegistry.registerTool).toHaveBeenCalled(); expect(mockedToolRegistry.registerTool).toHaveBeenCalled();
// It should notify the manager // It should notify the manager
expect(onToolsUpdatedSpy).toHaveBeenCalled(); expect(onContextUpdatedSpy).toHaveBeenCalled();
// It should emit feedback event // It should emit feedback event
expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith( expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith(
@@ -1259,6 +1376,7 @@ describe('mcp-client', () => {
const mockedToolRegistry = { const mockedToolRegistry = {
removeMcpToolsByServer: vi.fn(), removeMcpToolsByServer: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
@@ -1276,8 +1394,11 @@ describe('mcp-client', () => {
await client.connect(); await client.connect();
const notificationCallback = const toolUpdateCall =
mockedClient.setNotificationHandler.mock.calls[0][1]; mockedClient.setNotificationHandler.mock.calls.find(
(call) => call[0] === ToolListChangedNotificationSchema,
);
const notificationCallback = toolUpdateCall![1];
// Trigger notification - should fail internally but catch the error // Trigger notification - should fail internally but catch the error
await notificationCallback(); await notificationCallback();
@@ -1328,10 +1449,11 @@ describe('mcp-client', () => {
removeMcpToolsByServer: vi.fn(), removeMcpToolsByServer: vi.fn(),
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined); const onContextUpdatedSpy = vi.fn().mockResolvedValue(undefined);
const clientA = new McpClient( const clientA = new McpClient(
'server-A', 'server-A',
@@ -1343,7 +1465,7 @@ describe('mcp-client', () => {
MOCK_CONTEXT, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
onToolsUpdatedSpy, onContextUpdatedSpy,
); );
const clientB = new McpClient( const clientB = new McpClient(
@@ -1356,14 +1478,23 @@ describe('mcp-client', () => {
MOCK_CONTEXT, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
onToolsUpdatedSpy, onContextUpdatedSpy,
); );
await clientA.connect(); await clientA.connect();
await clientB.connect(); await clientB.connect();
const handlerA = mockClientA.setNotificationHandler.mock.calls[0][1]; const toolUpdateCallA =
const handlerB = mockClientB.setNotificationHandler.mock.calls[0][1]; mockClientA.setNotificationHandler.mock.calls.find(
(call) => call[0] === ToolListChangedNotificationSchema,
);
const handlerA = toolUpdateCallA![1];
const toolUpdateCallB =
mockClientB.setNotificationHandler.mock.calls.find(
(call) => call[0] === ToolListChangedNotificationSchema,
);
const handlerB = toolUpdateCallB![1];
// Trigger burst updates simultaneously // Trigger burst updates simultaneously
await Promise.all([handlerA(), handlerB()]); await Promise.all([handlerA(), handlerB()]);
@@ -1383,12 +1514,11 @@ describe('mcp-client', () => {
expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2); expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2);
// Verify the update callback was triggered for both // Verify the update callback was triggered for both
expect(onToolsUpdatedSpy).toHaveBeenCalledTimes(2); expect(onContextUpdatedSpy).toHaveBeenCalledTimes(2);
}); });
it('should abort discovery and log error if timeout is exceeded during refresh', async () => { it('should abort discovery and log error if timeout is exceeded during refresh', async () => {
vi.useFakeTimers(); vi.useFakeTimers();
const mockedClient = { const mockedClient = {
connect: vi.fn(), connect: vi.fn(),
getServerCapabilities: vi getServerCapabilities: vi
@@ -1412,6 +1542,7 @@ describe('mcp-client', () => {
}), }),
), ),
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
listResources: vi.fn().mockResolvedValue({ resources: [] }),
request: vi.fn().mockResolvedValue({}), request: vi.fn().mockResolvedValue({}),
registerCapabilities: vi.fn().mockResolvedValue({}), registerCapabilities: vi.fn().mockResolvedValue({}),
setRequestHandler: vi.fn().mockResolvedValue({}), setRequestHandler: vi.fn().mockResolvedValue({}),
@@ -1428,16 +1559,26 @@ describe('mcp-client', () => {
removeMcpToolsByServer: vi.fn(), removeMcpToolsByServer: vi.fn(),
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const client = new McpClient( const client = new McpClient(
'test-server', 'test-server',
// Set a short timeout // Set a very short timeout
{ command: 'test-command', timeout: 100 }, { command: 'test-command', timeout: 50 },
mockedToolRegistry, mockedToolRegistry,
{} as PromptRegistry, {
{} as ResourceRegistry, getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry,
{
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
workspaceContext, workspaceContext,
MOCK_CONTEXT, MOCK_CONTEXT,
false, false,
@@ -1446,13 +1587,16 @@ describe('mcp-client', () => {
await client.connect(); await client.connect();
const notificationCallback = const toolUpdateCall =
mockedClient.setNotificationHandler.mock.calls[0][1]; mockedClient.setNotificationHandler.mock.calls.find(
(call) => call[0] === ToolListChangedNotificationSchema,
);
const notificationCallback = toolUpdateCall![1];
const refreshPromise = notificationCallback(); const refreshPromise = notificationCallback();
vi.advanceTimersByTime(150); // Advance timers to trigger the timeout (11 minutes to cover even the default timeout)
await vi.advanceTimersByTimeAsync(11 * 60 * 1000);
await refreshPromise; await refreshPromise;
expect(mockedClient.listTools).toHaveBeenCalledWith( expect(mockedClient.listTools).toHaveBeenCalledWith(
@@ -1463,8 +1607,6 @@ describe('mcp-client', () => {
); );
expect(mockedToolRegistry.registerTool).not.toHaveBeenCalled(); expect(mockedToolRegistry.registerTool).not.toHaveBeenCalled();
vi.useRealTimers();
}); });
it('should pass abort signal to onToolsUpdated callback', async () => { it('should pass abort signal to onToolsUpdated callback', async () => {
@@ -1492,35 +1634,51 @@ describe('mcp-client', () => {
removeMcpToolsByServer: vi.fn(), removeMcpToolsByServer: vi.fn(),
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
getToolsByServer: vi.fn().mockReturnValue([]),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry; } as unknown as ToolRegistry;
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined); const onContextUpdatedSpy = vi.fn().mockResolvedValue(undefined);
const client = new McpClient( const client = new McpClient(
'test-server', 'test-server',
{ command: 'test-command' }, { command: 'test-command' },
mockedToolRegistry, mockedToolRegistry,
{} as PromptRegistry, {
{} as ResourceRegistry, getPromptsByServer: vi.fn().mockReturnValue([]),
registerPrompt: vi.fn(),
removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry,
{
getResourcesByServer: vi.fn().mockReturnValue([]),
registerResource: vi.fn(),
removeResourcesByServer: vi.fn(),
setResourcesForServer: vi.fn(),
} as unknown as ResourceRegistry,
workspaceContext, workspaceContext,
MOCK_CONTEXT, MOCK_CONTEXT,
false, false,
'0.0.1', '0.0.1',
onToolsUpdatedSpy, onContextUpdatedSpy,
); );
await client.connect(); await client.connect();
const notificationCallback = const toolUpdateCall =
mockedClient.setNotificationHandler.mock.calls[0][1]; mockedClient.setNotificationHandler.mock.calls.find(
(call) => call[0] === ToolListChangedNotificationSchema,
);
const notificationCallback = toolUpdateCall![1];
await notificationCallback(); vi.useFakeTimers();
const refreshPromise = notificationCallback();
await vi.advanceTimersByTimeAsync(500);
await refreshPromise;
expect(onToolsUpdatedSpy).toHaveBeenCalledWith(expect.any(AbortSignal)); expect(onContextUpdatedSpy).toHaveBeenCalledWith(expect.any(AbortSignal));
// Verify the signal passed was not aborted (happy path) // Verify the signal passed was not aborted (happy path)
const signal = onToolsUpdatedSpy.mock.calls[0][0]; const signal = onContextUpdatedSpy.mock.calls[0][0];
expect(signal.aborted).toBe(false); expect(signal.aborted).toBe(false);
}); });
}); });
+124 -20
View File
@@ -70,7 +70,10 @@ import type { ToolRegistry } from './tool-registry.js';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import { type MessageBus } from '../confirmation-bus/message-bus.js'; import { type MessageBus } from '../confirmation-bus/message-bus.js';
import { coreEvents } from '../utils/events.js'; import { coreEvents } from '../utils/events.js';
import type { ResourceRegistry } from '../resources/resource-registry.js'; import {
type ResourceRegistry,
type MCPResource,
} from '../resources/resource-registry.js';
import { validateMcpPolicyToolNames } from '../policy/toml-loader.js'; import { validateMcpPolicyToolNames } from '../policy/toml-loader.js';
import { import {
sanitizeEnvironment, sanitizeEnvironment,
@@ -156,7 +159,7 @@ export class McpClient implements McpProgressReporter {
private readonly cliConfig: McpContext, private readonly cliConfig: McpContext,
private readonly debugMode: boolean, private readonly debugMode: boolean,
private readonly clientVersion: string, private readonly clientVersion: string,
private readonly onToolsUpdated?: (signal?: AbortSignal) => Promise<void>, private readonly onContextUpdated?: (signal?: AbortSignal) => Promise<void>,
) {} ) {}
/** /**
@@ -352,10 +355,21 @@ export class McpClient implements McpProgressReporter {
const capabilities = this.client.getServerCapabilities(); const capabilities = this.client.getServerCapabilities();
if (capabilities?.tools?.listChanged) { debugLogger.log(
debugLogger.log( `Registering notification handlers for server '${this.serverName}'. Capabilities:`,
`Server '${this.serverName}' supports tool updates. Listening for changes...`, capabilities,
); );
if (capabilities?.tools) {
if (capabilities.tools.listChanged) {
debugLogger.log(
`Server '${this.serverName}' supports tool updates. Listening for changes...`,
);
} else {
debugLogger.log(
`Server '${this.serverName}' has tools but did not declare 'listChanged' capability. Listening anyway for robustness...`,
);
}
this.client.setNotificationHandler( this.client.setNotificationHandler(
ToolListChangedNotificationSchema, ToolListChangedNotificationSchema,
@@ -368,10 +382,16 @@ export class McpClient implements McpProgressReporter {
); );
} }
if (capabilities?.resources?.listChanged) { if (capabilities?.resources) {
debugLogger.log( if (capabilities.resources.listChanged) {
`Server '${this.serverName}' supports resource updates. Listening for changes...`, debugLogger.log(
); `Server '${this.serverName}' supports resource updates. Listening for changes...`,
);
} else {
debugLogger.log(
`Server '${this.serverName}' has resources but did not declare 'listChanged' capability. Listening anyway for robustness...`,
);
}
this.client.setNotificationHandler( this.client.setNotificationHandler(
ResourceListChangedNotificationSchema, ResourceListChangedNotificationSchema,
@@ -384,10 +404,16 @@ export class McpClient implements McpProgressReporter {
); );
} }
if (capabilities?.prompts?.listChanged) { if (capabilities?.prompts) {
debugLogger.log( if (capabilities.prompts.listChanged) {
`Server '${this.serverName}' supports prompt updates. Listening for changes...`, debugLogger.log(
); `Server '${this.serverName}' supports prompt updates. Listening for changes...`,
);
} else {
debugLogger.log(
`Server '${this.serverName}' has prompts but did not declare 'listChanged' capability. Listening anyway for robustness...`,
);
}
this.client.setNotificationHandler( this.client.setNotificationHandler(
PromptListChangedNotificationSchema, PromptListChangedNotificationSchema,
@@ -451,6 +477,25 @@ export class McpClient implements McpProgressReporter {
let newResources; let newResources;
try { try {
newResources = await this.discoverResources(); newResources = await this.discoverResources();
// Verification Retry: If no resources are found or resources didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentResources =
this.resourceRegistry.getResourcesByServer(this.serverName) || [];
const resourceMatch =
newResources.length === currentResources.length &&
newResources.every((nr: Resource) =>
currentResources.some((cr: MCPResource) => cr.uri === nr.uri),
);
if (resourceMatch && !this.pendingResourceRefresh) {
debugLogger.log(
`No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newResources = await this.discoverResources();
}
} catch (err) { } catch (err) {
debugLogger.error( debugLogger.error(
`Resource discovery failed during refresh: ${getErrorMessage(err)}`, `Resource discovery failed during refresh: ${getErrorMessage(err)}`,
@@ -461,6 +506,10 @@ export class McpClient implements McpProgressReporter {
this.updateResourceRegistry(newResources); this.updateResourceRegistry(newResources);
if (this.onContextUpdated) {
await this.onContextUpdated(abortController.signal);
}
clearTimeout(timeoutId); clearTimeout(timeoutId);
this.cliConfig.emitMcpDiagnostic( this.cliConfig.emitMcpDiagnostic(
@@ -476,7 +525,6 @@ export class McpClient implements McpProgressReporter {
); );
} finally { } finally {
this.isRefreshingResources = false; this.isRefreshingResources = false;
this.pendingResourceRefresh = false;
} }
} }
@@ -519,9 +567,31 @@ export class McpClient implements McpProgressReporter {
const timeoutId = setTimeout(() => abortController.abort(), timeoutMs); const timeoutId = setTimeout(() => abortController.abort(), timeoutMs);
try { try {
const newPrompts = await this.fetchPrompts({ let newPrompts = await this.fetchPrompts({
signal: abortController.signal, signal: abortController.signal,
}); });
// Verification Retry: If no prompts are found or prompts didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentPrompts =
this.promptRegistry.getPromptsByServer(this.serverName) || [];
const promptsMatch =
newPrompts.length === currentPrompts.length &&
newPrompts.every((np) =>
currentPrompts.some((cp) => cp.name === np.name),
);
if (promptsMatch && !this.pendingPromptRefresh) {
debugLogger.log(
`No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newPrompts = await this.fetchPrompts({
signal: abortController.signal,
});
}
this.promptRegistry.removePromptsByServer(this.serverName); this.promptRegistry.removePromptsByServer(this.serverName);
for (const prompt of newPrompts) { for (const prompt of newPrompts) {
this.promptRegistry.registerPrompt(prompt); this.promptRegistry.registerPrompt(prompt);
@@ -534,6 +604,10 @@ export class McpClient implements McpProgressReporter {
break; break;
} }
if (this.onContextUpdated) {
await this.onContextUpdated(abortController.signal);
}
clearTimeout(timeoutId); clearTimeout(timeoutId);
this.cliConfig.emitMcpDiagnostic( this.cliConfig.emitMcpDiagnostic(
@@ -549,7 +623,6 @@ export class McpClient implements McpProgressReporter {
); );
} finally { } finally {
this.isRefreshingPrompts = false; this.isRefreshingPrompts = false;
this.pendingPromptRefresh = false;
} }
} }
@@ -594,6 +667,38 @@ export class McpClient implements McpProgressReporter {
newTools = await this.discoverTools(this.cliConfig, { newTools = await this.discoverTools(this.cliConfig, {
signal: abortController.signal, signal: abortController.signal,
}); });
debugLogger.log(
`Refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
);
// Verification Retry (Option 3): If no tools are found or tools didn't change,
// wait briefly and try one more time. Some servers notify before they're fully ready.
const currentTools =
this.toolRegistry.getToolsByServer(this.serverName) || [];
const toolNamesMatch =
newTools.length === currentTools.length &&
newTools.every((nt) =>
currentTools.some(
(ct) =>
ct.name === nt.name ||
(ct instanceof DiscoveredMCPTool &&
ct.serverToolName === nt.serverToolName),
),
);
if (toolNamesMatch && !this.pendingToolRefresh) {
debugLogger.log(
`No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`,
);
const retryDelay = 500;
await new Promise((resolve) => setTimeout(resolve, retryDelay));
newTools = await this.discoverTools(this.cliConfig, {
signal: abortController.signal,
});
debugLogger.log(
`Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
);
}
} catch (err) { } catch (err) {
debugLogger.error( debugLogger.error(
`Discovery failed during refresh: ${getErrorMessage(err)}`, `Discovery failed during refresh: ${getErrorMessage(err)}`,
@@ -609,8 +714,8 @@ export class McpClient implements McpProgressReporter {
} }
this.toolRegistry.sortTools(); this.toolRegistry.sortTools();
if (this.onToolsUpdated) { if (this.onContextUpdated) {
await this.onToolsUpdated(abortController.signal); await this.onContextUpdated(abortController.signal);
} }
clearTimeout(timeoutId); clearTimeout(timeoutId);
@@ -628,7 +733,6 @@ export class McpClient implements McpProgressReporter {
); );
} finally { } finally {
this.isRefreshingTools = false; this.isRefreshingTools = false;
this.pendingToolRefresh = false;
} }
} }
} }