fix(mcp): ensure MCP transport is closed to prevent memory leaks (#18054)

Co-authored-by: Jack Wotherspoon <jackwoth@google.com>
This commit is contained in:
Chris Coutinho
2026-02-04 22:00:41 +01:00
committed by GitHub
parent 3afc8f25e1
commit 821355c429
2 changed files with 59 additions and 25 deletions

View File

@@ -749,9 +749,9 @@ describe('mcp-client', () => {
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue({
close: vi.fn(),
} as unknown as SdkClientStdioLib.StdioClientTransport);
const mockedToolRegistry = {
registerTool: vi.fn(),
unregisterTool: vi.fn(),
@@ -1888,7 +1888,7 @@ describe('connectToMcpServer with OAuth', () => {
EMPTY_CONFIG,
);
expect(client).toBe(mockedClient);
expect(client.client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
@@ -1934,7 +1934,7 @@ describe('connectToMcpServer with OAuth', () => {
EMPTY_CONFIG,
);
expect(client).toBe(mockedClient);
expect(client.client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl);
@@ -2029,7 +2029,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
EMPTY_CONFIG,
);
expect(client).toBe(mockedClient);
expect(client.client).toBe(mockedClient);
// First HTTP attempt fails, second SSE attempt succeeds
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
});
@@ -2070,7 +2070,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
EMPTY_CONFIG,
);
expect(client).toBe(mockedClient);
expect(client.client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
});
});
@@ -2155,7 +2155,7 @@ describe('connectToMcpServer - OAuth with transport fallback', () => {
EMPTY_CONFIG,
);
expect(client).toBe(mockedClient);
expect(client.client).toBe(mockedClient);
expect(mockedClient.connect).toHaveBeenCalledTimes(3);
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
});

View File

