mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-28 14:04:41 -07:00
Co-authored-by: Allen Hutchison <adh@google.com>
This commit is contained in:
@@ -278,6 +278,18 @@ export interface SandboxConfig {
|
||||
image: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Callbacks for checking MCP server enablement status.
|
||||
* These callbacks are provided by the CLI package to bridge
|
||||
* the enablement state to the core package.
|
||||
*/
|
||||
export interface McpEnablementCallbacks {
|
||||
/** Check if a server is disabled for the current session only */
|
||||
isSessionDisabled: (serverId: string) => boolean;
|
||||
/** Check if a server is enabled in the file-based configuration */
|
||||
isFileEnabled: (serverId: string) => Promise<boolean>;
|
||||
}
|
||||
|
||||
export interface ConfigParameters {
|
||||
sessionId: string;
|
||||
clientVersion?: string;
|
||||
@@ -294,6 +306,7 @@ export interface ConfigParameters {
|
||||
toolCallCommand?: string;
|
||||
mcpServerCommand?: string;
|
||||
mcpServers?: Record<string, MCPServerConfig>;
|
||||
mcpEnablementCallbacks?: McpEnablementCallbacks;
|
||||
userMemory?: string;
|
||||
geminiMdFileCount?: number;
|
||||
geminiMdFilePaths?: string[];
|
||||
@@ -426,6 +439,7 @@ export class Config {
|
||||
private readonly mcpEnabled: boolean;
|
||||
private readonly extensionsEnabled: boolean;
|
||||
private mcpServers: Record<string, MCPServerConfig> | undefined;
|
||||
private readonly mcpEnablementCallbacks?: McpEnablementCallbacks;
|
||||
private userMemory: string;
|
||||
private geminiMdFileCount: number;
|
||||
private geminiMdFilePaths: string[];
|
||||
@@ -564,6 +578,7 @@ export class Config {
|
||||
this.toolCallCommand = params.toolCallCommand;
|
||||
this.mcpServerCommand = params.mcpServerCommand;
|
||||
this.mcpServers = params.mcpServers;
|
||||
this.mcpEnablementCallbacks = params.mcpEnablementCallbacks;
|
||||
this.mcpEnabled = params.mcpEnabled ?? true;
|
||||
this.extensionsEnabled = params.extensionsEnabled ?? true;
|
||||
this.allowedMcpServers = params.allowedMcpServers ?? [];
|
||||
@@ -1235,6 +1250,10 @@ export class Config {
|
||||
return this.mcpEnabled;
|
||||
}
|
||||
|
||||
getMcpEnablementCallbacks(): McpEnablementCallbacks | undefined {
|
||||
return this.mcpEnablementCallbacks;
|
||||
}
|
||||
|
||||
getExtensionsEnabled(): boolean {
|
||||
return this.extensionsEnabled;
|
||||
}
|
||||
|
||||
@@ -50,6 +50,7 @@ describe('McpClientManager', () => {
|
||||
getAllowedMcpServers: vi.fn().mockReturnValue([]),
|
||||
getBlockedMcpServers: vi.fn().mockReturnValue([]),
|
||||
getMcpServerCommand: vi.fn().mockReturnValue(''),
|
||||
getMcpEnablementCallbacks: vi.fn().mockReturnValue(undefined),
|
||||
getGeminiClient: vi.fn().mockReturnValue({
|
||||
isInitialized: vi.fn(),
|
||||
}),
|
||||
|
||||
@@ -27,6 +27,8 @@ import { debugLogger } from '../utils/debugLogger.js';
|
||||
*/
|
||||
export class McpClientManager {
|
||||
private clients: Map<string, McpClient> = new Map();
|
||||
// Track all configured servers (including disabled ones) for UI display
|
||||
private allServerConfigs: Map<string, MCPServerConfig> = new Map();
|
||||
private readonly clientVersion: string;
|
||||
private readonly toolRegistry: ToolRegistry;
|
||||
private readonly cliConfig: Config;
|
||||
@@ -97,24 +99,44 @@ export class McpClientManager {
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
}
|
||||
|
||||
private isAllowedMcpServer(name: string) {
|
||||
/**
|
||||
* Check if server is blocked by admin settings (allowlist/excludelist).
|
||||
* Returns true if blocked, false if allowed.
|
||||
*/
|
||||
private isBlockedBySettings(name: string): boolean {
|
||||
const allowedNames = this.cliConfig.getAllowedMcpServers();
|
||||
if (
|
||||
allowedNames &&
|
||||
allowedNames.length > 0 &&
|
||||
allowedNames.indexOf(name) === -1
|
||||
!allowedNames.includes(name)
|
||||
) {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
const blockedNames = this.cliConfig.getBlockedMcpServers();
|
||||
if (
|
||||
blockedNames &&
|
||||
blockedNames.length > 0 &&
|
||||
blockedNames.indexOf(name) !== -1
|
||||
blockedNames.includes(name)
|
||||
) {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if server is disabled by user (session or file-based).
|
||||
*/
|
||||
private async isDisabledByUser(name: string): Promise<boolean> {
|
||||
const callbacks = this.cliConfig.getMcpEnablementCallbacks();
|
||||
if (callbacks) {
|
||||
if (callbacks.isSessionDisabled(name)) {
|
||||
return true;
|
||||
}
|
||||
if (!(await callbacks.isFileEnabled(name))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private async disconnectClient(name: string, skipRefresh = false) {
|
||||
@@ -138,11 +160,15 @@ export class McpClientManager {
|
||||
}
|
||||
}
|
||||
|
||||
maybeDiscoverMcpServer(
|
||||
async maybeDiscoverMcpServer(
|
||||
name: string,
|
||||
config: MCPServerConfig,
|
||||
): Promise<void> | void {
|
||||
if (!this.isAllowedMcpServer(name)) {
|
||||
): Promise<void> {
|
||||
// Always track server config for UI display
|
||||
this.allServerConfigs.set(name, config);
|
||||
|
||||
// Check if blocked by admin settings (allowlist/excludelist)
|
||||
if (this.isBlockedBySettings(name)) {
|
||||
if (!this.blockedMcpServers.find((s) => s.name === name)) {
|
||||
this.blockedMcpServers?.push({
|
||||
name,
|
||||
@@ -151,6 +177,14 @@ export class McpClientManager {
|
||||
}
|
||||
return;
|
||||
}
|
||||
// User-disabled servers: disconnect if running, don't start
|
||||
if (await this.isDisabledByUser(name)) {
|
||||
const existing = this.clients.get(name);
|
||||
if (existing) {
|
||||
await this.disconnectClient(name);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (!this.cliConfig.isTrustedFolder()) {
|
||||
return;
|
||||
}
|
||||
@@ -273,6 +307,11 @@ export class McpClientManager {
|
||||
this.cliConfig.getMcpServerCommand(),
|
||||
);
|
||||
|
||||
// Set state synchronously before any await yields control
|
||||
if (!this.discoveryPromise) {
|
||||
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
}
|
||||
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
await Promise.all(
|
||||
Object.entries(servers).map(([name, config]) =>
|
||||
@@ -283,23 +322,21 @@ export class McpClientManager {
|
||||
}
|
||||
|
||||
/**
|
||||
* Restarts all active MCP Clients.
|
||||
* Restarts all MCP servers (including newly enabled ones).
|
||||
*/
|
||||
async restart(): Promise<void> {
|
||||
await Promise.all(
|
||||
Array.from(this.clients.keys()).map(async (name) => {
|
||||
const client = this.clients.get(name);
|
||||
if (!client) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await this.maybeDiscoverMcpServer(name, client.getServerConfig());
|
||||
} catch (error) {
|
||||
debugLogger.error(
|
||||
`Error restarting client '${name}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}),
|
||||
Array.from(this.allServerConfigs.entries()).map(
|
||||
async ([name, config]) => {
|
||||
try {
|
||||
await this.maybeDiscoverMcpServer(name, config);
|
||||
} catch (error) {
|
||||
debugLogger.error(
|
||||
`Error restarting client '${name}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
},
|
||||
),
|
||||
);
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
}
|
||||
@@ -308,11 +345,11 @@ export class McpClientManager {
|
||||
* Restart a single MCP server by name.
|
||||
*/
|
||||
async restartServer(name: string) {
|
||||
const client = this.clients.get(name);
|
||||
if (!client) {
|
||||
const config = this.allServerConfigs.get(name);
|
||||
if (!config) {
|
||||
throw new Error(`No MCP server registered with the name "${name}"`);
|
||||
}
|
||||
await this.maybeDiscoverMcpServer(name, client.getServerConfig());
|
||||
await this.maybeDiscoverMcpServer(name, config);
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
}
|
||||
|
||||
@@ -344,12 +381,12 @@ export class McpClientManager {
|
||||
}
|
||||
|
||||
/**
|
||||
* All of the MCP server configurations currently loaded.
|
||||
* All of the MCP server configurations (including disabled ones).
|
||||
*/
|
||||
getMcpServers(): Record<string, MCPServerConfig> {
|
||||
const mcpServers: Record<string, MCPServerConfig> = {};
|
||||
for (const [name, client] of this.clients.entries()) {
|
||||
mcpServers[name] = client.getServerConfig();
|
||||
for (const [name, config] of this.allServerConfigs.entries()) {
|
||||
mcpServers[name] = config;
|
||||
}
|
||||
return mcpServers;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user