mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-28 14:04:41 -07:00
Co-authored-by: Allen Hutchison <adh@google.com>
This commit is contained in:
@@ -50,6 +50,7 @@ describe('McpClientManager', () => {
|
||||
getAllowedMcpServers: vi.fn().mockReturnValue([]),
|
||||
getBlockedMcpServers: vi.fn().mockReturnValue([]),
|
||||
getMcpServerCommand: vi.fn().mockReturnValue(''),
|
||||
getMcpEnablementCallbacks: vi.fn().mockReturnValue(undefined),
|
||||
getGeminiClient: vi.fn().mockReturnValue({
|
||||
isInitialized: vi.fn(),
|
||||
}),
|
||||
|
||||
@@ -27,6 +27,8 @@ import { debugLogger } from '../utils/debugLogger.js';
|
||||
*/
|
||||
export class McpClientManager {
|
||||
private clients: Map<string, McpClient> = new Map();
|
||||
// Track all configured servers (including disabled ones) for UI display
|
||||
private allServerConfigs: Map<string, MCPServerConfig> = new Map();
|
||||
private readonly clientVersion: string;
|
||||
private readonly toolRegistry: ToolRegistry;
|
||||
private readonly cliConfig: Config;
|
||||
@@ -97,24 +99,44 @@ export class McpClientManager {
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
}
|
||||
|
||||
private isAllowedMcpServer(name: string) {
|
||||
/**
|
||||
* Check if server is blocked by admin settings (allowlist/excludelist).
|
||||
* Returns true if blocked, false if allowed.
|
||||
*/
|
||||
private isBlockedBySettings(name: string): boolean {
|
||||
const allowedNames = this.cliConfig.getAllowedMcpServers();
|
||||
if (
|
||||
allowedNames &&
|
||||
allowedNames.length > 0 &&
|
||||
allowedNames.indexOf(name) === -1
|
||||
!allowedNames.includes(name)
|
||||
) {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
const blockedNames = this.cliConfig.getBlockedMcpServers();
|
||||
if (
|
||||
blockedNames &&
|
||||
blockedNames.length > 0 &&
|
||||
blockedNames.indexOf(name) !== -1
|
||||
blockedNames.includes(name)
|
||||
) {
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if server is disabled by user (session or file-based).
|
||||
*/
|
||||
private async isDisabledByUser(name: string): Promise<boolean> {
|
||||
const callbacks = this.cliConfig.getMcpEnablementCallbacks();
|
||||
if (callbacks) {
|
||||
if (callbacks.isSessionDisabled(name)) {
|
||||
return true;
|
||||
}
|
||||
if (!(await callbacks.isFileEnabled(name))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private async disconnectClient(name: string, skipRefresh = false) {
|
||||
@@ -138,11 +160,15 @@ export class McpClientManager {
|
||||
}
|
||||
}
|
||||
|
||||
maybeDiscoverMcpServer(
|
||||
async maybeDiscoverMcpServer(
|
||||
name: string,
|
||||
config: MCPServerConfig,
|
||||
): Promise<void> | void {
|
||||
if (!this.isAllowedMcpServer(name)) {
|
||||
): Promise<void> {
|
||||
// Always track server config for UI display
|
||||
this.allServerConfigs.set(name, config);
|
||||
|
||||
// Check if blocked by admin settings (allowlist/excludelist)
|
||||
if (this.isBlockedBySettings(name)) {
|
||||
if (!this.blockedMcpServers.find((s) => s.name === name)) {
|
||||
this.blockedMcpServers?.push({
|
||||
name,
|
||||
@@ -151,6 +177,14 @@ export class McpClientManager {
|
||||
}
|
||||
return;
|
||||
}
|
||||
// User-disabled servers: disconnect if running, don't start
|
||||
if (await this.isDisabledByUser(name)) {
|
||||
const existing = this.clients.get(name);
|
||||
if (existing) {
|
||||
await this.disconnectClient(name);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (!this.cliConfig.isTrustedFolder()) {
|
||||
return;
|
||||
}
|
||||
@@ -273,6 +307,11 @@ export class McpClientManager {
|
||||
this.cliConfig.getMcpServerCommand(),
|
||||
);
|
||||
|
||||
// Set state synchronously before any await yields control
|
||||
if (!this.discoveryPromise) {
|
||||
this.discoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
}
|
||||
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
await Promise.all(
|
||||
Object.entries(servers).map(([name, config]) =>
|
||||
@@ -283,23 +322,21 @@ export class McpClientManager {
|
||||
}
|
||||
|
||||
/**
|
||||
* Restarts all active MCP Clients.
|
||||
* Restarts all MCP servers (including newly enabled ones).
|
||||
*/
|
||||
async restart(): Promise<void> {
|
||||
await Promise.all(
|
||||
Array.from(this.clients.keys()).map(async (name) => {
|
||||
const client = this.clients.get(name);
|
||||
if (!client) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
await this.maybeDiscoverMcpServer(name, client.getServerConfig());
|
||||
} catch (error) {
|
||||
debugLogger.error(
|
||||
`Error restarting client '${name}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}),
|
||||
Array.from(this.allServerConfigs.entries()).map(
|
||||
async ([name, config]) => {
|
||||
try {
|
||||
await this.maybeDiscoverMcpServer(name, config);
|
||||
} catch (error) {
|
||||
debugLogger.error(
|
||||
`Error restarting client '${name}': ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
},
|
||||
),
|
||||
);
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
}
|
||||
@@ -308,11 +345,11 @@ export class McpClientManager {
|
||||
* Restart a single MCP server by name.
|
||||
*/
|
||||
async restartServer(name: string) {
|
||||
const client = this.clients.get(name);
|
||||
if (!client) {
|
||||
const config = this.allServerConfigs.get(name);
|
||||
if (!config) {
|
||||
throw new Error(`No MCP server registered with the name "${name}"`);
|
||||
}
|
||||
await this.maybeDiscoverMcpServer(name, client.getServerConfig());
|
||||
await this.maybeDiscoverMcpServer(name, config);
|
||||
await this.cliConfig.refreshMcpContext();
|
||||
}
|
||||
|
||||
@@ -344,12 +381,12 @@ export class McpClientManager {
|
||||
}
|
||||
|
||||
/**
|
||||
* All of the MCP server configurations currently loaded.
|
||||
* All of the MCP server configurations (including disabled ones).
|
||||
*/
|
||||
getMcpServers(): Record<string, MCPServerConfig> {
|
||||
const mcpServers: Record<string, MCPServerConfig> = {};
|
||||
for (const [name, client] of this.clients.entries()) {
|
||||
mcpServers[name] = client.getServerConfig();
|
||||
for (const [name, config] of this.allServerConfigs.entries()) {
|
||||
mcpServers[name] = config;
|
||||
}
|
||||
return mcpServers;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user