mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-27 13:34:15 -07:00
feat(policy): Propagate Tool Annotations for MCP Servers (#20083)
This commit is contained in:
@@ -23,7 +23,7 @@ import {
|
||||
ResourceListChangedNotificationSchema,
|
||||
ToolListChangedNotificationSchema,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { ApprovalMode, PolicyDecision } from '../policy/types.js';
|
||||
import type { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
|
||||
import { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
import {
|
||||
@@ -392,7 +392,7 @@ describe('mcp-client', () => {
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should register tool with readOnlyHint and add policy rule', async () => {
|
||||
it('should register tool with readOnlyHint and preserve annotations', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
@@ -462,17 +462,18 @@ describe('mcp-client', () => {
|
||||
// Verify tool registration
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
|
||||
// Verify policy rule addition
|
||||
expect(mockPolicyEngine.addRule).toHaveBeenCalledWith({
|
||||
toolName: 'test-server__readOnlyTool',
|
||||
decision: PolicyDecision.ASK_USER,
|
||||
priority: 50,
|
||||
modes: [ApprovalMode.PLAN],
|
||||
source: 'MCP Annotation (readOnlyHint) - test-server',
|
||||
});
|
||||
// Verify addRule is NOT called (annotation-based rules are in plan.toml now)
|
||||
expect(mockPolicyEngine.addRule).not.toHaveBeenCalled();
|
||||
|
||||
// Verify annotations are preserved on the registered tool
|
||||
const registeredTool = (
|
||||
mockedToolRegistry.registerTool as ReturnType<typeof vi.fn>
|
||||
).mock.calls[0][0] as DiscoveredMCPTool;
|
||||
expect(registeredTool.toolAnnotations).toEqual({ readOnlyHint: true });
|
||||
expect(registeredTool.isReadOnly).toBe(true);
|
||||
});
|
||||
|
||||
it('should not add policy rule for tool without readOnlyHint', async () => {
|
||||
it('should preserve undefined annotations for tool without readOnlyHint', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
@@ -541,6 +542,93 @@ describe('mcp-client', () => {
|
||||
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
expect(mockPolicyEngine.addRule).not.toHaveBeenCalled();
|
||||
|
||||
// Verify annotations are undefined for tools without annotations
|
||||
const registeredTool = (
|
||||
mockedToolRegistry.registerTool as ReturnType<typeof vi.fn>
|
||||
).mock.calls[0][0] as DiscoveredMCPTool;
|
||||
expect(registeredTool.toolAnnotations).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should preserve full annotations object with multiple hints', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
setNotificationHandler: vi.fn(),
|
||||
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
|
||||
listTools: vi.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'multiAnnotationTool',
|
||||
description: 'A tool with multiple annotations',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
annotations: {
|
||||
readOnlyHint: true,
|
||||
destructiveHint: false,
|
||||
idempotentHint: true,
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
|
||||
const mockConfig = {
|
||||
getPolicyEngine: vi.fn().mockReturnValue({ addRule: vi.fn() }),
|
||||
} as unknown as Config;
|
||||
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
removeMcpToolsByServer: vi.fn(),
|
||||
} as unknown as ToolRegistry;
|
||||
const promptRegistry = {
|
||||
registerPrompt: vi.fn(),
|
||||
removePromptsByServer: vi.fn(),
|
||||
} as unknown as PromptRegistry;
|
||||
const resourceRegistry = {
|
||||
setResourcesForServer: vi.fn(),
|
||||
removeResourcesByServer: vi.fn(),
|
||||
} as unknown as ResourceRegistry;
|
||||
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{ command: 'test-command' },
|
||||
mockedToolRegistry,
|
||||
promptRegistry,
|
||||
resourceRegistry,
|
||||
workspaceContext,
|
||||
{ sanitizationConfig: EMPTY_CONFIG } as Config,
|
||||
false,
|
||||
'0.0.1',
|
||||
);
|
||||
|
||||
await client.connect();
|
||||
await client.discover(mockConfig);
|
||||
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
|
||||
const registeredTool = (
|
||||
mockedToolRegistry.registerTool as ReturnType<typeof vi.fn>
|
||||
).mock.calls[0][0] as DiscoveredMCPTool;
|
||||
expect(registeredTool.toolAnnotations).toEqual({
|
||||
readOnlyHint: true,
|
||||
destructiveHint: false,
|
||||
idempotentHint: true,
|
||||
});
|
||||
expect(registeredTool.isReadOnly).toBe(true);
|
||||
});
|
||||
|
||||
it('should discover tools with $defs and $ref in schema', async () => {
|
||||
|
||||
@@ -33,7 +33,6 @@ import {
|
||||
PromptListChangedNotificationSchema,
|
||||
ProgressNotificationSchema,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { ApprovalMode, PolicyDecision } from '../policy/types.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import type { Config, MCPServerConfig } from '../config/config.js';
|
||||
import { AuthProviderType } from '../config/config.js';
|
||||
@@ -1078,8 +1077,9 @@ export async function discoverTools(
|
||||
options?.progressReporter,
|
||||
);
|
||||
|
||||
// Extract readOnlyHint from annotations
|
||||
const isReadOnly = toolDef.annotations?.readOnlyHint === true;
|
||||
// Extract annotations from the tool definition
|
||||
const annotations = toolDef.annotations;
|
||||
const isReadOnly = annotations?.readOnlyHint === true;
|
||||
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mcpCallableTool,
|
||||
@@ -1094,19 +1094,9 @@ export async function discoverTools(
|
||||
cliConfig,
|
||||
mcpServerConfig.extension?.name,
|
||||
mcpServerConfig.extension?.id,
|
||||
annotations as Record<string, unknown> | undefined,
|
||||
);
|
||||
|
||||
// If the tool is read-only, allow it in Plan mode
|
||||
if (isReadOnly) {
|
||||
cliConfig.getPolicyEngine().addRule({
|
||||
toolName: tool.getFullyQualifiedName(),
|
||||
decision: PolicyDecision.ASK_USER,
|
||||
priority: 50, // Match priority of built-in plan tools
|
||||
modes: [ApprovalMode.PLAN],
|
||||
source: `MCP Annotation (readOnlyHint) - ${mcpServerName}`,
|
||||
});
|
||||
}
|
||||
|
||||
discoveredTools.push(tool);
|
||||
} catch (error) {
|
||||
coreEvents.emitFeedback(
|
||||
|
||||
@@ -82,6 +82,7 @@ export class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
private readonly cliConfig?: Config,
|
||||
private readonly toolDescription?: string,
|
||||
private readonly toolParameterSchema?: unknown,
|
||||
toolAnnotationsData?: Record<string, unknown>,
|
||||
) {
|
||||
// Use composite format for policy checks: serverName__toolName
|
||||
// This enables server wildcards (e.g., "google-workspace__*")
|
||||
@@ -93,6 +94,7 @@ export class DiscoveredMCPToolInvocation extends BaseToolInvocation<
|
||||
`${serverName}${MCP_QUALIFIED_NAME_SEPARATOR}${serverToolName}`,
|
||||
displayName,
|
||||
serverName,
|
||||
toolAnnotationsData,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -257,6 +259,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
private readonly cliConfig?: Config,
|
||||
override readonly extensionName?: string,
|
||||
override readonly extensionId?: string,
|
||||
private readonly _toolAnnotations?: Record<string, unknown>,
|
||||
) {
|
||||
super(
|
||||
nameOverride ?? generateValidName(serverToolName),
|
||||
@@ -282,6 +285,10 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
return super.isReadOnly;
|
||||
}
|
||||
|
||||
override get toolAnnotations(): Record<string, unknown> | undefined {
|
||||
return this._toolAnnotations;
|
||||
}
|
||||
|
||||
getFullyQualifiedPrefix(): string {
|
||||
return `${this.serverName}${MCP_QUALIFIED_NAME_SEPARATOR}`;
|
||||
}
|
||||
@@ -304,6 +311,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
this.cliConfig,
|
||||
this.extensionName,
|
||||
this.extensionId,
|
||||
this._toolAnnotations,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -324,6 +332,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool<
|
||||
this.cliConfig,
|
||||
this.description,
|
||||
this.parameterSchema,
|
||||
this._toolAnnotations,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -441,13 +441,25 @@ export class ToolRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
private buildToolMetadata(): Map<string, Record<string, unknown>> {
|
||||
const toolMetadata = new Map<string, Record<string, unknown>>();
|
||||
for (const [name, tool] of this.allKnownTools) {
|
||||
if (tool.toolAnnotations) {
|
||||
toolMetadata.set(name, tool.toolAnnotations);
|
||||
}
|
||||
}
|
||||
return toolMetadata;
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns All the tools that are not excluded.
|
||||
*/
|
||||
private getActiveTools(): AnyDeclarativeTool[] {
|
||||
const toolMetadata = this.buildToolMetadata();
|
||||
const excludedTools =
|
||||
this.expandExcludeToolsWithAliases(this.config.getExcludeTools()) ??
|
||||
new Set([]);
|
||||
this.expandExcludeToolsWithAliases(
|
||||
this.config.getExcludeTools(toolMetadata),
|
||||
) ?? new Set([]);
|
||||
const activeTools: AnyDeclarativeTool[] = [];
|
||||
for (const tool of this.allKnownTools.values()) {
|
||||
if (this.isActiveTool(tool, excludedTools)) {
|
||||
@@ -487,8 +499,9 @@ export class ToolRegistry {
|
||||
excludeTools?: Set<string>,
|
||||
): boolean {
|
||||
excludeTools ??=
|
||||
this.expandExcludeToolsWithAliases(this.config.getExcludeTools()) ??
|
||||
new Set([]);
|
||||
this.expandExcludeToolsWithAliases(
|
||||
this.config.getExcludeTools(this.buildToolMetadata()),
|
||||
) ?? new Set([]);
|
||||
|
||||
// Filter tools in Plan Mode to only allow approved read-only tools.
|
||||
const isPlanMode =
|
||||
|
||||
@@ -91,6 +91,7 @@ export abstract class BaseToolInvocation<
|
||||
readonly _toolName?: string,
|
||||
readonly _toolDisplayName?: string,
|
||||
readonly _serverName?: string,
|
||||
readonly _toolAnnotations?: Record<string, unknown>,
|
||||
) {}
|
||||
|
||||
abstract getDescription(): string;
|
||||
@@ -199,6 +200,7 @@ export abstract class BaseToolInvocation<
|
||||
args: this.params as Record<string, unknown>,
|
||||
},
|
||||
serverName: this._serverName,
|
||||
toolAnnotations: this._toolAnnotations,
|
||||
};
|
||||
|
||||
return new Promise<'ALLOW' | 'DENY' | 'ASK_USER'>((resolve) => {
|
||||
@@ -372,6 +374,10 @@ export abstract class DeclarativeTool<
|
||||
return READ_ONLY_KINDS.includes(this.kind);
|
||||
}
|
||||
|
||||
get toolAnnotations(): Record<string, unknown> | undefined {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
getSchema(_modelId?: string): FunctionDeclaration {
|
||||
return {
|
||||
name: this.name,
|
||||
|
||||
Reference in New Issue
Block a user