diff --git a/packages/core/src/policy/config.ts b/packages/core/src/policy/config.ts index 78cf1e85ac..ca641d09ea 100644 --- a/packages/core/src/policy/config.ts +++ b/packages/core/src/policy/config.ts @@ -6,6 +6,7 @@ import * as fs from 'node:fs/promises'; import * as path from 'node:path'; +import * as crypto from 'node:crypto'; import { fileURLToPath } from 'node:url'; import { Storage } from '../config/storage.js'; import { @@ -17,7 +18,7 @@ import { } from './types.js'; import type { PolicyEngine } from './policy-engine.js'; import { loadPoliciesFromToml, type PolicyFileError } from './toml-loader.js'; -import { buildArgsPatterns } from './utils.js'; +import { buildArgsPatterns, isSafeRegExp } from './utils.js'; import toml from '@iarna/toml'; import { MessageBusType, @@ -331,6 +332,9 @@ export function createPolicyUpdater( policyEngine: PolicyEngine, messageBus: MessageBus, ) { + // Use a sequential queue for persistence to avoid lost updates from concurrent events. + let persistenceQueue = Promise.resolve(); + messageBus.subscribe( MessageBusType.UPDATE_POLICY, async (message: UpdatePolicy) => { @@ -341,6 +345,8 @@ export function createPolicyUpdater( const patterns = buildArgsPatterns(undefined, message.commandPrefix); for (const pattern of patterns) { if (pattern) { + // Note: patterns from buildArgsPatterns are derived from escapeRegex, + // which is safe and won't contain ReDoS patterns. policyEngine.addRule({ toolName, decision: PolicyDecision.ALLOW, @@ -354,6 +360,14 @@ export function createPolicyUpdater( } } } else { + if (message.argsPattern && !isSafeRegExp(message.argsPattern)) { + coreEvents.emitFeedback( + 'error', + `Invalid or unsafe regular expression for tool ${toolName}: ${message.argsPattern}`, + ); + return; + } + const argsPattern = message.argsPattern ? 
new RegExp(message.argsPattern) : undefined; @@ -371,74 +385,88 @@ export function createPolicyUpdater( } if (message.persist) { - try { - const userPoliciesDir = Storage.getUserPoliciesDir(); - await fs.mkdir(userPoliciesDir, { recursive: true }); - const policyFile = path.join(userPoliciesDir, 'auto-saved.toml'); - - // Read existing file - let existingData: { rule?: TomlRule[] } = {}; + persistenceQueue = persistenceQueue.then(async () => { try { - const fileContent = await fs.readFile(policyFile, 'utf-8'); - existingData = toml.parse(fileContent) as { rule?: TomlRule[] }; - } catch (error) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - if ((error as NodeJS.ErrnoException).code !== 'ENOENT') { - debugLogger.warn( - `Failed to parse ${policyFile}, overwriting with new policy.`, - error, - ); + const userPoliciesDir = Storage.getUserPoliciesDir(); + await fs.mkdir(userPoliciesDir, { recursive: true }); + const policyFile = path.join(userPoliciesDir, 'auto-saved.toml'); + + // Read existing file + let existingData: { rule?: TomlRule[] } = {}; + try { + const fileContent = await fs.readFile(policyFile, 'utf-8'); + existingData = toml.parse(fileContent) as { rule?: TomlRule[] }; + } catch (error) { + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + if ((error as NodeJS.ErrnoException).code !== 'ENOENT') { + debugLogger.warn( + `Failed to parse ${policyFile}, overwriting with new policy.`, + error, + ); + } } + + // Initialize rule array if needed + if (!existingData.rule) { + existingData.rule = []; + } + + // Create new rule object + const newRule: TomlRule = {}; + + if (message.mcpName) { + newRule.mcpName = message.mcpName; + // Extract simple tool name + const simpleToolName = toolName.startsWith(`${message.mcpName}__`) + ? 
toolName.slice(message.mcpName.length + 2) + : toolName; + newRule.toolName = simpleToolName; + newRule.decision = 'allow'; + newRule.priority = 200; + } else { + newRule.toolName = toolName; + newRule.decision = 'allow'; + newRule.priority = 100; + } + + if (message.commandPrefix) { + newRule.commandPrefix = message.commandPrefix; + } else if (message.argsPattern) { + // message.argsPattern was already validated above + newRule.argsPattern = message.argsPattern; + } + + // Add to rules + existingData.rule.push(newRule); + + // Serialize back to TOML + // @iarna/toml stringify might not produce beautiful output but it handles escaping correctly + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + const newContent = toml.stringify(existingData as toml.JsonMap); + + // Atomic write: write to a unique tmp file then rename to the target file. + // Using a unique suffix avoids race conditions where concurrent processes + // overwrite each other's temporary files, leading to ENOENT errors on rename. + const tmpSuffix = crypto.randomBytes(8).toString('hex'); + const tmpFile = `${policyFile}.${tmpSuffix}.tmp`; + + let handle: fs.FileHandle | undefined; + try { + // Use 'wx' to create the file exclusively (fails if exists) for security. + handle = await fs.open(tmpFile, 'wx'); + await handle.writeFile(newContent, 'utf-8'); + } finally { + await handle?.close(); + } + await fs.rename(tmpFile, policyFile); + } catch (error) { + coreEvents.emitFeedback( + 'error', + `Failed to persist policy for ${toolName}`, + error, + ); } - - // Initialize rule array if needed - if (!existingData.rule) { - existingData.rule = []; - } - - // Create new rule object - const newRule: TomlRule = {}; - - if (message.mcpName) { - newRule.mcpName = message.mcpName; - // Extract simple tool name - const simpleToolName = toolName.startsWith(`${message.mcpName}__`) - ? 
toolName.slice(message.mcpName.length + 2) - : toolName; - newRule.toolName = simpleToolName; - newRule.decision = 'allow'; - newRule.priority = 200; - } else { - newRule.toolName = toolName; - newRule.decision = 'allow'; - newRule.priority = 100; - } - - if (message.commandPrefix) { - newRule.commandPrefix = message.commandPrefix; - } else if (message.argsPattern) { - newRule.argsPattern = message.argsPattern; - } - - // Add to rules - existingData.rule.push(newRule); - - // Serialize back to TOML - // @iarna/toml stringify might not produce beautiful output but it handles escaping correctly - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const newContent = toml.stringify(existingData as toml.JsonMap); - - // Atomic write: write to tmp then rename - const tmpFile = `${policyFile}.tmp`; - await fs.writeFile(tmpFile, newContent, 'utf-8'); - await fs.rename(tmpFile, policyFile); - } catch (error) { - coreEvents.emitFeedback( - 'error', - `Failed to persist policy for ${toolName}`, - error, - ); - } + }); } }, ); diff --git a/packages/core/src/policy/persistence.test.ts b/packages/core/src/policy/persistence.test.ts index 22f00ac9a8..7d80b41893 100644 --- a/packages/core/src/policy/persistence.test.ts +++ b/packages/core/src/policy/persistence.test.ts @@ -52,7 +52,12 @@ describe('createPolicyUpdater', () => { (fs.readFile as unknown as Mock).mockRejectedValue( new Error('File not found'), ); // Simulate new file - (fs.writeFile as unknown as Mock).mockResolvedValue(undefined); + + const mockFileHandle = { + writeFile: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + }; + (fs.open as unknown as Mock).mockResolvedValue(mockFileHandle); (fs.rename as unknown as Mock).mockResolvedValue(undefined); const toolName = 'test_tool'; @@ -70,10 +75,11 @@ describe('createPolicyUpdater', () => { recursive: true, }); + expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx'); + // Check written 
content const expectedContent = expect.stringContaining(`toolName = "test_tool"`); - expect(fs.writeFile).toHaveBeenCalledWith( - expect.stringMatching(/\.tmp$/), + expect(mockFileHandle.writeFile).toHaveBeenCalledWith( expectedContent, 'utf-8', ); @@ -106,7 +112,12 @@ describe('createPolicyUpdater', () => { (fs.readFile as unknown as Mock).mockRejectedValue( new Error('File not found'), ); - (fs.writeFile as unknown as Mock).mockResolvedValue(undefined); + + const mockFileHandle = { + writeFile: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + }; + (fs.open as unknown as Mock).mockResolvedValue(mockFileHandle); (fs.rename as unknown as Mock).mockResolvedValue(undefined); const toolName = 'run_shell_command'; @@ -131,8 +142,8 @@ describe('createPolicyUpdater', () => { ); // Verify file written - expect(fs.writeFile).toHaveBeenCalledWith( - expect.stringMatching(/\.tmp$/), + expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx'); + expect(mockFileHandle.writeFile).toHaveBeenCalledWith( expect.stringContaining(`commandPrefix = "git status"`), 'utf-8', ); @@ -147,7 +158,12 @@ describe('createPolicyUpdater', () => { (fs.readFile as unknown as Mock).mockRejectedValue( new Error('File not found'), ); - (fs.writeFile as unknown as Mock).mockResolvedValue(undefined); + + const mockFileHandle = { + writeFile: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + }; + (fs.open as unknown as Mock).mockResolvedValue(mockFileHandle); (fs.rename as unknown as Mock).mockResolvedValue(undefined); const mcpName = 'my-jira-server'; @@ -164,8 +180,9 @@ describe('createPolicyUpdater', () => { await new Promise((resolve) => setTimeout(resolve, 0)); // Verify file written - const writeCall = (fs.writeFile as unknown as Mock).mock.calls[0]; - const writtenContent = writeCall[1] as string; + expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx'); + const writeCall = 
mockFileHandle.writeFile.mock.calls[0]; + const writtenContent = writeCall[0] as string; expect(writtenContent).toContain(`mcpName = "${mcpName}"`); expect(writtenContent).toContain(`toolName = "${simpleToolName}"`); expect(writtenContent).toContain('priority = 200'); @@ -180,7 +197,12 @@ describe('createPolicyUpdater', () => { (fs.readFile as unknown as Mock).mockRejectedValue( new Error('File not found'), ); - (fs.writeFile as unknown as Mock).mockResolvedValue(undefined); + + const mockFileHandle = { + writeFile: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + }; + (fs.open as unknown as Mock).mockResolvedValue(mockFileHandle); (fs.rename as unknown as Mock).mockResolvedValue(undefined); const mcpName = 'my"jira"server'; @@ -195,8 +217,9 @@ describe('createPolicyUpdater', () => { await new Promise((resolve) => setTimeout(resolve, 0)); - const writeCall = (fs.writeFile as unknown as Mock).mock.calls[0]; - const writtenContent = writeCall[1] as string; + expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx'); + const writeCall = mockFileHandle.writeFile.mock.calls[0]; + const writtenContent = writeCall[0] as string; // Verify escaping - should be valid TOML // Note: @iarna/toml optimizes for shortest representation, so it may use single quotes 'foo"bar' diff --git a/packages/core/src/policy/policy-updater.test.ts b/packages/core/src/policy/policy-updater.test.ts index aa6b7ac887..928d84408b 100644 --- a/packages/core/src/policy/policy-updater.test.ts +++ b/packages/core/src/policy/policy-updater.test.ts @@ -107,7 +107,14 @@ describe('createPolicyUpdater', () => { createPolicyUpdater(policyEngine, messageBus); vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); vi.mocked(fs.mkdir).mockResolvedValue(undefined); - vi.mocked(fs.writeFile).mockResolvedValue(undefined); + + const mockFileHandle = { + writeFile: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + }; + 
vi.mocked(fs.open).mockResolvedValue( + mockFileHandle as unknown as fs.FileHandle, + ); vi.mocked(fs.rename).mockResolvedValue(undefined); await messageBus.publish({ @@ -120,8 +127,8 @@ describe('createPolicyUpdater', () => { // Wait for the async listener to complete await new Promise((resolve) => setTimeout(resolve, 0)); - expect(fs.writeFile).toHaveBeenCalled(); - const [_path, content] = vi.mocked(fs.writeFile).mock.calls[0] as [ + expect(fs.open).toHaveBeenCalled(); + const [content] = mockFileHandle.writeFile.mock.calls[0] as [ string, string, ]; @@ -130,6 +137,19 @@ describe('createPolicyUpdater', () => { expect(parsed.rule).toHaveLength(1); expect(parsed.rule![0].commandPrefix).toEqual(['echo', 'ls']); }); + + it('should reject unsafe regex patterns', async () => { + createPolicyUpdater(policyEngine, messageBus); + + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'test_tool', + argsPattern: '(a+)+', + persist: false, + }); + + expect(policyEngine.addRule).not.toHaveBeenCalled(); + }); }); describe('ShellToolInvocation Policy Update', () => { diff --git a/packages/core/src/policy/toml-loader.ts b/packages/core/src/policy/toml-loader.ts index df3bc4e9ba..67fcacce75 100644 --- a/packages/core/src/policy/toml-loader.ts +++ b/packages/core/src/policy/toml-loader.ts @@ -12,7 +12,7 @@ import { type SafetyCheckerRule, InProcessCheckerType, } from './types.js'; -import { buildArgsPatterns } from './utils.js'; +import { buildArgsPatterns, isSafeRegExp } from './utils.js'; import fs from 'node:fs/promises'; import path from 'node:path'; import toml from '@iarna/toml'; @@ -356,7 +356,7 @@ export async function loadPoliciesFromToml( // Compile regex pattern if (argsPattern) { try { - policyRule.argsPattern = new RegExp(argsPattern); + new RegExp(argsPattern); } catch (e) { // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const error = e as Error; @@ -370,9 +370,24 @@ export async function loadPoliciesFromToml( 
suggestion: 'Check regex syntax for errors like unmatched brackets or invalid escape sequences', }); - // Skip this rule if regex compilation fails return null; } + + if (!isSafeRegExp(argsPattern)) { + errors.push({ + filePath, + fileName: file, + tier: tierName, + errorType: 'regex_compilation', + message: 'Unsafe regex pattern (potential ReDoS)', + details: `Pattern: ${argsPattern}`, + suggestion: + 'Avoid nested quantifiers or extremely long patterns', + }); + return null; + } + + policyRule.argsPattern = new RegExp(argsPattern); } return policyRule; @@ -421,7 +436,7 @@ export async function loadPoliciesFromToml( if (argsPattern) { try { - safetyCheckerRule.argsPattern = new RegExp(argsPattern); + new RegExp(argsPattern); } catch (e) { // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const error = e as Error; @@ -435,6 +450,21 @@ export async function loadPoliciesFromToml( }); return null; } + + if (!isSafeRegExp(argsPattern)) { + errors.push({ + filePath, + fileName: file, + tier: tierName, + errorType: 'regex_compilation', + message: + 'Unsafe regex pattern in safety checker (potential ReDoS)', + details: `Pattern: ${argsPattern}`, + }); + return null; + } + + safetyCheckerRule.argsPattern = new RegExp(argsPattern); } return safetyCheckerRule; diff --git a/packages/core/src/policy/utils.test.ts b/packages/core/src/policy/utils.test.ts index dfbb8b298c..90f3c632c7 100644 --- a/packages/core/src/policy/utils.test.ts +++ b/packages/core/src/policy/utils.test.ts @@ -5,7 +5,7 @@ */ import { describe, it, expect } from 'vitest'; -import { escapeRegex, buildArgsPatterns } from './utils.js'; +import { escapeRegex, buildArgsPatterns, isSafeRegExp } from './utils.js'; describe('policy/utils', () => { describe('escapeRegex', () => { @@ -23,6 +23,44 @@ describe('policy/utils', () => { }); }); + describe('isSafeRegExp', () => { + it('should return true for simple regexes', () => { + expect(isSafeRegExp('abc')).toBe(true); + 
expect(isSafeRegExp('^abc$')).toBe(true); + expect(isSafeRegExp('a|b')).toBe(true); + }); + + it('should return true for safe quantifiers', () => { + expect(isSafeRegExp('a+')).toBe(true); + expect(isSafeRegExp('a*')).toBe(true); + expect(isSafeRegExp('a?')).toBe(true); + expect(isSafeRegExp('a{1,3}')).toBe(true); + }); + + it('should return true for safe groups', () => { + expect(isSafeRegExp('(abc)*')).toBe(true); + expect(isSafeRegExp('(a|b)+')).toBe(true); + }); + + it('should return false for invalid regexes', () => { + expect(isSafeRegExp('([a-z)')).toBe(false); + expect(isSafeRegExp('*')).toBe(false); + }); + + it('should return false for extremely long regexes', () => { + expect(isSafeRegExp('a'.repeat(2049))).toBe(false); + }); + + it('should return false for nested quantifiers (potential ReDoS)', () => { + expect(isSafeRegExp('(a+)+')).toBe(false); + expect(isSafeRegExp('(a+)*')).toBe(false); + expect(isSafeRegExp('(a*)+')).toBe(false); + expect(isSafeRegExp('(a*)*')).toBe(false); + expect(isSafeRegExp('(a|b+)+')).toBe(false); + expect(isSafeRegExp('(.*)+')).toBe(false); + }); + }); + describe('buildArgsPatterns', () => { it('should return argsPattern if provided and no commandPrefix/regex', () => { const result = buildArgsPatterns('my-pattern', undefined, undefined); diff --git a/packages/core/src/policy/utils.ts b/packages/core/src/policy/utils.ts index b891a8fda1..3742ba3ed6 100644 --- a/packages/core/src/policy/utils.ts +++ b/packages/core/src/policy/utils.ts @@ -11,6 +11,37 @@ export function escapeRegex(text: string): string { return text.replace(/[-[\]{}()*+?.,\\^$|#\s"]/g, '\\$&'); } +/** + * Basic validation for regular expressions to prevent common ReDoS patterns. + * This is a heuristic check and not a substitute for a full ReDoS scanner. + */ +export function isSafeRegExp(pattern: string): boolean { + try { + // 1. Ensure it's a valid regex + new RegExp(pattern); + } catch { + return false; + } + + // 2. 
Limit length to prevent extremely long regexes + if (pattern.length > 2048) { + return false; + } + + // 3. Heuristic: Check for nested quantifiers which are a primary source of ReDoS. + // Examples: (a+)+, (a|b+)*, (.*)*, ([a-z]+)+ + // We look for a group (...) followed by a quantifier (+, *, ?, or {n,m}) + // where the group itself contains a quantifier. + // This matches a '(' followed by some content including a quantifier, then ')', + // followed by another quantifier. + const nestedQuantifierPattern = /\([^)]*[*+?{].*\)[*+?{]/; + if (nestedQuantifierPattern.test(pattern)) { + return false; + } + + return true; +} + /** * Builds a list of args patterns for policy matching. *