@@ -144,7 +144,7 @@ export class McpClient {
}
this.updateStatus(MCPServerStatus.CONNECTING);
try {
this.client = await connectToMcpServer(
const { client, transport } = await connectToMcpServer(
this.clientVersion,
this.serverName,
this.serverConfig,
@@ -152,11 +152,13 @@ export class McpClient {
this.workspaceContext,
this.cliConfig.sanitizationConfig,
);
this.client = client;
this.transport = transport;
this.registerNotificationHandlers();
const originalOnError = this.client.onerror;
this.client.onerror = (error) => {
this.client.onerror = async (error) => {
if (this.status !== MCPServerStatus.CONNECTED) {
return;
}
@@ -167,6 +169,14 @@ export class McpClient {
error,
);
this.updateStatus(MCPServerStatus.DISCONNECTED);
// Close transport to prevent memory leaks
if (this.transport) {
try {
await this.transport.close();
} catch {
// Ignore errors when closing transport on error
}
}
};
this.updateStatus(MCPServerStatus.CONNECTED);
} catch (error) {
@@ -909,8 +919,9 @@ export async function connectAndDiscover(
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
let mcpClient: Client | undefined;
let transport: Transport | undefined;
try {
mcpClient = await connectToMcpServer(
const result = await connectToMcpServer(
clientVersion,
mcpServerName,
mcpServerConfig,
@@ -918,10 +929,20 @@ export async function connectAndDiscover(
workspaceContext,
cliConfig.sanitizationConfig,
);
mcpClient = result.client;
transport = result.transport;
mcpClient.onerror = (error) => {
mcpClient.onerror = async (error) => {
coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error);
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
// Close transport to prevent memory leaks
if (transport) {
try {
await transport.close();
} catch {
// Ignore errors when closing transport on error
}
}
};
// Attempt to discover both prompts and tools
@@ -1302,16 +1323,18 @@ function createSSETransportWithAuth(
* @param client The MCP client to connect
* @param config The MCP server configuration
* @param accessToken Optional OAuth access token for authentication
* @returns The transport used for connection
*/
async function connectWithSSETransport(
client: Client,
config: MCPServerConfig,
accessToken?: string | null,
): Promise<void> {
): Promise<Transport> {
const transport = createSSETransportWithAuth(config, accessToken);
await client.connect(transport, {
timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
return transport;
}
/**
@@ -1341,6 +1364,7 @@ async function showAuthRequiredMessage(serverName: string): Promise<never> {
* @param config The MCP server configuration
* @param accessToken The OAuth access token to use
* @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server)
* @returns The transport used for connection
*/
async function retryWithOAuth(
client: Client,
@@ -1348,17 +1372,21 @@ async function retryWithOAuth(
config: MCPServerConfig,
accessToken: string,
httpReturned404: boolean,
): Promise<void> {
): Promise<Transport> {
if (httpReturned404) {
// HTTP returned 404, only try SSE
debugLogger.log(
`Retrying SSE connection to '${serverName}' with OAuth token...`,
);
await connectWithSSETransport(client, config, accessToken);
const transport = await connectWithSSETransport(
client,
config,
accessToken,
);
debugLogger.log(
`Successfully connected to '${serverName}' using SSE with OAuth.`,
);
return;
return transport;
}
// HTTP returned 401, try HTTP with OAuth first
@@ -1382,6 +1410,7 @@ async function retryWithOAuth(
debugLogger.log(
`Successfully connected to '${serverName}' using HTTP with OAuth.`,
);
return httpTransport;
} catch (httpError) {
await httpTransport.close();
@@ -1393,10 +1422,15 @@ async function retryWithOAuth(
!config.httpUrl
) {
debugLogger.log(`HTTP with OAuth returned 404, trying SSE with OAuth...`);
await connectWithSSETransport(client, config, accessToken);
const sseTransport = await connectWithSSETransport(
client,
config,
accessToken,
);
debugLogger.log(
`Successfully connected to '${serverName}' using SSE with OAuth.`,
);
return sseTransport;
} else {
throw httpError;
}
@@ -1410,7 +1444,7 @@ async function retryWithOAuth(
*
* @param mcpServerName The name of the MCP server, used for logging and identification.
* @param mcpServerConfig The configuration specifying how to connect to the server.
* @returns A promise that resolves to a connected MCP `Client` instance.
* @returns A promise that resolves to a connected MCP `Client` instance and its transport.
* @throws An error if the connection fails or the configuration is invalid.
*/
export async function connectToMcpServer(
@@ -1420,7 +1454,7 @@ export async function connectToMcpServer(
debugMode: boolean,
workspaceContext: WorkspaceContext,
sanitizationConfig: EnvironmentSanitizationConfig,
): Promise<Client> {
): Promise<{ client: Client; transport: Transport }> {
const mcpClient = new Client(
{
name: 'gemini-cli-mcp-client',
@@ -1492,7 +1526,7 @@ export async function connectToMcpServer(
await mcpClient.connect(transport, {
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
return mcpClient;
return { client: mcpClient, transport };
} catch (error) {
await transport.close();
firstAttemptError = error as Error;
@@ -1523,7 +1557,7 @@ export async function connectToMcpServer(
try {
// Try SSE with stored OAuth token if available
// This ensures that SSE fallback works for authenticated servers
await connectWithSSETransport(
const sseTransport = await connectWithSSETransport(
mcpClient,
mcpServerConfig,
await getStoredOAuthToken(mcpServerName),
@@ -1532,7 +1566,7 @@ export async function connectToMcpServer(
debugLogger.log(
`MCP server '${mcpServerName}': Successfully connected using SSE transport.`,
);
return mcpClient;
return { client: mcpClient, transport: sseTransport };
} catch (sseFallbackError) {
sseError = sseFallbackError as Error;
@@ -1639,14 +1673,14 @@ export async function connectToMcpServer(
);
}
await retryWithOAuth(
const oauthTransport = await retryWithOAuth(
mcpClient,
mcpServerName,
mcpServerConfig,
accessToken,
httpReturned404,
);
return mcpClient;
return { client: mcpClient, transport: oauthTransport };
} else {
throw new Error(
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
@@ -1727,7 +1761,7 @@ export async function connectToMcpServer(
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
// Connection successful with OAuth
return mcpClient;
return { client: mcpClient, transport: oauthTransport };
} else {
throw new Error(
`OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`,