mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
Revert "fix(mcp): ensure MCP transport is closed to prevent memory leaks" (#18771)
This commit is contained in:
@@ -901,9 +901,9 @@ describe('mcp-client', () => {
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue({
|
||||
close: vi.fn(),
|
||||
} as unknown as SdkClientStdioLib.StdioClientTransport);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
unregisterTool: vi.fn(),
|
||||
@@ -2040,7 +2040,7 @@ describe('connectToMcpServer with OAuth', () => {
|
||||
EMPTY_CONFIG,
|
||||
);
|
||||
|
||||
expect(client.client).toBe(mockedClient);
|
||||
expect(client).toBe(mockedClient);
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
||||
|
||||
@@ -2086,7 +2086,7 @@ describe('connectToMcpServer with OAuth', () => {
|
||||
EMPTY_CONFIG,
|
||||
);
|
||||
|
||||
expect(client.client).toBe(mockedClient);
|
||||
expect(client).toBe(mockedClient);
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
||||
expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl);
|
||||
@@ -2181,7 +2181,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
|
||||
EMPTY_CONFIG,
|
||||
);
|
||||
|
||||
expect(client.client).toBe(mockedClient);
|
||||
expect(client).toBe(mockedClient);
|
||||
// First HTTP attempt fails, second SSE attempt succeeds
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
@@ -2222,7 +2222,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
|
||||
EMPTY_CONFIG,
|
||||
);
|
||||
|
||||
expect(client.client).toBe(mockedClient);
|
||||
expect(client).toBe(mockedClient);
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
@@ -2307,7 +2307,7 @@ describe('connectToMcpServer - OAuth with transport fallback', () => {
|
||||
EMPTY_CONFIG,
|
||||
);
|
||||
|
||||
expect(client.client).toBe(mockedClient);
|
||||
expect(client).toBe(mockedClient);
|
||||
expect(mockedClient.connect).toHaveBeenCalledTimes(3);
|
||||
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
@@ -146,7 +146,7 @@ export class McpClient {
|
||||
}
|
||||
this.updateStatus(MCPServerStatus.CONNECTING);
|
||||
try {
|
||||
const { client, transport } = await connectToMcpServer(
|
||||
this.client = await connectToMcpServer(
|
||||
this.clientVersion,
|
||||
this.serverName,
|
||||
this.serverConfig,
|
||||
@@ -154,13 +154,11 @@ export class McpClient {
|
||||
this.workspaceContext,
|
||||
this.cliConfig.sanitizationConfig,
|
||||
);
|
||||
this.client = client;
|
||||
this.transport = transport;
|
||||
|
||||
this.registerNotificationHandlers();
|
||||
|
||||
const originalOnError = this.client.onerror;
|
||||
this.client.onerror = async (error) => {
|
||||
this.client.onerror = (error) => {
|
||||
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||
return;
|
||||
}
|
||||
@@ -171,14 +169,6 @@ export class McpClient {
|
||||
error,
|
||||
);
|
||||
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
||||
// Close transport to prevent memory leaks
|
||||
if (this.transport) {
|
||||
try {
|
||||
await this.transport.close();
|
||||
} catch {
|
||||
// Ignore errors when closing transport on error
|
||||
}
|
||||
}
|
||||
};
|
||||
this.updateStatus(MCPServerStatus.CONNECTED);
|
||||
} catch (error) {
|
||||
@@ -927,9 +917,8 @@ export async function connectAndDiscover(
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
||||
|
||||
let mcpClient: Client | undefined;
|
||||
let transport: Transport | undefined;
|
||||
try {
|
||||
const result = await connectToMcpServer(
|
||||
mcpClient = await connectToMcpServer(
|
||||
clientVersion,
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
@@ -937,20 +926,10 @@ export async function connectAndDiscover(
|
||||
workspaceContext,
|
||||
cliConfig.sanitizationConfig,
|
||||
);
|
||||
mcpClient = result.client;
|
||||
transport = result.transport;
|
||||
|
||||
mcpClient.onerror = async (error) => {
|
||||
mcpClient.onerror = (error) => {
|
||||
coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error);
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||
// Close transport to prevent memory leaks
|
||||
if (transport) {
|
||||
try {
|
||||
await transport.close();
|
||||
} catch {
|
||||
// Ignore errors when closing transport on error
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Attempt to discover both prompts and tools
|
||||
@@ -1348,18 +1327,16 @@ function createSSETransportWithAuth(
|
||||
* @param client The MCP client to connect
|
||||
* @param config The MCP server configuration
|
||||
* @param accessToken Optional OAuth access token for authentication
|
||||
* @returns The transport used for connection
|
||||
*/
|
||||
async function connectWithSSETransport(
|
||||
client: Client,
|
||||
config: MCPServerConfig,
|
||||
accessToken?: string | null,
|
||||
): Promise<Transport> {
|
||||
): Promise<void> {
|
||||
const transport = createSSETransportWithAuth(config, accessToken);
|
||||
await client.connect(transport, {
|
||||
timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
return transport;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1389,7 +1366,6 @@ async function showAuthRequiredMessage(serverName: string): Promise<never> {
|
||||
* @param config The MCP server configuration
|
||||
* @param accessToken The OAuth access token to use
|
||||
* @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server)
|
||||
* @returns The transport used for connection
|
||||
*/
|
||||
async function retryWithOAuth(
|
||||
client: Client,
|
||||
@@ -1397,21 +1373,17 @@ async function retryWithOAuth(
|
||||
config: MCPServerConfig,
|
||||
accessToken: string,
|
||||
httpReturned404: boolean,
|
||||
): Promise<Transport> {
|
||||
): Promise<void> {
|
||||
if (httpReturned404) {
|
||||
// HTTP returned 404, only try SSE
|
||||
debugLogger.log(
|
||||
`Retrying SSE connection to '${serverName}' with OAuth token...`,
|
||||
);
|
||||
const transport = await connectWithSSETransport(
|
||||
client,
|
||||
config,
|
||||
accessToken,
|
||||
);
|
||||
await connectWithSSETransport(client, config, accessToken);
|
||||
debugLogger.log(
|
||||
`Successfully connected to '${serverName}' using SSE with OAuth.`,
|
||||
);
|
||||
return transport;
|
||||
return;
|
||||
}
|
||||
|
||||
// HTTP returned 401, try HTTP with OAuth first
|
||||
@@ -1435,7 +1407,6 @@ async function retryWithOAuth(
|
||||
debugLogger.log(
|
||||
`Successfully connected to '${serverName}' using HTTP with OAuth.`,
|
||||
);
|
||||
return httpTransport;
|
||||
} catch (httpError) {
|
||||
await httpTransport.close();
|
||||
|
||||
@@ -1447,15 +1418,10 @@ async function retryWithOAuth(
|
||||
!config.httpUrl
|
||||
) {
|
||||
debugLogger.log(`HTTP with OAuth returned 404, trying SSE with OAuth...`);
|
||||
const sseTransport = await connectWithSSETransport(
|
||||
client,
|
||||
config,
|
||||
accessToken,
|
||||
);
|
||||
await connectWithSSETransport(client, config, accessToken);
|
||||
debugLogger.log(
|
||||
`Successfully connected to '${serverName}' using SSE with OAuth.`,
|
||||
);
|
||||
return sseTransport;
|
||||
} else {
|
||||
throw httpError;
|
||||
}
|
||||
@@ -1469,7 +1435,7 @@ async function retryWithOAuth(
|
||||
*
|
||||
* @param mcpServerName The name of the MCP server, used for logging and identification.
|
||||
* @param mcpServerConfig The configuration specifying how to connect to the server.
|
||||
* @returns A promise that resolves to a connected MCP `Client` instance and its transport.
|
||||
* @returns A promise that resolves to a connected MCP `Client` instance.
|
||||
* @throws An error if the connection fails or the configuration is invalid.
|
||||
*/
|
||||
export async function connectToMcpServer(
|
||||
@@ -1479,7 +1445,7 @@ export async function connectToMcpServer(
|
||||
debugMode: boolean,
|
||||
workspaceContext: WorkspaceContext,
|
||||
sanitizationConfig: EnvironmentSanitizationConfig,
|
||||
): Promise<{ client: Client; transport: Transport }> {
|
||||
): Promise<Client> {
|
||||
const mcpClient = new Client(
|
||||
{
|
||||
name: 'gemini-cli-mcp-client',
|
||||
@@ -1551,7 +1517,7 @@ export async function connectToMcpServer(
|
||||
await mcpClient.connect(transport, {
|
||||
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
return { client: mcpClient, transport };
|
||||
return mcpClient;
|
||||
} catch (error) {
|
||||
await transport.close();
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
@@ -1583,7 +1549,7 @@ export async function connectToMcpServer(
|
||||
try {
|
||||
// Try SSE with stored OAuth token if available
|
||||
// This ensures that SSE fallback works for authenticated servers
|
||||
const sseTransport = await connectWithSSETransport(
|
||||
await connectWithSSETransport(
|
||||
mcpClient,
|
||||
mcpServerConfig,
|
||||
await getStoredOAuthToken(mcpServerName),
|
||||
@@ -1592,7 +1558,7 @@ export async function connectToMcpServer(
|
||||
debugLogger.log(
|
||||
`MCP server '${mcpServerName}': Successfully connected using SSE transport.`,
|
||||
);
|
||||
return { client: mcpClient, transport: sseTransport };
|
||||
return mcpClient;
|
||||
} catch (sseFallbackError) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
sseError = sseFallbackError as Error;
|
||||
@@ -1700,14 +1666,14 @@ export async function connectToMcpServer(
|
||||
);
|
||||
}
|
||||
|
||||
const oauthTransport = await retryWithOAuth(
|
||||
await retryWithOAuth(
|
||||
mcpClient,
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
accessToken,
|
||||
httpReturned404,
|
||||
);
|
||||
return { client: mcpClient, transport: oauthTransport };
|
||||
return mcpClient;
|
||||
} else {
|
||||
throw new Error(
|
||||
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
|
||||
@@ -1788,7 +1754,7 @@ export async function connectToMcpServer(
|
||||
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
// Connection successful with OAuth
|
||||
return { client: mcpClient, transport: oauthTransport };
|
||||
return mcpClient;
|
||||
} else {
|
||||
throw new Error(
|
||||
`OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`,
|
||||
|
||||
Reference in New Issue
Block a user