mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-02 16:04:38 -07:00
feat(policy): support auto-add to policy by default and scoped persistence (#20361)
This commit is contained in:
@@ -29,7 +29,7 @@ import { type MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { SHELL_TOOL_NAMES } from '../utils/shell-utils.js';
|
||||
import { SHELL_TOOL_NAME } from '../tools/tool-names.js';
|
||||
import { SHELL_TOOL_NAME, SENSITIVE_TOOLS } from '../tools/tool-names.js';
|
||||
import { isNodeError } from '../utils/errors.js';
|
||||
import { MCP_TOOL_PREFIX } from '../tools/mcp-tool.js';
|
||||
|
||||
@@ -46,13 +46,20 @@ export const WORKSPACE_POLICY_TIER = 3;
|
||||
export const USER_POLICY_TIER = 4;
|
||||
export const ADMIN_POLICY_TIER = 5;
|
||||
|
||||
// Specific priority offsets and derived priorities for dynamic/settings rules.
|
||||
// These are added to the tier base (e.g., USER_POLICY_TIER).
|
||||
/**
|
||||
* The fractional priority of "Always allow" rules (e.g., 950/1000).
|
||||
* Higher fraction within a tier wins.
|
||||
*/
|
||||
export const ALWAYS_ALLOW_PRIORITY_FRACTION = 950;
|
||||
|
||||
// Workspace tier (3) + high priority (950/1000) = ALWAYS_ALLOW_PRIORITY
|
||||
// This ensures user "always allow" selections are high priority
|
||||
// within the workspace tier but still lose to user/admin policies.
|
||||
export const ALWAYS_ALLOW_PRIORITY = WORKSPACE_POLICY_TIER + 0.95;
|
||||
/**
|
||||
* The fractional priority offset for "Always allow" rules (e.g., 0.95).
|
||||
* This ensures consistency between in-memory rules and persisted rules.
|
||||
*/
|
||||
export const ALWAYS_ALLOW_PRIORITY_OFFSET =
|
||||
ALWAYS_ALLOW_PRIORITY_FRACTION / 1000;
|
||||
|
||||
// Specific priority offsets and derived priorities for dynamic/settings rules.
|
||||
|
||||
export const MCP_EXCLUDED_PRIORITY = USER_POLICY_TIER + 0.9;
|
||||
export const EXCLUDE_TOOLS_FLAG_PRIORITY = USER_POLICY_TIER + 0.4;
|
||||
@@ -60,6 +67,18 @@ export const ALLOWED_TOOLS_FLAG_PRIORITY = USER_POLICY_TIER + 0.3;
|
||||
export const TRUSTED_MCP_SERVER_PRIORITY = USER_POLICY_TIER + 0.2;
|
||||
export const ALLOWED_MCP_SERVER_PRIORITY = USER_POLICY_TIER + 0.1;
|
||||
|
||||
// These are added to the tier base (e.g., USER_POLICY_TIER).
|
||||
// Workspace tier (3) + high priority (950/1000) = ALWAYS_ALLOW_PRIORITY
|
||||
export const ALWAYS_ALLOW_PRIORITY =
|
||||
WORKSPACE_POLICY_TIER + ALWAYS_ALLOW_PRIORITY_OFFSET;
|
||||
|
||||
/**
|
||||
* Returns the fractional priority of ALWAYS_ALLOW_PRIORITY scaled to 1000.
|
||||
*/
|
||||
export function getAlwaysAllowPriorityFraction(): number {
|
||||
return Math.round((ALWAYS_ALLOW_PRIORITY % 1) * 1000);
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the list of directories to search for policy files, in order of increasing priority
|
||||
* (Default -> Extension -> Workspace -> User -> Admin).
|
||||
@@ -492,6 +511,19 @@ export function createPolicyUpdater(
|
||||
if (message.commandPrefix) {
|
||||
// Convert commandPrefix(es) to argsPatterns for in-memory rules
|
||||
const patterns = buildArgsPatterns(undefined, message.commandPrefix);
|
||||
const tier =
|
||||
message.persistScope === 'user'
|
||||
? USER_POLICY_TIER
|
||||
: WORKSPACE_POLICY_TIER;
|
||||
const priority = tier + getAlwaysAllowPriorityFraction() / 1000;
|
||||
|
||||
if (SENSITIVE_TOOLS.has(toolName) && !message.commandPrefix) {
|
||||
debugLogger.warn(
|
||||
`Attempted to update policy for sensitive tool '${toolName}' without a commandPrefix. Skipping.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
for (const pattern of patterns) {
|
||||
if (pattern) {
|
||||
// Note: patterns from buildArgsPatterns are derived from escapeRegex,
|
||||
@@ -499,7 +531,7 @@ export function createPolicyUpdater(
|
||||
policyEngine.addRule({
|
||||
toolName,
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: ALWAYS_ALLOW_PRIORITY,
|
||||
priority,
|
||||
argsPattern: new RegExp(pattern),
|
||||
source: 'Dynamic (Confirmed)',
|
||||
});
|
||||
@@ -518,10 +550,23 @@ export function createPolicyUpdater(
|
||||
? new RegExp(message.argsPattern)
|
||||
: undefined;
|
||||
|
||||
const tier =
|
||||
message.persistScope === 'user'
|
||||
? USER_POLICY_TIER
|
||||
: WORKSPACE_POLICY_TIER;
|
||||
const priority = tier + getAlwaysAllowPriorityFraction() / 1000;
|
||||
|
||||
if (SENSITIVE_TOOLS.has(toolName) && !message.argsPattern) {
|
||||
debugLogger.warn(
|
||||
`Attempted to update policy for sensitive tool '${toolName}' without an argsPattern. Skipping.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
policyEngine.addRule({
|
||||
toolName,
|
||||
decision: PolicyDecision.ALLOW,
|
||||
priority: ALWAYS_ALLOW_PRIORITY,
|
||||
priority,
|
||||
argsPattern,
|
||||
source: 'Dynamic (Confirmed)',
|
||||
});
|
||||
@@ -530,7 +575,10 @@ export function createPolicyUpdater(
|
||||
if (message.persist) {
|
||||
persistenceQueue = persistenceQueue.then(async () => {
|
||||
try {
|
||||
const policyFile = storage.getAutoSavedPolicyPath();
|
||||
const policyFile =
|
||||
message.persistScope === 'workspace'
|
||||
? storage.getWorkspaceAutoSavedPolicyPath()
|
||||
: storage.getAutoSavedPolicyPath();
|
||||
await fs.mkdir(path.dirname(policyFile), { recursive: true });
|
||||
|
||||
// Read existing file
|
||||
@@ -560,21 +608,19 @@ export function createPolicyUpdater(
|
||||
}
|
||||
|
||||
// Create new rule object
|
||||
const newRule: TomlRule = {};
|
||||
const newRule: TomlRule = {
|
||||
decision: 'allow',
|
||||
priority: getAlwaysAllowPriorityFraction(),
|
||||
};
|
||||
|
||||
if (message.mcpName) {
|
||||
newRule.mcpName = message.mcpName;
|
||||
// Extract simple tool name
|
||||
const simpleToolName = toolName.startsWith(`${message.mcpName}__`)
|
||||
newRule.toolName = toolName.startsWith(`${message.mcpName}__`)
|
||||
? toolName.slice(message.mcpName.length + 2)
|
||||
: toolName;
|
||||
newRule.toolName = simpleToolName;
|
||||
newRule.decision = 'allow';
|
||||
newRule.priority = 200;
|
||||
} else {
|
||||
newRule.toolName = toolName;
|
||||
newRule.decision = 'allow';
|
||||
newRule.priority = 100;
|
||||
}
|
||||
|
||||
if (message.commandPrefix) {
|
||||
|
||||
@@ -4,25 +4,22 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import * as path from 'node:path';
|
||||
import { createPolicyUpdater, ALWAYS_ALLOW_PRIORITY } from './config.js';
|
||||
import {
|
||||
createPolicyUpdater,
|
||||
getAlwaysAllowPriorityFraction,
|
||||
} from './config.js';
|
||||
import { PolicyEngine } from './policy-engine.js';
|
||||
import { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { MessageBusType } from '../confirmation-bus/types.js';
|
||||
import { Storage, AUTO_SAVED_POLICY_FILENAME } from '../config/storage.js';
|
||||
import { ApprovalMode } from './types.js';
|
||||
import { vol, fs as memfs } from 'memfs';
|
||||
|
||||
// Use memfs for all fs operations in this test
|
||||
vi.mock('node:fs/promises', () => import('memfs').then((m) => m.fs.promises));
|
||||
|
||||
vi.mock('node:fs/promises');
|
||||
vi.mock('../config/storage.js');
|
||||
|
||||
describe('createPolicyUpdater', () => {
|
||||
@@ -31,6 +28,8 @@ describe('createPolicyUpdater', () => {
|
||||
let mockStorage: Storage;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
vol.reset();
|
||||
policyEngine = new PolicyEngine({
|
||||
rules: [],
|
||||
checkers: [],
|
||||
@@ -43,202 +42,184 @@ describe('createPolicyUpdater', () => {
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should persist policy when persist flag is true', async () => {
|
||||
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
||||
|
||||
const userPoliciesDir = '/mock/user/.gemini/policies';
|
||||
const policyFile = path.join(userPoliciesDir, AUTO_SAVED_POLICY_FILENAME);
|
||||
const policyFile = '/mock/user/.gemini/policies/auto-saved.toml';
|
||||
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile);
|
||||
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
|
||||
(fs.readFile as unknown as Mock).mockRejectedValue(
|
||||
new Error('File not found'),
|
||||
); // Simulate new file
|
||||
|
||||
const mockFileHandle = {
|
||||
writeFile: vi.fn().mockResolvedValue(undefined),
|
||||
close: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
|
||||
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
|
||||
|
||||
const toolName = 'test_tool';
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName,
|
||||
toolName: 'test_tool',
|
||||
persist: true,
|
||||
});
|
||||
|
||||
// Wait for async operations (microtasks)
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
// Policy updater handles persistence asynchronously in a promise queue.
|
||||
// We use advanceTimersByTimeAsync to yield to the microtask queue.
|
||||
await vi.advanceTimersByTimeAsync(100);
|
||||
|
||||
expect(fs.mkdir).toHaveBeenCalledWith(userPoliciesDir, {
|
||||
recursive: true,
|
||||
});
|
||||
const fileExists = memfs.existsSync(policyFile);
|
||||
expect(fileExists).toBe(true);
|
||||
|
||||
expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');
|
||||
|
||||
// Check written content
|
||||
const expectedContent = expect.stringContaining(`toolName = "test_tool"`);
|
||||
expect(mockFileHandle.writeFile).toHaveBeenCalledWith(
|
||||
expectedContent,
|
||||
'utf-8',
|
||||
);
|
||||
expect(fs.rename).toHaveBeenCalledWith(
|
||||
expect.stringMatching(/\.tmp$/),
|
||||
policyFile,
|
||||
);
|
||||
const content = memfs.readFileSync(policyFile, 'utf-8') as string;
|
||||
expect(content).toContain('toolName = "test_tool"');
|
||||
expect(content).toContain('decision = "allow"');
|
||||
const expectedPriority = getAlwaysAllowPriorityFraction();
|
||||
expect(content).toContain(`priority = ${expectedPriority}`);
|
||||
});
|
||||
|
||||
it('should not persist policy when persist flag is false or undefined', async () => {
|
||||
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
||||
|
||||
const policyFile = '/mock/user/.gemini/policies/auto-saved.toml';
|
||||
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile);
|
||||
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName: 'test_tool',
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
await vi.advanceTimersByTimeAsync(100);
|
||||
|
||||
expect(fs.writeFile).not.toHaveBeenCalled();
|
||||
expect(fs.rename).not.toHaveBeenCalled();
|
||||
expect(memfs.existsSync(policyFile)).toBe(false);
|
||||
});
|
||||
|
||||
it('should persist policy with commandPrefix when provided', async () => {
|
||||
it('should append to existing policy file', async () => {
|
||||
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
||||
|
||||
const userPoliciesDir = '/mock/user/.gemini/policies';
|
||||
const policyFile = path.join(userPoliciesDir, AUTO_SAVED_POLICY_FILENAME);
|
||||
const policyFile = '/mock/user/.gemini/policies/auto-saved.toml';
|
||||
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile);
|
||||
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
|
||||
(fs.readFile as unknown as Mock).mockRejectedValue(
|
||||
new Error('File not found'),
|
||||
);
|
||||
|
||||
const mockFileHandle = {
|
||||
writeFile: vi.fn().mockResolvedValue(undefined),
|
||||
close: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
|
||||
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
|
||||
|
||||
const toolName = 'run_shell_command';
|
||||
const commandPrefix = 'git status';
|
||||
const existingContent =
|
||||
'[[rule]]\ntoolName = "existing_tool"\ndecision = "allow"\n';
|
||||
const dir = path.dirname(policyFile);
|
||||
memfs.mkdirSync(dir, { recursive: true });
|
||||
memfs.writeFileSync(policyFile, existingContent);
|
||||
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName,
|
||||
toolName: 'new_tool',
|
||||
persist: true,
|
||||
commandPrefix,
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
await vi.advanceTimersByTimeAsync(100);
|
||||
|
||||
// In-memory rule check (unchanged)
|
||||
const rules = policyEngine.getRules();
|
||||
const addedRule = rules.find((r) => r.toolName === toolName);
|
||||
expect(addedRule).toBeDefined();
|
||||
expect(addedRule?.priority).toBe(ALWAYS_ALLOW_PRIORITY);
|
||||
expect(addedRule?.argsPattern).toEqual(
|
||||
new RegExp(`"command":"git\\ status(?:[\\s"]|\\\\")`),
|
||||
);
|
||||
|
||||
// Verify file written
|
||||
expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');
|
||||
expect(mockFileHandle.writeFile).toHaveBeenCalledWith(
|
||||
expect.stringContaining(`commandPrefix = "git status"`),
|
||||
'utf-8',
|
||||
);
|
||||
const content = memfs.readFileSync(policyFile, 'utf-8') as string;
|
||||
expect(content).toContain('toolName = "existing_tool"');
|
||||
expect(content).toContain('toolName = "new_tool"');
|
||||
});
|
||||
|
||||
it('should persist policy with mcpName and toolName when provided', async () => {
|
||||
it('should handle toml with multiple rules correctly', async () => {
|
||||
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
||||
|
||||
const userPoliciesDir = '/mock/user/.gemini/policies';
|
||||
const policyFile = path.join(userPoliciesDir, AUTO_SAVED_POLICY_FILENAME);
|
||||
const policyFile = '/mock/user/.gemini/policies/auto-saved.toml';
|
||||
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile);
|
||||
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
|
||||
(fs.readFile as unknown as Mock).mockRejectedValue(
|
||||
new Error('File not found'),
|
||||
);
|
||||
|
||||
const mockFileHandle = {
|
||||
writeFile: vi.fn().mockResolvedValue(undefined),
|
||||
close: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
|
||||
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
|
||||
const existingContent = `
|
||||
[[rule]]
|
||||
toolName = "tool1"
|
||||
decision = "allow"
|
||||
|
||||
const mcpName = 'my-jira-server';
|
||||
const simpleToolName = 'search';
|
||||
const toolName = `${mcpName}__${simpleToolName}`;
|
||||
[[rule]]
|
||||
toolName = "tool2"
|
||||
decision = "deny"
|
||||
`;
|
||||
const dir = path.dirname(policyFile);
|
||||
memfs.mkdirSync(dir, { recursive: true });
|
||||
memfs.writeFileSync(policyFile, existingContent);
|
||||
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName,
|
||||
toolName: 'tool3',
|
||||
persist: true,
|
||||
mcpName,
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
await vi.advanceTimersByTimeAsync(100);
|
||||
|
||||
// Verify file written
|
||||
expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');
|
||||
const writeCall = mockFileHandle.writeFile.mock.calls[0];
|
||||
const writtenContent = writeCall[0] as string;
|
||||
expect(writtenContent).toContain(`mcpName = "${mcpName}"`);
|
||||
expect(writtenContent).toContain(`toolName = "${simpleToolName}"`);
|
||||
expect(writtenContent).toContain('priority = 200');
|
||||
const content = memfs.readFileSync(policyFile, 'utf-8') as string;
|
||||
expect(content).toContain('toolName = "tool1"');
|
||||
expect(content).toContain('toolName = "tool2"');
|
||||
expect(content).toContain('toolName = "tool3"');
|
||||
});
|
||||
|
||||
it('should escape special characters in toolName and mcpName', async () => {
|
||||
it('should include argsPattern if provided', async () => {
|
||||
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
||||
|
||||
const userPoliciesDir = '/mock/user/.gemini/policies';
|
||||
const policyFile = path.join(userPoliciesDir, AUTO_SAVED_POLICY_FILENAME);
|
||||
const policyFile = '/mock/user/.gemini/policies/auto-saved.toml';
|
||||
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile);
|
||||
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
|
||||
(fs.readFile as unknown as Mock).mockRejectedValue(
|
||||
new Error('File not found'),
|
||||
);
|
||||
|
||||
const mockFileHandle = {
|
||||
writeFile: vi.fn().mockResolvedValue(undefined),
|
||||
close: vi.fn().mockResolvedValue(undefined),
|
||||
};
|
||||
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
|
||||
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
|
||||
|
||||
const mcpName = 'my"jira"server';
|
||||
const toolName = `my"jira"server__search"tool"`;
|
||||
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName,
|
||||
toolName: 'test_tool',
|
||||
persist: true,
|
||||
mcpName,
|
||||
argsPattern: '^foo.*$',
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 0));
|
||||
await vi.advanceTimersByTimeAsync(100);
|
||||
|
||||
expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');
|
||||
const writeCall = mockFileHandle.writeFile.mock.calls[0];
|
||||
const writtenContent = writeCall[0] as string;
|
||||
const content = memfs.readFileSync(policyFile, 'utf-8') as string;
|
||||
expect(content).toContain('argsPattern = "^foo.*$"');
|
||||
});
|
||||
|
||||
// Verify escaping - should be valid TOML
|
||||
it('should include mcpName if provided', async () => {
|
||||
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
||||
|
||||
const policyFile = '/mock/user/.gemini/policies/auto-saved.toml';
|
||||
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile);
|
||||
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName: 'search"tool"',
|
||||
persist: true,
|
||||
mcpName: 'my"jira"server',
|
||||
});
|
||||
|
||||
await vi.advanceTimersByTimeAsync(100);
|
||||
|
||||
const writtenContent = memfs.readFileSync(policyFile, 'utf-8') as string;
|
||||
|
||||
// Verify escaping - should be valid TOML and contain the values
|
||||
// Note: @iarna/toml optimizes for shortest representation, so it may use single quotes 'foo"bar'
|
||||
// instead of "foo\"bar\"" if there are no single quotes in the string.
|
||||
try {
|
||||
expect(writtenContent).toContain(`mcpName = "my\\"jira\\"server"`);
|
||||
expect(writtenContent).toContain('mcpName = "my\\"jira\\"server"');
|
||||
} catch {
|
||||
expect(writtenContent).toContain(`mcpName = 'my"jira"server'`);
|
||||
expect(writtenContent).toContain('mcpName = \'my"jira"server\'');
|
||||
}
|
||||
|
||||
try {
|
||||
expect(writtenContent).toContain(`toolName = "search\\"tool\\""`);
|
||||
expect(writtenContent).toContain('toolName = "search\\"tool\\""');
|
||||
} catch {
|
||||
expect(writtenContent).toContain(`toolName = 'search"tool"'`);
|
||||
expect(writtenContent).toContain('toolName = \'search"tool"\'');
|
||||
}
|
||||
});
|
||||
|
||||
it('should persist to workspace when persistScope is workspace', async () => {
|
||||
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
||||
|
||||
const workspacePoliciesDir = '/mock/project/.gemini/policies';
|
||||
const policyFile = path.join(
|
||||
workspacePoliciesDir,
|
||||
AUTO_SAVED_POLICY_FILENAME,
|
||||
);
|
||||
vi.spyOn(mockStorage, 'getWorkspaceAutoSavedPolicyPath').mockReturnValue(
|
||||
policyFile,
|
||||
);
|
||||
|
||||
await messageBus.publish({
|
||||
type: MessageBusType.UPDATE_POLICY,
|
||||
toolName: 'test_tool',
|
||||
persist: true,
|
||||
persistScope: 'workspace',
|
||||
});
|
||||
|
||||
await vi.advanceTimersByTimeAsync(100);
|
||||
|
||||
expect(memfs.existsSync(policyFile)).toBe(true);
|
||||
const content = memfs.readFileSync(policyFile, 'utf-8') as string;
|
||||
expect(content).toContain('toolName = "test_tool"');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -19,6 +19,7 @@ import {
|
||||
type PolicyUpdateOptions,
|
||||
} from '../tools/tools.js';
|
||||
import * as shellUtils from '../utils/shell-utils.js';
|
||||
import { escapeRegex } from './utils.js';
|
||||
|
||||
vi.mock('node:fs/promises');
|
||||
vi.mock('../config/storage.js');
|
||||
@@ -75,7 +76,9 @@ describe('createPolicyUpdater', () => {
|
||||
expect.objectContaining({
|
||||
toolName: 'run_shell_command',
|
||||
priority: ALWAYS_ALLOW_PRIORITY,
|
||||
argsPattern: new RegExp('"command":"echo(?:[\\s"]|\\\\")'),
|
||||
argsPattern: new RegExp(
|
||||
escapeRegex('"command":"echo') + '(?:[\\s"]|\\\\")',
|
||||
),
|
||||
}),
|
||||
);
|
||||
expect(policyEngine.addRule).toHaveBeenNthCalledWith(
|
||||
@@ -83,7 +86,9 @@ describe('createPolicyUpdater', () => {
|
||||
expect.objectContaining({
|
||||
toolName: 'run_shell_command',
|
||||
priority: ALWAYS_ALLOW_PRIORITY,
|
||||
argsPattern: new RegExp('"command":"ls(?:[\\s"]|\\\\")'),
|
||||
argsPattern: new RegExp(
|
||||
escapeRegex('"command":"ls') + '(?:[\\s"]|\\\\")',
|
||||
),
|
||||
}),
|
||||
);
|
||||
});
|
||||
@@ -103,7 +108,9 @@ describe('createPolicyUpdater', () => {
|
||||
expect.objectContaining({
|
||||
toolName: 'run_shell_command',
|
||||
priority: ALWAYS_ALLOW_PRIORITY,
|
||||
argsPattern: new RegExp('"command":"git(?:[\\s"]|\\\\")'),
|
||||
argsPattern: new RegExp(
|
||||
escapeRegex('"command":"git') + '(?:[\\s"]|\\\\")',
|
||||
),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { expect, describe, it } from 'vitest';
|
||||
import { escapeRegex, buildArgsPatterns, isSafeRegExp } from './utils.js';
|
||||
|
||||
describe('policy/utils', () => {
|
||||
@@ -43,20 +43,20 @@ describe('policy/utils', () => {
|
||||
});
|
||||
|
||||
it('should return false for invalid regexes', () => {
|
||||
expect(isSafeRegExp('[')).toBe(false);
|
||||
expect(isSafeRegExp('([a-z)')).toBe(false);
|
||||
expect(isSafeRegExp('*')).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false for extremely long regexes', () => {
|
||||
expect(isSafeRegExp('a'.repeat(2049))).toBe(false);
|
||||
it('should return false for long regexes', () => {
|
||||
expect(isSafeRegExp('a'.repeat(3000))).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false for nested quantifiers (potential ReDoS)', () => {
|
||||
it('should return false for nested quantifiers (ReDoS heuristic)', () => {
|
||||
expect(isSafeRegExp('(a+)+')).toBe(false);
|
||||
expect(isSafeRegExp('(a+)*')).toBe(false);
|
||||
expect(isSafeRegExp('(a*)+')).toBe(false);
|
||||
expect(isSafeRegExp('(a*)*')).toBe(false);
|
||||
expect(isSafeRegExp('(a|b+)+')).toBe(false);
|
||||
expect(isSafeRegExp('(a|b)*')).toBe(true);
|
||||
expect(isSafeRegExp('(.*)*')).toBe(false);
|
||||
expect(isSafeRegExp('([a-z]+)+')).toBe(false);
|
||||
expect(isSafeRegExp('(.*)+')).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -69,14 +69,14 @@ describe('policy/utils', () => {
|
||||
|
||||
it('should build pattern from a single commandPrefix', () => {
|
||||
const result = buildArgsPatterns(undefined, 'ls', undefined);
|
||||
expect(result).toEqual(['"command":"ls(?:[\\s"]|\\\\")']);
|
||||
expect(result).toEqual(['\\"command\\":\\"ls(?:[\\s"]|\\\\")']);
|
||||
});
|
||||
|
||||
it('should build patterns from an array of commandPrefixes', () => {
|
||||
const result = buildArgsPatterns(undefined, ['ls', 'cd'], undefined);
|
||||
const result = buildArgsPatterns(undefined, ['echo', 'ls'], undefined);
|
||||
expect(result).toEqual([
|
||||
'"command":"ls(?:[\\s"]|\\\\")',
|
||||
'"command":"cd(?:[\\s"]|\\\\")',
|
||||
'\\"command\\":\\"echo(?:[\\s"]|\\\\")',
|
||||
'\\"command\\":\\"ls(?:[\\s"]|\\\\")',
|
||||
]);
|
||||
});
|
||||
|
||||
@@ -87,7 +87,7 @@ describe('policy/utils', () => {
|
||||
|
||||
it('should prioritize commandPrefix over commandRegex and argsPattern', () => {
|
||||
const result = buildArgsPatterns('raw', 'prefix', 'regex');
|
||||
expect(result).toEqual(['"command":"prefix(?:[\\s"]|\\\\")']);
|
||||
expect(result).toEqual(['\\"command\\":\\"prefix(?:[\\s"]|\\\\")']);
|
||||
});
|
||||
|
||||
it('should prioritize commandRegex over argsPattern if no commandPrefix', () => {
|
||||
@@ -98,14 +98,15 @@ describe('policy/utils', () => {
|
||||
it('should escape characters in commandPrefix', () => {
|
||||
const result = buildArgsPatterns(undefined, 'git checkout -b', undefined);
|
||||
expect(result).toEqual([
|
||||
'"command":"git\\ checkout\\ \\-b(?:[\\s"]|\\\\")',
|
||||
'\\"command\\":\\"git\\ checkout\\ \\-b(?:[\\s"]|\\\\")',
|
||||
]);
|
||||
});
|
||||
|
||||
it('should correctly escape quotes in commandPrefix', () => {
|
||||
const result = buildArgsPatterns(undefined, 'git "fix"', undefined);
|
||||
expect(result).toEqual([
|
||||
'"command":"git\\ \\\\\\"fix\\\\\\"(?:[\\s"]|\\\\")',
|
||||
// eslint-disable-next-line no-useless-escape
|
||||
'\\\"command\\\":\\\"git\\ \\\\\\\"fix\\\\\\\"(?:[\\s\"]|\\\\\")',
|
||||
]);
|
||||
});
|
||||
|
||||
@@ -142,7 +143,7 @@ describe('policy/utils', () => {
|
||||
const gitRegex = new RegExp(gitPatterns[0]!);
|
||||
// git\status -> {"command":"git\\status"}
|
||||
const gitAttack = '{"command":"git\\\\status"}';
|
||||
expect(gitRegex.test(gitAttack)).toBe(false);
|
||||
expect(gitAttack).not.toMatch(gitRegex);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -63,16 +63,22 @@ export function buildArgsPatterns(
|
||||
? commandPrefix
|
||||
: [commandPrefix];
|
||||
|
||||
// Expand command prefixes to multiple patterns.
|
||||
// We append [\\s"] to ensure we match whole words only (e.g., "git" but not
|
||||
// "github"). Since we match against JSON stringified args, the value is
|
||||
// always followed by a space or a closing quote.
|
||||
return prefixes.map((prefix) => {
|
||||
const jsonPrefix = JSON.stringify(prefix).slice(1, -1);
|
||||
// JSON.stringify safely encodes the prefix in quotes.
|
||||
// We remove ONLY the trailing quote to match it as an open prefix string.
|
||||
const encodedPrefix = JSON.stringify(prefix);
|
||||
const openQuotePrefix = encodedPrefix.substring(
|
||||
0,
|
||||
encodedPrefix.length - 1,
|
||||
);
|
||||
|
||||
// Escape the exact JSON literal segment we expect to see
|
||||
const matchSegment = escapeRegex(`"command":${openQuotePrefix}`);
|
||||
|
||||
// We allow [\s], ["], or the specific sequence [\"] (for escaped quotes
|
||||
// in JSON). We do NOT allow generic [\\], which would match "git\status"
|
||||
// -> "gitstatus".
|
||||
return `"command":"${escapeRegex(jsonPrefix)}(?:[\\s"]|\\\\")`;
|
||||
return `${matchSegment}(?:[\\s"]|\\\\")`;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -82,3 +88,30 @@ export function buildArgsPatterns(
|
||||
|
||||
return [argsPattern];
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a regex pattern to match a specific file path in tool arguments.
|
||||
* This is used to narrow tool approvals for edit tools to specific files.
|
||||
*
|
||||
* @param filePath The relative path to the file.
|
||||
* @returns A regex string that matches "file_path":"<path>" in a JSON string.
|
||||
*/
|
||||
export function buildFilePathArgsPattern(filePath: string): string {
|
||||
// JSON.stringify safely encodes the path (handling quotes, backslashes, etc)
|
||||
// and wraps it in double quotes. We simply prepend the key name and escape
|
||||
// the entire sequence for Regex matching without any slicing.
|
||||
const encodedPath = JSON.stringify(filePath);
|
||||
return escapeRegex(`"file_path":${encodedPath}`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds a regex pattern to match a specific "pattern" in tool arguments.
|
||||
* This is used to narrow tool approvals for search tools like glob/grep to specific patterns.
|
||||
*
|
||||
* @param pattern The pattern to match.
|
||||
* @returns A regex string that matches "pattern":"<pattern>" in a JSON string.
|
||||
*/
|
||||
export function buildPatternArgsPattern(pattern: string): string {
|
||||
const encodedPattern = JSON.stringify(pattern);
|
||||
return escapeRegex(`"pattern":${encodedPattern}`);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user