feat(plan): implement support for MCP servers in Plan mode (#18229)

This commit is contained in:
Adib234
2026-02-05 16:37:28 -05:00
committed by GitHub
parent 00a739e84c
commit fe975da91e
8 changed files with 256 additions and 9 deletions

View File

@@ -60,6 +60,7 @@ const createMockMCPTool = (
{ type: 'object', properties: {} },
mockMessageBus,
undefined, // trust
undefined, // isReadOnly
undefined, // nameOverride
undefined, // cliConfig
undefined, // extensionName

View File

@@ -22,6 +22,9 @@ import {
DEFAULT_GEMINI_MODEL,
} from '../config/models.js';
import { ApprovalMode } from '../policy/types.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import type { CallableTool } from '@google/genai';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
// Mock tool names if they are dynamically generated or complex
vi.mock('../tools/ls', () => ({ LSTool: { Name: 'list_directory' } }));
@@ -33,7 +36,10 @@ vi.mock('../tools/read-many-files', () => ({
ReadManyFilesTool: { Name: 'read_many_files' },
}));
vi.mock('../tools/shell', () => ({
ShellTool: { Name: 'run_shell_command' },
ShellTool: class {
static readonly Name = 'run_shell_command';
name = 'run_shell_command';
},
}));
vi.mock('../tools/write-file', () => ({
WriteFileTool: { Name: 'write_file' },
@@ -76,6 +82,7 @@ describe('Core System Prompt (prompts.ts)', () => {
mockConfig = {
getToolRegistry: vi.fn().mockReturnValue({
getAllToolNames: vi.fn().mockReturnValue([]),
getAllTools: vi.fn().mockReturnValue([]),
}),
getEnableShellOutputEfficiency: vi.fn().mockReturnValue(true),
storage: {
@@ -90,6 +97,7 @@ describe('Core System Prompt (prompts.ts)', () => {
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
getActiveModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL),
getPreviewFeatures: vi.fn().mockReturnValue(false),
getMessageBus: vi.fn(),
getAgentRegistry: vi.fn().mockReturnValue({
getDirectoryContext: vi.fn().mockReturnValue('Mock Agent Directory'),
}),
@@ -299,6 +307,48 @@ describe('Core System Prompt (prompts.ts)', () => {
expect(prompt).toMatchSnapshot();
});
it('should include read-only MCP tools in PLAN mode', () => {
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);
const readOnlyMcpTool = new DiscoveredMCPTool(
{} as CallableTool,
'readonly-server',
'read_static_value',
'A read-only tool',
{},
{} as MessageBus,
false,
true, // isReadOnly
);
const nonReadOnlyMcpTool = new DiscoveredMCPTool(
{} as CallableTool,
'nonreadonly-server',
'non_read_static_value',
'A non-read-only tool',
{},
{} as MessageBus,
false,
false,
);
vi.mocked(mockConfig.getToolRegistry().getAllTools).mockReturnValue([
readOnlyMcpTool,
nonReadOnlyMcpTool,
]);
vi.mocked(mockConfig.getToolRegistry().getAllToolNames).mockReturnValue([
readOnlyMcpTool.name,
nonReadOnlyMcpTool.name,
]);
const prompt = getCoreSystemPrompt(mockConfig);
expect(prompt).toContain('`read_static_value` (readonly-server)');
expect(prompt).not.toContain(
'`non_read_static_value` (nonreadonly-server)',
);
});
it('should only list available tools in PLAN mode', () => {
vi.mocked(mockConfig.getApprovalMode).mockReturnValue(ApprovalMode.PLAN);
// Only enable a subset of tools, including ask_user

View File

@@ -26,6 +26,7 @@ import {
ENTER_PLAN_MODE_TOOL_NAME,
} from '../tools/tool-names.js';
import { resolveModel, isPreviewModel } from '../config/models.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
/**
* Orchestrates prompt generation by gathering context and building options.
@@ -48,6 +49,7 @@ export class PromptProvider {
const isPlanMode = approvalMode === ApprovalMode.PLAN;
const skills = config.getSkillManager().getSkills();
const toolNames = config.getToolRegistry().getAllToolNames();
const enabledToolNames = new Set(toolNames);
const approvedPlanPath = config.getApprovedPlanPath();
const desiredModel = resolveModel(
@@ -56,6 +58,28 @@ export class PromptProvider {
);
const isGemini3 = isPreviewModel(desiredModel);
// --- Context Gathering ---
let planModeToolsList = PLAN_MODE_TOOLS.filter((t) =>
enabledToolNames.has(t),
)
.map((t) => `- \`${t}\``)
.join('\n');
// Add read-only MCP tools to the list
if (isPlanMode) {
const allTools = config.getToolRegistry().getAllTools();
const readOnlyMcpTools = allTools.filter(
(t): t is DiscoveredMCPTool =>
t instanceof DiscoveredMCPTool && !!t.isReadOnly,
);
if (readOnlyMcpTools.length > 0) {
const mcpToolsList = readOnlyMcpTools
.map((t) => `- \`${t.name}\` (${t.serverName})`)
.join('\n');
planModeToolsList += `\n${mcpToolsList}`;
}
}
let basePrompt: string;
// --- Template File Override ---
@@ -105,11 +129,11 @@ export class PromptProvider {
'primaryWorkflows',
() => ({
interactive: interactiveMode,
enableCodebaseInvestigator: toolNames.includes(
enableCodebaseInvestigator: enabledToolNames.has(
CodebaseInvestigatorAgent.name,
),
enableWriteTodosTool: toolNames.includes(WRITE_TODOS_TOOL_NAME),
enableEnterPlanModeTool: toolNames.includes(
enableWriteTodosTool: enabledToolNames.has(WRITE_TODOS_TOOL_NAME),
enableEnterPlanModeTool: enabledToolNames.has(
ENTER_PLAN_MODE_TOOL_NAME,
),
approvedPlan: approvedPlanPath
@@ -121,11 +145,7 @@ export class PromptProvider {
planningWorkflow: this.withSection(
'planningWorkflow',
() => ({
planModeToolsList: PLAN_MODE_TOOLS.filter((t) =>
new Set(toolNames).has(t),
)
.map((t) => `- \`${t}\``)
.join('\n'),
planModeToolsList,
plansDir: config.storage.getProjectTempPlansDir(),
approvedPlanPath: config.getApprovedPlanPath(),
}),

View File

@@ -1494,6 +1494,7 @@ describe('loggers', () => {
false,
undefined,
undefined,
undefined,
'test-extension',
'test-extension-id',
);

View File

@@ -19,6 +19,7 @@ import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
import { OAuthUtils } from '../mcp/oauth-utils.js';
import type { PromptRegistry } from '../prompts/prompt-registry.js';
import { ToolListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
import { ApprovalMode, PolicyDecision } from '../policy/types.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
import {
@@ -387,6 +388,157 @@ describe('mcp-client', () => {
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
});
it('should register tool with readOnlyHint and add policy rule', async () => {
const mockedClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
setNotificationHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
listTools: vi.fn().mockResolvedValue({
tools: [
{
name: 'readOnlyTool',
description: 'A read-only tool',
inputSchema: { type: 'object', properties: {} },
annotations: { readOnlyHint: true },
},
],
}),
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
request: vi.fn().mockResolvedValue({}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockPolicyEngine = {
addRule: vi.fn(),
};
const mockConfig = {
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
} as unknown as Config;
const mockedToolRegistry = {
registerTool: vi.fn(),
sortTools: vi.fn(),
getMessageBus: vi.fn().mockReturnValue(undefined),
removeMcpToolsByServer: vi.fn(),
} as unknown as ToolRegistry;
const promptRegistry = {
registerPrompt: vi.fn(),
removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry;
const resourceRegistry = {
setResourcesForServer: vi.fn(),
removeResourcesByServer: vi.fn(),
} as unknown as ResourceRegistry;
const client = new McpClient(
'test-server',
{ command: 'test-command' },
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config,
false,
'0.0.1',
);
await client.connect();
await client.discover(mockConfig);
// Verify tool registration
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
// Verify policy rule addition
expect(mockPolicyEngine.addRule).toHaveBeenCalledWith({
toolName: 'test-server__readOnlyTool',
decision: PolicyDecision.ASK_USER,
priority: 50,
modes: [ApprovalMode.PLAN],
source: 'MCP Annotation (readOnlyHint) - test-server',
});
});
it('should not add policy rule for tool without readOnlyHint', async () => {
const mockedClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
setNotificationHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
listTools: vi.fn().mockResolvedValue({
tools: [
{
name: 'writeTool',
description: 'A write tool',
inputSchema: { type: 'object', properties: {} },
// No annotations or readOnlyHint: false
},
],
}),
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
request: vi.fn().mockResolvedValue({}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockPolicyEngine = {
addRule: vi.fn(),
};
const mockConfig = {
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
} as unknown as Config;
const mockedToolRegistry = {
registerTool: vi.fn(),
sortTools: vi.fn(),
getMessageBus: vi.fn().mockReturnValue(undefined),
removeMcpToolsByServer: vi.fn(),
} as unknown as ToolRegistry;
const promptRegistry = {
registerPrompt: vi.fn(),
removePromptsByServer: vi.fn(),
} as unknown as PromptRegistry;
const resourceRegistry = {
setResourcesForServer: vi.fn(),
removeResourcesByServer: vi.fn(),
} as unknown as ResourceRegistry;
const client = new McpClient(
'test-server',
{ command: 'test-command' },
mockedToolRegistry,
promptRegistry,
resourceRegistry,
workspaceContext,
{ sanitizationConfig: EMPTY_CONFIG } as Config,
false,
'0.0.1',
);
await client.connect();
await client.discover(mockConfig);
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
expect(mockPolicyEngine.addRule).not.toHaveBeenCalled();
});
it('should discover tools with $defs and $ref in schema', async () => {
const mockedClient = {
connect: vi.fn(),

View File

@@ -32,6 +32,7 @@ import {
PromptListChangedNotificationSchema,
type Tool as McpTool,
} from '@modelcontextprotocol/sdk/types.js';
import { ApprovalMode, PolicyDecision } from '../policy/types.js';
import { parse } from 'shell-quote';
import type {
Config,
@@ -1028,6 +1029,9 @@ export async function discoverTools(
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
);
// Extract readOnlyHint from annotations
const isReadOnly = toolDef.annotations?.readOnlyHint === true;
const tool = new DiscoveredMCPTool(
mcpCallableTool,
mcpServerName,
@@ -1036,12 +1040,24 @@ export async function discoverTools(
toolDef.inputSchema ?? { type: 'object', properties: {} },
messageBus,
mcpServerConfig.trust,
isReadOnly,
undefined,
cliConfig,
mcpServerConfig.extension?.name,
mcpServerConfig.extension?.id,
);
// If the tool is read-only, allow it in Plan mode
if (isReadOnly) {
cliConfig.getPolicyEngine().addRule({
toolName: tool.getFullyQualifiedName(),
decision: PolicyDecision.ASK_USER,
priority: 50, // Match priority of built-in plan tools
modes: [ApprovalMode.PLAN],
source: `MCP Annotation (readOnlyHint) - ${mcpServerName}`,
});
}
discoveredTools.push(tool);
} catch (error) {
coreEvents.emitFeedback(

View File

@@ -203,6 +203,7 @@ describe('DiscoveredMCPTool', () => {
undefined,
undefined,
undefined,
undefined,
);
const params = { param: 'isErrorTrueCase' };
const functionCall = {
@@ -249,6 +250,7 @@ describe('DiscoveredMCPTool', () => {
undefined,
undefined,
undefined,
undefined,
);
const params = { param: 'isErrorTopLevelCase' };
const functionCall = {
@@ -298,6 +300,7 @@ describe('DiscoveredMCPTool', () => {
undefined,
undefined,
undefined,
undefined,
);
const params = { param: 'isErrorFalseCase' };
const mockToolSuccessResultObject = {
@@ -756,6 +759,7 @@ describe('DiscoveredMCPTool', () => {
createMockMessageBus(),
true,
undefined,
undefined,
{ isTrustedFolder: () => true } as any,
undefined,
undefined,
@@ -901,6 +905,7 @@ describe('DiscoveredMCPTool', () => {
bus,
trust,
undefined,
undefined,
mockConfig(isTrusted) as any,
undefined,
undefined,

View File

@@ -247,6 +247,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
override readonly parameterSchema: unknown,
messageBus: MessageBus,
readonly trust?: boolean,
readonly isReadOnly?: boolean,
nameOverride?: string,
private readonly cliConfig?: Config,
override readonly extensionName?: string,
@@ -283,6 +284,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
this.parameterSchema,
this.messageBus,
this.trust,
this.isReadOnly,
this.getFullyQualifiedName(),
this.cliConfig,
this.extensionName,