fix(mcp): ensure MCP transport is closed to prevent memory leaks (#18054)

Co-authored-by: Jack Wotherspoon <jackwoth@google.com>
This commit is contained in:
Chris Coutinho
2026-02-04 22:00:41 +01:00
committed by GitHub
parent 3afc8f25e1
commit 821355c429
2 changed files with 59 additions and 25 deletions
+8 -8
View File
@@ -749,9 +749,9 @@ describe('mcp-client', () => {
vi.mocked(ClientLib.Client).mockReturnValue( vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client, mockedClient as unknown as ClientLib.Client,
); );
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue({
{} as SdkClientStdioLib.StdioClientTransport, close: vi.fn(),
); } as unknown as SdkClientStdioLib.StdioClientTransport);
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
unregisterTool: vi.fn(), unregisterTool: vi.fn(),
@@ -1888,7 +1888,7 @@ describe('connectToMcpServer with OAuth', () => {
EMPTY_CONFIG, EMPTY_CONFIG,
); );
expect(client).toBe(mockedClient); expect(client.client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2); expect(mockedClient.connect).toHaveBeenCalledTimes(2);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
@@ -1934,7 +1934,7 @@ describe('connectToMcpServer with OAuth', () => {
EMPTY_CONFIG, EMPTY_CONFIG,
); );
expect(client).toBe(mockedClient); expect(client.client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2); expect(mockedClient.connect).toHaveBeenCalledTimes(2);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl); expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl);
@@ -2029,7 +2029,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
EMPTY_CONFIG, EMPTY_CONFIG,
); );
expect(client).toBe(mockedClient); expect(client.client).toBe(mockedClient);
// First HTTP attempt fails, second SSE attempt succeeds // First HTTP attempt fails, second SSE attempt succeeds
expect(mockedClient.connect).toHaveBeenCalledTimes(2); expect(mockedClient.connect).toHaveBeenCalledTimes(2);
}); });
@@ -2070,7 +2070,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
EMPTY_CONFIG, EMPTY_CONFIG,
); );
expect(client).toBe(mockedClient); expect(client.client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2); expect(mockedClient.connect).toHaveBeenCalledTimes(2);
}); });
}); });
@@ -2155,7 +2155,7 @@ describe('connectToMcpServer - OAuth with transport fallback', () => {
EMPTY_CONFIG, EMPTY_CONFIG,
); );
expect(client).toBe(mockedClient); expect(client.client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(3); expect(mockedClient.connect).toHaveBeenCalledTimes(3);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
}); });
+51 -17
View File
@@ -144,7 +144,7 @@ export class McpClient {
} }
this.updateStatus(MCPServerStatus.CONNECTING); this.updateStatus(MCPServerStatus.CONNECTING);
try { try {
this.client = await connectToMcpServer( const { client, transport } = await connectToMcpServer(
this.clientVersion, this.clientVersion,
this.serverName, this.serverName,
this.serverConfig, this.serverConfig,
@@ -152,11 +152,13 @@ export class McpClient {
this.workspaceContext, this.workspaceContext,
this.cliConfig.sanitizationConfig, this.cliConfig.sanitizationConfig,
); );
this.client = client;
this.transport = transport;
this.registerNotificationHandlers(); this.registerNotificationHandlers();
const originalOnError = this.client.onerror; const originalOnError = this.client.onerror;
this.client.onerror = (error) => { this.client.onerror = async (error) => {
if (this.status !== MCPServerStatus.CONNECTED) { if (this.status !== MCPServerStatus.CONNECTED) {
return; return;
} }
@@ -167,6 +169,14 @@ export class McpClient {
error, error,
); );
this.updateStatus(MCPServerStatus.DISCONNECTED); this.updateStatus(MCPServerStatus.DISCONNECTED);
// Close transport to prevent memory leaks
if (this.transport) {
try {
await this.transport.close();
} catch {
// Ignore errors when closing transport on error
}
}
}; };
this.updateStatus(MCPServerStatus.CONNECTED); this.updateStatus(MCPServerStatus.CONNECTED);
} catch (error) { } catch (error) {
@@ -909,8 +919,9 @@ export async function connectAndDiscover(
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
let mcpClient: Client | undefined; let mcpClient: Client | undefined;
let transport: Transport | undefined;
try { try {
mcpClient = await connectToMcpServer( const result = await connectToMcpServer(
clientVersion, clientVersion,
mcpServerName, mcpServerName,
mcpServerConfig, mcpServerConfig,
@@ -918,10 +929,20 @@ export async function connectAndDiscover(
workspaceContext, workspaceContext,
cliConfig.sanitizationConfig, cliConfig.sanitizationConfig,
); );
mcpClient = result.client;
transport = result.transport;
mcpClient.onerror = (error) => { mcpClient.onerror = async (error) => {
coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error); coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
// Close transport to prevent memory leaks
if (transport) {
try {
await transport.close();
} catch {
// Ignore errors when closing transport on error
}
}
}; };
// Attempt to discover both prompts and tools // Attempt to discover both prompts and tools
@@ -1302,16 +1323,18 @@ function createSSETransportWithAuth(
* @param client The MCP client to connect * @param client The MCP client to connect
* @param config The MCP server configuration * @param config The MCP server configuration
* @param accessToken Optional OAuth access token for authentication * @param accessToken Optional OAuth access token for authentication
* @returns The transport used for connection
*/ */
async function connectWithSSETransport( async function connectWithSSETransport(
client: Client, client: Client,
config: MCPServerConfig, config: MCPServerConfig,
accessToken?: string | null, accessToken?: string | null,
): Promise<void> { ): Promise<Transport> {
const transport = createSSETransportWithAuth(config, accessToken); const transport = createSSETransportWithAuth(config, accessToken);
await client.connect(transport, { await client.connect(transport, {
timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
}); });
return transport;
} }
/** /**
@@ -1341,6 +1364,7 @@ async function showAuthRequiredMessage(serverName: string): Promise<never> {
* @param config The MCP server configuration * @param config The MCP server configuration
* @param accessToken The OAuth access token to use * @param accessToken The OAuth access token to use
* @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server) * @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server)
* @returns The transport used for connection
*/ */
async function retryWithOAuth( async function retryWithOAuth(
client: Client, client: Client,
@@ -1348,17 +1372,21 @@ async function retryWithOAuth(
config: MCPServerConfig, config: MCPServerConfig,
accessToken: string, accessToken: string,
httpReturned404: boolean, httpReturned404: boolean,
): Promise<void> { ): Promise<Transport> {
if (httpReturned404) { if (httpReturned404) {
// HTTP returned 404, only try SSE // HTTP returned 404, only try SSE
debugLogger.log( debugLogger.log(
`Retrying SSE connection to '${serverName}' with OAuth token...`, `Retrying SSE connection to '${serverName}' with OAuth token...`,
); );
await connectWithSSETransport(client, config, accessToken); const transport = await connectWithSSETransport(
client,
config,
accessToken,
);
debugLogger.log( debugLogger.log(
`Successfully connected to '${serverName}' using SSE with OAuth.`, `Successfully connected to '${serverName}' using SSE with OAuth.`,
); );
return; return transport;
} }
// HTTP returned 401, try HTTP with OAuth first // HTTP returned 401, try HTTP with OAuth first
@@ -1382,6 +1410,7 @@ async function retryWithOAuth(
debugLogger.log( debugLogger.log(
`Successfully connected to '${serverName}' using HTTP with OAuth.`, `Successfully connected to '${serverName}' using HTTP with OAuth.`,
); );
return httpTransport;
} catch (httpError) { } catch (httpError) {
await httpTransport.close(); await httpTransport.close();
@@ -1393,10 +1422,15 @@ async function retryWithOAuth(
!config.httpUrl !config.httpUrl
) { ) {
debugLogger.log(`HTTP with OAuth returned 404, trying SSE with OAuth...`); debugLogger.log(`HTTP with OAuth returned 404, trying SSE with OAuth...`);
await connectWithSSETransport(client, config, accessToken); const sseTransport = await connectWithSSETransport(
client,
config,
accessToken,
);
debugLogger.log( debugLogger.log(
`Successfully connected to '${serverName}' using SSE with OAuth.`, `Successfully connected to '${serverName}' using SSE with OAuth.`,
); );
return sseTransport;
} else { } else {
throw httpError; throw httpError;
} }
@@ -1410,7 +1444,7 @@ async function retryWithOAuth(
* *
* @param mcpServerName The name of the MCP server, used for logging and identification. * @param mcpServerName The name of the MCP server, used for logging and identification.
* @param mcpServerConfig The configuration specifying how to connect to the server. * @param mcpServerConfig The configuration specifying how to connect to the server.
* @returns A promise that resolves to a connected MCP `Client` instance. * @returns A promise that resolves to a connected MCP `Client` instance and its transport.
* @throws An error if the connection fails or the configuration is invalid. * @throws An error if the connection fails or the configuration is invalid.
*/ */
export async function connectToMcpServer( export async function connectToMcpServer(
@@ -1420,7 +1454,7 @@ export async function connectToMcpServer(
debugMode: boolean, debugMode: boolean,
workspaceContext: WorkspaceContext, workspaceContext: WorkspaceContext,
sanitizationConfig: EnvironmentSanitizationConfig, sanitizationConfig: EnvironmentSanitizationConfig,
): Promise<Client> { ): Promise<{ client: Client; transport: Transport }> {
const mcpClient = new Client( const mcpClient = new Client(
{ {
name: 'gemini-cli-mcp-client', name: 'gemini-cli-mcp-client',
@@ -1492,7 +1526,7 @@ export async function connectToMcpServer(
await mcpClient.connect(transport, { await mcpClient.connect(transport, {
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
}); });
return mcpClient; return { client: mcpClient, transport };
} catch (error) { } catch (error) {
await transport.close(); await transport.close();
firstAttemptError = error as Error; firstAttemptError = error as Error;
@@ -1523,7 +1557,7 @@ export async function connectToMcpServer(
try { try {
// Try SSE with stored OAuth token if available // Try SSE with stored OAuth token if available
// This ensures that SSE fallback works for authenticated servers // This ensures that SSE fallback works for authenticated servers
await connectWithSSETransport( const sseTransport = await connectWithSSETransport(
mcpClient, mcpClient,
mcpServerConfig, mcpServerConfig,
await getStoredOAuthToken(mcpServerName), await getStoredOAuthToken(mcpServerName),
@@ -1532,7 +1566,7 @@ export async function connectToMcpServer(
debugLogger.log( debugLogger.log(
`MCP server '${mcpServerName}': Successfully connected using SSE transport.`, `MCP server '${mcpServerName}': Successfully connected using SSE transport.`,
); );
return mcpClient; return { client: mcpClient, transport: sseTransport };
} catch (sseFallbackError) { } catch (sseFallbackError) {
sseError = sseFallbackError as Error; sseError = sseFallbackError as Error;
@@ -1639,14 +1673,14 @@ export async function connectToMcpServer(
); );
} }
await retryWithOAuth( const oauthTransport = await retryWithOAuth(
mcpClient, mcpClient,
mcpServerName, mcpServerName,
mcpServerConfig, mcpServerConfig,
accessToken, accessToken,
httpReturned404, httpReturned404,
); );
return mcpClient; return { client: mcpClient, transport: oauthTransport };
} else { } else {
throw new Error( throw new Error(
`Failed to handle automatic OAuth for server '${mcpServerName}'`, `Failed to handle automatic OAuth for server '${mcpServerName}'`,
@@ -1727,7 +1761,7 @@ export async function connectToMcpServer(
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
}); });
// Connection successful with OAuth // Connection successful with OAuth
return mcpClient; return { client: mcpClient, transport: oauthTransport };
} else { } else {
throw new Error( throw new Error(
`OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`, `OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`,