Extensions MCP refactor (#12413)

This commit is contained in:
Jacob MacDonald
2025-11-04 07:51:18 -08:00
committed by GitHub
parent 2b77c1ded4
commit da4fa5ad75
28 changed files with 877 additions and 478 deletions
+11 -28
View File
@@ -1146,9 +1146,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
];
const argv = await parseArguments({} as Settings);
const config = await loadCliConfig(baseSettings, 'test-session', argv);
expect(config.getMcpServers()).toEqual({
server1: { url: 'http://localhost:8080' },
});
expect(config.getAllowedMcpServers()).toEqual(['server1']);
});
it('should allow multiple specified MCP servers', async () => {
@@ -1162,10 +1160,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
];
const argv = await parseArguments({} as Settings);
const config = await loadCliConfig(baseSettings, 'test-session', argv);
expect(config.getMcpServers()).toEqual({
server1: { url: 'http://localhost:8080' },
server3: { url: 'http://localhost:8082' },
});
expect(config.getAllowedMcpServers()).toEqual(['server1', 'server3']);
});
it('should handle server names that do not exist', async () => {
@@ -1179,16 +1174,14 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
];
const argv = await parseArguments({} as Settings);
const config = await loadCliConfig(baseSettings, 'test-session', argv);
expect(config.getMcpServers()).toEqual({
server1: { url: 'http://localhost:8080' },
});
expect(config.getAllowedMcpServers()).toEqual(['server1', 'server4']);
});
it('should allow no MCP servers if the flag is provided but empty', async () => {
process.argv = ['node', 'script.js', '--allowed-mcp-server-names', ''];
const argv = await parseArguments({} as Settings);
const config = await loadCliConfig(baseSettings, 'test-session', argv);
expect(config.getMcpServers()).toEqual({});
expect(config.getAllowedMcpServers()).toEqual(['']);
});
it('should read allowMCPServers from settings', async () => {
@@ -1199,10 +1192,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
mcp: { allowed: ['server1', 'server2'] },
};
const config = await loadCliConfig(settings, 'test-session', argv);
expect(config.getMcpServers()).toEqual({
server1: { url: 'http://localhost:8080' },
server2: { url: 'http://localhost:8081' },
});
expect(config.getAllowedMcpServers()).toEqual(['server1', 'server2']);
});
it('should read excludeMCPServers from settings', async () => {
@@ -1213,9 +1203,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
mcp: { excluded: ['server1', 'server2'] },
};
const config = await loadCliConfig(settings, 'test-session', argv);
expect(config.getMcpServers()).toEqual({
server3: { url: 'http://localhost:8082' },
});
expect(config.getBlockedMcpServers()).toEqual(['server1', 'server2']);
});
it('should override allowMCPServers with excludeMCPServers if overlapping', async () => {
@@ -1229,9 +1217,8 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
},
};
const config = await loadCliConfig(settings, 'test-session', argv);
expect(config.getMcpServers()).toEqual({
server2: { url: 'http://localhost:8081' },
});
expect(config.getAllowedMcpServers()).toEqual(['server1', 'server2']);
expect(config.getBlockedMcpServers()).toEqual(['server1']);
});
it('should prioritize mcp server flag if set', async () => {
@@ -1250,9 +1237,7 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
},
};
const config = await loadCliConfig(settings, 'test-session', argv);
expect(config.getMcpServers()).toEqual({
server1: { url: 'http://localhost:8080' },
});
expect(config.getAllowedMcpServers()).toEqual(['server1']);
});
it('should prioritize CLI flag over both allowed and excluded settings', async () => {
@@ -1273,10 +1258,8 @@ describe('loadCliConfig with allowed-mcp-server-names', () => {
},
};
const config = await loadCliConfig(settings, 'test-session', argv);
expect(config.getMcpServers()).toEqual({
server2: { url: 'http://localhost:8081' },
server3: { url: 'http://localhost:8082' },
});
expect(config.getAllowedMcpServers()).toEqual(['server2', 'server3']);
expect(config.getBlockedMcpServers()).toEqual([]);
});
});
+11 -97
View File
@@ -13,9 +13,7 @@ import process from 'node:process';
import { mcpCommand } from '../commands/mcp.js';
import type {
FileFilteringOptions,
MCPServerConfig,
OutputFormat,
GeminiCLIExtension,
} from '@google/gemini-cli-core';
import { extensionsCommand } from '../commands/extensions.js';
import {
@@ -49,9 +47,13 @@ import { appEvents } from '../utils/events.js';
import { isWorkspaceTrusted } from './trustedFolders.js';
import { createPolicyEngineConfig } from './policy.js';
import { ExtensionManager } from './extension-manager.js';
import type { ExtensionLoader } from '@google/gemini-cli-core/src/utils/extensionLoader.js';
import type {
ExtensionEvents,
ExtensionLoader,
} from '@google/gemini-cli-core/src/utils/extensionLoader.js';
import { requestConsentNonInteractive } from './extensions/consent.js';
import { promptForSetting } from './extensions/extensionSettings.js';
import type { EventEmitter } from 'node:stream';
export interface CliArgs {
query: string | undefined;
@@ -429,6 +431,7 @@ export async function loadCliConfig(
requestSetting: promptForSetting,
workspaceDir: cwd,
enabledExtensionOverrides: argv.extensions,
eventEmitter: appEvents as EventEmitter<ExtensionEvents>,
});
await extensionManager.loadExtensions();
@@ -448,7 +451,6 @@ export async function loadCliConfig(
memoryFileFiltering,
);
let mcpServers = mergeMcpServers(settings, extensionManager.getExtensions());
const question = argv.promptInteractive || argv.prompt || '';
// Determine approval mode with backward compatibility
@@ -565,37 +567,8 @@ export async function loadCliConfig(
const excludeTools = mergeExcludeTools(
settings,
extensionManager.getExtensions(),
extraExcludes.length > 0 ? extraExcludes : undefined,
);
const blockedMcpServers: Array<{ name: string; extensionName: string }> = [];
if (!argv.allowedMcpServerNames) {
if (settings.mcp?.allowed) {
mcpServers = allowedMcpServers(
mcpServers,
settings.mcp.allowed,
blockedMcpServers,
);
}
if (settings.mcp?.excluded) {
const excludedNames = new Set(settings.mcp.excluded.filter(Boolean));
if (excludedNames.size > 0) {
mcpServers = Object.fromEntries(
Object.entries(mcpServers).filter(([key]) => !excludedNames.has(key)),
);
}
}
}
if (argv.allowedMcpServerNames) {
mcpServers = allowedMcpServers(
mcpServers,
argv.allowedMcpServerNames,
blockedMcpServers,
);
}
const useModelRouter = settings.experimental?.useModelRouter ?? true;
const defaultModel = useModelRouter
@@ -633,7 +606,11 @@ export async function loadCliConfig(
toolDiscoveryCommand: settings.tools?.discoveryCommand,
toolCallCommand: settings.tools?.callCommand,
mcpServerCommand: settings.mcp?.serverCommand,
mcpServers,
mcpServers: settings.mcpServers,
allowedMcpServers: argv.allowedMcpServerNames ?? settings.mcp?.allowed,
blockedMcpServers: argv.allowedMcpServerNames
? [] // explicitly allowed servers overrides everything
: settings.mcp?.excluded,
userMemory: memoryContent,
geminiMdFileCount: fileCount,
geminiMdFilePaths: filePaths,
@@ -663,7 +640,6 @@ export async function loadCliConfig(
enabledExtensions: argv.extensions,
extensionLoader: extensionManager,
enableExtensionReloading: settings.experimental?.extensionReloading,
blockedMcpServers,
noBrowser: !!process.env['NO_BROWSER'],
summarizeToolOutput: settings.model?.summarizeToolOutput,
ideMode,
@@ -699,75 +675,13 @@ export async function loadCliConfig(
});
}
function allowedMcpServers(
mcpServers: { [x: string]: MCPServerConfig },
allowMCPServers: string[],
blockedMcpServers: Array<{ name: string; extensionName: string }>,
) {
const allowedNames = new Set(allowMCPServers.filter(Boolean));
if (allowedNames.size > 0) {
mcpServers = Object.fromEntries(
Object.entries(mcpServers).filter(([key, server]) => {
const isAllowed = allowedNames.has(key);
if (!isAllowed) {
blockedMcpServers.push({
name: key,
extensionName: server.extension?.name || '',
});
}
return isAllowed;
}),
);
} else {
blockedMcpServers.push(
...Object.entries(mcpServers).map(([key, server]) => ({
name: key,
extensionName: server.extension?.name || '',
})),
);
mcpServers = {};
}
return mcpServers;
}
function mergeMcpServers(settings: Settings, extensions: GeminiCLIExtension[]) {
const mcpServers = { ...(settings.mcpServers || {}) };
for (const extension of extensions) {
if (!extension.isActive) {
continue;
}
Object.entries(extension.mcpServers || {}).forEach(([key, server]) => {
if (mcpServers[key]) {
debugLogger.warn(
`Skipping extension MCP config for server with key "${key}" as it already exists.`,
);
return;
}
mcpServers[key] = {
...server,
extension,
};
});
}
return mcpServers;
}
function mergeExcludeTools(
settings: Settings,
extensions: GeminiCLIExtension[],
extraExcludes?: string[] | undefined,
): string[] {
const allExcludeTools = new Set([
...(settings.tools?.exclude || []),
...(extraExcludes || []),
]);
for (const extension of extensions) {
if (!extension.isActive) {
continue;
}
for (const tool of extension.excludeTools || []) {
allExcludeTools.add(tool);
}
}
return [...allExcludeTools];
}
+32 -25
View File
@@ -28,6 +28,7 @@ import {
ExtensionDisableEvent,
ExtensionEnableEvent,
ExtensionInstallEvent,
ExtensionLoader,
ExtensionUninstallEvent,
ExtensionUpdateEvent,
getErrorMessage,
@@ -36,6 +37,7 @@ import {
logExtensionInstallEvent,
logExtensionUninstall,
logExtensionUpdateEvent,
type ExtensionEvents,
type MCPServerConfig,
type ExtensionInstallMetadata,
type GeminiCLIExtension,
@@ -54,11 +56,7 @@ import {
maybePromptForSettings,
type ExtensionSetting,
} from './extensions/extensionSettings.js';
import type {
ExtensionEvents,
ExtensionLoader,
} from '@google/gemini-cli-core/src/utils/extensionLoader.js';
import { EventEmitter } from 'node:events';
import type { EventEmitter } from 'node:stream';
interface ExtensionManagerParams {
enabledExtensionOverrides?: string[];
@@ -66,6 +64,7 @@ interface ExtensionManagerParams {
requestConsent: (consent: string) => Promise<boolean>;
requestSetting: ((setting: ExtensionSetting) => Promise<string>) | null;
workspaceDir: string;
eventEmitter?: EventEmitter<ExtensionEvents>;
}
/**
@@ -73,7 +72,7 @@ interface ExtensionManagerParams {
*
* You must call `loadExtensions` prior to calling other methods on this class.
*/
export class ExtensionManager implements ExtensionLoader {
export class ExtensionManager extends ExtensionLoader {
private extensionEnablementManager: ExtensionEnablementManager;
private settings: Settings;
private requestConsent: (consent: string) => Promise<boolean>;
@@ -83,9 +82,9 @@ export class ExtensionManager implements ExtensionLoader {
private telemetryConfig: Config;
private workspaceDir: string;
private loadedExtensions: GeminiCLIExtension[] | undefined;
private eventEmitter: EventEmitter<ExtensionEvents>;
constructor(options: ExtensionManagerParams) {
super(options.eventEmitter);
this.workspaceDir = options.workspaceDir;
this.extensionEnablementManager = new ExtensionEnablementManager(
options.enabledExtensionOverrides,
@@ -102,7 +101,6 @@ export class ExtensionManager implements ExtensionLoader {
});
this.requestConsent = options.requestConsent;
this.requestSetting = options.requestSetting ?? undefined;
this.eventEmitter = new EventEmitter();
}
setRequestConsent(
@@ -126,10 +124,6 @@ export class ExtensionManager implements ExtensionLoader {
return this.loadedExtensions!;
}
extensionEvents(): EventEmitter<ExtensionEvents> {
return this.eventEmitter;
}
async installOrUpdateExtension(
installMetadata: ExtensionInstallMetadata,
previousExtensionConfig?: ExtensionConfig,
@@ -303,7 +297,7 @@ export class ExtensionManager implements ExtensionLoader {
await fs.promises.writeFile(metadataPath, metadataString);
// TODO: Gracefully handle this call failing, we should back up the old
// extension prior to overwriting it and then restore it.
// extension prior to overwriting it and then restore and restart it.
extension = await this.loadExtension(destinationPath)!;
if (!extension) {
throw new Error(`Extension not found`);
@@ -320,7 +314,6 @@ export class ExtensionManager implements ExtensionLoader {
'success',
),
);
this.eventEmitter.emit('extensionUpdated', { extension });
} else {
logExtensionInstallEvent(
this.telemetryConfig,
@@ -332,7 +325,6 @@ export class ExtensionManager implements ExtensionLoader {
'success',
),
);
this.eventEmitter.emit('extensionInstalled', { extension });
this.enableExtension(newExtensionConfig.name, SettingScope.User);
}
} finally {
@@ -397,7 +389,7 @@ export class ExtensionManager implements ExtensionLoader {
if (!extension) {
throw new Error(`Extension not found.`);
}
this.unloadExtension(extension);
await this.unloadExtension(extension);
const storage = new ExtensionStorage(extension.name);
await fs.promises.rm(storage.getExtensionDir(), {
@@ -419,9 +411,11 @@ export class ExtensionManager implements ExtensionLoader {
'success',
),
);
this.eventEmitter.emit('extensionUninstalled', { extension });
}
/**
* Loads all installed extensions, should only be called once.
*/
async loadExtensions(): Promise<GeminiCLIExtension[]> {
if (this.loadedExtensions) {
throw new Error('Extensions already loaded, only load extensions once.');
@@ -433,12 +427,14 @@ export class ExtensionManager implements ExtensionLoader {
}
for (const subdir of fs.readdirSync(extensionsDir)) {
const extensionDir = path.join(extensionsDir, subdir);
await this.loadExtension(extensionDir);
}
return this.loadedExtensions;
}
/**
* Adds `extension` to the list of extensions and starts it if appropriate.
*/
private async loadExtension(
extensionDir: string,
): Promise<GeminiCLIExtension | null> {
@@ -499,8 +495,9 @@ export class ExtensionManager implements ExtensionLoader {
),
id: getExtensionId(config, installMetadata),
};
this.eventEmitter.emit('extensionLoaded', { extension });
this.getExtensions().push(extension);
this.loadedExtensions = [...this.loadedExtensions, extension];
await this.maybeStartExtension(extension);
return extension;
} catch (e) {
debugLogger.error(
@@ -512,11 +509,17 @@ export class ExtensionManager implements ExtensionLoader {
}
}
private unloadExtension(extension: GeminiCLIExtension) {
/**
* Removes `extension` from the list of extensions and stops it if
* appropriate.
*/
private unloadExtension(
extension: GeminiCLIExtension,
): Promise<void> | undefined {
this.loadedExtensions = this.getExtensions().filter(
(entry) => extension !== entry,
);
this.eventEmitter.emit('extensionUnloaded', { extension });
return this.maybeStopExtension(extension);
}
loadExtensionConfig(extensionDir: string): ExtensionConfig {
@@ -616,14 +619,18 @@ export class ExtensionManager implements ExtensionLoader {
const scopePath =
scope === SettingScope.Workspace ? this.workspaceDir : os.homedir();
this.extensionEnablementManager.disable(name, true, scopePath);
extension.isActive = false;
await this.maybeStopExtension(extension);
logExtensionDisable(
this.telemetryConfig,
new ExtensionDisableEvent(hashValue(name), extension.id, scope),
);
extension.isActive = false;
this.eventEmitter.emit('extensionDisabled', { extension });
}
/**
* Enables an existing extension for a given scope, and starts it if
* appropriate.
*/
async enableExtension(name: string, scope: SettingScope) {
if (
scope === SettingScope.System ||
@@ -645,7 +652,7 @@ export class ExtensionManager implements ExtensionLoader {
new ExtensionEnableEvent(hashValue(name), extension.id, scope),
);
extension.isActive = true;
this.eventEmitter.emit('extensionEnabled', { extension });
await this.maybeStartExtension(extension);
}
}
+21 -9
View File
@@ -1666,7 +1666,10 @@ This extension will run the following MCP servers:
});
await extensionManager.loadExtensions();
extensionManager.disableExtension('my-extension', SettingScope.User);
await extensionManager.disableExtension(
'my-extension',
SettingScope.User,
);
expect(
isEnabled({
name: 'my-extension',
@@ -1683,7 +1686,10 @@ This extension will run the following MCP servers:
});
await extensionManager.loadExtensions();
extensionManager.disableExtension('my-extension', SettingScope.Workspace);
await extensionManager.disableExtension(
'my-extension',
SettingScope.Workspace,
);
expect(
isEnabled({
name: 'my-extension',
@@ -1706,8 +1712,14 @@ This extension will run the following MCP servers:
});
await extensionManager.loadExtensions();
extensionManager.disableExtension('my-extension', SettingScope.User);
extensionManager.disableExtension('my-extension', SettingScope.User);
await extensionManager.disableExtension(
'my-extension',
SettingScope.User,
);
await extensionManager.disableExtension(
'my-extension',
SettingScope.User,
);
expect(
isEnabled({
name: 'my-extension',
@@ -1738,7 +1750,7 @@ This extension will run the following MCP servers:
});
await extensionManager.loadExtensions();
extensionManager.disableExtension('ext1', SettingScope.Workspace);
await extensionManager.disableExtension('ext1', SettingScope.Workspace);
expect(mockLogExtensionDisable).toHaveBeenCalled();
expect(ExtensionDisableEvent).toHaveBeenCalledWith(
@@ -1766,7 +1778,7 @@ This extension will run the following MCP servers:
version: '1.0.0',
});
await extensionManager.loadExtensions();
extensionManager.disableExtension('ext1', SettingScope.User);
await extensionManager.disableExtension('ext1', SettingScope.User);
let activeExtensions = getActiveExtensions();
expect(activeExtensions).toHaveLength(0);
@@ -1783,7 +1795,7 @@ This extension will run the following MCP servers:
version: '1.0.0',
});
await extensionManager.loadExtensions();
extensionManager.disableExtension('ext1', SettingScope.Workspace);
await extensionManager.disableExtension('ext1', SettingScope.Workspace);
let activeExtensions = getActiveExtensions();
expect(activeExtensions).toHaveLength(0);
@@ -1804,8 +1816,8 @@ This extension will run the following MCP servers:
},
});
await extensionManager.loadExtensions();
extensionManager.disableExtension('ext1', SettingScope.Workspace);
extensionManager.enableExtension('ext1', SettingScope.Workspace);
await extensionManager.disableExtension('ext1', SettingScope.Workspace);
await extensionManager.enableExtension('ext1', SettingScope.Workspace);
expect(mockLogExtensionEnable).toHaveBeenCalled();
expect(ExtensionEnableEvent).toHaveBeenCalledWith(