Revert "fix(mcp): ensure MCP transport is closed to prevent memory leaks" (#18771)

This commit is contained in:
Shreya Keshive
2026-02-10 17:00:36 -05:00
committed by GitHub
parent 49533cd106
commit 9590a092ae
2 changed files with 25 additions and 59 deletions
+8 -8
View File
@@ -901,9 +901,9 @@ describe('mcp-client', () => {
vi.mocked(ClientLib.Client).mockReturnValue( vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client, mockedClient as unknown as ClientLib.Client,
); );
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue({ vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
close: vi.fn(), {} as SdkClientStdioLib.StdioClientTransport,
} as unknown as SdkClientStdioLib.StdioClientTransport); );
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
unregisterTool: vi.fn(), unregisterTool: vi.fn(),
@@ -2040,7 +2040,7 @@ describe('connectToMcpServer with OAuth', () => {
EMPTY_CONFIG, EMPTY_CONFIG,
); );
expect(client.client).toBe(mockedClient); expect(client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2); expect(mockedClient.connect).toHaveBeenCalledTimes(2);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
@@ -2086,7 +2086,7 @@ describe('connectToMcpServer with OAuth', () => {
EMPTY_CONFIG, EMPTY_CONFIG,
); );
expect(client.client).toBe(mockedClient); expect(client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2); expect(mockedClient.connect).toHaveBeenCalledTimes(2);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl); expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl);
@@ -2181,7 +2181,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
EMPTY_CONFIG, EMPTY_CONFIG,
); );
expect(client.client).toBe(mockedClient); expect(client).toBe(mockedClient);
// First HTTP attempt fails, second SSE attempt succeeds // First HTTP attempt fails, second SSE attempt succeeds
expect(mockedClient.connect).toHaveBeenCalledTimes(2); expect(mockedClient.connect).toHaveBeenCalledTimes(2);
}); });
@@ -2222,7 +2222,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
EMPTY_CONFIG, EMPTY_CONFIG,
); );
expect(client.client).toBe(mockedClient); expect(client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2); expect(mockedClient.connect).toHaveBeenCalledTimes(2);
}); });
}); });
@@ -2307,7 +2307,7 @@ describe('connectToMcpServer - OAuth with transport fallback', () => {
EMPTY_CONFIG, EMPTY_CONFIG,
); );
expect(client.client).toBe(mockedClient); expect(client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(3); expect(mockedClient.connect).toHaveBeenCalledTimes(3);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
}); });
+17 -51
View File
@@ -146,7 +146,7 @@ export class McpClient {
} }
this.updateStatus(MCPServerStatus.CONNECTING); this.updateStatus(MCPServerStatus.CONNECTING);
try { try {
const { client, transport } = await connectToMcpServer( this.client = await connectToMcpServer(
this.clientVersion, this.clientVersion,
this.serverName, this.serverName,
this.serverConfig, this.serverConfig,
@@ -154,13 +154,11 @@ export class McpClient {
this.workspaceContext, this.workspaceContext,
this.cliConfig.sanitizationConfig, this.cliConfig.sanitizationConfig,
); );
this.client = client;
this.transport = transport;
this.registerNotificationHandlers(); this.registerNotificationHandlers();
const originalOnError = this.client.onerror; const originalOnError = this.client.onerror;
this.client.onerror = async (error) => { this.client.onerror = (error) => {
if (this.status !== MCPServerStatus.CONNECTED) { if (this.status !== MCPServerStatus.CONNECTED) {
return; return;
} }
@@ -171,14 +169,6 @@ export class McpClient {
error, error,
); );
this.updateStatus(MCPServerStatus.DISCONNECTED); this.updateStatus(MCPServerStatus.DISCONNECTED);
// Close transport to prevent memory leaks
if (this.transport) {
try {
await this.transport.close();
} catch {
// Ignore errors when closing transport on error
}
}
}; };
this.updateStatus(MCPServerStatus.CONNECTED); this.updateStatus(MCPServerStatus.CONNECTED);
} catch (error) { } catch (error) {
@@ -927,9 +917,8 @@ export async function connectAndDiscover(
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
let mcpClient: Client | undefined; let mcpClient: Client | undefined;
let transport: Transport | undefined;
try { try {
const result = await connectToMcpServer( mcpClient = await connectToMcpServer(
clientVersion, clientVersion,
mcpServerName, mcpServerName,
mcpServerConfig, mcpServerConfig,
@@ -937,20 +926,10 @@ export async function connectAndDiscover(
workspaceContext, workspaceContext,
cliConfig.sanitizationConfig, cliConfig.sanitizationConfig,
); );
mcpClient = result.client;
transport = result.transport;
mcpClient.onerror = async (error) => { mcpClient.onerror = (error) => {
coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error); coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED); updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
// Close transport to prevent memory leaks
if (transport) {
try {
await transport.close();
} catch {
// Ignore errors when closing transport on error
}
}
}; };
// Attempt to discover both prompts and tools // Attempt to discover both prompts and tools
@@ -1348,18 +1327,16 @@ function createSSETransportWithAuth(
* @param client The MCP client to connect * @param client The MCP client to connect
* @param config The MCP server configuration * @param config The MCP server configuration
* @param accessToken Optional OAuth access token for authentication * @param accessToken Optional OAuth access token for authentication
* @returns The transport used for connection
*/ */
async function connectWithSSETransport( async function connectWithSSETransport(
client: Client, client: Client,
config: MCPServerConfig, config: MCPServerConfig,
accessToken?: string | null, accessToken?: string | null,
): Promise<Transport> { ): Promise<void> {
const transport = createSSETransportWithAuth(config, accessToken); const transport = createSSETransportWithAuth(config, accessToken);
await client.connect(transport, { await client.connect(transport, {
timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
}); });
return transport;
} }
/** /**
@@ -1389,7 +1366,6 @@ async function showAuthRequiredMessage(serverName: string): Promise<never> {
* @param config The MCP server configuration * @param config The MCP server configuration
* @param accessToken The OAuth access token to use * @param accessToken The OAuth access token to use
* @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server) * @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server)
* @returns The transport used for connection
*/ */
async function retryWithOAuth( async function retryWithOAuth(
client: Client, client: Client,
@@ -1397,21 +1373,17 @@ async function retryWithOAuth(
config: MCPServerConfig, config: MCPServerConfig,
accessToken: string, accessToken: string,
httpReturned404: boolean, httpReturned404: boolean,
): Promise<Transport> { ): Promise<void> {
if (httpReturned404) { if (httpReturned404) {
// HTTP returned 404, only try SSE // HTTP returned 404, only try SSE
debugLogger.log( debugLogger.log(
`Retrying SSE connection to '${serverName}' with OAuth token...`, `Retrying SSE connection to '${serverName}' with OAuth token...`,
); );
const transport = await connectWithSSETransport( await connectWithSSETransport(client, config, accessToken);
client,
config,
accessToken,
);
debugLogger.log( debugLogger.log(
`Successfully connected to '${serverName}' using SSE with OAuth.`, `Successfully connected to '${serverName}' using SSE with OAuth.`,
); );
return transport; return;
} }
// HTTP returned 401, try HTTP with OAuth first // HTTP returned 401, try HTTP with OAuth first
@@ -1435,7 +1407,6 @@ async function retryWithOAuth(
debugLogger.log( debugLogger.log(
`Successfully connected to '${serverName}' using HTTP with OAuth.`, `Successfully connected to '${serverName}' using HTTP with OAuth.`,
); );
return httpTransport;
} catch (httpError) { } catch (httpError) {
await httpTransport.close(); await httpTransport.close();
@@ -1447,15 +1418,10 @@ async function retryWithOAuth(
!config.httpUrl !config.httpUrl
) { ) {
debugLogger.log(`HTTP with OAuth returned 404, trying SSE with OAuth...`); debugLogger.log(`HTTP with OAuth returned 404, trying SSE with OAuth...`);
const sseTransport = await connectWithSSETransport( await connectWithSSETransport(client, config, accessToken);
client,
config,
accessToken,
);
debugLogger.log( debugLogger.log(
`Successfully connected to '${serverName}' using SSE with OAuth.`, `Successfully connected to '${serverName}' using SSE with OAuth.`,
); );
return sseTransport;
} else { } else {
throw httpError; throw httpError;
} }
@@ -1469,7 +1435,7 @@ async function retryWithOAuth(
* *
* @param mcpServerName The name of the MCP server, used for logging and identification. * @param mcpServerName The name of the MCP server, used for logging and identification.
* @param mcpServerConfig The configuration specifying how to connect to the server. * @param mcpServerConfig The configuration specifying how to connect to the server.
* @returns A promise that resolves to a connected MCP `Client` instance and its transport. * @returns A promise that resolves to a connected MCP `Client` instance.
* @throws An error if the connection fails or the configuration is invalid. * @throws An error if the connection fails or the configuration is invalid.
*/ */
export async function connectToMcpServer( export async function connectToMcpServer(
@@ -1479,7 +1445,7 @@ export async function connectToMcpServer(
debugMode: boolean, debugMode: boolean,
workspaceContext: WorkspaceContext, workspaceContext: WorkspaceContext,
sanitizationConfig: EnvironmentSanitizationConfig, sanitizationConfig: EnvironmentSanitizationConfig,
): Promise<{ client: Client; transport: Transport }> { ): Promise<Client> {
const mcpClient = new Client( const mcpClient = new Client(
{ {
name: 'gemini-cli-mcp-client', name: 'gemini-cli-mcp-client',
@@ -1551,7 +1517,7 @@ export async function connectToMcpServer(
await mcpClient.connect(transport, { await mcpClient.connect(transport, {
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
}); });
return { client: mcpClient, transport }; return mcpClient;
} catch (error) { } catch (error) {
await transport.close(); await transport.close();
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
@@ -1583,7 +1549,7 @@ export async function connectToMcpServer(
try { try {
// Try SSE with stored OAuth token if available // Try SSE with stored OAuth token if available
// This ensures that SSE fallback works for authenticated servers // This ensures that SSE fallback works for authenticated servers
const sseTransport = await connectWithSSETransport( await connectWithSSETransport(
mcpClient, mcpClient,
mcpServerConfig, mcpServerConfig,
await getStoredOAuthToken(mcpServerName), await getStoredOAuthToken(mcpServerName),
@@ -1592,7 +1558,7 @@ export async function connectToMcpServer(
debugLogger.log( debugLogger.log(
`MCP server '${mcpServerName}': Successfully connected using SSE transport.`, `MCP server '${mcpServerName}': Successfully connected using SSE transport.`,
); );
return { client: mcpClient, transport: sseTransport }; return mcpClient;
} catch (sseFallbackError) { } catch (sseFallbackError) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
sseError = sseFallbackError as Error; sseError = sseFallbackError as Error;
@@ -1700,14 +1666,14 @@ export async function connectToMcpServer(
); );
} }
const oauthTransport = await retryWithOAuth( await retryWithOAuth(
mcpClient, mcpClient,
mcpServerName, mcpServerName,
mcpServerConfig, mcpServerConfig,
accessToken, accessToken,
httpReturned404, httpReturned404,
); );
return { client: mcpClient, transport: oauthTransport }; return mcpClient;
} else { } else {
throw new Error( throw new Error(
`Failed to handle automatic OAuth for server '${mcpServerName}'`, `Failed to handle automatic OAuth for server '${mcpServerName}'`,
@@ -1788,7 +1754,7 @@ export async function connectToMcpServer(
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
}); });
// Connection successful with OAuth // Connection successful with OAuth
return { client: mcpClient, transport: oauthTransport }; return mcpClient;
} else { } else {
throw new Error( throw new Error(
`OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`, `OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`,