Skip MCP server connections in untrusted folders (#7358)

This commit is contained in:
shrutip90
2025-08-28 15:46:27 -07:00
committed by GitHub
parent f00cf42f69
commit a0fbe000ee
7 changed files with 179 additions and 16 deletions

View File

@@ -10,6 +10,7 @@ import { McpClient } from './mcp-client.js';
import type { ToolRegistry } from './tool-registry.js';
import type { PromptRegistry } from '../prompts/prompt-registry.js';
import type { WorkspaceContext } from '../utils/workspaceContext.js';
import type { Config } from '../config/config.js';
vi.mock('./mcp-client.js', async () => {
const originalModule = await vi.importActual('./mcp-client.js');
@@ -47,8 +48,64 @@ describe('McpClientManager', () => {
false,
{} as WorkspaceContext,
);
await manager.discoverAllMcpTools();
await manager.discoverAllMcpTools({
isTrustedFolder: () => true,
} as unknown as Config);
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
});
it('should discover tools if isTrustedFolder is undefined', async () => {
const mockedMcpClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
};
vi.mocked(McpClient).mockReturnValue(
mockedMcpClient as unknown as McpClient,
);
const manager = new McpClientManager(
{
'test-server': {},
},
'',
{} as ToolRegistry,
{} as PromptRegistry,
false,
{} as WorkspaceContext,
);
await manager.discoverAllMcpTools({
isTrustedFolder: () => undefined,
} as unknown as Config);
expect(mockedMcpClient.connect).toHaveBeenCalledOnce();
expect(mockedMcpClient.discover).toHaveBeenCalledOnce();
});
it('should not discover tools if folder is not trusted', async () => {
const mockedMcpClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
};
vi.mocked(McpClient).mockReturnValue(
mockedMcpClient as unknown as McpClient,
);
const manager = new McpClientManager(
{
'test-server': {},
},
'',
{} as ToolRegistry,
{} as PromptRegistry,
false,
{} as WorkspaceContext,
);
await manager.discoverAllMcpTools({
isTrustedFolder: () => false,
} as unknown as Config);
expect(mockedMcpClient.connect).not.toHaveBeenCalled();
expect(mockedMcpClient.discover).not.toHaveBeenCalled();
});
});

View File

