diff --git a/packages/core/src/tools/mcp-client-manager.test.ts b/packages/core/src/tools/mcp-client-manager.test.ts index fbd4785e65..bbab5ef12d 100644 --- a/packages/core/src/tools/mcp-client-manager.test.ts +++ b/packages/core/src/tools/mcp-client-manager.test.ts @@ -16,7 +16,7 @@ import { import { McpClientManager } from './mcp-client-manager.js'; import { McpClient, MCPDiscoveryState } from './mcp-client.js'; import type { ToolRegistry } from './tool-registry.js'; -import type { Config } from '../config/config.js'; +import type { Config, GeminiCLIExtension } from '../config/config.js'; vi.mock('./mcp-client.js', async () => { const originalModule = await vi.importActual('./mcp-client.js'); @@ -320,4 +320,57 @@ describe('McpClientManager', () => { await expect(manager.restartServer('test-server')).resolves.not.toThrow(); }); }); + + describe('Extension handling', () => { + it('should remove mcp servers from allServerConfigs when stopExtension is called', async () => { + const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const mcpServers = { + 'test-server': { command: 'node', args: ['server.js'] }, + }; + const extension: GeminiCLIExtension = { + name: 'test-extension', + mcpServers, + isActive: true, + version: '1.0.0', + path: '/some-path', + contextFiles: [], + id: '123', + }; + + await manager.startExtension(extension); + expect(manager.getMcpServers()).toHaveProperty('test-server'); + + await manager.stopExtension(extension); + expect(manager.getMcpServers()).not.toHaveProperty('test-server'); + }); + + it('should remove servers from blockedMcpServers when stopExtension is called', async () => { + mockConfig.getBlockedMcpServers.mockReturnValue(['blocked-server']); + const manager = new McpClientManager('0.0.1', toolRegistry, mockConfig); + const mcpServers = { + 'blocked-server': { command: 'node', args: ['server.js'] }, + }; + const extension: GeminiCLIExtension = { + name: 'test-extension', + mcpServers, + isActive: true, + version: '1.0.0', + path: '/some-path', + contextFiles: [], + id: '123', + }; + + await manager.startExtension(extension); + expect(manager.getBlockedMcpServers()).toContainEqual({ + name: 'blocked-server', + extensionName: 'test-extension', + }); + + await manager.stopExtension(extension); + expect(manager.getBlockedMcpServers()).not.toContainEqual({ + name: 'blocked-server', + extensionName: 'test-extension', + }); + }); + }); }); diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index 743d7adb47..a56876323d 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -72,9 +72,21 @@ export class McpClientManager { async stopExtension(extension: GeminiCLIExtension) { debugLogger.log(`Unloading extension: ${extension.name}`); await Promise.all( - Object.keys(extension.mcpServers ?? {}).map((name) => - this.disconnectClient(name, true), - ), + Object.keys(extension.mcpServers ?? {}).map((name) => { + const config = this.allServerConfigs.get(name); + if (config?.extension === extension) { + this.allServerConfigs.delete(name); + // Also remove from blocked servers if present + const index = this.blockedMcpServers.findIndex( + (s) => s.name === name && s.extensionName === extension.name, + ); + if (index !== -1) { + this.blockedMcpServers.splice(index, 1); + } + return this.disconnectClient(name, true); + } + return Promise.resolve(); + }), ); await this.cliConfig.refreshMcpContext(); }