diff --git a/packages/core/src/tools/mcp-client-manager.test.ts b/packages/core/src/tools/mcp-client-manager.test.ts index b30425c8f4..6095381841 100644 --- a/packages/core/src/tools/mcp-client-manager.test.ts +++ b/packages/core/src/tools/mcp-client-manager.test.ts @@ -10,6 +10,7 @@ import { McpClient } from './mcp-client.js'; import type { ToolRegistry } from './tool-registry.js'; import type { PromptRegistry } from '../prompts/prompt-registry.js'; import type { WorkspaceContext } from '../utils/workspaceContext.js'; +import type { Config } from '../config/config.js'; vi.mock('./mcp-client.js', async () => { const originalModule = await vi.importActual('./mcp-client.js'); @@ -47,8 +48,64 @@ describe('McpClientManager', () => { false, {} as WorkspaceContext, ); - await manager.discoverAllMcpTools(); + await manager.discoverAllMcpTools({ + isTrustedFolder: () => true, + } as unknown as Config); expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); }); + + it('should discover tools if isTrustedFolder is undefined', async () => { + const mockedMcpClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + getStatus: vi.fn(), + }; + vi.mocked(McpClient).mockReturnValue( + mockedMcpClient as unknown as McpClient, + ); + const manager = new McpClientManager( + { + 'test-server': {}, + }, + '', + {} as ToolRegistry, + {} as PromptRegistry, + false, + {} as WorkspaceContext, + ); + await manager.discoverAllMcpTools({ + isTrustedFolder: () => undefined, + } as unknown as Config); + expect(mockedMcpClient.connect).toHaveBeenCalledOnce(); + expect(mockedMcpClient.discover).toHaveBeenCalledOnce(); + }); + + it('should not discover tools if folder is not trusted', async () => { + const mockedMcpClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + getStatus: vi.fn(), + }; + vi.mocked(McpClient).mockReturnValue( + mockedMcpClient as unknown as McpClient, + ); + const manager = new McpClientManager( + { + 'test-server': {}, + }, + '', + {} as ToolRegistry, + {} as PromptRegistry, + false, + {} as WorkspaceContext, + ); + await manager.discoverAllMcpTools({ + isTrustedFolder: () => false, + } as unknown as Config); + expect(mockedMcpClient.connect).not.toHaveBeenCalled(); + expect(mockedMcpClient.discover).not.toHaveBeenCalled(); + }); }); diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index 182977efe6..3f25102af6 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { MCPServerConfig } from '../config/config.js'; +import type { Config, MCPServerConfig } from '../config/config.js'; import type { ToolRegistry } from './tool-registry.js'; import type { PromptRegistry } from '../prompts/prompt-registry.js'; import { @@ -55,7 +55,10 @@ export class McpClientManager { * It connects to each server, discovers its available tools, and registers * them with the `ToolRegistry`. */ - async discoverAllMcpTools(): Promise { + async discoverAllMcpTools(cliConfig: Config): Promise { + if (cliConfig.isTrustedFolder() === false) { + return; + } await this.stop(); const servers = populateMcpServerCommand( @@ -91,7 +94,7 @@ export class McpClientManager { try { await client.connect(); - await client.discover(); + await client.discover(cliConfig); this.eventEmitter?.emit('mcp-server-connected', { name, current, diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index d97686de2f..d2c87bb8e4 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -19,7 +19,7 @@ import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js'; import * as GenAiLib from '@google/genai'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; -import { AuthProviderType } from '../config/config.js'; +import { AuthProviderType, type Config } from '../config/config.js'; import type { PromptRegistry } from '../prompts/prompt-registry.js'; import type { ToolRegistry } from './tool-registry.js'; import type { WorkspaceContext } from '../utils/workspaceContext.js'; @@ -74,7 +74,7 @@ describe('mcp-client', () => { false, ); await client.connect(); - await client.discover(); + await client.discover({} as Config); expect(mockedMcpToTool).toHaveBeenCalledOnce(); }); @@ -136,7 +136,7 @@ describe('mcp-client', () => { false, ); await client.connect(); - await client.discover(); + await client.discover({} as Config); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); expect(consoleWarnSpy).toHaveBeenCalledOnce(); expect(consoleWarnSpy).toHaveBeenCalledWith( @@ -180,7 +180,7 @@ describe('mcp-client', () => { false, ); await client.connect(); - await expect(client.discover()).rejects.toThrow( + await expect(client.discover({} as Config)).rejects.toThrow( 'No prompts or tools found on the server.', ); expect(consoleErrorSpy).toHaveBeenCalledWith( diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 67010e9269..d34c413065 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -21,7 +21,7 @@ import { ListRootsRequestSchema, } from '@modelcontextprotocol/sdk/types.js'; import { parse } from 'shell-quote'; -import type { MCPServerConfig } from '../config/config.js'; +import type { Config, MCPServerConfig } from '../config/config.js'; import { AuthProviderType } from '../config/config.js'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; @@ -146,13 +146,13 @@ export class McpClient { /** * Discovers tools and prompts from the MCP server. */ - async discover(): Promise { + async discover(cliConfig: Config): Promise { if (this.status !== MCPServerStatus.CONNECTED) { throw new Error('Client is not connected.'); } const prompts = await this.discoverPrompts(); - const tools = await this.discoverTools(); + const tools = await this.discoverTools(cliConfig); if (prompts.length === 0 && tools.length === 0) { throw new Error('No prompts or tools found on the server.'); @@ -191,8 +191,13 @@ export class McpClient { return createTransport(this.serverName, this.serverConfig, this.debugMode); } - private async discoverTools(): Promise { - return discoverTools(this.serverName, this.serverConfig, this.client); + private async discoverTools(cliConfig: Config): Promise { + return discoverTools( + this.serverName, + this.serverConfig, + this.client, + cliConfig, + ); } private async discoverPrompts(): Promise { @@ -445,6 +450,7 @@ export async function discoverMcpTools( promptRegistry: PromptRegistry, debugMode: boolean, workspaceContext: WorkspaceContext, + cliConfig: Config, ): Promise { mcpDiscoveryState = MCPDiscoveryState.IN_PROGRESS; try { @@ -459,6 +465,7 @@ export async function discoverMcpTools( promptRegistry, debugMode, workspaceContext, + cliConfig, ), ); await Promise.all(discoveryPromises); @@ -504,6 +511,7 @@ export async function connectAndDiscover( promptRegistry: PromptRegistry, debugMode: boolean, workspaceContext: WorkspaceContext, + cliConfig: Config, ): Promise { updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); @@ -531,6 +539,7 @@ export async function connectAndDiscover( mcpServerName, mcpServerConfig, mcpClient, + cliConfig, ); // If we have neither prompts nor tools, it's a failed discovery @@ -632,6 +641,7 @@ export async function discoverTools( mcpServerName: string, mcpServerConfig: MCPServerConfig, mcpClient: Client, + cliConfig: Config, ): Promise { try { const mcpCallableTool = mcpToTool(mcpClient); @@ -667,6 +677,8 @@ export async function discoverTools( funcDecl.parametersJsonSchema ?? { type: 'object', properties: {} }, mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, mcpServerConfig.trust, + undefined, + cliConfig, ), ); } catch (error) { diff --git a/packages/core/src/tools/mcp-tool.test.ts b/packages/core/src/tools/mcp-tool.test.ts index b9d7ddc3ac..5fc64cbaa6 100644 --- a/packages/core/src/tools/mcp-tool.test.ts +++ b/packages/core/src/tools/mcp-tool.test.ts @@ -747,6 +747,89 @@ describe('DiscoveredMCPTool', () => { }); }); + describe('shouldConfirmExecute with folder trust', () => { + const mockConfig = (isTrusted: boolean | undefined) => ({ + isTrustedFolder: () => isTrusted, + }); + + it('should return false if trust is true and folder is trusted', async () => { + const trustedTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + true, // trust = true + undefined, + mockConfig(true) as any, // isTrustedFolder = true + ); + const invocation = trustedTool.build({ param: 'mock' }); + expect( + await invocation.shouldConfirmExecute(new AbortController().signal), + ).toBe(false); + }); + + it('should return confirmation details if trust is true but folder is not trusted', async () => { + const trustedTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + true, // trust = true + undefined, + mockConfig(false) as any, // isTrustedFolder = false + ); + const invocation = trustedTool.build({ param: 'mock' }); + const confirmation = await invocation.shouldConfirmExecute( + new AbortController().signal, + ); + expect(confirmation).not.toBe(false); + expect(confirmation).toHaveProperty('type', 'mcp'); + }); + + it('should return confirmation details if trust is false, even if folder is trusted', async () => { + const untrustedTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + false, // trust = false + undefined, + mockConfig(true) as any, // isTrustedFolder = true + ); + const invocation = untrustedTool.build({ param: 'mock' }); + const confirmation = await invocation.shouldConfirmExecute( + new AbortController().signal, + ); + expect(confirmation).not.toBe(false); + expect(confirmation).toHaveProperty('type', 'mcp'); + }); + + it('should return false if trust is true and folder trust is undefined', async () => { + // The check is `isTrustedFolder() !== false`, so `undefined` should pass + const trustedTool = new DiscoveredMCPTool( + mockCallableToolInstance, + serverName, + serverToolName, + baseDescription, + inputSchema, + undefined, + true, // trust = true + undefined, + mockConfig(undefined) as any, // isTrustedFolder = undefined + ); + const invocation = trustedTool.build({ param: 'mock' }); + expect( + await invocation.shouldConfirmExecute(new AbortController().signal), + ).toBe(false); + }); + }); + describe('DiscoveredMCPToolInvocation', () => { it('should return the stringified params from getDescription', () => { const params = { param: 'testValue', param2: 'anotherOne' }; diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index ede885d88e..5bc48a1e7d 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -19,6 +19,7 @@ import { } from './tools.js'; import type { CallableTool, FunctionCall, Part } from '@google/genai'; import { ToolErrorType } from './tool-error.js'; +import type { Config } from '../config/config.js'; type ToolParams = Record; @@ -70,6 +71,7 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< readonly timeout?: number, readonly trust?: boolean, params: ToolParams = {}, + private readonly cliConfig?: Config, ) { super(params); } @@ -80,7 +82,9 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< const serverAllowListKey = this.serverName; const toolAllowListKey = `${this.serverName}.${this.serverToolName}`; - if (this.trust) { + const isTrustedFolder = this.cliConfig?.isTrustedFolder() !== false; + + if (this.trust && isTrustedFolder) { return false; // server is trusted, no confirmation needed } @@ -183,6 +187,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool< readonly timeout?: number, readonly trust?: boolean, nameOverride?: string, + private readonly cliConfig?: Config, ) { super( nameOverride ?? generateValidName(serverToolName), @@ -205,6 +210,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool< this.timeout, this.trust, `${this.serverName}__${this.serverToolName}`, + this.cliConfig, ); } @@ -219,6 +225,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool< this.timeout, this.trust, params, + this.cliConfig, ); } } diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index ec054d1821..c4cb46e9b2 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -236,7 +236,7 @@ export class ToolRegistry { await this.discoverAndRegisterToolsFromCommand(); // discover tools using MCP servers, if configured - await this.mcpClientManager.discoverAllMcpTools(); + await this.mcpClientManager.discoverAllMcpTools(this.config); } /** @@ -251,7 +251,7 @@ export class ToolRegistry { this.config.getPromptRegistry().clear(); // discover tools using MCP servers, if configured - await this.mcpClientManager.discoverAllMcpTools(); + await this.mcpClientManager.discoverAllMcpTools(this.config); } /** @@ -285,6 +285,7 @@ export class ToolRegistry { this.config.getPromptRegistry(), this.config.getDebugMode(), this.config.getWorkspaceContext(), + this.config, ); } }