mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-04 00:44:05 -07:00
Add support for MCP dynamic tool update by notifications/tools/list_changed (#14375)
This commit is contained in:
@@ -18,6 +18,8 @@ import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import { ToolListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
|
||||
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
import {
|
||||
connectToMcpServer,
|
||||
@@ -111,11 +113,15 @@ describe('mcp-client', () => {
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover({} as Config);
|
||||
expect(mockedClient.listTools).toHaveBeenCalledWith({});
|
||||
expect(mockedClient.listTools).toHaveBeenCalledWith(
|
||||
{},
|
||||
{ timeout: 600000 },
|
||||
);
|
||||
});
|
||||
|
||||
it('should not skip tools even if a parameter is missing a type', async () => {
|
||||
@@ -177,6 +183,7 @@ describe('mcp-client', () => {
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
@@ -217,6 +224,7 @@ describe('mcp-client', () => {
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
@@ -261,6 +269,7 @@ describe('mcp-client', () => {
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
@@ -309,6 +318,7 @@ describe('mcp-client', () => {
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
@@ -371,6 +381,7 @@ describe('mcp-client', () => {
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
@@ -444,6 +455,7 @@ describe('mcp-client', () => {
|
||||
mockedToolRegistry,
|
||||
mockedPromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
@@ -459,6 +471,439 @@ describe('mcp-client', () => {
|
||||
expect(mockedPromptRegistry.removePromptsByServer).toHaveBeenCalledOnce();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Dynamic Tool Updates', () => {
|
||||
it('should set up notification handler if server supports tool list changes', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
// Capability enables the listener
|
||||
getServerCapabilities: vi
|
||||
.fn()
|
||||
.mockReturnValue({ tools: { listChanged: true } }),
|
||||
setNotificationHandler: vi.fn(),
|
||||
listTools: vi.fn().mockResolvedValue({ tools: [] }),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
{} as ToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
|
||||
expect(mockedClient.setNotificationHandler).toHaveBeenCalledWith(
|
||||
ToolListChangedNotificationSchema,
|
||||
expect.any(Function),
|
||||
);
|
||||
});
|
||||
|
||||
it('should NOT set up notification handler if server lacks capability', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), // No listChanged
|
||||
setNotificationHandler: vi.fn(),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
{} as ToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
|
||||
expect(mockedClient.setNotificationHandler).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should refresh tools and notify manager when notification is received', async () => {
|
||||
// Setup mocks
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
getServerCapabilities: vi
|
||||
.fn()
|
||||
.mockReturnValue({ tools: { listChanged: true } }),
|
||||
setNotificationHandler: vi.fn(),
|
||||
listTools: vi.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'newTool',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
],
|
||||
}),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
|
||||
const mockedToolRegistry = {
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
// Initialize client with onToolsUpdated callback
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
onToolsUpdatedSpy,
|
||||
);
|
||||
|
||||
// 1. Connect (sets up listener)
|
||||
await client.connect();
|
||||
|
||||
// 2. Extract the callback passed to setNotificationHandler
|
||||
const notificationCallback =
|
||||
mockedClient.setNotificationHandler.mock.calls[0][1];
|
||||
|
||||
// 3. Trigger the notification manually
|
||||
await notificationCallback();
|
||||
|
||||
// 4. Assertions
|
||||
// It should clear old tools
|
||||
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
|
||||
// It should fetch new tools (listTools called inside discoverTools)
|
||||
expect(mockedClient.listTools).toHaveBeenCalled();
|
||||
|
||||
// It should register the new tool
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalled();
|
||||
|
||||
// It should notify the manager
|
||||
expect(onToolsUpdatedSpy).toHaveBeenCalled();
|
||||
|
||||
// It should emit feedback event
|
||||
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
|
||||
'info',
|
||||
'Tools updated for server: test-server',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle errors during tool refresh gracefully', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
getServerCapabilities: vi
|
||||
.fn()
|
||||
.mockReturnValue({ tools: { listChanged: true } }),
|
||||
setNotificationHandler: vi.fn(),
|
||||
// Simulate error during discovery
|
||||
listTools: vi.fn().mockRejectedValue(new Error('Network blip')),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
|
||||
const mockedToolRegistry = {
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
|
||||
const notificationCallback =
|
||||
mockedClient.setNotificationHandler.mock.calls[0][1];
|
||||
|
||||
// Trigger notification - should fail internally but catch the error
|
||||
await notificationCallback();
|
||||
|
||||
// Should try to remove tools
|
||||
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalled();
|
||||
|
||||
// Should NOT emit success feedback
|
||||
expect(coreEvents.emitFeedback).not.toHaveBeenCalledWith(
|
||||
'info',
|
||||
expect.stringContaining('Tools updated'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle concurrent updates from multiple servers', async () => {
|
||||
const createMockSdkClient = (toolName: string) => ({
|
||||
connect: vi.fn(),
|
||||
getServerCapabilities: vi
|
||||
.fn()
|
||||
.mockReturnValue({ tools: { listChanged: true } }),
|
||||
setNotificationHandler: vi.fn(),
|
||||
listTools: vi.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: toolName,
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
],
|
||||
}),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||
});
|
||||
|
||||
const mockClientA = createMockSdkClient('tool-from-A');
|
||||
const mockClientB = createMockSdkClient('tool-from-B');
|
||||
|
||||
vi.mocked(ClientLib.Client)
|
||||
.mockReturnValueOnce(mockClientA as unknown as ClientLib.Client)
|
||||
.mockReturnValueOnce(mockClientB as unknown as ClientLib.Client);
|
||||
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
|
||||
const mockedToolRegistry = {
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
const clientA = new McpClient(
|
||||
'server-A',
|
||||
{ command: 'cmd-a' },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
onToolsUpdatedSpy,
|
||||
);
|
||||
|
||||
const clientB = new McpClient(
|
||||
'server-B',
|
||||
{ command: 'cmd-b' },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
onToolsUpdatedSpy,
|
||||
);
|
||||
|
||||
await clientA.connect();
|
||||
await clientB.connect();
|
||||
|
||||
const handlerA = mockClientA.setNotificationHandler.mock.calls[0][1];
|
||||
const handlerB = mockClientB.setNotificationHandler.mock.calls[0][1];
|
||||
|
||||
// Trigger burst updates simultaneously
|
||||
await Promise.all([handlerA(), handlerB()]);
|
||||
|
||||
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledWith(
|
||||
'server-A',
|
||||
);
|
||||
expect(mockedToolRegistry.removeMcpToolsByServer).toHaveBeenCalledWith(
|
||||
'server-B',
|
||||
);
|
||||
|
||||
// Verify fetching happened on both clients
|
||||
expect(mockClientA.listTools).toHaveBeenCalled();
|
||||
expect(mockClientB.listTools).toHaveBeenCalled();
|
||||
|
||||
// Verify tools from both servers were registered (2 total calls)
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2);
|
||||
|
||||
// Verify the update callback was triggered for both
|
||||
expect(onToolsUpdatedSpy).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should abort discovery and log error if timeout is exceeded during refresh', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
getServerCapabilities: vi
|
||||
.fn()
|
||||
.mockReturnValue({ tools: { listChanged: true } }),
|
||||
setNotificationHandler: vi.fn(),
|
||||
// Mock listTools to simulate a long running process that respects the abort signal
|
||||
listTools: vi.fn().mockImplementation(
|
||||
async (params, options) =>
|
||||
new Promise((resolve, reject) => {
|
||||
if (options?.signal?.aborted) {
|
||||
return reject(new Error('Operation aborted'));
|
||||
}
|
||||
options?.signal?.addEventListener('abort', () => {
|
||||
reject(new Error('Operation aborted'));
|
||||
});
|
||||
// Intentionally do not resolve immediately to simulate lag
|
||||
}),
|
||||
),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
|
||||
const mockedToolRegistry = {
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
// Set a short timeout
|
||||
{ command: 'test-command', timeout: 100 },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
|
||||
const notificationCallback =
|
||||
mockedClient.setNotificationHandler.mock.calls[0][1];
|
||||
|
||||
const refreshPromise = notificationCallback();
|
||||
|
||||
vi.advanceTimersByTime(150);
|
||||
|
||||
await refreshPromise;
|
||||
|
||||
expect(mockedClient.listTools).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({
|
||||
signal: expect.any(AbortSignal),
|
||||
}),
|
||||
);
|
||||
|
||||
expect(mockedToolRegistry.registerTool).not.toHaveBeenCalled();
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should pass abort signal to onToolsUpdated callback', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
getServerCapabilities: vi
|
||||
.fn()
|
||||
.mockReturnValue({ tools: { listChanged: true } }),
|
||||
setNotificationHandler: vi.fn(),
|
||||
listTools: vi.fn().mockResolvedValue({ tools: [] }),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
registerCapabilities: vi.fn().mockResolvedValue({}),
|
||||
setRequestHandler: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
|
||||
const mockedToolRegistry = {
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const onToolsUpdatedSpy = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
{} as Config,
|
||||
false,
|
||||
onToolsUpdatedSpy,
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
|
||||
const notificationCallback =
|
||||
mockedClient.setNotificationHandler.mock.calls[0][1];
|
||||
|
||||
await notificationCallback();
|
||||
|
||||
expect(onToolsUpdatedSpy).toHaveBeenCalledWith(expect.any(AbortSignal));
|
||||
|
||||
// Verify the signal passed was not aborted (happy path)
|
||||
const signal = onToolsUpdatedSpy.mock.calls[0][0];
|
||||
expect(signal.aborted).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('appendMcpServerCommand', () => {
|
||||
it('should do nothing if no MCP servers or command are configured', () => {
|
||||
const out = populateMcpServerCommand({}, undefined);
|
||||
|
||||
Reference in New Issue
Block a user