mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-26 21:14:35 -07:00
Revert "fix(mcp): ensure MCP transport is closed to prevent memory leaks" (#18771)
This commit is contained in:
@@ -901,9 +901,9 @@ describe('mcp-client', () => {
|
|||||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||||
mockedClient as unknown as ClientLib.Client,
|
mockedClient as unknown as ClientLib.Client,
|
||||||
);
|
);
|
||||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue({
|
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||||
close: vi.fn(),
|
{} as SdkClientStdioLib.StdioClientTransport,
|
||||||
} as unknown as SdkClientStdioLib.StdioClientTransport);
|
);
|
||||||
const mockedToolRegistry = {
|
const mockedToolRegistry = {
|
||||||
registerTool: vi.fn(),
|
registerTool: vi.fn(),
|
||||||
unregisterTool: vi.fn(),
|
unregisterTool: vi.fn(),
|
||||||
@@ -2040,7 +2040,7 @@ describe('connectToMcpServer with OAuth', () => {
|
|||||||
EMPTY_CONFIG,
|
EMPTY_CONFIG,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(client.client).toBe(mockedClient);
|
expect(client).toBe(mockedClient);
|
||||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||||
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
||||||
|
|
||||||
@@ -2086,7 +2086,7 @@ describe('connectToMcpServer with OAuth', () => {
|
|||||||
EMPTY_CONFIG,
|
EMPTY_CONFIG,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(client.client).toBe(mockedClient);
|
expect(client).toBe(mockedClient);
|
||||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||||
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
||||||
expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl);
|
expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl);
|
||||||
@@ -2181,7 +2181,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
|
|||||||
EMPTY_CONFIG,
|
EMPTY_CONFIG,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(client.client).toBe(mockedClient);
|
expect(client).toBe(mockedClient);
|
||||||
// First HTTP attempt fails, second SSE attempt succeeds
|
// First HTTP attempt fails, second SSE attempt succeeds
|
||||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||||
});
|
});
|
||||||
@@ -2222,7 +2222,7 @@ describe('connectToMcpServer - HTTP→SSE fallback', () => {
|
|||||||
EMPTY_CONFIG,
|
EMPTY_CONFIG,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(client.client).toBe(mockedClient);
|
expect(client).toBe(mockedClient);
|
||||||
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
expect(mockedClient.connect).toHaveBeenCalledTimes(2);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
@@ -2307,7 +2307,7 @@ describe('connectToMcpServer - OAuth with transport fallback', () => {
|
|||||||
EMPTY_CONFIG,
|
EMPTY_CONFIG,
|
||||||
);
|
);
|
||||||
|
|
||||||
expect(client.client).toBe(mockedClient);
|
expect(client).toBe(mockedClient);
|
||||||
expect(mockedClient.connect).toHaveBeenCalledTimes(3);
|
expect(mockedClient.connect).toHaveBeenCalledTimes(3);
|
||||||
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce();
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ export class McpClient {
|
|||||||
}
|
}
|
||||||
this.updateStatus(MCPServerStatus.CONNECTING);
|
this.updateStatus(MCPServerStatus.CONNECTING);
|
||||||
try {
|
try {
|
||||||
const { client, transport } = await connectToMcpServer(
|
this.client = await connectToMcpServer(
|
||||||
this.clientVersion,
|
this.clientVersion,
|
||||||
this.serverName,
|
this.serverName,
|
||||||
this.serverConfig,
|
this.serverConfig,
|
||||||
@@ -154,13 +154,11 @@ export class McpClient {
|
|||||||
this.workspaceContext,
|
this.workspaceContext,
|
||||||
this.cliConfig.sanitizationConfig,
|
this.cliConfig.sanitizationConfig,
|
||||||
);
|
);
|
||||||
this.client = client;
|
|
||||||
this.transport = transport;
|
|
||||||
|
|
||||||
this.registerNotificationHandlers();
|
this.registerNotificationHandlers();
|
||||||
|
|
||||||
const originalOnError = this.client.onerror;
|
const originalOnError = this.client.onerror;
|
||||||
this.client.onerror = async (error) => {
|
this.client.onerror = (error) => {
|
||||||
if (this.status !== MCPServerStatus.CONNECTED) {
|
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -171,14 +169,6 @@ export class McpClient {
|
|||||||
error,
|
error,
|
||||||
);
|
);
|
||||||
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
this.updateStatus(MCPServerStatus.DISCONNECTED);
|
||||||
// Close transport to prevent memory leaks
|
|
||||||
if (this.transport) {
|
|
||||||
try {
|
|
||||||
await this.transport.close();
|
|
||||||
} catch {
|
|
||||||
// Ignore errors when closing transport on error
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
this.updateStatus(MCPServerStatus.CONNECTED);
|
this.updateStatus(MCPServerStatus.CONNECTED);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
@@ -927,9 +917,8 @@ export async function connectAndDiscover(
|
|||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
||||||
|
|
||||||
let mcpClient: Client | undefined;
|
let mcpClient: Client | undefined;
|
||||||
let transport: Transport | undefined;
|
|
||||||
try {
|
try {
|
||||||
const result = await connectToMcpServer(
|
mcpClient = await connectToMcpServer(
|
||||||
clientVersion,
|
clientVersion,
|
||||||
mcpServerName,
|
mcpServerName,
|
||||||
mcpServerConfig,
|
mcpServerConfig,
|
||||||
@@ -937,20 +926,10 @@ export async function connectAndDiscover(
|
|||||||
workspaceContext,
|
workspaceContext,
|
||||||
cliConfig.sanitizationConfig,
|
cliConfig.sanitizationConfig,
|
||||||
);
|
);
|
||||||
mcpClient = result.client;
|
|
||||||
transport = result.transport;
|
|
||||||
|
|
||||||
mcpClient.onerror = async (error) => {
|
mcpClient.onerror = (error) => {
|
||||||
coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error);
|
coreEvents.emitFeedback('error', `MCP ERROR (${mcpServerName}):`, error);
|
||||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
updateMCPServerStatus(mcpServerName, MCPServerStatus.DISCONNECTED);
|
||||||
// Close transport to prevent memory leaks
|
|
||||||
if (transport) {
|
|
||||||
try {
|
|
||||||
await transport.close();
|
|
||||||
} catch {
|
|
||||||
// Ignore errors when closing transport on error
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Attempt to discover both prompts and tools
|
// Attempt to discover both prompts and tools
|
||||||
@@ -1348,18 +1327,16 @@ function createSSETransportWithAuth(
|
|||||||
* @param client The MCP client to connect
|
* @param client The MCP client to connect
|
||||||
* @param config The MCP server configuration
|
* @param config The MCP server configuration
|
||||||
* @param accessToken Optional OAuth access token for authentication
|
* @param accessToken Optional OAuth access token for authentication
|
||||||
* @returns The transport used for connection
|
|
||||||
*/
|
*/
|
||||||
async function connectWithSSETransport(
|
async function connectWithSSETransport(
|
||||||
client: Client,
|
client: Client,
|
||||||
config: MCPServerConfig,
|
config: MCPServerConfig,
|
||||||
accessToken?: string | null,
|
accessToken?: string | null,
|
||||||
): Promise<Transport> {
|
): Promise<void> {
|
||||||
const transport = createSSETransportWithAuth(config, accessToken);
|
const transport = createSSETransportWithAuth(config, accessToken);
|
||||||
await client.connect(transport, {
|
await client.connect(transport, {
|
||||||
timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||||
});
|
});
|
||||||
return transport;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -1389,7 +1366,6 @@ async function showAuthRequiredMessage(serverName: string): Promise<never> {
|
|||||||
* @param config The MCP server configuration
|
* @param config The MCP server configuration
|
||||||
* @param accessToken The OAuth access token to use
|
* @param accessToken The OAuth access token to use
|
||||||
* @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server)
|
* @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server)
|
||||||
* @returns The transport used for connection
|
|
||||||
*/
|
*/
|
||||||
async function retryWithOAuth(
|
async function retryWithOAuth(
|
||||||
client: Client,
|
client: Client,
|
||||||
@@ -1397,21 +1373,17 @@ async function retryWithOAuth(
|
|||||||
config: MCPServerConfig,
|
config: MCPServerConfig,
|
||||||
accessToken: string,
|
accessToken: string,
|
||||||
httpReturned404: boolean,
|
httpReturned404: boolean,
|
||||||
): Promise<Transport> {
|
): Promise<void> {
|
||||||
if (httpReturned404) {
|
if (httpReturned404) {
|
||||||
// HTTP returned 404, only try SSE
|
// HTTP returned 404, only try SSE
|
||||||
debugLogger.log(
|
debugLogger.log(
|
||||||
`Retrying SSE connection to '${serverName}' with OAuth token...`,
|
`Retrying SSE connection to '${serverName}' with OAuth token...`,
|
||||||
);
|
);
|
||||||
const transport = await connectWithSSETransport(
|
await connectWithSSETransport(client, config, accessToken);
|
||||||
client,
|
|
||||||
config,
|
|
||||||
accessToken,
|
|
||||||
);
|
|
||||||
debugLogger.log(
|
debugLogger.log(
|
||||||
`Successfully connected to '${serverName}' using SSE with OAuth.`,
|
`Successfully connected to '${serverName}' using SSE with OAuth.`,
|
||||||
);
|
);
|
||||||
return transport;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTP returned 401, try HTTP with OAuth first
|
// HTTP returned 401, try HTTP with OAuth first
|
||||||
@@ -1435,7 +1407,6 @@ async function retryWithOAuth(
|
|||||||
debugLogger.log(
|
debugLogger.log(
|
||||||
`Successfully connected to '${serverName}' using HTTP with OAuth.`,
|
`Successfully connected to '${serverName}' using HTTP with OAuth.`,
|
||||||
);
|
);
|
||||||
return httpTransport;
|
|
||||||
} catch (httpError) {
|
} catch (httpError) {
|
||||||
await httpTransport.close();
|
await httpTransport.close();
|
||||||
|
|
||||||
@@ -1447,15 +1418,10 @@ async function retryWithOAuth(
|
|||||||
!config.httpUrl
|
!config.httpUrl
|
||||||
) {
|
) {
|
||||||
debugLogger.log(`HTTP with OAuth returned 404, trying SSE with OAuth...`);
|
debugLogger.log(`HTTP with OAuth returned 404, trying SSE with OAuth...`);
|
||||||
const sseTransport = await connectWithSSETransport(
|
await connectWithSSETransport(client, config, accessToken);
|
||||||
client,
|
|
||||||
config,
|
|
||||||
accessToken,
|
|
||||||
);
|
|
||||||
debugLogger.log(
|
debugLogger.log(
|
||||||
`Successfully connected to '${serverName}' using SSE with OAuth.`,
|
`Successfully connected to '${serverName}' using SSE with OAuth.`,
|
||||||
);
|
);
|
||||||
return sseTransport;
|
|
||||||
} else {
|
} else {
|
||||||
throw httpError;
|
throw httpError;
|
||||||
}
|
}
|
||||||
@@ -1469,7 +1435,7 @@ async function retryWithOAuth(
|
|||||||
*
|
*
|
||||||
* @param mcpServerName The name of the MCP server, used for logging and identification.
|
* @param mcpServerName The name of the MCP server, used for logging and identification.
|
||||||
* @param mcpServerConfig The configuration specifying how to connect to the server.
|
* @param mcpServerConfig The configuration specifying how to connect to the server.
|
||||||
* @returns A promise that resolves to a connected MCP `Client` instance and its transport.
|
* @returns A promise that resolves to a connected MCP `Client` instance.
|
||||||
* @throws An error if the connection fails or the configuration is invalid.
|
* @throws An error if the connection fails or the configuration is invalid.
|
||||||
*/
|
*/
|
||||||
export async function connectToMcpServer(
|
export async function connectToMcpServer(
|
||||||
@@ -1479,7 +1445,7 @@ export async function connectToMcpServer(
|
|||||||
debugMode: boolean,
|
debugMode: boolean,
|
||||||
workspaceContext: WorkspaceContext,
|
workspaceContext: WorkspaceContext,
|
||||||
sanitizationConfig: EnvironmentSanitizationConfig,
|
sanitizationConfig: EnvironmentSanitizationConfig,
|
||||||
): Promise<{ client: Client; transport: Transport }> {
|
): Promise<Client> {
|
||||||
const mcpClient = new Client(
|
const mcpClient = new Client(
|
||||||
{
|
{
|
||||||
name: 'gemini-cli-mcp-client',
|
name: 'gemini-cli-mcp-client',
|
||||||
@@ -1551,7 +1517,7 @@ export async function connectToMcpServer(
|
|||||||
await mcpClient.connect(transport, {
|
await mcpClient.connect(transport, {
|
||||||
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||||
});
|
});
|
||||||
return { client: mcpClient, transport };
|
return mcpClient;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
await transport.close();
|
await transport.close();
|
||||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||||
@@ -1583,7 +1549,7 @@ export async function connectToMcpServer(
|
|||||||
try {
|
try {
|
||||||
// Try SSE with stored OAuth token if available
|
// Try SSE with stored OAuth token if available
|
||||||
// This ensures that SSE fallback works for authenticated servers
|
// This ensures that SSE fallback works for authenticated servers
|
||||||
const sseTransport = await connectWithSSETransport(
|
await connectWithSSETransport(
|
||||||
mcpClient,
|
mcpClient,
|
||||||
mcpServerConfig,
|
mcpServerConfig,
|
||||||
await getStoredOAuthToken(mcpServerName),
|
await getStoredOAuthToken(mcpServerName),
|
||||||
@@ -1592,7 +1558,7 @@ export async function connectToMcpServer(
|
|||||||
debugLogger.log(
|
debugLogger.log(
|
||||||
`MCP server '${mcpServerName}': Successfully connected using SSE transport.`,
|
`MCP server '${mcpServerName}': Successfully connected using SSE transport.`,
|
||||||
);
|
);
|
||||||
return { client: mcpClient, transport: sseTransport };
|
return mcpClient;
|
||||||
} catch (sseFallbackError) {
|
} catch (sseFallbackError) {
|
||||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||||
sseError = sseFallbackError as Error;
|
sseError = sseFallbackError as Error;
|
||||||
@@ -1700,14 +1666,14 @@ export async function connectToMcpServer(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const oauthTransport = await retryWithOAuth(
|
await retryWithOAuth(
|
||||||
mcpClient,
|
mcpClient,
|
||||||
mcpServerName,
|
mcpServerName,
|
||||||
mcpServerConfig,
|
mcpServerConfig,
|
||||||
accessToken,
|
accessToken,
|
||||||
httpReturned404,
|
httpReturned404,
|
||||||
);
|
);
|
||||||
return { client: mcpClient, transport: oauthTransport };
|
return mcpClient;
|
||||||
} else {
|
} else {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
|
`Failed to handle automatic OAuth for server '${mcpServerName}'`,
|
||||||
@@ -1788,7 +1754,7 @@ export async function connectToMcpServer(
|
|||||||
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||||
});
|
});
|
||||||
// Connection successful with OAuth
|
// Connection successful with OAuth
|
||||||
return { client: mcpClient, transport: oauthTransport };
|
return mcpClient;
|
||||||
} else {
|
} else {
|
||||||
throw new Error(
|
throw new Error(
|
||||||
`OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`,
|
`OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`,
|
||||||
|
|||||||
Reference in New Issue
Block a user