mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-15 08:31:14 -07:00
feat (mcp): Refresh MCP prompts on list changed notification (#14863)
Co-authored-by: christine betts <chrstn@uw.edu> Co-authored-by: Bryan Morgan <bryanmorgan@google.com>
This commit is contained in:
@@ -223,7 +223,7 @@ describe('mcp-client', () => {
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should handle errors when discovering prompts', async () => {
|
||||
it('should propagate errors when discovering prompts', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
@@ -269,9 +269,7 @@ describe('mcp-client', () => {
|
||||
'0.0.1',
|
||||
);
|
||||
await client.connect();
|
||||
await expect(client.discover({} as Config)).rejects.toThrow(
|
||||
'No prompts, tools, or resources found on the server.',
|
||||
);
|
||||
await expect(client.discover({} as Config)).rejects.toThrow('Test error');
|
||||
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
|
||||
'error',
|
||||
`Error discovering prompts from test-server: Test error`,
|
||||
@@ -640,6 +638,89 @@ describe('mcp-client', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('refreshes prompts when prompt list change notification is received', async () => {
|
||||
let listCallCount = 0;
|
||||
let promptListHandler:
|
||||
| ((notification: unknown) => Promise<void> | void)
|
||||
| undefined;
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
setNotificationHandler: vi.fn((_, handler) => {
|
||||
promptListHandler = handler;
|
||||
}),
|
||||
getServerCapabilities: vi
|
||||
.fn()
|
||||
.mockReturnValue({ prompts: { listChanged: true } }),
|
||||
listPrompts: vi.fn().mockImplementation(() => {
|
||||
listCallCount += 1;
|
||||
if (listCallCount === 1) {
|
||||
return Promise.resolve({
|
||||
prompts: [{ name: 'one', description: 'first' }],
|
||||
});
|
||||
}
|
||||
return Promise.resolve({
|
||||
prompts: [{ name: 'two', description: 'second' }],
|
||||
});
|
||||
}),
|
||||
request: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
} as unknown as ClientLib.Client;
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(mockedClient);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
const promptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const resourceRegistry = {
|
||||
setResourcesForServer: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry;
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover({ sanitizationConfig: EMPTY_CONFIG } as Config);
|
||||
|
||||
expect(mockedClient.setNotificationHandler).toHaveBeenCalledOnce();
|
||||
expect(promptListHandler).toBeDefined();
|
||||
|
||||
await promptListHandler?.({
|
||||
method: 'notifications/prompts/list_changed',
|
||||
});
|
||||
|
||||
expect(promptRegistry.removePromptsByServer).toHaveBeenCalledWith(
|
||||
'test-server',
|
||||
);
|
||||
expect(promptRegistry.registerPrompt).toHaveBeenLastCalledWith(
|
||||
expect.objectContaining({ name: 'two' }),
|
||||
);
|
||||
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
|
||||
'info',
|
||||
'Prompts updated for server: test-server',
|
||||
);
|
||||
});
|
||||
|
||||
it('should remove tools and prompts on disconnect', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
|
||||
@@ -29,6 +29,7 @@ import {
|
||||
ReadResourceResultSchema,
|
||||
ResourceListChangedNotificationSchema,
|
||||
ToolListChangedNotificationSchema,
|
||||
PromptListChangedNotificationSchema,
|
||||
type Tool as McpTool,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { parse } from 'shell-quote';
|
||||
@@ -112,6 +113,8 @@ export class McpClient {
|
||||
private pendingToolRefresh: boolean = false;
|
||||
private isRefreshingResources: boolean = false;
|
||||
private pendingResourceRefresh: boolean = false;
|
||||
private isRefreshingPrompts: boolean = false;
|
||||
private pendingPromptRefresh: boolean = false;
|
||||
|
||||
constructor(
|
||||
private readonly serverName: string,
|
||||
@@ -174,7 +177,7 @@ export class McpClient {
|
||||
async discover(cliConfig: Config): Promise<void> {
|
||||
this.assertConnected();
|
||||
|
||||
const prompts = await this.discoverPrompts();
|
||||
const prompts = await this.fetchPrompts();
|
||||
const tools = await this.discoverTools(cliConfig);
|
||||
const resources = await this.discoverResources();
|
||||
this.updateResourceRegistry(resources);
|
||||
@@ -183,6 +186,9 @@ export class McpClient {
|
||||
throw new Error('No prompts, tools, or resources found on the server.');
|
||||
}
|
||||
|
||||
for (const prompt of prompts) {
|
||||
this.promptRegistry.registerPrompt(prompt);
|
||||
}
|
||||
for (const tool of tools) {
|
||||
this.toolRegistry.registerTool(tool);
|
||||
}
|
||||
@@ -248,9 +254,11 @@ export class McpClient {
|
||||
);
|
||||
}
|
||||
|
||||
private async discoverPrompts(): Promise<Prompt[]> {
|
||||
private async fetchPrompts(options?: {
|
||||
signal?: AbortSignal;
|
||||
}): Promise<DiscoveredMCPPrompt[]> {
|
||||
this.assertConnected();
|
||||
return discoverPrompts(this.serverName, this.client!, this.promptRegistry);
|
||||
return discoverPrompts(this.serverName, this.client!, options);
|
||||
}
|
||||
|
||||
private async discoverResources(): Promise<Resource[]> {
|
||||
@@ -315,6 +323,22 @@ export class McpClient {
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
if (capabilities?.prompts?.listChanged) {
|
||||
debugLogger.log(
|
||||
`Server '${this.serverName}' supports prompt updates. Listening for changes...`,
|
||||
);
|
||||
|
||||
this.client.setNotificationHandler(
|
||||
PromptListChangedNotificationSchema,
|
||||
async () => {
|
||||
debugLogger.log(
|
||||
`🔔 Received prompt update notification from '${this.serverName}'`,
|
||||
);
|
||||
await this.refreshPrompts();
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -375,6 +399,63 @@ export class McpClient {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refreshes prompts for this server by re-querying the MCP `prompts/list` endpoint.
|
||||
*/
|
||||
private async refreshPrompts(): Promise<void> {
|
||||
if (this.isRefreshingPrompts) {
|
||||
debugLogger.log(
|
||||
`Prompt refresh for '${this.serverName}' is already in progress. Pending update.`,
|
||||
);
|
||||
this.pendingPromptRefresh = true;
|
||||
return;
|
||||
}
|
||||
|
||||
this.isRefreshingPrompts = true;
|
||||
|
||||
try {
|
||||
do {
|
||||
this.pendingPromptRefresh = false;
|
||||
|
||||
if (this.status !== MCPServerStatus.CONNECTED || !this.client) break;
|
||||
|
||||
const timeoutMs = this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC;
|
||||
const abortController = new AbortController();
|
||||
const timeoutId = setTimeout(() => abortController.abort(), timeoutMs);
|
||||
|
||||
try {
|
||||
const newPrompts = await this.fetchPrompts({
|
||||
signal: abortController.signal,
|
||||
});
|
||||
this.promptRegistry.removePromptsByServer(this.serverName);
|
||||
for (const prompt of newPrompts) {
|
||||
this.promptRegistry.registerPrompt(prompt);
|
||||
}
|
||||
} catch (err) {
|
||||
debugLogger.error(
|
||||
`Prompt discovery failed during refresh: ${getErrorMessage(err)}`,
|
||||
);
|
||||
clearTimeout(timeoutId);
|
||||
break;
|
||||
}
|
||||
|
||||
clearTimeout(timeoutId);
|
||||
|
||||
coreEvents.emitFeedback(
|
||||
'info',
|
||||
`Prompts updated for server: ${this.serverName}`,
|
||||
);
|
||||
} while (this.pendingPromptRefresh);
|
||||
} catch (error) {
|
||||
debugLogger.error(
|
||||
`Critical error in prompt refresh loop for ${this.serverName}: ${getErrorMessage(error)}`,
|
||||
);
|
||||
} finally {
|
||||
this.isRefreshingPrompts = false;
|
||||
this.pendingPromptRefresh = false;
|
||||
}
|
||||
}
|
||||
|
||||
getServerConfig(): MCPServerConfig {
|
||||
return this.serverConfig;
|
||||
}
|
||||
@@ -840,11 +921,7 @@ export async function connectAndDiscover(
|
||||
};
|
||||
|
||||
// Attempt to discover both prompts and tools
|
||||
const prompts = await discoverPrompts(
|
||||
mcpServerName,
|
||||
mcpClient,
|
||||
promptRegistry,
|
||||
);
|
||||
const prompts = await discoverPrompts(mcpServerName, mcpClient);
|
||||
const tools = await discoverTools(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
@@ -862,7 +939,10 @@ export async function connectAndDiscover(
|
||||
// If we found anything, the server is connected
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTED);
|
||||
|
||||
// Register any discovered tools
|
||||
// Register any discovered prompts and tools
|
||||
for (const prompt of prompts) {
|
||||
promptRegistry.registerPrompt(prompt);
|
||||
}
|
||||
for (const tool of tools) {
|
||||
toolRegistry.registerTool(tool);
|
||||
}
|
||||
@@ -1038,39 +1118,32 @@ class McpCallableTool implements CallableTool {
|
||||
export async function discoverPrompts(
|
||||
mcpServerName: string,
|
||||
mcpClient: Client,
|
||||
promptRegistry: PromptRegistry,
|
||||
): Promise<Prompt[]> {
|
||||
options?: { signal?: AbortSignal },
|
||||
): Promise<DiscoveredMCPPrompt[]> {
|
||||
// Only request prompts if the server supports them.
|
||||
if (mcpClient.getServerCapabilities()?.prompts == null) return [];
|
||||
|
||||
try {
|
||||
// Only request prompts if the server supports them.
|
||||
if (mcpClient.getServerCapabilities()?.prompts == null) return [];
|
||||
|
||||
const response = await mcpClient.listPrompts({});
|
||||
|
||||
for (const prompt of response.prompts) {
|
||||
promptRegistry.registerPrompt({
|
||||
...prompt,
|
||||
serverName: mcpServerName,
|
||||
invoke: (params: Record<string, unknown>) =>
|
||||
invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
|
||||
});
|
||||
}
|
||||
return response.prompts;
|
||||
const response = await mcpClient.listPrompts({}, options);
|
||||
return response.prompts.map((prompt) => ({
|
||||
...prompt,
|
||||
serverName: mcpServerName,
|
||||
invoke: (params: Record<string, unknown>) =>
|
||||
invokeMcpPrompt(mcpServerName, mcpClient, prompt.name, params),
|
||||
}));
|
||||
} catch (error) {
|
||||
// It's okay if this fails, not all servers will have prompts.
|
||||
// Don't log an error if the method is not found, which is a common case.
|
||||
if (
|
||||
error instanceof Error &&
|
||||
!error.message?.includes('Method not found')
|
||||
) {
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
error,
|
||||
);
|
||||
// It's okay if the method is not found, which is a common case.
|
||||
if (error instanceof Error && error.message?.includes('Method not found')) {
|
||||
return [];
|
||||
}
|
||||
return [];
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Error discovering prompts from ${mcpServerName}: ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
error,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user