diff --git a/packages/cli/src/commands/extensions/enable.test.ts b/packages/cli/src/commands/extensions/enable.test.ts index 4b08f85046..aafb5193e4 100644 --- a/packages/cli/src/commands/extensions/enable.test.ts +++ b/packages/cli/src/commands/extensions/enable.test.ts @@ -62,6 +62,18 @@ vi.mock('../utils.js', () => ({ exitCli: vi.fn(), })); +const mockEnablementInstance = vi.hoisted(() => ({ + getDisplayState: vi.fn(), + enable: vi.fn(), + clearSessionDisable: vi.fn(), + autoEnableServers: vi.fn(), +})); +vi.mock('../../config/mcp/mcpServerEnablement.js', () => ({ + McpServerEnablementManager: { + getInstance: () => mockEnablementInstance, + }, +})); + describe('extensions enable command', () => { const mockLoadSettings = vi.mocked(loadSettings); const mockExtensionManager = vi.mocked(ExtensionManager); @@ -75,6 +87,12 @@ describe('extensions enable command', () => { .fn() .mockResolvedValue(undefined); mockExtensionManager.prototype.enableExtension = vi.fn(); + mockExtensionManager.prototype.getExtensions = vi.fn().mockReturnValue([]); + mockEnablementInstance.getDisplayState.mockReset(); + mockEnablementInstance.enable.mockReset(); + mockEnablementInstance.clearSessionDisable.mockReset(); + mockEnablementInstance.autoEnableServers.mockReset(); + mockEnablementInstance.autoEnableServers.mockResolvedValue([]); }); afterEach(() => { @@ -134,6 +152,50 @@ describe('extensions enable command', () => { mockCwd.mockRestore(); }); + + it('should auto-enable disabled MCP servers for the extension', async () => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + mockEnablementInstance.autoEnableServers.mockResolvedValue([ + 'test-server', + ]); + mockExtensionManager.prototype.getExtensions = vi + .fn() + .mockReturnValue([ + { name: 'my-extension', mcpServers: { 'test-server': {} } }, + ]); + + await handleEnable({ name: 'my-extension' }); + + expect(mockEnablementInstance.autoEnableServers).toHaveBeenCalledWith([ + 'test-server', + ]); + expect(emitConsoleLog).toHaveBeenCalledWith( + 'log', + expect.stringContaining("MCP server 'test-server' was disabled"), + ); + mockCwd.mockRestore(); + }); + + it('should not log when MCP servers are already enabled', async () => { + const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir'); + mockEnablementInstance.autoEnableServers.mockResolvedValue([]); + mockExtensionManager.prototype.getExtensions = vi + .fn() + .mockReturnValue([ + { name: 'my-extension', mcpServers: { 'test-server': {} } }, + ]); + + await handleEnable({ name: 'my-extension' }); + + expect(mockEnablementInstance.autoEnableServers).toHaveBeenCalledWith([ + 'test-server', + ]); + expect(emitConsoleLog).not.toHaveBeenCalledWith( + 'log', + expect.stringContaining("MCP server 'test-server' was disabled"), + ); + mockCwd.mockRestore(); + }); }); describe('enableCommand', () => { diff --git a/packages/cli/src/commands/extensions/enable.ts b/packages/cli/src/commands/extensions/enable.ts index 67d7087d2d..55f3e596c4 100644 --- a/packages/cli/src/commands/extensions/enable.ts +++ b/packages/cli/src/commands/extensions/enable.ts @@ -15,6 +15,7 @@ import { } from '@google/gemini-cli-core'; import { promptForSetting } from '../../config/extensions/extensionSettings.js'; import { exitCli } from '../utils.js'; +import { McpServerEnablementManager } from '../../config/mcp/mcpServerEnablement.js'; interface EnableArgs { name: string; @@ -37,6 +38,26 @@ export async function handleEnable(args: EnableArgs) { } else { await extensionManager.enableExtension(args.name, SettingScope.User); } + + // Auto-enable any disabled MCP servers for this extension + const extension = extensionManager + .getExtensions() + .find((e) => e.name === args.name); + + if (extension?.mcpServers) { + const mcpEnablementManager = McpServerEnablementManager.getInstance(); + const enabledServers = await mcpEnablementManager.autoEnableServers( + Object.keys(extension.mcpServers ?? {}), + ); + + for (const serverName of enabledServers) { + debugLogger.log( + `MCP server '${serverName}' was disabled - now enabled.`, + ); + } + // Note: No restartServer() - CLI exits immediately, servers load on next session + } + if (args.scope) { debugLogger.log( `Extension "${args.name}" successfully enabled for scope "${args.scope}".`, diff --git a/packages/cli/src/commands/mcp/enableDisable.ts b/packages/cli/src/commands/mcp/enableDisable.ts index f4146897eb..b47e259eca 100644 --- a/packages/cli/src/commands/mcp/enableDisable.ts +++ b/packages/cli/src/commands/mcp/enableDisable.ts @@ -42,21 +42,6 @@ async function handleEnable(args: Args): Promise { return; } - // Check if server is from an extension - const serverKey = Object.keys(servers).find( - (key) => normalizeServerId(key) === name, - ); - const server = serverKey ? servers[serverKey] : undefined; - if (server?.extension) { - debugLogger.log( - `${RED}Error:${RESET} Server '${args.name}' is provided by extension '${server.extension.name}'.`, - ); - debugLogger.log( - `Use 'gemini extensions enable ${server.extension.name}' to manage this extension.`, - ); - return; - } - const result = await canLoadServer(name, { adminMcpEnabled: settings.merged.admin?.mcp?.enabled ?? true, allowedList: settings.merged.mcp?.allowed, @@ -100,21 +85,6 @@ async function handleDisable(args: Args): Promise { return; } - // Check if server is from an extension - const serverKey = Object.keys(servers).find( - (key) => normalizeServerId(key) === name, - ); - const server = serverKey ? servers[serverKey] : undefined; - if (server?.extension) { - debugLogger.log( - `${RED}Error:${RESET} Server '${args.name}' is provided by extension '${server.extension.name}'.`, - ); - debugLogger.log( - `Use 'gemini extensions disable ${server.extension.name}' to manage this extension.`, - ); - return; - } - if (args.session) { manager.disableForSession(name); debugLogger.log( diff --git a/packages/cli/src/config/mcp/mcpServerEnablement.ts b/packages/cli/src/config/mcp/mcpServerEnablement.ts index da8a7a92a8..a510dd6697 100644 --- a/packages/cli/src/config/mcp/mcpServerEnablement.ts +++ b/packages/cli/src/config/mcp/mcpServerEnablement.ts @@ -323,6 +323,35 @@ export class McpServerEnablementManager { }; } + /** + * Auto-enable any disabled MCP servers by name. + * Returns server names that were actually re-enabled. + */ + async autoEnableServers(serverNames: string[]): Promise { + const enabledServers: string[] = []; + + for (const serverName of serverNames) { + const normalizedName = normalizeServerId(serverName); + const state = await this.getDisplayState(normalizedName); + + let wasDisabled = false; + if (state.isPersistentDisabled) { + await this.enable(normalizedName); + wasDisabled = true; + } + if (state.isSessionDisabled) { + this.clearSessionDisable(normalizedName); + wasDisabled = true; + } + + if (wasDisabled) { + enabledServers.push(serverName); + } + } + + return enabledServers; + } + /** * Read config from file asynchronously. */ diff --git a/packages/cli/src/ui/commands/extensionsCommand.ts b/packages/cli/src/ui/commands/extensionsCommand.ts index 6aa748153a..1258e30002 100644 --- a/packages/cli/src/ui/commands/extensionsCommand.ts +++ b/packages/cli/src/ui/commands/extensionsCommand.ts @@ -29,6 +29,7 @@ import { inferInstallMetadata, } from '../../config/extension-manager.js'; import { SettingScope } from '../../config/settings.js'; +import { McpServerEnablementManager } from '../../config/mcp/mcpServerEnablement.js'; import { theme } from '../semantic-colors.js'; import { stat } from 'node:fs/promises'; @@ -381,6 +382,38 @@ async function enableAction(context: CommandContext, args: string) { type: MessageType.INFO, text: `Extension "${name}" enabled for the scope "${scope}"`, }); + + // Auto-enable any disabled MCP servers for this extension + const extension = extensionManager + .getExtensions() + .find((e) => e.name === name); + + if (extension?.mcpServers) { + const mcpEnablementManager = McpServerEnablementManager.getInstance(); + const mcpClientManager = context.services.config?.getMcpClientManager(); + const enabledServers = await mcpEnablementManager.autoEnableServers( + Object.keys(extension.mcpServers ?? {}), + ); + + if (mcpClientManager && enabledServers.length > 0) { + const restartPromises = enabledServers.map((serverName) => + mcpClientManager.restartServer(serverName).catch((error) => { + context.ui.addItem({ + type: MessageType.WARNING, + text: `Failed to restart MCP server '${serverName}': ${getErrorMessage(error)}`, + }); + }), + ); + await Promise.all(restartPromises); + } + + if (enabledServers.length > 0) { + context.ui.addItem({ + type: MessageType.INFO, + text: `Re-enabled MCP servers: ${enabledServers.join(', ')}`, + }); + } + } } } diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index 4f4c098918..62154eb6fd 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -397,19 +397,6 @@ async function handleEnableDisable( }; } - // Check if server is from an extension - const serverKey = Object.keys(servers).find( - (key) => normalizeServerId(key) === name, - ); - const server = serverKey ? servers[serverKey] : undefined; - if (server?.extension) { - return { - type: 'message', - messageType: 'error', - content: `Server '${serverName}' is provided by extension '${server.extension.name}'.\nUse '/extensions ${action} ${server.extension.name}' to manage this extension.`, - }; - } - const manager = McpServerEnablementManager.getInstance(); if (enable) {