mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-01 07:24:38 -07:00
fix(policy): ensure MCP policies match unqualified names in non-interactive mode (#16490)
This commit is contained in:
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
import { describe, it, expect, vi } from 'vitest';
|
import { describe, it, expect, vi } from 'vitest';
|
||||||
import type { Mock } from 'vitest';
|
import type { Mock } from 'vitest';
|
||||||
|
import type { CallableTool } from '@google/genai';
|
||||||
import { CoreToolScheduler } from './coreToolScheduler.js';
|
import { CoreToolScheduler } from './coreToolScheduler.js';
|
||||||
import type {
|
import type {
|
||||||
ToolCall,
|
ToolCall,
|
||||||
@@ -41,6 +42,7 @@ import {
|
|||||||
import * as modifiableToolModule from '../tools/modifiable-tool.js';
|
import * as modifiableToolModule from '../tools/modifiable-tool.js';
|
||||||
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
|
import { DEFAULT_GEMINI_MODEL } from '../config/models.js';
|
||||||
import type { PolicyEngine } from '../policy/policy-engine.js';
|
import type { PolicyEngine } from '../policy/policy-engine.js';
|
||||||
|
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||||
|
|
||||||
vi.mock('fs/promises', () => ({
|
vi.mock('fs/promises', () => ({
|
||||||
writeFile: vi.fn(),
|
writeFile: vi.fn(),
|
||||||
@@ -283,7 +285,10 @@ function createMockConfig(overrides: Partial<Config> = {}): Config {
|
|||||||
if (!overrides.getPolicyEngine) {
|
if (!overrides.getPolicyEngine) {
|
||||||
finalConfig.getPolicyEngine = () =>
|
finalConfig.getPolicyEngine = () =>
|
||||||
({
|
({
|
||||||
check: async (toolCall: { name: string; args: object }) => {
|
check: async (
|
||||||
|
toolCall: { name: string; args: object },
|
||||||
|
_serverName?: string,
|
||||||
|
) => {
|
||||||
// Mock simple policy logic for tests
|
// Mock simple policy logic for tests
|
||||||
const mode = finalConfig.getApprovalMode();
|
const mode = finalConfig.getApprovalMode();
|
||||||
if (mode === ApprovalMode.YOLO) {
|
if (mode === ApprovalMode.YOLO) {
|
||||||
@@ -1834,4 +1839,69 @@ describe('CoreToolScheduler Sequential Execution', () => {
|
|||||||
|
|
||||||
modifyWithEditorSpy.mockRestore();
|
modifyWithEditorSpy.mockRestore();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should pass serverName to policy engine for DiscoveredMCPTool', async () => {
|
||||||
|
const mockMcpTool = {
|
||||||
|
tool: async () => ({ functionDeclarations: [] }),
|
||||||
|
callTool: async () => [],
|
||||||
|
};
|
||||||
|
const serverName = 'test-server';
|
||||||
|
const toolName = 'test-tool';
|
||||||
|
const mcpTool = new DiscoveredMCPTool(
|
||||||
|
mockMcpTool as unknown as CallableTool,
|
||||||
|
serverName,
|
||||||
|
toolName,
|
||||||
|
'description',
|
||||||
|
{ type: 'object', properties: {} },
|
||||||
|
createMockMessageBus() as unknown as MessageBus,
|
||||||
|
);
|
||||||
|
|
||||||
|
const mockToolRegistry = {
|
||||||
|
getTool: () => mcpTool,
|
||||||
|
getFunctionDeclarations: () => [],
|
||||||
|
tools: new Map(),
|
||||||
|
discovery: {},
|
||||||
|
registerTool: () => {},
|
||||||
|
getToolByName: () => mcpTool,
|
||||||
|
getToolByDisplayName: () => mcpTool,
|
||||||
|
getTools: () => [],
|
||||||
|
discoverTools: async () => {},
|
||||||
|
getAllTools: () => [],
|
||||||
|
getToolsByServer: () => [],
|
||||||
|
} as unknown as ToolRegistry;
|
||||||
|
|
||||||
|
const mockPolicyEngineCheck = vi.fn().mockResolvedValue({
|
||||||
|
decision: PolicyDecision.ALLOW,
|
||||||
|
});
|
||||||
|
|
||||||
|
const mockConfig = createMockConfig({
|
||||||
|
getToolRegistry: () => mockToolRegistry,
|
||||||
|
getPolicyEngine: () =>
|
||||||
|
({
|
||||||
|
check: mockPolicyEngineCheck,
|
||||||
|
}) as unknown as PolicyEngine,
|
||||||
|
isInteractive: () => false,
|
||||||
|
});
|
||||||
|
|
||||||
|
const scheduler = new CoreToolScheduler({
|
||||||
|
config: mockConfig,
|
||||||
|
getPreferredEditor: () => 'vscode',
|
||||||
|
});
|
||||||
|
|
||||||
|
const abortController = new AbortController();
|
||||||
|
const request = {
|
||||||
|
callId: '1',
|
||||||
|
name: toolName,
|
||||||
|
args: {},
|
||||||
|
isClientInitiated: false,
|
||||||
|
prompt_id: 'prompt-id-1',
|
||||||
|
};
|
||||||
|
|
||||||
|
await scheduler.schedule(request, abortController.signal);
|
||||||
|
|
||||||
|
expect(mockPolicyEngineCheck).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({ name: toolName }),
|
||||||
|
serverName,
|
||||||
|
);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ import {
|
|||||||
type ToolCallResponseInfo,
|
type ToolCallResponseInfo,
|
||||||
} from '../scheduler/types.js';
|
} from '../scheduler/types.js';
|
||||||
import { ToolExecutor } from '../scheduler/tool-executor.js';
|
import { ToolExecutor } from '../scheduler/tool-executor.js';
|
||||||
|
import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
|
||||||
|
|
||||||
export type {
|
export type {
|
||||||
ToolCall,
|
ToolCall,
|
||||||
@@ -591,9 +592,15 @@ export class CoreToolScheduler {
|
|||||||
name: toolCall.request.name,
|
name: toolCall.request.name,
|
||||||
args: toolCall.request.args,
|
args: toolCall.request.args,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const serverName =
|
||||||
|
toolCall.tool instanceof DiscoveredMCPTool
|
||||||
|
? toolCall.tool.serverName
|
||||||
|
: undefined;
|
||||||
|
|
||||||
const { decision } = await this.config
|
const { decision } = await this.config
|
||||||
.getPolicyEngine()
|
.getPolicyEngine()
|
||||||
.check(toolCallForPolicy, undefined); // Server name undefined for local tools
|
.check(toolCallForPolicy, serverName);
|
||||||
|
|
||||||
if (decision === PolicyDecision.DENY) {
|
if (decision === PolicyDecision.DENY) {
|
||||||
const errorMessage = `Tool execution denied by policy.`;
|
const errorMessage = `Tool execution denied by policy.`;
|
||||||
|
|||||||
@@ -109,6 +109,37 @@ describe('PolicyEngine', () => {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should match unqualified tool names with qualified rules when serverName is provided', async () => {
|
||||||
|
const rules: PolicyRule[] = [
|
||||||
|
{
|
||||||
|
toolName: 'my-server__tool',
|
||||||
|
decision: PolicyDecision.ALLOW,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
engine = new PolicyEngine({ rules });
|
||||||
|
|
||||||
|
// Match with qualified name (standard)
|
||||||
|
expect(
|
||||||
|
(await engine.check({ name: 'my-server__tool' }, 'my-server')).decision,
|
||||||
|
).toBe(PolicyDecision.ALLOW);
|
||||||
|
|
||||||
|
// Match with unqualified name + serverName (the fix)
|
||||||
|
expect((await engine.check({ name: 'tool' }, 'my-server')).decision).toBe(
|
||||||
|
PolicyDecision.ALLOW,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Should NOT match with unqualified name but NO serverName
|
||||||
|
expect((await engine.check({ name: 'tool' }, undefined)).decision).toBe(
|
||||||
|
PolicyDecision.ASK_USER,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Should NOT match with unqualified name but WRONG serverName
|
||||||
|
expect(
|
||||||
|
(await engine.check({ name: 'tool' }, 'wrong-server')).decision,
|
||||||
|
).toBe(PolicyDecision.ASK_USER);
|
||||||
|
});
|
||||||
|
|
||||||
it('should match by args pattern', async () => {
|
it('should match by args pattern', async () => {
|
||||||
const rules: PolicyRule[] = [
|
const rules: PolicyRule[] = [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -310,16 +310,22 @@ export class PolicyEngine {
|
|||||||
let matchedRule: PolicyRule | undefined;
|
let matchedRule: PolicyRule | undefined;
|
||||||
let decision: PolicyDecision | undefined;
|
let decision: PolicyDecision | undefined;
|
||||||
|
|
||||||
|
// For tools with a server name, we want to try matching both the
|
||||||
|
// original name and the fully qualified name (server__tool).
|
||||||
|
const toolCallsToTry: FunctionCall[] = [toolCall];
|
||||||
|
if (serverName && toolCall.name && !toolCall.name.includes('__')) {
|
||||||
|
toolCallsToTry.push({
|
||||||
|
...toolCall,
|
||||||
|
name: `${serverName}__${toolCall.name}`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
for (const rule of this.rules) {
|
for (const rule of this.rules) {
|
||||||
if (
|
const match = toolCallsToTry.some((tc) =>
|
||||||
ruleMatches(
|
ruleMatches(rule, tc, stringifiedArgs, serverName, this.approvalMode),
|
||||||
rule,
|
);
|
||||||
toolCall,
|
|
||||||
stringifiedArgs,
|
if (match) {
|
||||||
serverName,
|
|
||||||
this.approvalMode,
|
|
||||||
)
|
|
||||||
) {
|
|
||||||
debugLogger.debug(
|
debugLogger.debug(
|
||||||
`[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`,
|
`[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`,
|
||||||
);
|
);
|
||||||
|
|||||||
Reference in New Issue
Block a user