mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-26 13:04:49 -07:00
Initial support for reloading extensions in the CLI - mcp servers only (#12239)
This commit is contained in:
@@ -255,6 +255,7 @@ export interface ConfigParameters {
|
||||
listExtensions?: boolean;
|
||||
extensionLoader?: ExtensionLoader;
|
||||
enabledExtensions?: string[];
|
||||
enableExtensionReloading?: boolean;
|
||||
blockedMcpServers?: Array<{ name: string; extensionName: string }>;
|
||||
noBrowser?: boolean;
|
||||
summarizeToolOutput?: Record<string, SummarizeToolOutputSettings>;
|
||||
@@ -312,7 +313,7 @@ export class Config {
|
||||
private readonly toolDiscoveryCommand: string | undefined;
|
||||
private readonly toolCallCommand: string | undefined;
|
||||
private readonly mcpServerCommand: string | undefined;
|
||||
private readonly mcpServers: Record<string, MCPServerConfig> | undefined;
|
||||
private mcpServers: Record<string, MCPServerConfig> | undefined;
|
||||
private userMemory: string;
|
||||
private geminiMdFileCount: number;
|
||||
private geminiMdFilePaths: string[];
|
||||
@@ -346,6 +347,7 @@ export class Config {
|
||||
private readonly listExtensions: boolean;
|
||||
private readonly _extensionLoader: ExtensionLoader;
|
||||
private readonly _enabledExtensions: string[];
|
||||
private readonly enableExtensionReloading: boolean;
|
||||
private readonly _blockedMcpServers: Array<{
|
||||
name: string;
|
||||
extensionName: string;
|
||||
@@ -501,6 +503,7 @@ export class Config {
|
||||
this.enableShellOutputEfficiency =
|
||||
params.enableShellOutputEfficiency ?? true;
|
||||
this.extensionManagement = params.extensionManagement ?? true;
|
||||
this.enableExtensionReloading = params.enableExtensionReloading ?? false;
|
||||
this.storage = new Storage(this.targetDir);
|
||||
this.fakeResponses = params.fakeResponses;
|
||||
this.recordResponses = params.recordResponses;
|
||||
@@ -749,6 +752,10 @@ export class Config {
|
||||
return this.mcpServers;
|
||||
}
|
||||
|
||||
setMcpServers(mcpServers: Record<string, MCPServerConfig>): void {
|
||||
this.mcpServers = mcpServers;
|
||||
}
|
||||
|
||||
getUserMemory(): string {
|
||||
return this.userMemory;
|
||||
}
|
||||
@@ -924,6 +931,10 @@ export class Config {
|
||||
return this._enabledExtensions;
|
||||
}
|
||||
|
||||
getEnableExtensionReloading(): boolean {
|
||||
return this.enableExtensionReloading;
|
||||
}
|
||||
|
||||
getBlockedMcpServers(): Array<{ name: string; extensionName: string }> {
|
||||
return this._blockedMcpServers;
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import { McpClientManager } from './mcp-client-manager.js';
|
||||
import { McpClient } from './mcp-client.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { SimpleExtensionLoader } from '../utils/extensionLoader.js';
|
||||
|
||||
vi.mock('./mcp-client.js', async () => {
|
||||
const originalModule = await vi.importActual('./mcp-client.js');
|
||||
@@ -36,17 +37,22 @@ describe('McpClientManager', () => {
|
||||
vi.mocked(McpClient).mockReturnValue(
|
||||
mockedMcpClient as unknown as McpClient,
|
||||
);
|
||||
const manager = new McpClientManager({} as ToolRegistry);
|
||||
await manager.discoverAllMcpTools({
|
||||
isTrustedFolder: () => true,
|
||||
getMcpServers: () => ({
|
||||
'test-server': {},
|
||||
}),
|
||||
getMcpServerCommand: () => '',
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
} as unknown as Config);
|
||||
const manager = new McpClientManager(
|
||||
{} as ToolRegistry,
|
||||
{
|
||||
isTrustedFolder: () => true,
|
||||
getExtensionLoader: () => new SimpleExtensionLoader([]),
|
||||
getMcpServers: () => ({
|
||||
'test-server': {},
|
||||
}),
|
||||
getMcpServerCommand: () => '',
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
getEnableExtensionReloading: () => false,
|
||||
} as unknown as Config,
|
||||
);
|
||||
await manager.discoverAllMcpTools();
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
});
|
||||
@@ -61,17 +67,22 @@ describe('McpClientManager', () => {
|
||||
vi.mocked(McpClient).mockReturnValue(
|
||||
mockedMcpClient as unknown as McpClient,
|
||||
);
|
||||
const manager = new McpClientManager({} as ToolRegistry);
|
||||
await manager.discoverAllMcpTools({
|
||||
isTrustedFolder: () => false,
|
||||
getMcpServers: () => ({
|
||||
'test-server': {},
|
||||
}),
|
||||
getMcpServerCommand: () => '',
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
} as unknown as Config);
|
||||
const manager = new McpClientManager(
|
||||
{} as ToolRegistry,
|
||||
{
|
||||
isTrustedFolder: () => false,
|
||||
getExtensionLoader: () => new SimpleExtensionLoader([]),
|
||||
getMcpServers: () => ({
|
||||
'test-server': {},
|
||||
}),
|
||||
getMcpServerCommand: () => '',
|
||||
getPromptRegistry: () => {},
|
||||
getDebugMode: () => false,
|
||||
getWorkspaceContext: () => {},
|
||||
getEnableExtensionReloading: () => false,
|
||||
} as unknown as Config,
|
||||
);
|
||||
await manager.discoverAllMcpTools();
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
@@ -4,7 +4,11 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import type {
|
||||
Config,
|
||||
GeminiCLIExtension,
|
||||
MCPServerConfig,
|
||||
} from '../config/config.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import {
|
||||
McpClient,
|
||||
@@ -14,6 +18,7 @@ import {
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import type { EventEmitter } from 'node:events';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
/**
|
||||
* Manages the lifecycle of multiple MCP clients, including local child processes.
|
||||
@@ -23,12 +28,162 @@ import { coreEvents } from '../utils/events.js';
|
||||
export class McpClientManager {
|
||||
private clients: Map<string, McpClient> = new Map();
|
||||
private readonly toolRegistry: ToolRegistry;
|
||||
private readonly cliConfig: Config;
|
||||
// If we have ongoing MCP client discovery, this completes once that is done.
|
||||
private discoveryPromise: Promise<void> | undefined;
|
||||
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
|
||||
private readonly eventEmitter?: EventEmitter;
|
||||
|
||||
constructor(toolRegistry: ToolRegistry, eventEmitter?: EventEmitter) {
|
||||
constructor(
|
||||
toolRegistry: ToolRegistry,
|
||||
cliConfig: Config,
|
||||
eventEmitter?: EventEmitter,
|
||||
) {
|
||||
this.toolRegistry = toolRegistry;
|
||||
this.cliConfig = cliConfig;
|
||||
this.eventEmitter = eventEmitter;
|
||||
if (this.cliConfig.getEnableExtensionReloading()) {
|
||||
this.cliConfig
|
||||
.getExtensionLoader()
|
||||
.extensionEvents()
|
||||
.on('extensionLoaded', (event) => this.loadExtension(event.extension))
|
||||
.on('extensionEnabled', (event) => this.loadExtension(event.extension))
|
||||
.on('extensionDisabled', (event) =>
|
||||
this.unloadExtension(event.extension),
|
||||
)
|
||||
.on('extensionUnloaded', (event) =>
|
||||
this.unloadExtension(event.extension),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* For all the MCP servers associated with this extension:
|
||||
*
|
||||
* - Removes all its MCP servers from the global configuration object.
|
||||
* - Disconnects all MCP clients from their servers.
|
||||
* - Updates the Gemini chat configuration to load the new tools.
|
||||
*/
|
||||
private async unloadExtension(extension: GeminiCLIExtension) {
|
||||
debugLogger.log(`Unloading extension: ${extension.name}`);
|
||||
await Promise.all(
|
||||
Object.keys(extension.mcpServers ?? {}).map((name) => {
|
||||
const newMcpServers = {
|
||||
...this.cliConfig.getMcpServers(),
|
||||
};
|
||||
delete newMcpServers[name];
|
||||
this.cliConfig.setMcpServers(newMcpServers);
|
||||
return this.disconnectClient(name);
|
||||
}),
|
||||
);
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration.
|
||||
this.cliConfig.getGeminiClient().setTools();
|
||||
}
|
||||
|
||||
/**
|
||||
* For all the MCP servers associated with this extension:
|
||||
*
|
||||
* - Adds all its MCP servers to the global configuration object.
|
||||
* - Connects MCP clients to each server and discovers their tools.
|
||||
* - Updates the Gemini chat configuration to load the new tools.
|
||||
*/
|
||||
private async loadExtension(extension: GeminiCLIExtension) {
|
||||
debugLogger.log(`Loading extension: ${extension.name}`);
|
||||
await Promise.all(
|
||||
Object.entries(extension.mcpServers ?? {}).map(([name, config]) => {
|
||||
this.cliConfig.setMcpServers({
|
||||
...this.cliConfig.getMcpServers(),
|
||||
[name]: config,
|
||||
});
|
||||
return this.discoverMcpTools(name, config);
|
||||
}),
|
||||
);
|
||||
// This is required to update the content generator configuration with the
|
||||
// new tool configuration.
|
||||
this.cliConfig.getGeminiClient().setTools();
|
||||
}
|
||||
|
||||
private async disconnectClient(name: string) {
|
||||
const existing = this.clients.get(name);
|
||||
if (existing) {
|
||||
try {
|
||||
this.clients.delete(name);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
await existing.disconnect();
|
||||
} catch (error) {
|
||||
debugLogger.warn(
|
||||
`Error stopping client '${name}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
discoverMcpTools(
|
||||
name: string,
|
||||
config: MCPServerConfig,
|
||||
): Promise<void> | void {
|
||||
if (!this.cliConfig.isTrustedFolder()) {
|
||||
return;
|
||||
}
|
||||
if (config.extension && !config.extension.isActive) {
|
||||
return;
|
||||
}
|
||||
|
||||
const currentDiscoveryPromise = new Promise<void>((resolve, _reject) => {
|
||||
(async () => {
|
||||
try {
|
||||
await this.disconnectClient(name);
|
||||
|
||||
const client = new McpClient(
|
||||
name,
|
||||
config,
|
||||
this.toolRegistry,
|
||||
this.cliConfig.getPromptRegistry(),
|
||||
this.cliConfig.getWorkspaceContext(),
|
||||
this.cliConfig.getDebugMode(),
|
||||
);
|
||||
this.clients.set(name, client);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
try {
|
||||
await client.connect();
|
||||
await client.discover(this.cliConfig);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
} catch (error) {
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
// Log the error but don't let a single failed server stop the others
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Error during discovery for server '${name}': ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
resolve();
|
||||
}
|
||||
})();
|
||||
});
|
||||
|
||||
if (this.discoveryPromise) {
|
||||
this.discoveryPromise = this.discoveryPromise.then(
|
||||
() => currentDiscoveryPromise,
|
||||
);
|
||||
} else {
|
||||
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
this.discoveryPromise = currentDiscoveryPromise;
|
||||
}
|
||||
const currentPromise = this.discoveryPromise;
|
||||
currentPromise.then((_) => {
|
||||
// If we are the last recorded discoveryPromise, then we are done, reset
|
||||
// the world.
|
||||
if (currentPromise === this.discoveryPromise) {
|
||||
this.discoveryPromise = undefined;
|
||||
this.discoveryState = MCPDiscoveryState.COMPLETED;
|
||||
}
|
||||
});
|
||||
return currentPromise;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -36,53 +191,23 @@ export class McpClientManager {
|
||||
* It connects to each server, discovers its available tools, and registers
|
||||
* them with the `ToolRegistry`.
|
||||
*/
|
||||
async discoverAllMcpTools(cliConfig: Config): Promise<void> {
|
||||
if (!cliConfig.isTrustedFolder()) {
|
||||
async discoverAllMcpTools(): Promise<void> {
|
||||
if (!this.cliConfig.isTrustedFolder()) {
|
||||
return;
|
||||
}
|
||||
await this.stop();
|
||||
|
||||
const servers = populateMcpServerCommand(
|
||||
cliConfig.getMcpServers() || {},
|
||||
cliConfig.getMcpServerCommand(),
|
||||
this.cliConfig.getMcpServers() || {},
|
||||
this.cliConfig.getMcpServerCommand(),
|
||||
);
|
||||
|
||||
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
const discoveryPromises = Object.entries(servers)
|
||||
.filter(([_, config]) => !config.extension || config.extension.isActive)
|
||||
.map(async ([name, config]) => {
|
||||
const client = new McpClient(
|
||||
name,
|
||||
config,
|
||||
this.toolRegistry,
|
||||
cliConfig.getPromptRegistry(),
|
||||
cliConfig.getWorkspaceContext(),
|
||||
cliConfig.getDebugMode(),
|
||||
);
|
||||
this.clients.set(name, client);
|
||||
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
try {
|
||||
await client.connect();
|
||||
await client.discover(cliConfig);
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
} catch (error) {
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
// Log the error but don't let a single failed server stop the others
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Error during discovery for server '${name}': ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
await Promise.all(discoveryPromises);
|
||||
this.discoveryState = MCPDiscoveryState.COMPLETED;
|
||||
await Promise.all(
|
||||
Object.entries(servers).map(async ([name, config]) =>
|
||||
this.discoverMcpTools(name, config),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -160,6 +160,7 @@ export class McpClient {
|
||||
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||
return;
|
||||
}
|
||||
this.toolRegistry.removeMcpToolsByServer(this.serverName);
|
||||
this.updateStatus(MCPServerStatus.DISCONNECTING);
|
||||
const client = this.client;
|
||||
this.client = undefined;
|
||||
|
||||
@@ -181,7 +181,7 @@ export class ToolRegistry {
|
||||
|
||||
constructor(config: Config, eventEmitter?: EventEmitter) {
|
||||
this.config = config;
|
||||
this.mcpClientManager = new McpClientManager(this, eventEmitter);
|
||||
this.mcpClientManager = new McpClientManager(this, config, eventEmitter);
|
||||
}
|
||||
|
||||
setMessageBus(messageBus: MessageBus): void {
|
||||
@@ -244,7 +244,7 @@ export class ToolRegistry {
|
||||
await this.discoverAndRegisterToolsFromCommand();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await this.mcpClientManager.discoverAllMcpTools(this.config);
|
||||
await this.mcpClientManager.discoverAllMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -259,7 +259,7 @@ export class ToolRegistry {
|
||||
this.config.getPromptRegistry().clear();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await this.mcpClientManager.discoverAllMcpTools(this.config);
|
||||
await this.mcpClientManager.discoverAllMcpTools();
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user