@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { MCPServerConfig } from '../config/config.js';
import type { Config, MCPServerConfig } from '../config/config.js';
import type { ToolRegistry } from './tool-registry.js';
import type { PromptRegistry } from '../prompts/prompt-registry.js';
import {
@@ -55,7 +55,10 @@ export class McpClientManager {
* It connects to each server, discovers its available tools, and registers
* them with the `ToolRegistry`.
*/
async discoverAllMcpTools(): Promise<void> {
async discoverAllMcpTools(cliConfig: Config): Promise<void> {
if (cliConfig.isTrustedFolder() === false) {
return;
}
await this.stop();
const servers = populateMcpServerCommand(
@@ -91,7 +94,7 @@ export class McpClientManager {
try {
await client.connect();
await client.discover();
await client.discover(cliConfig);
this.eventEmitter?.emit('mcp-server-connected', {
name,
current,

View File

@@ -19,7 +19,7 @@ import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
import * as GenAiLib from '@google/genai';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { AuthProviderType } from '../config/config.js';
import { AuthProviderType, type Config } from '../config/config.js';
import type { PromptRegistry } from '../prompts/prompt-registry.js';
import type { ToolRegistry } from './tool-registry.js';
import type { WorkspaceContext } from '../utils/workspaceContext.js';
@@ -74,7 +74,7 @@ describe('mcp-client', () => {
false,
);
await client.connect();
await client.discover();
await client.discover({} as Config);
expect(mockedMcpToTool).toHaveBeenCalledOnce();
});
@@ -136,7 +136,7 @@ describe('mcp-client', () => {
false,
);
await client.connect();
await client.discover();
await client.discover({} as Config);
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledOnce();
expect(consoleWarnSpy).toHaveBeenCalledWith(
@@ -180,7 +180,7 @@ describe('mcp-client', () => {
false,
);
await client.connect();
await expect(client.discover()).rejects.toThrow(
await expect(client.discover({} as Config)).rejects.toThrow(
'No prompts or tools found on the server.',
);
expect(consoleErrorSpy).toHaveBeenCalledWith(

View File

@@ -21,7 +21,7 @@ import {
ListRootsRequestSchema,
} from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote';
import type { MCPServerConfig } from '../config/config.js';
import type { Config, MCPServerConfig } from '../config/config.js';
import { AuthProviderType } from '../config/config.js';
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
@@ -146,13 +146,13 @@ export class McpClient {
/**
* Discovers tools and prompts from the MCP server.
*/
async discover(): Promise<void> {
async discover(cliConfig: Config): Promise<void> {
if (this.status !== MCPServerStatus.CONNECTED) {
throw new Error('Client is not connected.');
}
const prompts = await this.discoverPrompts();
const tools = await this.discoverTools();
const tools = await this.discoverTools(cliConfig);
if (prompts.length === 0 && tools.length === 0) {
throw new Error('No prompts or tools found on the server.');
@@ -191,8 +191,13 @@ export class McpClient {
return createTransport(this.serverName, this.serverConfig, this.debugMode);
}
private async discoverTools(): Promise<DiscoveredMCPTool[]> {
return discoverTools(this.serverName, this.serverConfig, this.client);
private async discoverTools(cliConfig: Config): Promise<DiscoveredMCPTool[]> {
return discoverTools(
this.serverName,
this.serverConfig,
this.client,
cliConfig,
);
}
private async discoverPrompts(): Promise<Prompt[]> {
@@ -445,6 +450,7 @@ export async function discoverMcpTools(
promptRegistry: PromptRegistry,
debugMode: boolean,
workspaceContext: WorkspaceContext,
cliConfig: Config,
): Promise<void> {
mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS;
try {
@@ -459,6 +465,7 @@ export async function discoverMcpTools(
promptRegistry,
debugMode,
workspaceContext,
cliConfig,
),
);
await Promise.all(discoveryPromises);
@@ -504,6 +511,7 @@ export async function connectAndDiscover(
promptRegistry: PromptRegistry,
debugMode: boolean,
workspaceContext: WorkspaceContext,
cliConfig: Config,
): Promise<void> {
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
@@ -531,6 +539,7 @@ export async function connectAndDiscover(
mcpServerName,
mcpServerConfig,
mcpClient,
cliConfig,
);
// If we have neither prompts nor tools, it's a failed discovery
@@ -632,6 +641,7 @@ export async function discoverTools(
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
mcpClient: Client,
cliConfig: Config,
): Promise<DiscoveredMCPTool[]> {
try {
const mcpCallableTool = mcpToTool(mcpClient);
@@ -667,6 +677,8 @@ export async function discoverTools(
funcDecl.parametersJsonSchema ?? { type: 'object', properties: {} },
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
mcpServerConfig.trust,
undefined,
cliConfig,
),
);
} catch (error) {

View File

@@ -747,6 +747,89 @@ describe('DiscoveredMCPTool', () => {
});
});
describe('shouldConfirmExecute with folder trust', () => {
const mockConfig = (isTrusted: boolean | undefined) => ({
isTrustedFolder: () => isTrusted,
});
it('should return false if trust is true and folder is trusted', async () => {
const trustedTool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
undefined,
true, // trust = true
undefined,
mockConfig(true) as any, // isTrustedFolder = true
);
const invocation = trustedTool.build({ param: 'mock' });
expect(
await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false);
});
it('should return confirmation details if trust is true but folder is not trusted', async () => {
const trustedTool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
undefined,
true, // trust = true
undefined,
mockConfig(false) as any, // isTrustedFolder = false
);
const invocation = trustedTool.build({ param: 'mock' });
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
expect(confirmation).toHaveProperty('type', 'mcp');
});
it('should return confirmation details if trust is false, even if folder is trusted', async () => {
const untrustedTool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
undefined,
false, // trust = false
undefined,
mockConfig(true) as any, // isTrustedFolder = true
);
const invocation = untrustedTool.build({ param: 'mock' });
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
expect(confirmation).toHaveProperty('type', 'mcp');
});
it('should return false if trust is true and folder trust is undefined', async () => {
// The check is `isTrustedFolder() !== false`, so `undefined` should pass
const trustedTool = new DiscoveredMCPTool(
mockCallableToolInstance,
serverName,
serverToolName,
baseDescription,
inputSchema,
undefined,
true, // trust = true
undefined,
mockConfig(undefined) as any, // isTrustedFolder = undefined
);
const invocation = trustedTool.build({ param: 'mock' });
expect(
await invocation.shouldConfirmExecute(new AbortController().signal),
).toBe(false);
});
});
describe('DiscoveredMCPToolInvocation', () => {
it('should return the stringified params from getDescription', () => {
const params = { param: 'testValue', param2: 'anotherOne' };

View File

@@ -19,6 +19,7 @@ import {
} from './tools.js';
import type { CallableTool, FunctionCall, Part } from '@google/genai';
import { ToolErrorType } from './tool-error.js';
import type { Config } from '../config/config.js';
type ToolParams = Record<string, unknown>;
@@ -70,6 +71,7 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
readonly timeout?: number,
readonly trust?: boolean,
params: ToolParams = {},
private readonly cliConfig?: Config,
) {
super(params);
}
@@ -80,7 +82,9 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
const serverAllowListKey = this.serverName;
const toolAllowListKey = `${this.serverName}.${this.serverToolName}`;
if (this.trust) {
const isTrustedFolder = this.cliConfig?.isTrustedFolder() !== false;
if (this.trust && isTrustedFolder) {
return false; // server is trusted, no confirmation needed
}
@@ -183,6 +187,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
readonly timeout?: number,
readonly trust?: boolean,
nameOverride?: string,
private readonly cliConfig?: Config,
) {
super(
nameOverride ?? generateValidName(serverToolName),
@@ -205,6 +210,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
this.timeout,
this.trust,
`${this.serverName}__${this.serverToolName}`,
this.cliConfig,
);
}
@@ -219,6 +225,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
this.timeout,
this.trust,
params,
this.cliConfig,
);
}
}

View File

@@ -236,7 +236,7 @@ export class ToolRegistry {
await this.discoverAndRegisterToolsFromCommand();
// discover tools using MCP servers, if configured
await this.mcpClientManager.discoverAllMcpTools();
await this.mcpClientManager.discoverAllMcpTools(this.config);
}
/**
@@ -251,7 +251,7 @@ export class ToolRegistry {
this.config.getPromptRegistry().clear();
// discover tools using MCP servers, if configured
await this.mcpClientManager.discoverAllMcpTools();
await this.mcpClientManager.discoverAllMcpTools(this.config);
}
/**
@@ -285,6 +285,7 @@ export class ToolRegistry {
this.config.getPromptRegistry(),
this.config.getDebugMode(),
this.config.getWorkspaceContext(),
this.config,
);
}
}