mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-13 23:51:16 -07:00
feat(mcp/extensions): Allow users to selectively enable/disable MCP servers included in an extension( Issue #11057 & #17402) (#17434)
This commit is contained in:
@@ -62,6 +62,18 @@ vi.mock('../utils.js', () => ({
|
||||
exitCli: vi.fn(),
|
||||
}));
|
||||
|
||||
const mockEnablementInstance = vi.hoisted(() => ({
|
||||
getDisplayState: vi.fn(),
|
||||
enable: vi.fn(),
|
||||
clearSessionDisable: vi.fn(),
|
||||
autoEnableServers: vi.fn(),
|
||||
}));
|
||||
vi.mock('../../config/mcp/mcpServerEnablement.js', () => ({
|
||||
McpServerEnablementManager: {
|
||||
getInstance: () => mockEnablementInstance,
|
||||
},
|
||||
}));
|
||||
|
||||
describe('extensions enable command', () => {
|
||||
const mockLoadSettings = vi.mocked(loadSettings);
|
||||
const mockExtensionManager = vi.mocked(ExtensionManager);
|
||||
@@ -75,6 +87,12 @@ describe('extensions enable command', () => {
|
||||
.fn()
|
||||
.mockResolvedValue(undefined);
|
||||
mockExtensionManager.prototype.enableExtension = vi.fn();
|
||||
mockExtensionManager.prototype.getExtensions = vi.fn().mockReturnValue([]);
|
||||
mockEnablementInstance.getDisplayState.mockReset();
|
||||
mockEnablementInstance.enable.mockReset();
|
||||
mockEnablementInstance.clearSessionDisable.mockReset();
|
||||
mockEnablementInstance.autoEnableServers.mockReset();
|
||||
mockEnablementInstance.autoEnableServers.mockResolvedValue([]);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -134,6 +152,50 @@ describe('extensions enable command', () => {
|
||||
|
||||
mockCwd.mockRestore();
|
||||
});
|
||||
|
||||
it('should auto-enable disabled MCP servers for the extension', async () => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
mockEnablementInstance.autoEnableServers.mockResolvedValue([
|
||||
'test-server',
|
||||
]);
|
||||
mockExtensionManager.prototype.getExtensions = vi
|
||||
.fn()
|
||||
.mockReturnValue([
|
||||
{ name: 'my-extension', mcpServers: { 'test-server': {} } },
|
||||
]);
|
||||
|
||||
await handleEnable({ name: 'my-extension' });
|
||||
|
||||
expect(mockEnablementInstance.autoEnableServers).toHaveBeenCalledWith([
|
||||
'test-server',
|
||||
]);
|
||||
expect(emitConsoleLog).toHaveBeenCalledWith(
|
||||
'log',
|
||||
expect.stringContaining("MCP server 'test-server' was disabled"),
|
||||
);
|
||||
mockCwd.mockRestore();
|
||||
});
|
||||
|
||||
it('should not log when MCP servers are already enabled', async () => {
|
||||
const mockCwd = vi.spyOn(process, 'cwd').mockReturnValue('/test/dir');
|
||||
mockEnablementInstance.autoEnableServers.mockResolvedValue([]);
|
||||
mockExtensionManager.prototype.getExtensions = vi
|
||||
.fn()
|
||||
.mockReturnValue([
|
||||
{ name: 'my-extension', mcpServers: { 'test-server': {} } },
|
||||
]);
|
||||
|
||||
await handleEnable({ name: 'my-extension' });
|
||||
|
||||
expect(mockEnablementInstance.autoEnableServers).toHaveBeenCalledWith([
|
||||
'test-server',
|
||||
]);
|
||||
expect(emitConsoleLog).not.toHaveBeenCalledWith(
|
||||
'log',
|
||||
expect.stringContaining("MCP server 'test-server' was disabled"),
|
||||
);
|
||||
mockCwd.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('enableCommand', () => {
|
||||
|
||||
@@ -15,6 +15,7 @@ import {
|
||||
} from '@google/gemini-cli-core';
|
||||
import { promptForSetting } from '../../config/extensions/extensionSettings.js';
|
||||
import { exitCli } from '../utils.js';
|
||||
import { McpServerEnablementManager } from '../../config/mcp/mcpServerEnablement.js';
|
||||
|
||||
interface EnableArgs {
|
||||
name: string;
|
||||
@@ -37,6 +38,26 @@ export async function handleEnable(args: EnableArgs) {
|
||||
} else {
|
||||
await extensionManager.enableExtension(args.name, SettingScope.User);
|
||||
}
|
||||
|
||||
// Auto-enable any disabled MCP servers for this extension
|
||||
const extension = extensionManager
|
||||
.getExtensions()
|
||||
.find((e) => e.name === args.name);
|
||||
|
||||
if (extension?.mcpServers) {
|
||||
const mcpEnablementManager = McpServerEnablementManager.getInstance();
|
||||
const enabledServers = await mcpEnablementManager.autoEnableServers(
|
||||
Object.keys(extension.mcpServers ?? {}),
|
||||
);
|
||||
|
||||
for (const serverName of enabledServers) {
|
||||
debugLogger.log(
|
||||
`MCP server '${serverName}' was disabled - now enabled.`,
|
||||
);
|
||||
}
|
||||
// Note: No restartServer() - CLI exits immediately, servers load on next session
|
||||
}
|
||||
|
||||
if (args.scope) {
|
||||
debugLogger.log(
|
||||
`Extension "${args.name}" successfully enabled for scope "${args.scope}".`,
|
||||
|
||||
@@ -42,21 +42,6 @@ async function handleEnable(args: Args): Promise<void> {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if server is from an extension
|
||||
const serverKey = Object.keys(servers).find(
|
||||
(key) => normalizeServerId(key) === name,
|
||||
);
|
||||
const server = serverKey ? servers[serverKey] : undefined;
|
||||
if (server?.extension) {
|
||||
debugLogger.log(
|
||||
`${RED}Error:${RESET} Server '${args.name}' is provided by extension '${server.extension.name}'.`,
|
||||
);
|
||||
debugLogger.log(
|
||||
`Use 'gemini extensions enable ${server.extension.name}' to manage this extension.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const result = await canLoadServer(name, {
|
||||
adminMcpEnabled: settings.merged.admin?.mcp?.enabled ?? true,
|
||||
allowedList: settings.merged.mcp?.allowed,
|
||||
@@ -100,21 +85,6 @@ async function handleDisable(args: Args): Promise<void> {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if server is from an extension
|
||||
const serverKey = Object.keys(servers).find(
|
||||
(key) => normalizeServerId(key) === name,
|
||||
);
|
||||
const server = serverKey ? servers[serverKey] : undefined;
|
||||
if (server?.extension) {
|
||||
debugLogger.log(
|
||||
`${RED}Error:${RESET} Server '${args.name}' is provided by extension '${server.extension.name}'.`,
|
||||
);
|
||||
debugLogger.log(
|
||||
`Use 'gemini extensions disable ${server.extension.name}' to manage this extension.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (args.session) {
|
||||
manager.disableForSession(name);
|
||||
debugLogger.log(
|
||||
|
||||
@@ -323,6 +323,35 @@ export class McpServerEnablementManager {
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Auto-enable any disabled MCP servers by name.
|
||||
* Returns server names that were actually re-enabled.
|
||||
*/
|
||||
async autoEnableServers(serverNames: string[]): Promise<string[]> {
|
||||
const enabledServers: string[] = [];
|
||||
|
||||
for (const serverName of serverNames) {
|
||||
const normalizedName = normalizeServerId(serverName);
|
||||
const state = await this.getDisplayState(normalizedName);
|
||||
|
||||
let wasDisabled = false;
|
||||
if (state.isPersistentDisabled) {
|
||||
await this.enable(normalizedName);
|
||||
wasDisabled = true;
|
||||
}
|
||||
if (state.isSessionDisabled) {
|
||||
this.clearSessionDisable(normalizedName);
|
||||
wasDisabled = true;
|
||||
}
|
||||
|
||||
if (wasDisabled) {
|
||||
enabledServers.push(serverName);
|
||||
}
|
||||
}
|
||||
|
||||
return enabledServers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Read config from file asynchronously.
|
||||
*/
|
||||
|
||||
@@ -29,6 +29,7 @@ import {
|
||||
inferInstallMetadata,
|
||||
} from '../../config/extension-manager.js';
|
||||
import { SettingScope } from '../../config/settings.js';
|
||||
import { McpServerEnablementManager } from '../../config/mcp/mcpServerEnablement.js';
|
||||
import { theme } from '../semantic-colors.js';
|
||||
import { stat } from 'node:fs/promises';
|
||||
|
||||
@@ -381,6 +382,38 @@ async function enableAction(context: CommandContext, args: string) {
|
||||
type: MessageType.INFO,
|
||||
text: `Extension "${name}" enabled for the scope "${scope}"`,
|
||||
});
|
||||
|
||||
// Auto-enable any disabled MCP servers for this extension
|
||||
const extension = extensionManager
|
||||
.getExtensions()
|
||||
.find((e) => e.name === name);
|
||||
|
||||
if (extension?.mcpServers) {
|
||||
const mcpEnablementManager = McpServerEnablementManager.getInstance();
|
||||
const mcpClientManager = context.services.config?.getMcpClientManager();
|
||||
const enabledServers = await mcpEnablementManager.autoEnableServers(
|
||||
Object.keys(extension.mcpServers ?? {}),
|
||||
);
|
||||
|
||||
if (mcpClientManager && enabledServers.length > 0) {
|
||||
const restartPromises = enabledServers.map((serverName) =>
|
||||
mcpClientManager.restartServer(serverName).catch((error) => {
|
||||
context.ui.addItem({
|
||||
type: MessageType.WARNING,
|
||||
text: `Failed to restart MCP server '${serverName}': ${getErrorMessage(error)}`,
|
||||
});
|
||||
}),
|
||||
);
|
||||
await Promise.all(restartPromises);
|
||||
}
|
||||
|
||||
if (enabledServers.length > 0) {
|
||||
context.ui.addItem({
|
||||
type: MessageType.INFO,
|
||||
text: `Re-enabled MCP servers: ${enabledServers.join(', ')}`,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -397,19 +397,6 @@ async function handleEnableDisable(
|
||||
};
|
||||
}
|
||||
|
||||
// Check if server is from an extension
|
||||
const serverKey = Object.keys(servers).find(
|
||||
(key) => normalizeServerId(key) === name,
|
||||
);
|
||||
const server = serverKey ? servers[serverKey] : undefined;
|
||||
if (server?.extension) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: `Server '${serverName}' is provided by extension '${server.extension.name}'.\nUse '/extensions ${action} ${server.extension.name}' to manage this extension.`,
|
||||
};
|
||||
}
|
||||
|
||||
const manager = McpServerEnablementManager.getInstance();
|
||||
|
||||
if (enable) {
|
||||
|
||||
Reference in New Issue
Block a user