mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 14:10:37 -07:00
fix(mcp): Notifications/tools/list_changed support not working (#21050)
Co-authored-by: Bryan Morgan <bryanmorgan@google.com>
This commit is contained in:
@@ -69,4 +69,17 @@ export class ResourceRegistry {
|
||||
clear(): void {
|
||||
this.resources.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an array of resources registered from a specific MCP server.
|
||||
*/
|
||||
getResourcesByServer(serverName: string): MCPResource[] {
|
||||
const serverResources: MCPResource[] = [];
|
||||
for (const resource of this.resources.values()) {
|
||||
if (resource.serverName === serverName) {
|
||||
serverResources.push(resource);
|
||||
}
|
||||
}
|
||||
return serverResources.sort((a, b) => a.uri.localeCompare(b.uri));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,7 +173,7 @@ export class McpClientManager {
|
||||
return Promise.resolve();
|
||||
}),
|
||||
);
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
await this.scheduleMcpContextRefresh();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -193,7 +193,7 @@ export class McpClientManager {
|
||||
}),
|
||||
),
|
||||
);
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
await this.scheduleMcpContextRefresh();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -251,7 +251,7 @@ export class McpClientManager {
|
||||
if (!skipRefresh) {
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration and system instructions.
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
await this.scheduleMcpContextRefresh();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -321,7 +321,7 @@ export class McpClientManager {
|
||||
this.cliConfig.getDebugMode(),
|
||||
this.clientVersion,
|
||||
async () => {
|
||||
debugLogger.log('Tools changed, updating Gemini context...');
|
||||
debugLogger.log(`🔔 Refreshing context for server '${name}'...`);
|
||||
await this.scheduleMcpContextRefresh();
|
||||
},
|
||||
);
|
||||
@@ -431,7 +431,7 @@ export class McpClientManager {
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
}
|
||||
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
await this.scheduleMcpContextRefresh();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -451,7 +451,7 @@ export class McpClientManager {
|
||||
},
|
||||
),
|
||||
);
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
await this.scheduleMcpContextRefresh();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -463,7 +463,7 @@ export class McpClientManager {
|
||||
throw new Error(`No MCP server registered with the name "${name}"`);
|
||||
}
|
||||
await this.maybeDiscoverMcpServer(name, config);
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
await this.scheduleMcpContextRefresh();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -517,21 +517,51 @@ export class McpClientManager {
|
||||
return instructions.join('\n\n');
|
||||
}
|
||||
|
||||
private isRefreshingMcpContext: boolean = false;
|
||||
private pendingMcpContextRefresh: boolean = false;
|
||||
|
||||
private async scheduleMcpContextRefresh(): Promise<void> {
|
||||
this.pendingMcpContextRefresh = true;
|
||||
|
||||
if (this.isRefreshingMcpContext) {
|
||||
debugLogger.log(
|
||||
'MCP context refresh already in progress, queuing trailing execution.',
|
||||
);
|
||||
return this.pendingRefreshPromise ?? Promise.resolve();
|
||||
}
|
||||
|
||||
if (this.pendingRefreshPromise) {
|
||||
debugLogger.log(
|
||||
'MCP context refresh already scheduled, coalescing with existing request.',
|
||||
);
|
||||
return this.pendingRefreshPromise;
|
||||
}
|
||||
|
||||
debugLogger.log('Scheduling MCP context refresh...');
|
||||
this.pendingRefreshPromise = (async () => {
|
||||
// Debounce to coalesce multiple rapid updates
|
||||
await new Promise((resolve) => setTimeout(resolve, 300));
|
||||
this.isRefreshingMcpContext = true;
|
||||
try {
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
do {
|
||||
this.pendingMcpContextRefresh = false;
|
||||
debugLogger.log('Executing MCP context refresh...');
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
debugLogger.log('MCP context refresh complete.');
|
||||
|
||||
// If more refresh requests came in during the execution, wait a bit
|
||||
// to coalesce them before the next iteration.
|
||||
if (this.pendingMcpContextRefresh) {
|
||||
debugLogger.log(
|
||||
'Coalescing burst refresh requests (300ms delay)...',
|
||||
);
|
||||
await new Promise((resolve) => setTimeout(resolve, 300));
|
||||
}
|
||||
} while (this.pendingMcpContextRefresh);
|
||||
} catch (error) {
|
||||
debugLogger.error(
|
||||
`Error refreshing MCP context: ${getErrorMessage(error)}`,
|
||||
);
|
||||
} finally {
|
||||
this.isRefreshingMcpContext = false;
|
||||
this.pendingRefreshPromise = null;
|
||||
}
|
||||
})();
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
PromptListChangedNotificationSchema,
|
||||
ResourceListChangedNotificationSchema,
|
||||
ToolListChangedNotificationSchema,
|
||||
ProgressNotificationSchema,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
|
||||
@@ -102,6 +103,7 @@ describe('mcp-client', () => {
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
describe('McpClient', () => {
|
||||
@@ -140,13 +142,16 @@ describe('mcp-client', () => {
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
const promptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const resourceRegistry = {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
setResourcesForServer: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry;
|
||||
@@ -221,13 +226,16 @@ describe('mcp-client', () => {
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
const promptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const resourceRegistry = {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
setResourcesForServer: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry;
|
||||
@@ -328,13 +336,16 @@ describe('mcp-client', () => {
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
const promptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const resourceRegistry = {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
setResourcesForServer: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry;
|
||||
@@ -388,13 +399,16 @@ describe('mcp-client', () => {
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
const promptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const resourceRegistry = {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
setResourcesForServer: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry;
|
||||
@@ -701,13 +715,16 @@ describe('mcp-client', () => {
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
const promptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const resourceRegistry = {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
setResourcesForServer: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry;
|
||||
@@ -778,13 +795,16 @@ describe('mcp-client', () => {
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
const promptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const resourceRegistry = {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
setResourcesForServer: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry;
|
||||
@@ -864,13 +884,16 @@ describe('mcp-client', () => {
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
const promptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const resourceRegistry = {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
setResourcesForServer: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry;
|
||||
@@ -950,13 +973,16 @@ describe('mcp-client', () => {
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
const promptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const resourceRegistry = {
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
setResourcesForServer: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry;
|
||||
@@ -1086,6 +1112,7 @@ describe('mcp-client', () => {
|
||||
setNotificationHandler: vi.fn(),
|
||||
listTools: vi.fn().mockResolvedValue({ tools: [] }),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
listResources: vi.fn().mockResolvedValue({ resources: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
|
||||
@@ -1096,12 +1123,27 @@ describe('mcp-client', () => {
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
{} as ToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as ResourceRegistry,
|
||||
mockedToolRegistry,
|
||||
{
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
{
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -1136,9 +1178,21 @@ describe('mcp-client', () => {
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
{} as ToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as ResourceRegistry,
|
||||
{
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
} as unknown as ToolRegistry,
|
||||
{
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
{
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -1147,7 +1201,62 @@ describe('mcp-client', () => {
|
||||
|
||||
await client.connect();
|
||||
|
||||
expect(mockedClient.setNotificationHandler).toHaveBeenCalledOnce();
|
||||
// Should be called for ProgressNotificationSchema, even if no other capabilities
|
||||
expect(mockedClient.setNotificationHandler).toHaveBeenCalled();
|
||||
const progressCall = mockedClient.setNotificationHandler.mock.calls.find(
|
||||
(call) => call[0] === ProgressNotificationSchema,
|
||||
);
|
||||
expect(progressCall).toBeDefined();
|
||||
});
|
||||
|
||||
it('should set up notification handler even if listChanged is false (robustness)', async () => {
|
||||
// Setup mocks
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
getServerCapabilities: vi
|
||||
.fn()
|
||||
.mockReturnValue({ tools: { listChanged: false } }),
|
||||
setNotificationHandler: vi.fn(),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
{
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
} as unknown as ToolRegistry,
|
||||
{
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
{
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
|
||||
const toolUpdateCall =
|
||||
mockedClient.setNotificationHandler.mock.calls.find(
|
||||
(call) => call[0] === ToolListChangedNotificationSchema,
|
||||
);
|
||||
expect(toolUpdateCall).toBeDefined();
|
||||
});
|
||||
|
||||
it('should refresh tools and notify manager when notification is received', async () => {
|
||||
@@ -1167,6 +1276,7 @@ describe('mcp-client', () => {
|
||||
],
|
||||
}),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
listResources: vi.fn().mockResolvedValue({ resources: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||
@@ -1183,31 +1293,38 @@ describe('mcp-client', () => {
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
||||
const onContextUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
// Initialize client with onToolsUpdated callback
|
||||
// Initialize client with onContextUpdated callback
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as ResourceRegistry,
|
||||
{
|
||||
removeMcpResourcesByServer: vi.fn(),
|
||||
registerResource: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
onToolsUpdatedSpy,
|
||||
onContextUpdatedSpy,
|
||||
);
|
||||
|
||||
// 1. Connect (sets up listener)
|
||||
await client.connect();
|
||||
|
||||
// 2. Extract the callback passed to setNotificationHandler
|
||||
const notificationCallback =
|
||||
mockedClient.setNotificationHandler.mock.calls[0][1];
|
||||
// 2. Extract the callback passed to setNotificationHandler for tools
|
||||
const toolUpdateCall =
|
||||
mockedClient.setNotificationHandler.mock.calls.find(
|
||||
(call) => call[0] === ToolListChangedNotificationSchema,
|
||||
);
|
||||
const notificationCallback = toolUpdateCall![1];
|
||||
|
||||
// 3. Trigger the notification manually
|
||||
await notificationCallback();
|
||||
@@ -1225,7 +1342,7 @@ describe('mcp-client', () => {
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalled();
|
||||
|
||||
// It should notify the manager
|
||||
expect(onToolsUpdatedSpy).toHaveBeenCalled();
|
||||
expect(onContextUpdatedSpy).toHaveBeenCalled();
|
||||
|
||||
// It should emit feedback event
|
||||
expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith(
|
||||
@@ -1259,6 +1376,7 @@ describe('mcp-client', () => {
|
||||
|
||||
const mockedToolRegistry = {
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
@@ -1276,8 +1394,11 @@ describe('mcp-client', () => {
|
||||
|
||||
await client.connect();
|
||||
|
||||
const notificationCallback =
|
||||
mockedClient.setNotificationHandler.mock.calls[0][1];
|
||||
const toolUpdateCall =
|
||||
mockedClient.setNotificationHandler.mock.calls.find(
|
||||
(call) => call[0] === ToolListChangedNotificationSchema,
|
||||
);
|
||||
const notificationCallback = toolUpdateCall![1];
|
||||
|
||||
// Trigger notification - should fail internally but catch the error
|
||||
await notificationCallback();
|
||||
@@ -1328,10 +1449,11 @@ describe('mcp-client', () => {
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
||||
const onContextUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
const clientA = new McpClient(
|
||||
'server-A',
|
||||
@@ -1343,7 +1465,7 @@ describe('mcp-client', () => {
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
onToolsUpdatedSpy,
|
||||
onContextUpdatedSpy,
|
||||
);
|
||||
|
||||
const clientB = new McpClient(
|
||||
@@ -1356,14 +1478,23 @@ describe('mcp-client', () => {
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
onToolsUpdatedSpy,
|
||||
onContextUpdatedSpy,
|
||||
);
|
||||
|
||||
await clientA.connect();
|
||||
await clientB.connect();
|
||||
|
||||
const handlerA = mockClientA.setNotificationHandler.mock.calls[0][1];
|
||||
const handlerB = mockClientB.setNotificationHandler.mock.calls[0][1];
|
||||
const toolUpdateCallA =
|
||||
mockClientA.setNotificationHandler.mock.calls.find(
|
||||
(call) => call[0] === ToolListChangedNotificationSchema,
|
||||
);
|
||||
const handlerA = toolUpdateCallA![1];
|
||||
|
||||
const toolUpdateCallB =
|
||||
mockClientB.setNotificationHandler.mock.calls.find(
|
||||
(call) => call[0] === ToolListChangedNotificationSchema,
|
||||
);
|
||||
const handlerB = toolUpdateCallB![1];
|
||||
|
||||
// Trigger burst updates simultaneously
|
||||
await Promise.all([handlerA(), handlerB()]);
|
||||
@@ -1383,12 +1514,11 @@ describe('mcp-client', () => {
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2);
|
||||
|
||||
// Verify the update callback was triggered for both
|
||||
expect(onToolsUpdatedSpy).toHaveBeenCalledTimes(2);
|
||||
expect(onContextUpdatedSpy).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should abort discovery and log error if timeout is exceeded during refresh', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
getServerCapabilities: vi
|
||||
@@ -1412,6 +1542,7 @@ describe('mcp-client', () => {
|
||||
}),
|
||||
),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
listResources: vi.fn().mockResolvedValue({ resources: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||
@@ -1428,16 +1559,26 @@ describe('mcp-client', () => {
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
// Set a short timeout
|
||||
{ command: 'test-command', timeout: 100 },
|
||||
// Set a very short timeout
|
||||
{ command: 'test-command', timeout: 50 },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as ResourceRegistry,
|
||||
{
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
{
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
@@ -1446,13 +1587,16 @@ describe('mcp-client', () => {
|
||||
|
||||
await client.connect();
|
||||
|
||||
const notificationCallback =
|
||||
mockedClient.setNotificationHandler.mock.calls[0][1];
|
||||
const toolUpdateCall =
|
||||
mockedClient.setNotificationHandler.mock.calls.find(
|
||||
(call) => call[0] === ToolListChangedNotificationSchema,
|
||||
);
|
||||
const notificationCallback = toolUpdateCall![1];
|
||||
|
||||
const refreshPromise = notificationCallback();
|
||||
|
||||
vi.advanceTimersByTime(150);
|
||||
|
||||
// Advance timers to trigger the timeout (11 minutes to cover even the default timeout)
|
||||
await vi.advanceTimersByTimeAsync(11 * 60 * 1000);
|
||||
await refreshPromise;
|
||||
|
||||
expect(mockedClient.listTools).toHaveBeenCalledWith(
|
||||
@@ -1463,8 +1607,6 @@ describe('mcp-client', () => {
|
||||
);
|
||||
|
||||
expect(mockedToolRegistry.registerTool).not.toHaveBeenCalled();
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should pass abort signal to onToolsUpdated callback', async () => {
|
||||
@@ -1492,35 +1634,51 @@ describe('mcp-client', () => {
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
||||
const onContextUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
{} as ResourceRegistry,
|
||||
{
|
||||
getPromptsByServer: vi.fn().mockReturnValue([]),
|
||||
registerPrompt: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry,
|
||||
{
|
||||
getResourcesByServer: vi.fn().mockReturnValue([]),
|
||||
registerResource: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
setResourcesForServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry,
|
||||
workspaceContext,
|
||||
MOCK_CONTEXT,
|
||||
false,
|
||||
'0.0.1',
|
||||
onToolsUpdatedSpy,
|
||||
onContextUpdatedSpy,
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
|
||||
const notificationCallback =
|
||||
mockedClient.setNotificationHandler.mock.calls[0][1];
|
||||
const toolUpdateCall =
|
||||
mockedClient.setNotificationHandler.mock.calls.find(
|
||||
(call) => call[0] === ToolListChangedNotificationSchema,
|
||||
);
|
||||
const notificationCallback = toolUpdateCall![1];
|
||||
|
||||
await notificationCallback();
|
||||
vi.useFakeTimers();
|
||||
const refreshPromise = notificationCallback();
|
||||
await vi.advanceTimersByTimeAsync(500);
|
||||
await refreshPromise;
|
||||
|
||||
expect(onToolsUpdatedSpy).toHaveBeenCalledWith(expect.any(AbortSignal));
|
||||
expect(onContextUpdatedSpy).toHaveBeenCalledWith(expect.any(AbortSignal));
|
||||
|
||||
// Verify the signal passed was not aborted (happy path)
|
||||
const signal = onToolsUpdatedSpy.mock.calls[0][0];
|
||||
const signal = onContextUpdatedSpy.mock.calls[0][0];
|
||||
expect(signal.aborted).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -70,7 +70,10 @@ import type { ToolRegistry } from './tool-registry.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { type MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
import type { ResourceRegistry } from '../resources/resource-registry.js';
|
||||
import {
|
||||
type ResourceRegistry,
|
||||
type MCPResource,
|
||||
} from '../resources/resource-registry.js';
|
||||
import { validateMcpPolicyToolNames } from '../policy/toml-loader.js';
|
||||
import {
|
||||
sanitizeEnvironment,
|
||||
@@ -156,7 +159,7 @@ export class McpClient implements McpProgressReporter {
|
||||
private readonly cliConfig: McpContext,
|
||||
private readonly debugMode: boolean,
|
||||
private readonly clientVersion: string,
|
||||
private readonly onToolsUpdated?: (signal?: AbortSignal) => Promise<void>,
|
||||
private readonly onContextUpdated?: (signal?: AbortSignal) => Promise<void>,
|
||||
) {}
|
||||
|
||||
/**
|
||||
@@ -352,10 +355,21 @@ export class McpClient implements McpProgressReporter {
|
||||
|
||||
const capabilities = this.client.getServerCapabilities();
|
||||
|
||||
if (capabilities?.tools?.listChanged) {
|
||||
debugLogger.log(
|
||||
`Server '${this.serverName}' supports tool updates. Listening for changes...`,
|
||||
);
|
||||
debugLogger.log(
|
||||
`Registering notification handlers for server '${this.serverName}'. Capabilities:`,
|
||||
capabilities,
|
||||
);
|
||||
|
||||
if (capabilities?.tools) {
|
||||
if (capabilities.tools.listChanged) {
|
||||
debugLogger.log(
|
||||
`Server '${this.serverName}' supports tool updates. Listening for changes...`,
|
||||
);
|
||||
} else {
|
||||
debugLogger.log(
|
||||
`Server '${this.serverName}' has tools but did not declare 'listChanged' capability. Listening anyway for robustness...`,
|
||||
);
|
||||
}
|
||||
|
||||
this.client.setNotificationHandler(
|
||||
ToolListChangedNotificationSchema,
|
||||
@@ -368,10 +382,16 @@ export class McpClient implements McpProgressReporter {
|
||||
);
|
||||
}
|
||||
|
||||
if (capabilities?.resources?.listChanged) {
|
||||
debugLogger.log(
|
||||
`Server '${this.serverName}' supports resource updates. Listening for changes...`,
|
||||
);
|
||||
if (capabilities?.resources) {
|
||||
if (capabilities.resources.listChanged) {
|
||||
debugLogger.log(
|
||||
`Server '${this.serverName}' supports resource updates. Listening for changes...`,
|
||||
);
|
||||
} else {
|
||||
debugLogger.log(
|
||||
`Server '${this.serverName}' has resources but did not declare 'listChanged' capability. Listening anyway for robustness...`,
|
||||
);
|
||||
}
|
||||
|
||||
this.client.setNotificationHandler(
|
||||
ResourceListChangedNotificationSchema,
|
||||
@@ -384,10 +404,16 @@ export class McpClient implements McpProgressReporter {
|
||||
);
|
||||
}
|
||||
|
||||
if (capabilities?.prompts?.listChanged) {
|
||||
debugLogger.log(
|
||||
`Server '${this.serverName}' supports prompt updates. Listening for changes...`,
|
||||
);
|
||||
if (capabilities?.prompts) {
|
||||
if (capabilities.prompts.listChanged) {
|
||||
debugLogger.log(
|
||||
`Server '${this.serverName}' supports prompt updates. Listening for changes...`,
|
||||
);
|
||||
} else {
|
||||
debugLogger.log(
|
||||
`Server '${this.serverName}' has prompts but did not declare 'listChanged' capability. Listening anyway for robustness...`,
|
||||
);
|
||||
}
|
||||
|
||||
this.client.setNotificationHandler(
|
||||
PromptListChangedNotificationSchema,
|
||||
@@ -451,6 +477,25 @@ export class McpClient implements McpProgressReporter {
|
||||
let newResources;
|
||||
try {
|
||||
newResources = await this.discoverResources();
|
||||
|
||||
// Verification Retry: If no resources are found or resources didn't change,
|
||||
// wait briefly and try one more time. Some servers notify before they're fully ready.
|
||||
const currentResources =
|
||||
this.resourceRegistry.getResourcesByServer(this.serverName) || [];
|
||||
const resourceMatch =
|
||||
newResources.length === currentResources.length &&
|
||||
newResources.every((nr: Resource) =>
|
||||
currentResources.some((cr: MCPResource) => cr.uri === nr.uri),
|
||||
);
|
||||
|
||||
if (resourceMatch && !this.pendingResourceRefresh) {
|
||||
debugLogger.log(
|
||||
`No resource changes detected for '${this.serverName}'. Retrying once in 500ms...`,
|
||||
);
|
||||
const retryDelay = 500;
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
newResources = await this.discoverResources();
|
||||
}
|
||||
} catch (err) {
|
||||
debugLogger.error(
|
||||
`Resource discovery failed during refresh: ${getErrorMessage(err)}`,
|
||||
@@ -461,6 +506,10 @@ export class McpClient implements McpProgressReporter {
|
||||
|
||||
this.updateResourceRegistry(newResources);
|
||||
|
||||
if (this.onContextUpdated) {
|
||||
await this.onContextUpdated(abortController.signal);
|
||||
}
|
||||
|
||||
clearTimeout(timeoutId);
|
||||
|
||||
this.cliConfig.emitMcpDiagnostic(
|
||||
@@ -476,7 +525,6 @@ export class McpClient implements McpProgressReporter {
|
||||
);
|
||||
} finally {
|
||||
this.isRefreshingResources = false;
|
||||
this.pendingResourceRefresh = false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -519,9 +567,31 @@ export class McpClient implements McpProgressReporter {
|
||||
const timeoutId = setTimeout(() => abortController.abort(), timeoutMs);
|
||||
|
||||
try {
|
||||
const newPrompts = await this.fetchPrompts({
|
||||
let newPrompts = await this.fetchPrompts({
|
||||
signal: abortController.signal,
|
||||
});
|
||||
|
||||
// Verification Retry: If no prompts are found or prompts didn't change,
|
||||
// wait briefly and try one more time. Some servers notify before they're fully ready.
|
||||
const currentPrompts =
|
||||
this.promptRegistry.getPromptsByServer(this.serverName) || [];
|
||||
const promptsMatch =
|
||||
newPrompts.length === currentPrompts.length &&
|
||||
newPrompts.every((np) =>
|
||||
currentPrompts.some((cp) => cp.name === np.name),
|
||||
);
|
||||
|
||||
if (promptsMatch && !this.pendingPromptRefresh) {
|
||||
debugLogger.log(
|
||||
`No prompt changes detected for '${this.serverName}'. Retrying once in 500ms...`,
|
||||
);
|
||||
const retryDelay = 500;
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
newPrompts = await this.fetchPrompts({
|
||||
signal: abortController.signal,
|
||||
});
|
||||
}
|
||||
|
||||
this.promptRegistry.removePromptsByServer(this.serverName);
|
||||
for (const prompt of newPrompts) {
|
||||
this.promptRegistry.registerPrompt(prompt);
|
||||
@@ -534,6 +604,10 @@ export class McpClient implements McpProgressReporter {
|
||||
break;
|
||||
}
|
||||
|
||||
if (this.onContextUpdated) {
|
||||
await this.onContextUpdated(abortController.signal);
|
||||
}
|
||||
|
||||
clearTimeout(timeoutId);
|
||||
|
||||
this.cliConfig.emitMcpDiagnostic(
|
||||
@@ -549,7 +623,6 @@ export class McpClient implements McpProgressReporter {
|
||||
);
|
||||
} finally {
|
||||
this.isRefreshingPrompts = false;
|
||||
this.pendingPromptRefresh = false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -594,6 +667,38 @@ export class McpClient implements McpProgressReporter {
|
||||
newTools = await this.discoverTools(this.cliConfig, {
|
||||
signal: abortController.signal,
|
||||
});
|
||||
debugLogger.log(
|
||||
`Refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
|
||||
);
|
||||
|
||||
// Verification Retry (Option 3): If no tools are found or tools didn't change,
|
||||
// wait briefly and try one more time. Some servers notify before they're fully ready.
|
||||
const currentTools =
|
||||
this.toolRegistry.getToolsByServer(this.serverName) || [];
|
||||
const toolNamesMatch =
|
||||
newTools.length === currentTools.length &&
|
||||
newTools.every((nt) =>
|
||||
currentTools.some(
|
||||
(ct) =>
|
||||
ct.name === nt.name ||
|
||||
(ct instanceof DiscoveredMCPTool &&
|
||||
ct.serverToolName === nt.serverToolName),
|
||||
),
|
||||
);
|
||||
|
||||
if (toolNamesMatch && !this.pendingToolRefresh) {
|
||||
debugLogger.log(
|
||||
`No tool changes detected for '${this.serverName}'. Retrying once in 500ms...`,
|
||||
);
|
||||
const retryDelay = 500;
|
||||
await new Promise((resolve) => setTimeout(resolve, retryDelay));
|
||||
newTools = await this.discoverTools(this.cliConfig, {
|
||||
signal: abortController.signal,
|
||||
});
|
||||
debugLogger.log(
|
||||
`Retry refresh for '${this.serverName}' discovered ${newTools.length} tools.`,
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
debugLogger.error(
|
||||
`Discovery failed during refresh: ${getErrorMessage(err)}`,
|
||||
@@ -609,8 +714,8 @@ export class McpClient implements McpProgressReporter {
|
||||
}
|
||||
this.toolRegistry.sortTools();
|
||||
|
||||
if (this.onToolsUpdated) {
|
||||
await this.onToolsUpdated(abortController.signal);
|
||||
if (this.onContextUpdated) {
|
||||
await this.onContextUpdated(abortController.signal);
|
||||
}
|
||||
|
||||
clearTimeout(timeoutId);
|
||||
@@ -628,7 +733,6 @@ export class McpClient implements McpProgressReporter {
|
||||
);
|
||||
} finally {
|
||||
this.isRefreshingTools = false;
|
||||
this.pendingToolRefresh = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user