mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-07 11:51:14 -07:00
feat(plan): implement support for MCP servers in Plan mode (#18229)
This commit is contained in:
@@ -19,6 +19,7 @@ import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js';
|
||||
import { OAuthUtils } from '../mcp/oauth-utils.js';
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import { ToolListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { ApprovalMode, PolicyDecision } from '../policy/types.js';
|
||||
|
||||
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
import {
|
||||
@@ -387,6 +388,157 @@ describe('mcp-client', () => {
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should register tool with readOnlyHint and add policy rule', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
setNotificationHandler: vi.fn(),
|
||||
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
|
||||
listTools: vi.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'readOnlyTool',
|
||||
description: 'A read-only tool',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
annotations: { readOnlyHint: true },
|
||||
},
|
||||
],
|
||||
}),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
|
||||
const mockPolicyEngine = {
|
||||
addRule: vi.fn(),
|
||||
};
|
||||
const mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
} as unknown as Config;
|
||||
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
} as unknown as ToolRegistry;
|
||||
const promptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const resourceRegistry = {
|
||||
setResourcesForServer: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry;
|
||||
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
await client.discover(mockConfig);
|
||||
|
||||
// Verify tool registration
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
|
||||
// Verify policy rule addition
|
||||
expect(mockPolicyEngine.addRule).toHaveBeenCalledWith({
|
||||
toolName: 'test-server__readOnlyTool',
|
||||
decision: PolicyDecision.ASK_USER,
|
||||
priority: 50,
|
||||
modes: [ApprovalMode.PLAN],
|
||||
source: 'MCP Annotation (readOnlyHint) - test-server',
|
||||
});
|
||||
});
|
||||
|
||||
it('should not add policy rule for tool without readOnlyHint', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
setNotificationHandler: vi.fn(),
|
||||
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
|
||||
listTools: vi.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'writeTool',
|
||||
description: 'A write tool',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
// No annotations or readOnlyHint: false
|
||||
},
|
||||
],
|
||||
}),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
|
||||
const mockPolicyEngine = {
|
||||
addRule: vi.fn(),
|
||||
};
|
||||
const mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue(mockPolicyEngine),
|
||||
} as unknown as Config;
|
||||
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
} as unknown as ToolRegistry;
|
||||
const promptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const resourceRegistry = {
|
||||
setResourcesForServer: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry;
|
||||
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
await client.discover(mockConfig);
|
||||
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
expect(mockPolicyEngine.addRule).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should discover tools with $defs and $ref in schema', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
|
||||
@@ -32,6 +32,7 @@ import {
|
||||
PromptListChangedNotificationSchema,
|
||||
type Tool as McpTool,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { ApprovalMode, PolicyDecision } from '../policy/types.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import type {
|
||||
Config,
|
||||
@@ -1028,6 +1029,9 @@ export async function discoverTools(
|
||||
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
);
|
||||
|
||||
// Extract readOnlyHint from annotations
|
||||
const isReadOnly = toolDef.annotations?.readOnlyHint === true;
|
||||
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mcpCallableTool,
|
||||
mcpServerName,
|
||||
@@ -1036,12 +1040,24 @@ export async function discoverTools(
|
||||
toolDef.inputSchema ?? { type: 'object', properties: {} },
|
||||
messageBus,
|
||||
mcpServerConfig.trust,
|
||||
isReadOnly,
|
||||
undefined,
|
||||
cliConfig,
|
||||
mcpServerConfig.extension?.name,
|
||||
mcpServerConfig.extension?.id,
|
||||
);
|
||||
|
||||
// If the tool is read-only, allow it in Plan mode
|
||||
if (isReadOnly) {
|
||||
cliConfig.getPolicyEngine().addRule({
|
||||
toolName: tool.getFullyQualifiedName(),
|
||||
decision: PolicyDecision.ASK_USER,
|
||||
priority: 50, // Match priority of built-in plan tools
|
||||
modes: [ApprovalMode.PLAN],
|
||||
source: `MCP Annotation (readOnlyHint) - ${mcpServerName}`,
|
||||
});
|
||||
}
|
||||
|
||||
discoveredTools.push(tool);
|
||||
} catch (error) {
|
||||
coreEvents.emitFeedback(
|
||||
|
||||
@@ -203,6 +203,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
const params = { param: 'isErrorTrueCase' };
|
||||
const functionCall = {
|
||||
@@ -249,6 +250,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
const params = { param: 'isErrorTopLevelCase' };
|
||||
const functionCall = {
|
||||
@@ -298,6 +300,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
const params = { param: 'isErrorFalseCase' };
|
||||
const mockToolSuccessResultObject = {
|
||||
@@ -756,6 +759,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
createMockMessageBus(),
|
||||
true,
|
||||
undefined,
|
||||
undefined,
|
||||
{ isTrustedFolder: () => true } as any,
|
||||
undefined,
|
||||
undefined,
|
||||
@@ -901,6 +905,7 @@ describe('DiscoveredMCPTool', () => {
|
||||
bus,
|
||||
trust,
|
||||
undefined,
|
||||
undefined,
|
||||
mockConfig(isTrusted) as any,
|
||||
undefined,
|
||||
undefined,
|
||||
|
||||
@@ -247,6 +247,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
override readonly parameterSchema: unknown,
|
||||
messageBus: MessageBus,
|
||||
readonly trust?: boolean,
|
||||
readonly isReadOnly?: boolean,
|
||||
nameOverride?: string,
|
||||
private readonly cliConfig?: Config,
|
||||
override readonly extensionName?: string,
|
||||
@@ -283,6 +284,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
this.parameterSchema,
|
||||
this.messageBus,
|
||||
this.trust,
|
||||
this.isReadOnly,
|
||||
this.getFullyQualifiedName(),
|
||||
this.cliConfig,
|
||||
this.extensionName,
|
||||
|
||||
Reference in New Issue
Block a user