feat(core): add support for MCP progress updates (#19046)

This commit is contained in:
N. Taylor Mullen
2026-02-18 12:46:12 -08:00
committed by GitHub
parent 1cf05b0375
commit 14415316c0
14 changed files with 270 additions and 14 deletions
+17 -9
View File
@@ -18,7 +18,11 @@ import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
import { OAuthUtils } from '../mcp/oauth-utils.js';
import type { PromptRegistry } from '../prompts/prompt-registry.js';
import { ToolListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import {
PromptListChangedNotificationSchema,
ResourceListChangedNotificationSchema,
ToolListChangedNotificationSchema,
} from '@modelcontextprotocol/sdk/types.js';
import { ApprovalMode, PolicyDecision } from '../policy/types.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
@@ -140,7 +144,7 @@ describe('mcp-client', () => {
await client.discover({} as Config);
expect(mockedClient.listTools).toHaveBeenCalledWith(
{},
{ timeout: 600000 },
expect.objectContaining({ timeout: 600000, progressReporter: client }),
);
});
@@ -710,8 +714,10 @@ describe('mcp-client', () => {
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
setNotificationHandler: vi.fn((_, handler) => {
resourceListHandler = handler;
setNotificationHandler: vi.fn((schema, handler) => {
if (schema === ResourceListChangedNotificationSchema) {
resourceListHandler = handler;
}
}),
getServerCapabilities: vi
.fn()
@@ -772,7 +778,7 @@ describe('mcp-client', () => {
await client.connect();
await client.discover({} as Config);
expect(mockedClient.setNotificationHandler).toHaveBeenCalledOnce();
expect(mockedClient.setNotificationHandler).toHaveBeenCalledTimes(2);
expect(resourceListHandler).toBeDefined();
await resourceListHandler?.({
@@ -802,8 +808,10 @@ describe('mcp-client', () => {
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
setNotificationHandler: vi.fn((_, handler) => {
promptListHandler = handler;
setNotificationHandler: vi.fn((schema, handler) => {
if (schema === PromptListChangedNotificationSchema) {
promptListHandler = handler;
}
}),
getServerCapabilities: vi
.fn()
@@ -854,7 +862,7 @@ describe('mcp-client', () => {
await client.connect();
await client.discover({ sanitizationConfig: EMPTY_CONFIG } as Config);
expect(mockedClient.setNotificationHandler).toHaveBeenCalledOnce();
expect(mockedClient.setNotificationHandler).toHaveBeenCalledTimes(2);
expect(promptListHandler).toBeDefined();
await promptListHandler?.({
@@ -1023,7 +1031,7 @@ describe('mcp-client', () => {
await client.connect();
expect(mockedClient.setNotificationHandler).not.toHaveBeenCalled();
expect(mockedClient.setNotificationHandler).toHaveBeenCalledOnce();
});
it('should refresh tools and notify manager when notification is received', async () => {
+77 -4
View File
@@ -30,6 +30,7 @@ import {
ResourceListChangedNotificationSchema,
ToolListChangedNotificationSchema,
PromptListChangedNotificationSchema,
ProgressNotificationSchema,
type Tool as McpTool,
} from '@modelcontextprotocol/sdk/types.js';
import { ApprovalMode, PolicyDecision } from '../policy/types.js';
@@ -44,6 +45,7 @@ import { XcodeMcpBridgeFixTransport } from './xcode-mcp-fix-transport.js';
import type { CallableTool, FunctionCall, Part, Tool } from '@google/genai';
import { basename } from 'node:path';
import { pathToFileURL } from 'node:url';
import { randomUUID } from 'node:crypto';
import type { McpAuthProvider } from '../mcp/auth-provider.js';
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
@@ -58,6 +60,7 @@ import type {
Unsubscribe,
WorkspaceContext,
} from '../utils/workspaceContext.js';
import { getToolCallContext } from '../utils/toolCallContext.js';
import type { ToolRegistry } from './tool-registry.js';
import { debugLogger } from '../utils/debugLogger.js';
import { type MessageBus } from '../confirmation-bus/message-bus.js';
@@ -105,13 +108,21 @@ export enum MCPDiscoveryState {
COMPLETED = 'completed',
}
/**
* Interface for reporting progress from MCP tool calls.
*/
export interface McpProgressReporter {
registerProgressToken(token: string | number, callId: string): void;
unregisterProgressToken(token: string | number): void;
}
/**
* A client for a single MCP server.
*
* This class is responsible for connecting to, discovering tools from, and
* managing the state of a single MCP server.
*/
export class McpClient {
export class McpClient implements McpProgressReporter {
private client: Client | undefined;
private transport: Transport | undefined;
private status: MCPServerStatus = MCPServerStatus.DISCONNECTED;
@@ -122,6 +133,12 @@ export class McpClient {
private isRefreshingPrompts: boolean = false;
private pendingPromptRefresh: boolean = false;
/**
* Map of progress tokens to tool call IDs.
* This allows us to route progress notifications to the correct tool call.
*/
private readonly progressTokenToCallId = new Map<string | number, string>();
constructor(
private readonly serverName: string,
private readonly serverConfig: MCPServerConfig,
@@ -254,8 +271,11 @@ export class McpClient {
this.client!,
cliConfig,
this.toolRegistry.getMessageBus(),
options ?? {
timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
{
...(options ?? {
timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
}),
progressReporter: this,
},
);
}
@@ -349,6 +369,25 @@ export class McpClient {
},
);
}
this.client.setNotificationHandler(
ProgressNotificationSchema,
(notification) => {
const { progressToken, progress, total, message } = notification.params;
const callId = this.progressTokenToCallId.get(progressToken);
if (callId) {
coreEvents.emitMcpProgress({
serverName: this.serverName,
callId,
progressToken,
progress,
total,
message,
});
}
},
);
}
/**
@@ -409,6 +448,20 @@ export class McpClient {
}
}
/**
* Registers a progress token for a tool call.
*/
registerProgressToken(token: string | number, callId: string): void {
this.progressTokenToCallId.set(token, callId);
}
/**
* Unregisters a progress token.
*/
unregisterProgressToken(token: string | number): void {
this.progressTokenToCallId.delete(token);
}
/**
* Refreshes prompts for this server by re-querying the MCP `prompts/list` endpoint.
*/
@@ -994,7 +1047,11 @@ export async function discoverTools(
mcpClient: Client,
cliConfig: Config,
messageBus: MessageBus,
options?: { timeout?: number; signal?: AbortSignal },
options?: {
timeout?: number;
signal?: AbortSignal;
progressReporter?: McpProgressReporter;
},
): Promise<DiscoveredMCPTool[]> {
try {
// Only request tools if the server supports them.
@@ -1012,6 +1069,7 @@ export async function discoverTools(
mcpClient,
toolDef,
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
options?.progressReporter,
);
// Extract readOnlyHint from annotations
@@ -1078,6 +1136,7 @@ class McpCallableTool implements CallableTool {
private readonly client: Client,
private readonly toolDef: McpTool,
private readonly timeout: number,
private readonly progressReporter?: McpProgressReporter,
) {}
async tool(): Promise<Tool> {
@@ -1099,12 +1158,22 @@ class McpCallableTool implements CallableTool {
}
const call = functionCalls[0];
const progressToken = randomUUID();
const context = getToolCallContext();
if (context && this.progressReporter) {
this.progressReporter.registerProgressToken(
progressToken,
context.callId,
);
}
try {
const result = await this.client.callTool(
{
name: call.name!,
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
arguments: call.args as Record<string, unknown>,
_meta: { progressToken },
},
undefined,
{ timeout: this.timeout },
@@ -1133,6 +1202,10 @@ class McpCallableTool implements CallableTool {
},
},
];
} finally {
if (this.progressReporter) {
this.progressReporter.unregisterProgressToken(progressToken);
}
}
}
}