Initial support for reloading extensions in the CLI - mcp servers only (#12239)

This commit is contained in:
Jacob MacDonald
2025-10-30 11:05:49 -07:00
committed by GitHub
parent d4cad0cdcc
commit cc081337b7
20 changed files with 437 additions and 107 deletions
+12 -1
View File
@@ -255,6 +255,7 @@ export interface ConfigParameters {
listExtensions?: boolean;
extensionLoader?: ExtensionLoader;
enabledExtensions?: string[];
enableExtensionReloading?: boolean;
blockedMcpServers?: Array<{ name: string; extensionName: string }>;
noBrowser?: boolean;
summarizeToolOutput?: Record<string, SummarizeToolOutputSettings>;
@@ -312,7 +313,7 @@ export class Config {
private readonly toolDiscoveryCommand: string | undefined;
private readonly toolCallCommand: string | undefined;
private readonly mcpServerCommand: string | undefined;
private readonly mcpServers: Record<string, MCPServerConfig> | undefined;
private mcpServers: Record<string, MCPServerConfig> | undefined;
private userMemory: string;
private geminiMdFileCount: number;
private geminiMdFilePaths: string[];
@@ -346,6 +347,7 @@ export class Config {
private readonly listExtensions: boolean;
private readonly _extensionLoader: ExtensionLoader;
private readonly _enabledExtensions: string[];
private readonly enableExtensionReloading: boolean;
private readonly _blockedMcpServers: Array<{
name: string;
extensionName: string;
@@ -501,6 +503,7 @@ export class Config {
this.enableShellOutputEfficiency =
params.enableShellOutputEfficiency ?? true;
this.extensionManagement = params.extensionManagement ?? true;
this.enableExtensionReloading = params.enableExtensionReloading ?? false;
this.storage = new Storage(this.targetDir);
this.fakeResponses = params.fakeResponses;
this.recordResponses = params.recordResponses;
@@ -749,6 +752,10 @@ export class Config {
return this.mcpServers;
}
setMcpServers(mcpServers: Record<string, MCPServerConfig>): void {
this.mcpServers = mcpServers;
}
getUserMemory(): string {
return this.userMemory;
}
@@ -924,6 +931,10 @@ export class Config {
return this._enabledExtensions;
}
getEnableExtensionReloading(): boolean {
return this.enableExtensionReloading;
}
getBlockedMcpServers(): Array<{ name: string; extensionName: string }> {
return this._blockedMcpServers;
}
@@ -9,6 +9,7 @@ import { McpClientManager } from './mcp-client-manager.js';
import { McpClient } from './mcp-client.js';
import type { ToolRegistry } from './tool-registry.js';
import type { Config } from '../config/config.js';
import { SimpleExtensionLoader } from '../utils/extensionLoader.js';
vi.mock('./mcp-client.js', async () => {
const originalModule = await vi.importActual('./mcp-client.js');
@@ -36,17 +37,22 @@ describe('McpClientManager', () => {
vi.mocked(McpClient).mockReturnValue(
mockedMcpClient as unknown as McpClient,
);
const manager = new McpClientManager({} as ToolRegistry);
await manager.discoverAllMcpTools({
isTrustedFolder: () => true,
getMcpServers: () => ({
'test-server': {},
}),
getMcpServerCommand: () => '',
getPromptRegistry: () => {},
getDebugMode: () => false,
getWorkspaceContext: () => {},
} as unknown as Config);
const manager = new McpClientManager(
{} as ToolRegistry,
{
isTrustedFolder: () => true,
getExtensionLoader: () => new SimpleExtensionLoader([]),
getMcpServers: () => ({
'test-server': {},
}),
getMcpServerCommand: () => '',
getPromptRegistry: () => {},
getDebugMode: () => false,
getWorkspaceContext: () => {},
getEnableExtensionReloading: () => false,
} as unknown as Config,
);
await manager.discoverAllMcpTools();
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
});
@@ -61,17 +67,22 @@ describe('McpClientManager', () => {
vi.mocked(McpClient).mockReturnValue(
mockedMcpClient as unknown as McpClient,
);
const manager = new McpClientManager({} as ToolRegistry);
await manager.discoverAllMcpTools({
isTrustedFolder: () => false,
getMcpServers: () => ({
'test-server': {},
}),
getMcpServerCommand: () => '',
getPromptRegistry: () => {},
getDebugMode: () => false,
getWorkspaceContext: () => {},
} as unknown as Config);
const manager = new McpClientManager(
{} as ToolRegistry,
{
isTrustedFolder: () => false,
getExtensionLoader: () => new SimpleExtensionLoader([]),
getMcpServers: () => ({
'test-server': {},
}),
getMcpServerCommand: () => '',
getPromptRegistry: () => {},
getDebugMode: () => false,
getWorkspaceContext: () => {},
getEnableExtensionReloading: () => false,
} as unknown as Config,
);
await manager.discoverAllMcpTools();
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
});
+166 -41
View File
@@ -4,7 +4,11 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '../config/config.js';
import type {
Config,
GeminiCLIExtension,
MCPServerConfig,
} from '../config/config.js';
import type { ToolRegistry } from './tool-registry.js';
import {
McpClient,
@@ -14,6 +18,7 @@ import {
import { getErrorMessage } from '../utils/errors.js';
import type { EventEmitter } from 'node:events';
import { coreEvents } from '../utils/events.js';
import { debugLogger } from '../utils/debugLogger.js';
/**
* Manages the lifecycle of multiple MCP clients, including local child processes.
@@ -23,12 +28,162 @@ import { coreEvents } from '../utils/events.js';
export class McpClientManager {
private clients: Map<string, McpClient> = new Map();
private readonly toolRegistry: ToolRegistry;
private readonly cliConfig: Config;
// If we have ongoing MCP client discovery, this completes once that is done.
private discoveryPromise: Promise<void> | undefined;
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
private readonly eventEmitter?: EventEmitter;
constructor(toolRegistry: ToolRegistry, eventEmitter?: EventEmitter) {
constructor(
toolRegistry: ToolRegistry,
cliConfig: Config,
eventEmitter?: EventEmitter,
) {
this.toolRegistry = toolRegistry;
this.cliConfig = cliConfig;
this.eventEmitter = eventEmitter;
if (this.cliConfig.getEnableExtensionReloading()) {
this.cliConfig
.getExtensionLoader()
.extensionEvents()
.on('extensionLoaded', (event) => this.loadExtension(event.extension))
.on('extensionEnabled', (event) => this.loadExtension(event.extension))
.on('extensionDisabled', (event) =>
this.unloadExtension(event.extension),
)
.on('extensionUnloaded', (event) =>
this.unloadExtension(event.extension),
);
}
}
/**
* For all the MCP servers associated with this extension:
*
* - Removes all its MCP servers from the global configuration object.
* - Disconnects all MCP clients from their servers.
* - Updates the Gemini chat configuration to load the new tools.
*/
private async unloadExtension(extension: GeminiCLIExtension) {
debugLogger.log(`Unloading extension: ${extension.name}`);
await Promise.all(
Object.keys(extension.mcpServers ?? {}).map((name) => {
const newMcpServers = {
...this.cliConfig.getMcpServers(),
};
delete newMcpServers[name];
this.cliConfig.setMcpServers(newMcpServers);
return this.disconnectClient(name);
}),
);
// This is required to update the content generator configuration with the
// new tool configuration.
this.cliConfig.getGeminiClient().setTools();
}
/**
* For all the MCP servers associated with this extension:
*
* - Adds all its MCP servers to the global configuration object.
* - Connects MCP clients to each server and discovers their tools.
* - Updates the Gemini chat configuration to load the new tools.
*/
private async loadExtension(extension: GeminiCLIExtension) {
debugLogger.log(`Loading extension: ${extension.name}`);
await Promise.all(
Object.entries(extension.mcpServers ?? {}).map(([name, config]) => {
this.cliConfig.setMcpServers({
...this.cliConfig.getMcpServers(),
[name]: config,
});
return this.discoverMcpTools(name, config);
}),
);
// This is required to update the content generator configuration with the
// new tool configuration.
this.cliConfig.getGeminiClient().setTools();
}
private async disconnectClient(name: string) {
const existing = this.clients.get(name);
if (existing) {
try {
this.clients.delete(name);
this.eventEmitter?.emit('mcp-client-update', this.clients);
await existing.disconnect();
} catch (error) {
debugLogger.warn(
`Error stopping client '${name}': ${getErrorMessage(error)}`,
);
}
}
}
discoverMcpTools(
name: string,
config: MCPServerConfig,
): Promise<void> | void {
if (!this.cliConfig.isTrustedFolder()) {
return;
}
if (config.extension && !config.extension.isActive) {
return;
}
const currentDiscoveryPromise = new Promise<void>((resolve, _reject) => {
(async () => {
try {
await this.disconnectClient(name);
const client = new McpClient(
name,
config,
this.toolRegistry,
this.cliConfig.getPromptRegistry(),
this.cliConfig.getWorkspaceContext(),
this.cliConfig.getDebugMode(),
);
this.clients.set(name, client);
this.eventEmitter?.emit('mcp-client-update', this.clients);
try {
await client.connect();
await client.discover(this.cliConfig);
this.eventEmitter?.emit('mcp-client-update', this.clients);
} catch (error) {
this.eventEmitter?.emit('mcp-client-update', this.clients);
// Log the error but don't let a single failed server stop the others
coreEvents.emitFeedback(
'error',
`Error during discovery for server '${name}': ${getErrorMessage(
error,
)}`,
error,
);
}
} finally {
resolve();
}
})();
});
if (this.discoveryPromise) {
this.discoveryPromise = this.discoveryPromise.then(
() => currentDiscoveryPromise,
);
} else {
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
this.discoveryPromise = currentDiscoveryPromise;
}
const currentPromise = this.discoveryPromise;
currentPromise.then((_) => {
// If we are the last recorded discoveryPromise, then we are done, reset
// the world.
if (currentPromise === this.discoveryPromise) {
this.discoveryPromise = undefined;
this.discoveryState = MCPDiscoveryState.COMPLETED;
}
});
return currentPromise;
}
/**
@@ -36,53 +191,23 @@ export class McpClientManager {
* It connects to each server, discovers its available tools, and registers
* them with the `ToolRegistry`.
*/
async discoverAllMcpTools(cliConfig: Config): Promise<void> {
if (!cliConfig.isTrustedFolder()) {
async discoverAllMcpTools(): Promise<void> {
if (!this.cliConfig.isTrustedFolder()) {
return;
}
await this.stop();
const servers = populateMcpServerCommand(
cliConfig.getMcpServers() || {},
cliConfig.getMcpServerCommand(),
this.cliConfig.getMcpServers() || {},
this.cliConfig.getMcpServerCommand(),
);
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
this.eventEmitter?.emit('mcp-client-update', this.clients);
const discoveryPromises = Object.entries(servers)
.filter(([_, config]) => !config.extension || config.extension.isActive)
.map(async ([name, config]) => {
const client = new McpClient(
name,
config,
this.toolRegistry,
cliConfig.getPromptRegistry(),
cliConfig.getWorkspaceContext(),
cliConfig.getDebugMode(),
);
this.clients.set(name, client);
this.eventEmitter?.emit('mcp-client-update', this.clients);
try {
await client.connect();
await client.discover(cliConfig);
this.eventEmitter?.emit('mcp-client-update', this.clients);
} catch (error) {
this.eventEmitter?.emit('mcp-client-update', this.clients);
// Log the error but don't let a single failed server stop the others
coreEvents.emitFeedback(
'error',
`Error during discovery for server '${name}': ${getErrorMessage(
error,
)}`,
error,
);
}
});
await Promise.all(discoveryPromises);
this.discoveryState = MCPDiscoveryState.COMPLETED;
await Promise.all(
Object.entries(servers).map(async ([name, config]) =>
this.discoverMcpTools(name, config),
),
);
}
/**
+1
View File
@@ -160,6 +160,7 @@ export class McpClient {
if (this.status !== MCPServerStatus.CONNECTED) {
return;
}
this.toolRegistry.removeMcpToolsByServer(this.serverName);
this.updateStatus(MCPServerStatus.DISCONNECTING);
const client = this.client;
this.client = undefined;
+3 -3
View File
@@ -181,7 +181,7 @@ export class ToolRegistry {
constructor(config: Config, eventEmitter?: EventEmitter) {
this.config = config;
this.mcpClientManager = new McpClientManager(this, eventEmitter);
this.mcpClientManager = new McpClientManager(this, config, eventEmitter);
}
setMessageBus(messageBus: MessageBus): void {
@@ -244,7 +244,7 @@ export class ToolRegistry {
await this.discoverAndRegisterToolsFromCommand();
// discover tools using MCP servers, if configured
await this.mcpClientManager.discoverAllMcpTools(this.config);
await this.mcpClientManager.discoverAllMcpTools();
}
/**
@@ -259,7 +259,7 @@ export class ToolRegistry {
this.config.getPromptRegistry().clear();
// discover tools using MCP servers, if configured
await this.mcpClientManager.discoverAllMcpTools(this.config);
await this.mcpClientManager.discoverAllMcpTools();
}
/**