mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
fix(policy): ensure MCP policies match unqualified names in non-interactive mode (#16490)
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import type { Mock } from 'vitest';
|
||||
import type { CallableTool } from '@google/genai';
|
||||
import { CoreToolScheduler } from './coreToolScheduler.js';
|
||||
import type {
|
||||
ToolCall,
|
||||
@@ -41,6 +42,7 @@ import {
|
||||
import * as modifiableToolModule from '../tools/modifiable-tool.js';
|
||||
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
|
||||
import type { PolicyEngine } from '../policy/policy-engine.js';
|
||||
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||
|
||||
vi.mock('fs/promises', () => ({
|
||||
writeFile: vi.fn(),
|
||||
@@ -283,7 +285,10 @@ function createMockConfig(overrides: Partial<Config> = {}): Config {
|
||||
if (!overrides.getPolicyEngine) {
|
||||
finalConfig.getPolicyEngine = () =>
|
||||
({
|
||||
check: async (toolCall: { name: string; args: object }) => {
|
||||
check: async (
|
||||
toolCall: { name: string; args: object },
|
||||
_serverName?: string,
|
||||
) => {
|
||||
// Mock simple policy logic for tests
|
||||
const mode = finalConfig.getApprovalMode();
|
||||
if (mode === ApprovalMode.YOLO) {
|
||||
@@ -1834,4 +1839,69 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
||||
|
||||
modifyWithEditorSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should pass serverName to policy engine for DiscoveredMCPTool', async () => {
|
||||
const mockMcpTool = {
|
||||
tool: async () => ({ functionDeclarations: [] }),
|
||||
callTool: async () => [],
|
||||
};
|
||||
const serverName = 'test-server';
|
||||
const toolName = 'test-tool';
|
||||
const mcpTool = new DiscoveredMCPTool(
|
||||
mockMcpTool as unknown as CallableTool,
|
||||
serverName,
|
||||
toolName,
|
||||
'description',
|
||||
{ type: 'object', properties: {} },
|
||||
createMockMessageBus() as unknown as MessageBus,
|
||||
);
|
||||
|
||||
const mockToolRegistry = {
|
||||
getTool: () => mcpTool,
|
||||
getFunctionDeclarations: () => [],
|
||||
tools: new Map(),
|
||||
discovery: {},
|
||||
registerTool: () => {},
|
||||
getToolByName: () => mcpTool,
|
||||
getToolByDisplayName: () => mcpTool,
|
||||
getTools: () => [],
|
||||
discoverTools: async () => {},
|
||||
getAllTools: () => [],
|
||||
getToolsByServer: () => [],
|
||||
} as unknown as ToolRegistry;
|
||||
|
||||
const mockPolicyEngineCheck = vi.fn().mockResolvedValue({
|
||||
decision: PolicyDecision.ALLOW,
|
||||
});
|
||||
|
||||
const mockConfig = createMockConfig({
|
||||
getToolRegistry: () => mockToolRegistry,
|
||||
getPolicyEngine: () =>
|
||||
({
|
||||
check: mockPolicyEngineCheck,
|
||||
}) as unknown as PolicyEngine,
|
||||
isInteractive: () => false,
|
||||
});
|
||||
|
||||
const scheduler = new CoreToolScheduler({
|
||||
config: mockConfig,
|
||||
getPreferredEditor: () => 'vscode',
|
||||
});
|
||||
|
||||
const abortController = new AbortController();
|
||||
const request = {
|
||||
callId: '1',
|
||||
name: toolName,
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
};
|
||||
|
||||
await scheduler.schedule(request, abortController.signal);
|
||||
|
||||
expect(mockPolicyEngineCheck).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ name: toolName }),
|
||||
serverName,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -44,6 +44,7 @@ import {
|
||||
type ToolCallResponseInfo,
|
||||
} from '../scheduler/types.js';
|
||||
import { ToolExecutor } from '../scheduler/tool-executor.js';
|
||||
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||
|
||||
export type {
|
||||
ToolCall,
|
||||
@@ -591,9 +592,15 @@ export class CoreToolScheduler {
|
||||
name: toolCall.request.name,
|
||||
args: toolCall.request.args,
|
||||
};
|
||||
|
||||
const serverName =
|
||||
toolCall.tool instanceof DiscoveredMCPTool
|
||||
? toolCall.tool.serverName
|
||||
: undefined;
|
||||
|
||||
const { decision } = await this.config
|
||||
.getPolicyEngine()
|
||||
.check(toolCallForPolicy, undefined); // Server name undefined for local tools
|
||||
.check(toolCallForPolicy, serverName);
|
||||
|
||||
if (decision === PolicyDecision.DENY) {
|
||||
const errorMessage = `Tool execution denied by policy.`;
|
||||
|
||||
@@ -109,6 +109,37 @@ describe('PolicyEngine', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should match unqualified tool names with qualified rules when serverName is provided', async () => {
|
||||
const rules: PolicyRule[] = [
|
||||
{
|
||||
toolName: 'my-server__tool',
|
||||
decision: PolicyDecision.ALLOW,
|
||||
},
|
||||
];
|
||||
|
||||
engine = new PolicyEngine({ rules });
|
||||
|
||||
// Match with qualified name (standard)
|
||||
expect(
|
||||
(await engine.check({ name: 'my-server__tool' }, 'my-server')).decision,
|
||||
).toBe(PolicyDecision.ALLOW);
|
||||
|
||||
// Match with unqualified name + serverName (the fix)
|
||||
expect((await engine.check({ name: 'tool' }, 'my-server')).decision).toBe(
|
||||
PolicyDecision.ALLOW,
|
||||
);
|
||||
|
||||
// Should NOT match with unqualified name but NO serverName
|
||||
expect((await engine.check({ name: 'tool' }, undefined)).decision).toBe(
|
||||
PolicyDecision.ASK_USER,
|
||||
);
|
||||
|
||||
// Should NOT match with unqualified name but WRONG serverName
|
||||
expect(
|
||||
(await engine.check({ name: 'tool' }, 'wrong-server')).decision,
|
||||
).toBe(PolicyDecision.ASK_USER);
|
||||
});
|
||||
|
||||
it('should match by args pattern', async () => {
|
||||
const rules: PolicyRule[] = [
|
||||
{
|
||||
|
||||
@@ -310,16 +310,22 @@ export class PolicyEngine {
|
||||
let matchedRule: PolicyRule | undefined;
|
||||
let decision: PolicyDecision | undefined;
|
||||
|
||||
// For tools with a server name, we want to try matching both the
|
||||
// original name and the fully qualified name (server__tool).
|
||||
const toolCallsToTry: FunctionCall[] = [toolCall];
|
||||
if (serverName && toolCall.name && !toolCall.name.includes('__')) {
|
||||
toolCallsToTry.push({
|
||||
...toolCall,
|
||||
name: `${serverName}__${toolCall.name}`,
|
||||
});
|
||||
}
|
||||
|
||||
for (const rule of this.rules) {
|
||||
if (
|
||||
ruleMatches(
|
||||
rule,
|
||||
toolCall,
|
||||
stringifiedArgs,
|
||||
serverName,
|
||||
this.approvalMode,
|
||||
)
|
||||
) {
|
||||
const match = toolCallsToTry.some((tc) =>
|
||||
ruleMatches(rule, tc, stringifiedArgs, serverName, this.approvalMode),
|
||||
);
|
||||
|
||||
if (match) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`,
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user