mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-28 05:55:17 -07:00
Add support for MCP dynamic tool update by notifications/tools/list_changed (#14375)
This commit is contained in:
@@ -174,7 +174,15 @@ export class McpClientManager {
|
|||||||
this.toolRegistry,
|
this.toolRegistry,
|
||||||
this.cliConfig.getPromptRegistry(),
|
this.cliConfig.getPromptRegistry(),
|
||||||
this.cliConfig.getWorkspaceContext(),
|
this.cliConfig.getWorkspaceContext(),
|
||||||
|
this.cliConfig,
|
||||||
this.cliConfig.getDebugMode(),
|
this.cliConfig.getDebugMode(),
|
||||||
|
async () => {
|
||||||
|
debugLogger.log('Tools changed, updating Gemini context...');
|
||||||
|
const geminiClient = this.cliConfig.getGeminiClient();
|
||||||
|
if (geminiClient.isInitialized()) {
|
||||||
|
await geminiClient.setTools();
|
||||||
|
}
|
||||||
|
},
|
||||||
);
|
);
|
||||||
if (!existing) {
|
if (!existing) {
|
||||||
this.clients.set(name, client);
|
this.clients.set(name, client);
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
|||||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||||
|
import { ToolListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||||
|
|
||||||
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||||
import {
|
import {
|
||||||
connectToMcpServer,
|
connectToMcpServer,
|
||||||
@@ -111,11 +113,15 @@ describe('mcp-client', () => {
|
|||||||
mockedToolRegistry,
|
mockedToolRegistry,
|
||||||
{} as PromptRegistry,
|
{} as PromptRegistry,
|
||||||
workspaceContext,
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
await client.connect();
|
await client.connect();
|
||||||
await client.discover({} as Config);
|
await client.discover({} as Config);
|
||||||
expect(mockedClient.listTools).toHaveBeenCalledWith({});
|
expect(mockedClient.listTools).toHaveBeenCalledWith(
|
||||||
|
{},
|
||||||
|
{ timeout: 600000 },
|
||||||
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should not skip tools even if a parameter is missing a type', async () => {
|
it('should not skip tools even if a parameter is missing a type', async () => {
|
||||||
@@ -177,6 +183,7 @@ describe('mcp-client', () => {
|
|||||||
mockedToolRegistry,
|
mockedToolRegistry,
|
||||||
{} as PromptRegistry,
|
{} as PromptRegistry,
|
||||||
workspaceContext,
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
await client.connect();
|
await client.connect();
|
||||||
@@ -217,6 +224,7 @@ describe('mcp-client', () => {
|
|||||||
mockedToolRegistry,
|
mockedToolRegistry,
|
||||||
{} as PromptRegistry,
|
{} as PromptRegistry,
|
||||||
workspaceContext,
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
await client.connect();
|
await client.connect();
|
||||||
@@ -261,6 +269,7 @@ describe('mcp-client', () => {
|
|||||||
mockedToolRegistry,
|
mockedToolRegistry,
|
||||||
{} as PromptRegistry,
|
{} as PromptRegistry,
|
||||||
workspaceContext,
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
await client.connect();
|
await client.connect();
|
||||||
@@ -309,6 +318,7 @@ describe('mcp-client', () => {
|
|||||||
mockedToolRegistry,
|
mockedToolRegistry,
|
||||||
{} as PromptRegistry,
|
{} as PromptRegistry,
|
||||||
workspaceContext,
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
await client.connect();
|
await client.connect();
|
||||||
@@ -371,6 +381,7 @@ describe('mcp-client', () => {
|
|||||||
mockedToolRegistry,
|
mockedToolRegistry,
|
||||||
{} as PromptRegistry,
|
{} as PromptRegistry,
|
||||||
workspaceContext,
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
await client.connect();
|
await client.connect();
|
||||||
@@ -444,6 +455,7 @@ describe('mcp-client', () => {
|
|||||||
mockedToolRegistry,
|
mockedToolRegistry,
|
||||||
mockedPromptRegistry,
|
mockedPromptRegistry,
|
||||||
workspaceContext,
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
await client.connect();
|
await client.connect();
|
||||||
@@ -459,6 +471,439 @@ describe('mcp-client', () => {
|
|||||||
expect(mockedPromptRegistry.removePromptsByServer).toHaveBeenCalledOnce();
|
expect(mockedPromptRegistry.removePromptsByServer).toHaveBeenCalledOnce();
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('Dynamic Tool Updates', () => {
|
||||||
|
it('should set up notification handler if server supports tool list changes', async () => {
|
||||||
|
const mockedClient = {
|
||||||
|
connect: vi.fn(),
|
||||||
|
getStatus: vi.fn(),
|
||||||
|
registerCapabilities: vi.fn(),
|
||||||
|
setRequestHandler: vi.fn(),
|
||||||
|
// Capability enables the listener
|
||||||
|
getServerCapabilities: vi
|
||||||
|
.fn()
|
||||||
|
.mockReturnValue({ tools: { listChanged: true } }),
|
||||||
|
setNotificationHandler: vi.fn(),
|
||||||
|
listTools: vi.fn().mockResolvedValue({ tools: [] }),
|
||||||
|
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||||
|
request: vi.fn().mockResolvedValue({}),
|
||||||
|
};
|
||||||
|
|
||||||
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||||
|
mockedClient as unknown as ClientLib.Client,
|
||||||
|
);
|
||||||
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||||
|
{} as SdkClientStdioLib.StdioClientTransport,
|
||||||
|
);
|
||||||
|
|
||||||
|
const client = new McpClient(
|
||||||
|
'test-server',
|
||||||
|
{ command: 'test-command' },
|
||||||
|
{} as ToolRegistry,
|
||||||
|
{} as PromptRegistry,
|
||||||
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
await client.connect();
|
||||||
|
|
||||||
|
expect(mockedClient.setNotificationHandler).toHaveBeenCalledWith(
|
||||||
|
ToolListChangedNotificationSchema,
|
||||||
|
expect.any(Function),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should NOT set up notification handler if server lacks capability', async () => {
|
||||||
|
const mockedClient = {
|
||||||
|
connect: vi.fn(),
|
||||||
|
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), // No listChanged
|
||||||
|
setNotificationHandler: vi.fn(),
|
||||||
|
request: vi.fn().mockResolvedValue({}),
|
||||||
|
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||||
|
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||||
|
};
|
||||||
|
|
||||||
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||||
|
mockedClient as unknown as ClientLib.Client,
|
||||||
|
);
|
||||||
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||||
|
{} as SdkClientStdioLib.StdioClientTransport,
|
||||||
|
);
|
||||||
|
|
||||||
|
const client = new McpClient(
|
||||||
|
'test-server',
|
||||||
|
{ command: 'test-command' },
|
||||||
|
{} as ToolRegistry,
|
||||||
|
{} as PromptRegistry,
|
||||||
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
await client.connect();
|
||||||
|
|
||||||
|
expect(mockedClient.setNotificationHandler).not.toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should refresh tools and notify manager when notification is received', async () => {
|
||||||
|
// Setup mocks
|
||||||
|
const mockedClient = {
|
||||||
|
connect: vi.fn(),
|
||||||
|
getServerCapabilities: vi
|
||||||
|
.fn()
|
||||||
|
.mockReturnValue({ tools: { listChanged: true } }),
|
||||||
|
setNotificationHandler: vi.fn(),
|
||||||
|
listTools: vi.fn().mockResolvedValue({
|
||||||
|
tools: [
|
||||||
|
{
|
||||||
|
name: 'newTool',
|
||||||
|
inputSchema: { type: 'object', properties: {} },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||||
|
request: vi.fn().mockResolvedValue({}),
|
||||||
|
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||||
|
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||||
|
};
|
||||||
|
|
||||||
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||||
|
mockedClient as unknown as ClientLib.Client,
|
||||||
|
);
|
||||||
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||||
|
{} as SdkClientStdioLib.StdioClientTransport,
|
||||||
|
);
|
||||||
|
|
||||||
|
const mockedToolRegistry = {
|
||||||
|
removeMcpToolsByServer: vi.fn(),
|
||||||
|
registerTool: vi.fn(),
|
||||||
|
sortTools: vi.fn(),
|
||||||
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
|
|
||||||
|
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
||||||
|
|
||||||
|
// Initialize client with onToolsUpdated callback
|
||||||
|
const client = new McpClient(
|
||||||
|
'test-server',
|
||||||
|
{ command: 'test-command' },
|
||||||
|
mockedToolRegistry,
|
||||||
|
{} as PromptRegistry,
|
||||||
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
|
false,
|
||||||
|
onToolsUpdatedSpy,
|
||||||
|
);
|
||||||
|
|
||||||
|
// 1. Connect (sets up listener)
|
||||||
|
await client.connect();
|
||||||
|
|
||||||
|
// 2. Extract the callback passed to setNotificationHandler
|
||||||
|
const notificationCallback =
|
||||||
|
mockedClient.setNotificationHandler.mock.calls[0][1];
|
||||||
|
|
||||||
|
// 3. Trigger the notification manually
|
||||||
|
await notificationCallback();
|
||||||
|
|
||||||
|
// 4. Assertions
|
||||||
|
// It should clear old tools
|
||||||
|
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledWith(
|
||||||
|
'test-server',
|
||||||
|
);
|
||||||
|
|
||||||
|
// It should fetch new tools (listTools called inside discoverTools)
|
||||||
|
expect(mockedClient.listTools).toHaveBeenCalled();
|
||||||
|
|
||||||
|
// It should register the new tool
|
||||||
|
expect(mockedToolRegistry.registerTool).toHaveBeenCalled();
|
||||||
|
|
||||||
|
// It should notify the manager
|
||||||
|
expect(onToolsUpdatedSpy).toHaveBeenCalled();
|
||||||
|
|
||||||
|
// It should emit feedback event
|
||||||
|
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
|
||||||
|
'info',
|
||||||
|
'Tools updated for server: test-server',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle errors during tool refresh gracefully', async () => {
|
||||||
|
const mockedClient = {
|
||||||
|
connect: vi.fn(),
|
||||||
|
getServerCapabilities: vi
|
||||||
|
.fn()
|
||||||
|
.mockReturnValue({ tools: { listChanged: true } }),
|
||||||
|
setNotificationHandler: vi.fn(),
|
||||||
|
// Simulate error during discovery
|
||||||
|
listTools: vi.fn().mockRejectedValue(new Error('Network blip')),
|
||||||
|
request: vi.fn().mockResolvedValue({}),
|
||||||
|
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||||
|
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||||
|
};
|
||||||
|
|
||||||
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||||
|
mockedClient as unknown as ClientLib.Client,
|
||||||
|
);
|
||||||
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||||
|
{} as SdkClientStdioLib.StdioClientTransport,
|
||||||
|
);
|
||||||
|
|
||||||
|
const mockedToolRegistry = {
|
||||||
|
removeMcpToolsByServer: vi.fn(),
|
||||||
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
|
|
||||||
|
const client = new McpClient(
|
||||||
|
'test-server',
|
||||||
|
{ command: 'test-command' },
|
||||||
|
mockedToolRegistry,
|
||||||
|
{} as PromptRegistry,
|
||||||
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
await client.connect();
|
||||||
|
|
||||||
|
const notificationCallback =
|
||||||
|
mockedClient.setNotificationHandler.mock.calls[0][1];
|
||||||
|
|
||||||
|
// Trigger notification - should fail internally but catch the error
|
||||||
|
await notificationCallback();
|
||||||
|
|
||||||
|
// Should try to remove tools
|
||||||
|
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalled();
|
||||||
|
|
||||||
|
// Should NOT emit success feedback
|
||||||
|
expect(coreEvents.emitFeedback).not.toHaveBeenCalledWith(
|
||||||
|
'info',
|
||||||
|
expect.stringContaining('Tools updated'),
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle concurrent updates from multiple servers', async () => {
|
||||||
|
const createMockSdkClient = (toolName: string) => ({
|
||||||
|
connect: vi.fn(),
|
||||||
|
getServerCapabilities: vi
|
||||||
|
.fn()
|
||||||
|
.mockReturnValue({ tools: { listChanged: true } }),
|
||||||
|
setNotificationHandler: vi.fn(),
|
||||||
|
listTools: vi.fn().mockResolvedValue({
|
||||||
|
tools: [
|
||||||
|
{
|
||||||
|
name: toolName,
|
||||||
|
inputSchema: { type: 'object', properties: {} },
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}),
|
||||||
|
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||||
|
request: vi.fn().mockResolvedValue({}),
|
||||||
|
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||||
|
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const mockClientA = createMockSdkClient('tool-from-A');
|
||||||
|
const mockClientB = createMockSdkClient('tool-from-B');
|
||||||
|
|
||||||
|
vi.mocked(ClientLib.Client)
|
||||||
|
.mockReturnValueOnce(mockClientA as unknown as ClientLib.Client)
|
||||||
|
.mockReturnValueOnce(mockClientB as unknown as ClientLib.Client);
|
||||||
|
|
||||||
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||||
|
{} as SdkClientStdioLib.StdioClientTransport,
|
||||||
|
);
|
||||||
|
|
||||||
|
const mockedToolRegistry = {
|
||||||
|
removeMcpToolsByServer: vi.fn(),
|
||||||
|
registerTool: vi.fn(),
|
||||||
|
sortTools: vi.fn(),
|
||||||
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
|
|
||||||
|
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
||||||
|
|
||||||
|
const clientA = new McpClient(
|
||||||
|
'server-A',
|
||||||
|
{ command: 'cmd-a' },
|
||||||
|
mockedToolRegistry,
|
||||||
|
{} as PromptRegistry,
|
||||||
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
|
false,
|
||||||
|
onToolsUpdatedSpy,
|
||||||
|
);
|
||||||
|
|
||||||
|
const clientB = new McpClient(
|
||||||
|
'server-B',
|
||||||
|
{ command: 'cmd-b' },
|
||||||
|
mockedToolRegistry,
|
||||||
|
{} as PromptRegistry,
|
||||||
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
|
false,
|
||||||
|
onToolsUpdatedSpy,
|
||||||
|
);
|
||||||
|
|
||||||
|
await clientA.connect();
|
||||||
|
await clientB.connect();
|
||||||
|
|
||||||
|
const handlerA = mockClientA.setNotificationHandler.mock.calls[0][1];
|
||||||
|
const handlerB = mockClientB.setNotificationHandler.mock.calls[0][1];
|
||||||
|
|
||||||
|
// Trigger burst updates simultaneously
|
||||||
|
await Promise.all([handlerA(), handlerB()]);
|
||||||
|
|
||||||
|
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledWith(
|
||||||
|
'server-A',
|
||||||
|
);
|
||||||
|
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledWith(
|
||||||
|
'server-B',
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify fetching happened on both clients
|
||||||
|
expect(mockClientA.listTools).toHaveBeenCalled();
|
||||||
|
expect(mockClientB.listTools).toHaveBeenCalled();
|
||||||
|
|
||||||
|
// Verify tools from both servers were registered (2 total calls)
|
||||||
|
expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2);
|
||||||
|
|
||||||
|
// Verify the update callback was triggered for both
|
||||||
|
expect(onToolsUpdatedSpy).toHaveBeenCalledTimes(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should abort discovery and log error if timeout is exceeded during refresh', async () => {
|
||||||
|
vi.useFakeTimers();
|
||||||
|
|
||||||
|
const mockedClient = {
|
||||||
|
connect: vi.fn(),
|
||||||
|
getServerCapabilities: vi
|
||||||
|
.fn()
|
||||||
|
.mockReturnValue({ tools: { listChanged: true } }),
|
||||||
|
setNotificationHandler: vi.fn(),
|
||||||
|
// Mock listTools to simulate a long running process that respects the abort signal
|
||||||
|
listTools: vi.fn().mockImplementation(
|
||||||
|
async (params, options) =>
|
||||||
|
new Promise((resolve, reject) => {
|
||||||
|
if (options?.signal?.aborted) {
|
||||||
|
return reject(new Error('Operation aborted'));
|
||||||
|
}
|
||||||
|
options?.signal?.addEventListener('abort', () => {
|
||||||
|
reject(new Error('Operation aborted'));
|
||||||
|
});
|
||||||
|
// Intentionally do not resolve immediately to simulate lag
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||||
|
request: vi.fn().mockResolvedValue({}),
|
||||||
|
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||||
|
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||||
|
};
|
||||||
|
|
||||||
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||||
|
mockedClient as unknown as ClientLib.Client,
|
||||||
|
);
|
||||||
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||||
|
{} as SdkClientStdioLib.StdioClientTransport,
|
||||||
|
);
|
||||||
|
|
||||||
|
const mockedToolRegistry = {
|
||||||
|
removeMcpToolsByServer: vi.fn(),
|
||||||
|
registerTool: vi.fn(),
|
||||||
|
sortTools: vi.fn(),
|
||||||
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
|
|
||||||
|
const client = new McpClient(
|
||||||
|
'test-server',
|
||||||
|
// Set a short timeout
|
||||||
|
{ command: 'test-command', timeout: 100 },
|
||||||
|
mockedToolRegistry,
|
||||||
|
{} as PromptRegistry,
|
||||||
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
await client.connect();
|
||||||
|
|
||||||
|
const notificationCallback =
|
||||||
|
mockedClient.setNotificationHandler.mock.calls[0][1];
|
||||||
|
|
||||||
|
const refreshPromise = notificationCallback();
|
||||||
|
|
||||||
|
vi.advanceTimersByTime(150);
|
||||||
|
|
||||||
|
await refreshPromise;
|
||||||
|
|
||||||
|
expect(mockedClient.listTools).toHaveBeenCalledWith(
|
||||||
|
expect.anything(),
|
||||||
|
expect.objectContaining({
|
||||||
|
signal: expect.any(AbortSignal),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(mockedToolRegistry.registerTool).not.toHaveBeenCalled();
|
||||||
|
|
||||||
|
vi.useRealTimers();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should pass abort signal to onToolsUpdated callback', async () => {
|
||||||
|
const mockedClient = {
|
||||||
|
connect: vi.fn(),
|
||||||
|
getServerCapabilities: vi
|
||||||
|
.fn()
|
||||||
|
.mockReturnValue({ tools: { listChanged: true } }),
|
||||||
|
setNotificationHandler: vi.fn(),
|
||||||
|
listTools: vi.fn().mockResolvedValue({ tools: [] }),
|
||||||
|
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||||
|
request: vi.fn().mockResolvedValue({}),
|
||||||
|
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||||
|
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||||
|
};
|
||||||
|
|
||||||
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||||
|
mockedClient as unknown as ClientLib.Client,
|
||||||
|
);
|
||||||
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||||
|
{} as SdkClientStdioLib.StdioClientTransport,
|
||||||
|
);
|
||||||
|
|
||||||
|
const mockedToolRegistry = {
|
||||||
|
removeMcpToolsByServer: vi.fn(),
|
||||||
|
registerTool: vi.fn(),
|
||||||
|
sortTools: vi.fn(),
|
||||||
|
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
|
|
||||||
|
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
||||||
|
|
||||||
|
const client = new McpClient(
|
||||||
|
'test-server',
|
||||||
|
{ command: 'test-command' },
|
||||||
|
mockedToolRegistry,
|
||||||
|
{} as PromptRegistry,
|
||||||
|
workspaceContext,
|
||||||
|
{} as Config,
|
||||||
|
false,
|
||||||
|
onToolsUpdatedSpy,
|
||||||
|
);
|
||||||
|
|
||||||
|
await client.connect();
|
||||||
|
|
||||||
|
const notificationCallback =
|
||||||
|
mockedClient.setNotificationHandler.mock.calls[0][1];
|
||||||
|
|
||||||
|
await notificationCallback();
|
||||||
|
|
||||||
|
expect(onToolsUpdatedSpy).toHaveBeenCalledWith(expect.any(AbortSignal));
|
||||||
|
|
||||||
|
// Verify the signal passed was not aborted (happy path)
|
||||||
|
const signal = onToolsUpdatedSpy.mock.calls[0][0];
|
||||||
|
expect(signal.aborted).toBe(false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
describe('appendMcpServerCommand', () => {
|
describe('appendMcpServerCommand', () => {
|
||||||
it('should do nothing if no MCP servers or command are configured', () => {
|
it('should do nothing if no MCP servers or command are configured', () => {
|
||||||
const out = populateMcpServerCommand({}, undefined);
|
const out = populateMcpServerCommand({}, undefined);
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import type {
|
|||||||
} from '@modelcontextprotocol/sdk/types.js';
|
} from '@modelcontextprotocol/sdk/types.js';
|
||||||
import {
|
import {
|
||||||
ListRootsRequestSchema,
|
ListRootsRequestSchema,
|
||||||
|
ToolListChangedNotificationSchema,
|
||||||
type Tool as McpTool,
|
type Tool as McpTool,
|
||||||
} from '@modelcontextprotocol/sdk/types.js';
|
} from '@modelcontextprotocol/sdk/types.js';
|
||||||
import { parse } from 'shell-quote';
|
import { parse } from 'shell-quote';
|
||||||
@@ -97,6 +98,8 @@ export class McpClient {
|
|||||||
private client: Client | undefined;
|
private client: Client | undefined;
|
||||||
private transport: Transport | undefined;
|
private transport: Transport | undefined;
|
||||||
private status: MCPServerStatus = MCPServerStatus.DISCONNECTED;
|
private status: MCPServerStatus = MCPServerStatus.DISCONNECTED;
|
||||||
|
private isRefreshing: boolean = false;
|
||||||
|
private pendingRefresh: boolean = false;
|
||||||
|
|
||||||
constructor(
|
constructor(
|
||||||
private readonly serverName: string,
|
private readonly serverName: string,
|
||||||
@@ -104,7 +107,9 @@ export class McpClient {
|
|||||||
private readonly toolRegistry: ToolRegistry,
|
private readonly toolRegistry: ToolRegistry,
|
||||||
private readonly promptRegistry: PromptRegistry,
|
private readonly promptRegistry: PromptRegistry,
|
||||||
private readonly workspaceContext: WorkspaceContext,
|
private readonly workspaceContext: WorkspaceContext,
|
||||||
|
private readonly cliConfig: Config,
|
||||||
private readonly debugMode: boolean,
|
private readonly debugMode: boolean,
|
||||||
|
private readonly onToolsUpdated?: (signal?: AbortSignal) => Promise<void>,
|
||||||
) {}
|
) {}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -124,6 +129,25 @@ export class McpClient {
|
|||||||
this.debugMode,
|
this.debugMode,
|
||||||
this.workspaceContext,
|
this.workspaceContext,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// setup dynamic tool listener
|
||||||
|
const capabilities = this.client.getServerCapabilities();
|
||||||
|
|
||||||
|
if (capabilities?.tools?.listChanged) {
|
||||||
|
debugLogger.log(
|
||||||
|
`Server '${this.serverName}' supports tool updates. Listening for changes...`,
|
||||||
|
);
|
||||||
|
|
||||||
|
this.client.setNotificationHandler(
|
||||||
|
ToolListChangedNotificationSchema,
|
||||||
|
async () => {
|
||||||
|
debugLogger.log(
|
||||||
|
`🔔 Received tool update notification from '${this.serverName}'`,
|
||||||
|
);
|
||||||
|
await this.refreshTools();
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
const originalOnError = this.client.onerror;
|
const originalOnError = this.client.onerror;
|
||||||
this.client.onerror = (error) => {
|
this.client.onerror = (error) => {
|
||||||
if (this.status !== MCPServerStatus.CONNECTED) {
|
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||||
@@ -204,7 +228,10 @@ export class McpClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async discoverTools(cliConfig: Config): Promise<DiscoveredMCPTool[]> {
|
private async discoverTools(
|
||||||
|
cliConfig: Config,
|
||||||
|
options?: { timeout?: number; signal?: AbortSignal },
|
||||||
|
): Promise<DiscoveredMCPTool[]> {
|
||||||
this.assertConnected();
|
this.assertConnected();
|
||||||
return discoverTools(
|
return discoverTools(
|
||||||
this.serverName,
|
this.serverName,
|
||||||
@@ -212,6 +239,9 @@ export class McpClient {
|
|||||||
this.client!,
|
this.client!,
|
||||||
cliConfig,
|
cliConfig,
|
||||||
this.toolRegistry.getMessageBus(),
|
this.toolRegistry.getMessageBus(),
|
||||||
|
options ?? {
|
||||||
|
timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||||
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,6 +257,75 @@ export class McpClient {
|
|||||||
getInstructions(): string | undefined {
|
getInstructions(): string | undefined {
|
||||||
return this.client?.getInstructions();
|
return this.client?.getInstructions();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Refreshes the tools for this server by re-querying the MCP `tools/list` endpoint.
|
||||||
|
*
|
||||||
|
* This method implements a **Coalescing Pattern** to handle rapid bursts of notifications
|
||||||
|
* (e.g., during server startup or bulk updates) without overwhelming the server or
|
||||||
|
* creating race conditions in the global ToolRegistry.
|
||||||
|
*/
|
||||||
|
private async refreshTools(): Promise<void> {
|
||||||
|
if (this.isRefreshing) {
|
||||||
|
debugLogger.log(
|
||||||
|
`Tool refresh for '${this.serverName}' is already in progress. Pending update.`,
|
||||||
|
);
|
||||||
|
this.pendingRefresh = true;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.isRefreshing = true;
|
||||||
|
|
||||||
|
try {
|
||||||
|
do {
|
||||||
|
this.pendingRefresh = false;
|
||||||
|
|
||||||
|
if (this.status !== MCPServerStatus.CONNECTED || !this.client) break;
|
||||||
|
|
||||||
|
const timeoutMs = this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC;
|
||||||
|
const abortController = new AbortController();
|
||||||
|
const timeoutId = setTimeout(() => abortController.abort(), timeoutMs);
|
||||||
|
|
||||||
|
let newTools;
|
||||||
|
try {
|
||||||
|
newTools = await this.discoverTools(this.cliConfig, {
|
||||||
|
signal: abortController.signal,
|
||||||
|
});
|
||||||
|
} catch (err) {
|
||||||
|
debugLogger.error(
|
||||||
|
`Discovery failed during refresh: ${getErrorMessage(err)}`,
|
||||||
|
);
|
||||||
|
clearTimeout(timeoutId);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
this.toolRegistry.removeMcpToolsByServer(this.serverName);
|
||||||
|
|
||||||
|
for (const tool of newTools) {
|
||||||
|
this.toolRegistry.registerTool(tool);
|
||||||
|
}
|
||||||
|
this.toolRegistry.sortTools();
|
||||||
|
|
||||||
|
if (this.onToolsUpdated) {
|
||||||
|
await this.onToolsUpdated(abortController.signal);
|
||||||
|
}
|
||||||
|
|
||||||
|
clearTimeout(timeoutId);
|
||||||
|
|
||||||
|
coreEvents.emitFeedback(
|
||||||
|
'info',
|
||||||
|
`Tools updated for server: ${this.serverName}`,
|
||||||
|
);
|
||||||
|
} while (this.pendingRefresh);
|
||||||
|
} catch (error) {
|
||||||
|
debugLogger.error(
|
||||||
|
`Critical error in refresh loop for ${this.serverName}: ${getErrorMessage(error)}`,
|
||||||
|
);
|
||||||
|
} finally {
|
||||||
|
this.isRefreshing = false;
|
||||||
|
this.pendingRefresh = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -622,6 +721,7 @@ export async function connectAndDiscover(
|
|||||||
mcpClient,
|
mcpClient,
|
||||||
cliConfig,
|
cliConfig,
|
||||||
toolRegistry.getMessageBus(),
|
toolRegistry.getMessageBus(),
|
||||||
|
{ timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC },
|
||||||
);
|
);
|
||||||
|
|
||||||
// If we have neither prompts nor tools, it's a failed discovery
|
// If we have neither prompts nor tools, it's a failed discovery
|
||||||
@@ -671,12 +771,13 @@ export async function discoverTools(
|
|||||||
mcpClient: Client,
|
mcpClient: Client,
|
||||||
cliConfig: Config,
|
cliConfig: Config,
|
||||||
messageBus?: MessageBus,
|
messageBus?: MessageBus,
|
||||||
|
options?: { timeout?: number; signal?: AbortSignal },
|
||||||
): Promise<DiscoveredMCPTool[]> {
|
): Promise<DiscoveredMCPTool[]> {
|
||||||
try {
|
try {
|
||||||
// Only request tools if the server supports them.
|
// Only request tools if the server supports them.
|
||||||
if (mcpClient.getServerCapabilities()?.tools == null) return [];
|
if (mcpClient.getServerCapabilities()?.tools == null) return [];
|
||||||
|
|
||||||
const response = await mcpClient.listTools({});
|
const response = await mcpClient.listTools({}, options);
|
||||||
const discoveredTools: DiscoveredMCPTool[] = [];
|
const discoveredTools: DiscoveredMCPTool[] = [];
|
||||||
for (const toolDef of response.tools) {
|
for (const toolDef of response.tools) {
|
||||||
try {
|
try {
|
||||||
|
|||||||
Reference in New Issue
Block a user