mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 06:31:01 -07:00
Skip MCP server connections in untrusted folders (#7358)
This commit is contained in:
@@ -10,6 +10,7 @@ import { McpClient } from './mcp-client.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import type { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
vi.mock('./mcp-client.js', async () => {
|
||||
const originalModule = await vi.importActual('./mcp-client.js');
|
||||
@@ -47,8 +48,64 @@ describe('McpClientManager', () => {
|
||||
false,
|
||||
{} as WorkspaceContext,
|
||||
);
|
||||
await manager.discoverAllMcpTools();
|
||||
await manager.discoverAllMcpTools({
|
||||
isTrustedFolder: () => true,
|
||||
} as unknown as Config);
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should discover tools if isTrustedFolder is undefined', async () => {
|
||||
const mockedMcpClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
};
|
||||
vi.mocked(McpClient).mockReturnValue(
|
||||
mockedMcpClient as unknown as McpClient,
|
||||
);
|
||||
const manager = new McpClientManager(
|
||||
{
|
||||
'test-server': {},
|
||||
},
|
||||
'',
|
||||
{} as ToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
false,
|
||||
{} as WorkspaceContext,
|
||||
);
|
||||
await manager.discoverAllMcpTools({
|
||||
isTrustedFolder: () => undefined,
|
||||
} as unknown as Config);
|
||||
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
|
||||
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should not discover tools if folder is not trusted', async () => {
|
||||
const mockedMcpClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
};
|
||||
vi.mocked(McpClient).mockReturnValue(
|
||||
mockedMcpClient as unknown as McpClient,
|
||||
);
|
||||
const manager = new McpClientManager(
|
||||
{
|
||||
'test-server': {},
|
||||
},
|
||||
'',
|
||||
{} as ToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
false,
|
||||
{} as WorkspaceContext,
|
||||
);
|
||||
await manager.discoverAllMcpTools({
|
||||
isTrustedFolder: () => false,
|
||||
} as unknown as Config);
|
||||
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
|
||||
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { MCPServerConfig } from '../config/config.js';
|
||||
import type { Config, MCPServerConfig } from '../config/config.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import {
|
||||
@@ -55,7 +55,10 @@ export class McpClientManager {
|
||||
* It connects to each server, discovers its available tools, and registers
|
||||
* them with the `ToolRegistry`.
|
||||
*/
|
||||
async discoverAllMcpTools(): Promise<void> {
|
||||
async discoverAllMcpTools(cliConfig: Config): Promise<void> {
|
||||
if (cliConfig.isTrustedFolder() === false) {
|
||||
return;
|
||||
}
|
||||
await this.stop();
|
||||
|
||||
const servers = populateMcpServerCommand(
|
||||
@@ -91,7 +94,7 @@ export class McpClientManager {
|
||||
|
||||
try {
|
||||
await client.connect();
|
||||
await client.discover();
|
||||
await client.discover(cliConfig);
|
||||
this.eventEmitter?.emit('mcp-server-connected', {
|
||||
name,
|
||||
current,
|
||||
|
||||
@@ -19,7 +19,7 @@ import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import * as GenAiLib from '@google/genai';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { AuthProviderType } from '../config/config.js';
|
||||
import { AuthProviderType, type Config } from '../config/config.js';
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import type { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
@@ -74,7 +74,7 @@ describe('mcp-client', () => {
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover();
|
||||
await client.discover({} as Config);
|
||||
expect(mockedMcpToTool).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
@@ -136,7 +136,7 @@ describe('mcp-client', () => {
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover();
|
||||
await client.discover({} as Config);
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
expect(consoleWarnSpy).toHaveBeenCalledOnce();
|
||||
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
||||
@@ -180,7 +180,7 @@ describe('mcp-client', () => {
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
await expect(client.discover()).rejects.toThrow(
|
||||
await expect(client.discover({} as Config)).rejects.toThrow(
|
||||
'No prompts or tools found on the server.',
|
||||
);
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
|
||||
@@ -21,7 +21,7 @@ import {
|
||||
ListRootsRequestSchema,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import type { MCPServerConfig } from '../config/config.js';
|
||||
import type { Config, MCPServerConfig } from '../config/config.js';
|
||||
import { AuthProviderType } from '../config/config.js';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
@@ -146,13 +146,13 @@ export class McpClient {
|
||||
/**
|
||||
* Discovers tools and prompts from the MCP server.
|
||||
*/
|
||||
async discover(): Promise<void> {
|
||||
async discover(cliConfig: Config): Promise<void> {
|
||||
if (this.status !== MCPServerStatus.CONNECTED) {
|
||||
throw new Error('Client is not connected.');
|
||||
}
|
||||
|
||||
const prompts = await this.discoverPrompts();
|
||||
const tools = await this.discoverTools();
|
||||
const tools = await this.discoverTools(cliConfig);
|
||||
|
||||
if (prompts.length === 0 && tools.length === 0) {
|
||||
throw new Error('No prompts or tools found on the server.');
|
||||
@@ -191,8 +191,13 @@ export class McpClient {
|
||||
return createTransport(this.serverName, this.serverConfig, this.debugMode);
|
||||
}
|
||||
|
||||
private async discoverTools(): Promise<DiscoveredMCPTool[]> {
|
||||
return discoverTools(this.serverName, this.serverConfig, this.client);
|
||||
private async discoverTools(cliConfig: Config): Promise<DiscoveredMCPTool[]> {
|
||||
return discoverTools(
|
||||
this.serverName,
|
||||
this.serverConfig,
|
||||
this.client,
|
||||
cliConfig,
|
||||
);
|
||||
}
|
||||
|
||||
private async discoverPrompts(): Promise<Prompt[]> {
|
||||
@@ -445,6 +450,7 @@ export async function discoverMcpTools(
|
||||
promptRegistry: PromptRegistry,
|
||||
debugMode: boolean,
|
||||
workspaceContext: WorkspaceContext,
|
||||
cliConfig: Config,
|
||||
): Promise<void> {
|
||||
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
|
||||
try {
|
||||
@@ -459,6 +465,7 @@ export async function discoverMcpTools(
|
||||
promptRegistry,
|
||||
debugMode,
|
||||
workspaceContext,
|
||||
cliConfig,
|
||||
),
|
||||
);
|
||||
await Promise.all(discoveryPromises);
|
||||
@@ -504,6 +511,7 @@ export async function connectAndDiscover(
|
||||
promptRegistry: PromptRegistry,
|
||||
debugMode: boolean,
|
||||
workspaceContext: WorkspaceContext,
|
||||
cliConfig: Config,
|
||||
): Promise<void> {
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
||||
|
||||
@@ -531,6 +539,7 @@ export async function connectAndDiscover(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
mcpClient,
|
||||
cliConfig,
|
||||
);
|
||||
|
||||
// If we have neither prompts nor tools, it's a failed discovery
|
||||
@@ -632,6 +641,7 @@ export async function discoverTools(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
mcpClient: Client,
|
||||
cliConfig: Config,
|
||||
): Promise<DiscoveredMCPTool[]> {
|
||||
try {
|
||||
const mcpCallableTool = mcpToTool(mcpClient);
|
||||
@@ -667,6 +677,8 @@ export async function discoverTools(
|
||||
funcDecl.parametersJsonSchema ?? { type: 'object', properties: {} },
|
||||
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
mcpServerConfig.trust,
|
||||
undefined,
|
||||
cliConfig,
|
||||
),
|
||||
);
|
||||
} catch (error) {
|
||||
|
||||
@@ -747,6 +747,89 @@ describe('DiscoveredMCPTool', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('shouldConfirmExecute with folder trust', () => {
|
||||
const mockConfig = (isTrusted: boolean | undefined) => ({
|
||||
isTrustedFolder: () => isTrusted,
|
||||
});
|
||||
|
||||
it('should return false if trust is true and folder is trusted', async () => {
|
||||
const trustedTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
true, // trust = true
|
||||
undefined,
|
||||
mockConfig(true) as any, // isTrustedFolder = true
|
||||
);
|
||||
const invocation = trustedTool.build({ param: 'mock' });
|
||||
expect(
|
||||
await invocation.shouldConfirmExecute(new AbortController().signal),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it('should return confirmation details if trust is true but folder is not trusted', async () => {
|
||||
const trustedTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
true, // trust = true
|
||||
undefined,
|
||||
mockConfig(false) as any, // isTrustedFolder = false
|
||||
);
|
||||
const invocation = trustedTool.build({ param: 'mock' });
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
expect(confirmation).not.toBe(false);
|
||||
expect(confirmation).toHaveProperty('type', 'mcp');
|
||||
});
|
||||
|
||||
it('should return confirmation details if trust is false, even if folder is trusted', async () => {
|
||||
const untrustedTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
false, // trust = false
|
||||
undefined,
|
||||
mockConfig(true) as any, // isTrustedFolder = true
|
||||
);
|
||||
const invocation = untrustedTool.build({ param: 'mock' });
|
||||
const confirmation = await invocation.shouldConfirmExecute(
|
||||
new AbortController().signal,
|
||||
);
|
||||
expect(confirmation).not.toBe(false);
|
||||
expect(confirmation).toHaveProperty('type', 'mcp');
|
||||
});
|
||||
|
||||
it('should return false if trust is true and folder trust is undefined', async () => {
|
||||
// The check is `isTrustedFolder() !== false`, so `undefined` should pass
|
||||
const trustedTool = new DiscoveredMCPTool(
|
||||
mockCallableToolInstance,
|
||||
serverName,
|
||||
serverToolName,
|
||||
baseDescription,
|
||||
inputSchema,
|
||||
undefined,
|
||||
true, // trust = true
|
||||
undefined,
|
||||
mockConfig(undefined) as any, // isTrustedFolder = undefined
|
||||
);
|
||||
const invocation = trustedTool.build({ param: 'mock' });
|
||||
expect(
|
||||
await invocation.shouldConfirmExecute(new AbortController().signal),
|
||||
).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('DiscoveredMCPToolInvocation', () => {
|
||||
it('should return the stringified params from getDescription', () => {
|
||||
const params = { param: 'testValue', param2: 'anotherOne' };
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
} from './tools.js';
|
||||
import type { CallableTool, FunctionCall, Part } from '@google/genai';
|
||||
import { ToolErrorType } from './tool-error.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
type ToolParams = Record<string, unknown>;
|
||||
|
||||
@@ -70,6 +71,7 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
readonly timeout?: number,
|
||||
readonly trust?: boolean,
|
||||
params: ToolParams = {},
|
||||
private readonly cliConfig?: Config,
|
||||
) {
|
||||
super(params);
|
||||
}
|
||||
@@ -80,7 +82,9 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
const serverAllowListKey = this.serverName;
|
||||
const toolAllowListKey = `${this.serverName}.${this.serverToolName}`;
|
||||
|
||||
if (this.trust) {
|
||||
const isTrustedFolder = this.cliConfig?.isTrustedFolder() !== false;
|
||||
|
||||
if (this.trust && isTrustedFolder) {
|
||||
return false; // server is trusted, no confirmation needed
|
||||
}
|
||||
|
||||
@@ -183,6 +187,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
readonly timeout?: number,
|
||||
readonly trust?: boolean,
|
||||
nameOverride?: string,
|
||||
private readonly cliConfig?: Config,
|
||||
) {
|
||||
super(
|
||||
nameOverride ?? generateValidName(serverToolName),
|
||||
@@ -205,6 +210,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
this.timeout,
|
||||
this.trust,
|
||||
`${this.serverName}__${this.serverToolName}`,
|
||||
this.cliConfig,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -219,6 +225,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
this.timeout,
|
||||
this.trust,
|
||||
params,
|
||||
this.cliConfig,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,7 +236,7 @@ export class ToolRegistry {
|
||||
await this.discoverAndRegisterToolsFromCommand();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await this.mcpClientManager.discoverAllMcpTools();
|
||||
await this.mcpClientManager.discoverAllMcpTools(this.config);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -251,7 +251,7 @@ export class ToolRegistry {
|
||||
this.config.getPromptRegistry().clear();
|
||||
|
||||
// discover tools using MCP servers, if configured
|
||||
await this.mcpClientManager.discoverAllMcpTools();
|
||||
await this.mcpClientManager.discoverAllMcpTools(this.config);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -285,6 +285,7 @@ export class ToolRegistry {
|
||||
this.config.getPromptRegistry(),
|
||||
this.config.getDebugMode(),
|
||||
this.config.getWorkspaceContext(),
|
||||
this.config,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user