feat (mcp): Refresh MCP prompts on list changed notification (#14863)

Co-authored-by: christine betts <chrstn@uw.edu>
Co-authored-by: Bryan Morgan <bryanmorgan@google.com>
This commit is contained in:
Alex Gavrilescu
2026-01-27 13:50:27 +01:00
committed by GitHub
parent eccc200f4f
commit 88d3df912f
2 changed files with 197 additions and 43 deletions

View File

@@ -223,7 +223,7 @@ describe('mcp-client', () => {
consoleWarnSpy.mockRestore();
});
it('should handle errors when discovering prompts', async () => {
it('should propagate errors when discovering prompts', async () => {
const mockedClient = {
connect: vi.fn(),
discover: vi.fn(),
@@ -269,9 +269,7 @@ describe('mcp-client', () => {
'0.0.1',
);
await client.connect();
await expect(client.discover({} as Config)).rejects.toThrow(
'No prompts, tools, or resources found on the server.',
);
await expect(client.discover({} as Config)).rejects.toThrow('Test error');
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
'error',
`Error discovering prompts from test-server: Test error`,
@@ -640,6 +638,89 @@ describe('mcp-client', () => {
);
});
it('refreshes prompts when prompt list change notification is received', async () => {
let listCallCount = 0;
let promptListHandler:
| ((notification: unknown) => Promise<void> | void)
| undefined;
const mockedClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
setNotificationHandler: vi.fn((_, handler) => {
promptListHandler = handler;
}),
getServerCapabilities: vi
.fn()
.mockReturnValue({ prompts: { listChanged: true } }),
listPrompts: vi.fn().mockImplementation(() => {
listCallCount += 1;
if (listCallCount === 1) {
return Promise.resolve({
prompts: [{ name: 'one', description: 'first' }],
});
}
return Promise.resolve({
prompts: [{ name: 'two', description: 'second' }],
});
}),
request: vi.fn().mockResolvedValue({ prompts: [] }),
} as unknown as ClientLib.Client;
vi.mocked(ClientLib.Client).mockReturnValue(mockedClient);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockedToolRegistry = {
registerTool: vi.fn(),
sortTools: vi.fn(),
getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry;
const promptRegistry = {
registerPrompt: vi.fn(),
removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry;
const resourceRegistry = {
setResourcesForServer: vi.fn(),
removeResourcesByServer: vi.fn(),
} as unknown as ResourceRegistry;
const client = new McpClient(
'test-server',
{
command: 'test-command',
},
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config,
false,
'0.0.1',
);
await client.connect();
await client.discover({ sanitizationConfig: EMPTY_CONFIG } as Config);
expect(mockedClient.setNotificationHandler).toHaveBeenCalledOnce();
expect(promptListHandler).toBeDefined();
await promptListHandler?.({
method: 'notifications/prompts/list_changed',
});
expect(promptRegistry.removePromptsByServer).toHaveBeenCalledWith(
'test-server',
);
expect(promptRegistry.registerPrompt).toHaveBeenLastCalledWith(
expect.objectContaining({ name: 'two' }),
);
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
'info',
'Prompts updated for server: test-server',
);
});
it('should remove tools and prompts on disconnect', async () => {
const mockedClient = {
connect: vi.fn(),

View File

@@ -29,6 +29,7 @@ import {
ReadResourceResultSchema,
ResourceListChangedNotificationSchema,
ToolListChangedNotificationSchema,
PromptListChangedNotificationSchema,
type Tool as McpTool,
} from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote';
@@ -112,6 +113,8 @@ export class McpClient {
private pendingToolRefresh: boolean = false;
private isRefreshingResources: boolean = false;
private pendingResourceRefresh: boolean = false;
private isRefreshingPrompts: boolean = false;
private pendingPromptRefresh: boolean = false;
constructor(
private readonly serverName: string,
@@ -174,7 +177,7 @@ export class McpClient {
async discover(cliConfig: Config): Promise<void> {
this.assertConnected();
const prompts = await this.discoverPrompts();
const prompts = await this.fetchPrompts();
const tools = await this.discoverTools(cliConfig);
const resources = await this.discoverResources();
this.updateResourceRegistry(resources);
@@ -183,6 +186,9 @@ export class McpClient {
throw new Error('No prompts, tools, or resources found on the server.');
}
for (const prompt of prompts) {
this.promptRegistry.registerPrompt(prompt);
}
for (const tool of tools) {
this.toolRegistry.registerTool(tool);
}
@@ -248,9 +254,11 @@ export class McpClient {
);
}
private async discoverPrompts(): Promise<Prompt[]> {
private async fetchPrompts(options?: {
signal?: AbortSignal;
}): Promise<DiscoveredMCPPrompt[]> {
this.assertConnected();
return discoverPrompts(this.serverName, this.client!, this.promptRegistry);
return discoverPrompts(this.serverName, this.client!, options);
}
private async discoverResources(): Promise<Resource[]> {
@@ -315,6 +323,22 @@ export class McpClient {
},
);
}
if (capabilities?.prompts?.listChanged) {
debugLogger.log(
`Server '${this.serverName}' supports prompt updates. Listening for changes...`,
);
this.client.setNotificationHandler(
PromptListChangedNotificationSchema,
async () => {
debugLogger.log(
`🔔 Received prompt update notification from '${this.serverName}'`,
);
await this.refreshPrompts();
},
);
}
}
/**
@@ -375,6 +399,63 @@ export class McpClient {
}
}
/**
* Refreshes prompts for this server by re-querying the MCP `prompts/list` endpoint.
*/
private async refreshPrompts(): Promise<void> {
if (this.isRefreshingPrompts) {
debugLogger.log(
`Prompt refresh for '${this.serverName}' is already in progress. Pending update.`,
);
this.pendingPromptRefresh = true;
return;
}
this.isRefreshingPrompts = true;
try {
do {
this.pendingPromptRefresh = false;
if (this.status !== MCPServerStatus.CONNECTED || !this.client) break;
const timeoutMs = this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC;
const abortController = new AbortController();
const timeoutId = setTimeout(() => abortController.abort(), timeoutMs);
try {
const newPrompts = await this.fetchPrompts({
signal: abortController.signal,
});
this.promptRegistry.removePromptsByServer(this.serverName);
for (const prompt of newPrompts) {
this.promptRegistry.registerPrompt(prompt);
}
} catch (err) {
debugLogger.error(
`Prompt discovery failed during refresh: ${getErrorMessage(err)}`,
);
clearTimeout(timeoutId);
break;
}
clearTimeout(timeoutId);
coreEvents.emitFeedback(
'info',
`Prompts updated for server: ${this.serverName}`,
);
} while (this.pendingPromptRefresh);
} catch (error) {
debugLogger.error(
`Critical error in prompt refresh loop for ${this.serverName}: ${getErrorMessage(error)}`,
);
} finally {
this.isRefreshingPrompts = false;
this.pendingPromptRefresh = false;
}
}
getServerConfig(): MCPServerConfig {
return this.serverConfig;
}
@@ -840,11 +921,7 @@ export async function connectAndDiscover(
};
// Attempt to discover both prompts and tools
const prompts = await discoverPrompts(
mcpServerName,
mcpClient,
promptRegistry,
);
const prompts = await discoverPrompts(mcpServerName, mcpClient);
const tools = await discoverTools(
mcpServerName,
mcpServerConfig,
@@ -862,7 +939,10 @@ export async function connectAndDiscover(
// If we found anything, the server is connected
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
// Register any discovered tools
// Register any discovered prompts and tools
for (const prompt of prompts) {
promptRegistry.registerPrompt(prompt);
}
for (const tool of tools) {
toolRegistry.registerTool(tool);
}
@@ -1038,39 +1118,32 @@ class McpCallableTool implements CallableTool {
export async function discoverPrompts(
mcpServerName: string,
mcpClient: Client,
promptRegistry: PromptRegistry,
): Promise<Prompt[]> {
options?: { signal?: AbortSignal },
): Promise<DiscoveredMCPPrompt[]> {
// Only request prompts if the server supports them.
if (mcpClient.getServerCapabilities()?.prompts == null) return [];
try {
// Only request prompts if the server supports them.
if (mcpClient.getServerCapabilities()?.prompts == null) return [];
const response = await mcpClient.listPrompts({});
for (const prompt of response.prompts) {
promptRegistry.registerPrompt({
...prompt,
serverName: mcpServerName,
invoke: (params: Record<string, unknown>) =>
invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
});
}
return response.prompts;
const response = await mcpClient.listPrompts({}, options);
return response.prompts.map((prompt) => ({
...prompt,
serverName: mcpServerName,
invoke: (params: Record<string, unknown>) =>
invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
}));
} catch (error) {
// It's okay if this fails, not all servers will have prompts.
// Don't log an error if the method is not found, which is a common case.
if (
error instanceof Error &&
!error.message?.includes('Method not found')
) {
coreEvents.emitFeedback(
'error',
`Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
error,
)}`,
error,
);
// It's okay if the method is not found, which is a common case.
if (error instanceof Error && error.message?.includes('Method not found')) {
return [];
}
return [];
coreEvents.emitFeedback(
'error',
`Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
error,
)}`,
error,
);
throw error;
}
}