fix(core): prevent race condition in policy persistence (#18506)

Co-authored-by: Allen Hutchison <adh@google.com>
This commit is contained in:
Brad Dux
2026-02-10 15:35:09 -08:00
committed by GitHub
parent be2ebd1772
commit 6d3fff2ea4
6 changed files with 256 additions and 86 deletions

View File

@@ -6,6 +6,7 @@
import * as fs from 'node:fs/promises'; import * as fs from 'node:fs/promises';
import * as path from 'node:path'; import * as path from 'node:path';
import * as crypto from 'node:crypto';
import { fileURLToPath } from 'node:url'; import { fileURLToPath } from 'node:url';
import { Storage } from '../config/storage.js'; import { Storage } from '../config/storage.js';
import { import {
@@ -17,7 +18,7 @@ import {
} from './types.js'; } from './types.js';
import type { PolicyEngine } from './policy-engine.js'; import type { PolicyEngine } from './policy-engine.js';
import { loadPoliciesFromToml, type PolicyFileError } from './toml-loader.js'; import { loadPoliciesFromToml, type PolicyFileError } from './toml-loader.js';
import { buildArgsPatterns } from './utils.js'; import { buildArgsPatterns, isSafeRegExp } from './utils.js';
import toml from '@iarna/toml'; import toml from '@iarna/toml';
import { import {
MessageBusType, MessageBusType,
@@ -331,6 +332,9 @@ export function createPolicyUpdater(
policyEngine: PolicyEngine, policyEngine: PolicyEngine,
messageBus: MessageBus, messageBus: MessageBus,
) { ) {
// Use a sequential queue for persistence to avoid lost updates from concurrent events.
let persistenceQueue = Promise.resolve();
messageBus.subscribe( messageBus.subscribe(
MessageBusType.UPDATE_POLICY, MessageBusType.UPDATE_POLICY,
async (message: UpdatePolicy) => { async (message: UpdatePolicy) => {
@@ -341,6 +345,8 @@ export function createPolicyUpdater(
const patterns = buildArgsPatterns(undefined, message.commandPrefix); const patterns = buildArgsPatterns(undefined, message.commandPrefix);
for (const pattern of patterns) { for (const pattern of patterns) {
if (pattern) { if (pattern) {
// Note: patterns from buildArgsPatterns are derived from escapeRegex,
// which is safe and won't contain ReDoS patterns.
policyEngine.addRule({ policyEngine.addRule({
toolName, toolName,
decision: PolicyDecision.ALLOW, decision: PolicyDecision.ALLOW,
@@ -354,6 +360,14 @@ export function createPolicyUpdater(
} }
} }
} else { } else {
if (message.argsPattern && !isSafeRegExp(message.argsPattern)) {
coreEvents.emitFeedback(
'error',
`Invalid or unsafe regular expression for tool ${toolName}: ${message.argsPattern}`,
);
return;
}
const argsPattern = message.argsPattern const argsPattern = message.argsPattern
? new RegExp(message.argsPattern) ? new RegExp(message.argsPattern)
: undefined; : undefined;
@@ -371,74 +385,88 @@ export function createPolicyUpdater(
} }
if (message.persist) { if (message.persist) {
try { persistenceQueue = persistenceQueue.then(async () => {
const userPoliciesDir = Storage.getUserPoliciesDir();
await fs.mkdir(userPoliciesDir, { recursive: true });
const policyFile = path.join(userPoliciesDir, 'auto-saved.toml');
// Read existing file
let existingData: { rule?: TomlRule[] } = {};
try { try {
const fileContent = await fs.readFile(policyFile, 'utf-8'); const userPoliciesDir = Storage.getUserPoliciesDir();
existingData = toml.parse(fileContent) as { rule?: TomlRule[] }; await fs.mkdir(userPoliciesDir, { recursive: true });
} catch (error) { const policyFile = path.join(userPoliciesDir, 'auto-saved.toml');
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
if ((error as NodeJS.ErrnoException).code !== 'ENOENT') { // Read existing file
debugLogger.warn( let existingData: { rule?: TomlRule[] } = {};
`Failed to parse ${policyFile}, overwriting with new policy.`, try {
error, const fileContent = await fs.readFile(policyFile, 'utf-8');
); existingData = toml.parse(fileContent) as { rule?: TomlRule[] };
} catch (error) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
if ((error as NodeJS.ErrnoException).code !== 'ENOENT') {
debugLogger.warn(
`Failed to parse ${policyFile}, overwriting with new policy.`,
error,
);
}
} }
// Initialize rule array if needed
if (!existingData.rule) {
existingData.rule = [];
}
// Create new rule object
const newRule: TomlRule = {};
if (message.mcpName) {
newRule.mcpName = message.mcpName;
// Extract simple tool name
const simpleToolName = toolName.startsWith(`${message.mcpName}__`)
? toolName.slice(message.mcpName.length + 2)
: toolName;
newRule.toolName = simpleToolName;
newRule.decision = 'allow';
newRule.priority = 200;
} else {
newRule.toolName = toolName;
newRule.decision = 'allow';
newRule.priority = 100;
}
if (message.commandPrefix) {
newRule.commandPrefix = message.commandPrefix;
} else if (message.argsPattern) {
// message.argsPattern was already validated above
newRule.argsPattern = message.argsPattern;
}
// Add to rules
existingData.rule.push(newRule);
// Serialize back to TOML
// @iarna/toml stringify might not produce beautiful output but it handles escaping correctly
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const newContent = toml.stringify(existingData as toml.JsonMap);
// Atomic write: write to a unique tmp file then rename to the target file.
// Using a unique suffix avoids race conditions where concurrent processes
// overwrite each other's temporary files, leading to ENOENT errors on rename.
const tmpSuffix = crypto.randomBytes(8).toString('hex');
const tmpFile = `${policyFile}.${tmpSuffix}.tmp`;
let handle: fs.FileHandle | undefined;
try {
// Use 'wx' to create the file exclusively (fails if exists) for security.
handle = await fs.open(tmpFile, 'wx');
await handle.writeFile(newContent, 'utf-8');
} finally {
await handle?.close();
}
await fs.rename(tmpFile, policyFile);
} catch (error) {
coreEvents.emitFeedback(
'error',
`Failed to persist policy for ${toolName}`,
error,
);
} }
});
// Initialize rule array if needed
if (!existingData.rule) {
existingData.rule = [];
}
// Create new rule object
const newRule: TomlRule = {};
if (message.mcpName) {
newRule.mcpName = message.mcpName;
// Extract simple tool name
const simpleToolName = toolName.startsWith(`${message.mcpName}__`)
? toolName.slice(message.mcpName.length + 2)
: toolName;
newRule.toolName = simpleToolName;
newRule.decision = 'allow';
newRule.priority = 200;
} else {
newRule.toolName = toolName;
newRule.decision = 'allow';
newRule.priority = 100;
}
if (message.commandPrefix) {
newRule.commandPrefix = message.commandPrefix;
} else if (message.argsPattern) {
newRule.argsPattern = message.argsPattern;
}
// Add to rules
existingData.rule.push(newRule);
// Serialize back to TOML
// @iarna/toml stringify might not produce beautiful output but it handles escaping correctly
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const newContent = toml.stringify(existingData as toml.JsonMap);
// Atomic write: write to tmp then rename
const tmpFile = `${policyFile}.tmp`;
await fs.writeFile(tmpFile, newContent, 'utf-8');
await fs.rename(tmpFile, policyFile);
} catch (error) {
coreEvents.emitFeedback(
'error',
`Failed to persist policy for ${toolName}`,
error,
);
}
} }
}, },
); );

View File

@@ -52,7 +52,12 @@ describe('createPolicyUpdater', () => {
(fs.readFile as unknown as Mock).mockRejectedValue( (fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'), new Error('File not found'),
); // Simulate new file ); // Simulate new file
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);
const mockFileHandle = {
writeFile: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
};
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
(fs.rename as unknown as Mock).mockResolvedValue(undefined); (fs.rename as unknown as Mock).mockResolvedValue(undefined);
const toolName = 'test_tool'; const toolName = 'test_tool';
@@ -70,10 +75,11 @@ describe('createPolicyUpdater', () => {
recursive: true, recursive: true,
}); });
expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');
// Check written content // Check written content
const expectedContent = expect.stringContaining(`toolName = "test_tool"`); const expectedContent = expect.stringContaining(`toolName = "test_tool"`);
expect(fs.writeFile).toHaveBeenCalledWith( expect(mockFileHandle.writeFile).toHaveBeenCalledWith(
expect.stringMatching(/\.tmp$/),
expectedContent, expectedContent,
'utf-8', 'utf-8',
); );
@@ -106,7 +112,12 @@ describe('createPolicyUpdater', () => {
(fs.readFile as unknown as Mock).mockRejectedValue( (fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'), new Error('File not found'),
); );
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);
const mockFileHandle = {
writeFile: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
};
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
(fs.rename as unknown as Mock).mockResolvedValue(undefined); (fs.rename as unknown as Mock).mockResolvedValue(undefined);
const toolName = 'run_shell_command'; const toolName = 'run_shell_command';
@@ -131,8 +142,8 @@ describe('createPolicyUpdater', () => {
); );
// Verify file written // Verify file written
expect(fs.writeFile).toHaveBeenCalledWith( expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');
expect.stringMatching(/\.tmp$/), expect(mockFileHandle.writeFile).toHaveBeenCalledWith(
expect.stringContaining(`commandPrefix = "git status"`), expect.stringContaining(`commandPrefix = "git status"`),
'utf-8', 'utf-8',
); );
@@ -147,7 +158,12 @@ describe('createPolicyUpdater', () => {
(fs.readFile as unknown as Mock).mockRejectedValue( (fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'), new Error('File not found'),
); );
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);
const mockFileHandle = {
writeFile: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
};
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
(fs.rename as unknown as Mock).mockResolvedValue(undefined); (fs.rename as unknown as Mock).mockResolvedValue(undefined);
const mcpName = 'my-jira-server'; const mcpName = 'my-jira-server';
@@ -164,8 +180,9 @@ describe('createPolicyUpdater', () => {
await new Promise((resolve) => setTimeout(resolve, 0)); await new Promise((resolve) => setTimeout(resolve, 0));
// Verify file written // Verify file written
const writeCall = (fs.writeFile as unknown as Mock).mock.calls[0]; expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');
const writtenContent = writeCall[1] as string; const writeCall = mockFileHandle.writeFile.mock.calls[0];
const writtenContent = writeCall[0] as string;
expect(writtenContent).toContain(`mcpName = "${mcpName}"`); expect(writtenContent).toContain(`mcpName = "${mcpName}"`);
expect(writtenContent).toContain(`toolName = "${simpleToolName}"`); expect(writtenContent).toContain(`toolName = "${simpleToolName}"`);
expect(writtenContent).toContain('priority = 200'); expect(writtenContent).toContain('priority = 200');
@@ -180,7 +197,12 @@ describe('createPolicyUpdater', () => {
(fs.readFile as unknown as Mock).mockRejectedValue( (fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'), new Error('File not found'),
); );
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);
const mockFileHandle = {
writeFile: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
};
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
(fs.rename as unknown as Mock).mockResolvedValue(undefined); (fs.rename as unknown as Mock).mockResolvedValue(undefined);
const mcpName = 'my"jira"server'; const mcpName = 'my"jira"server';
@@ -195,8 +217,9 @@ describe('createPolicyUpdater', () => {
await new Promise((resolve) => setTimeout(resolve, 0)); await new Promise((resolve) => setTimeout(resolve, 0));
const writeCall = (fs.writeFile as unknown as Mock).mock.calls[0]; expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');
const writtenContent = writeCall[1] as string; const writeCall = mockFileHandle.writeFile.mock.calls[0];
const writtenContent = writeCall[0] as string;
// Verify escaping - should be valid TOML // Verify escaping - should be valid TOML
// Note: @iarna/toml optimizes for shortest representation, so it may use single quotes 'foo"bar' // Note: @iarna/toml optimizes for shortest representation, so it may use single quotes 'foo"bar'

View File

@@ -107,7 +107,14 @@ describe('createPolicyUpdater', () => {
createPolicyUpdater(policyEngine, messageBus); createPolicyUpdater(policyEngine, messageBus);
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' }); vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
vi.mocked(fs.mkdir).mockResolvedValue(undefined); vi.mocked(fs.mkdir).mockResolvedValue(undefined);
vi.mocked(fs.writeFile).mockResolvedValue(undefined);
const mockFileHandle = {
writeFile: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
};
vi.mocked(fs.open).mockResolvedValue(
mockFileHandle as unknown as fs.FileHandle,
);
vi.mocked(fs.rename).mockResolvedValue(undefined); vi.mocked(fs.rename).mockResolvedValue(undefined);
await messageBus.publish({ await messageBus.publish({
@@ -120,8 +127,8 @@ describe('createPolicyUpdater', () => {
// Wait for the async listener to complete // Wait for the async listener to complete
await new Promise((resolve) => setTimeout(resolve, 0)); await new Promise((resolve) => setTimeout(resolve, 0));
expect(fs.writeFile).toHaveBeenCalled(); expect(fs.open).toHaveBeenCalled();
const [_path, content] = vi.mocked(fs.writeFile).mock.calls[0] as [ const [content] = mockFileHandle.writeFile.mock.calls[0] as [
string, string,
string, string,
]; ];
@@ -130,6 +137,19 @@ describe('createPolicyUpdater', () => {
expect(parsed.rule).toHaveLength(1); expect(parsed.rule).toHaveLength(1);
expect(parsed.rule![0].commandPrefix).toEqual(['echo', 'ls']); expect(parsed.rule![0].commandPrefix).toEqual(['echo', 'ls']);
}); });
it('should reject unsafe regex patterns', async () => {
createPolicyUpdater(policyEngine, messageBus);
await messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: 'test_tool',
argsPattern: '(a+)+',
persist: false,
});
expect(policyEngine.addRule).not.toHaveBeenCalled();
});
}); });
describe('ShellToolInvocation Policy Update', () => { describe('ShellToolInvocation Policy Update', () => {

View File

@@ -12,7 +12,7 @@ import {
type SafetyCheckerRule, type SafetyCheckerRule,
InProcessCheckerType, InProcessCheckerType,
} from './types.js'; } from './types.js';
import { buildArgsPatterns } from './utils.js'; import { buildArgsPatterns, isSafeRegExp } from './utils.js';
import fs from 'node:fs/promises'; import fs from 'node:fs/promises';
import path from 'node:path'; import path from 'node:path';
import toml from '@iarna/toml'; import toml from '@iarna/toml';
@@ -356,7 +356,7 @@ export async function loadPoliciesFromToml(
// Compile regex pattern // Compile regex pattern
if (argsPattern) { if (argsPattern) {
try { try {
policyRule.argsPattern = new RegExp(argsPattern); new RegExp(argsPattern);
} catch (e) { } catch (e) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const error = e as Error; const error = e as Error;
@@ -370,9 +370,24 @@ export async function loadPoliciesFromToml(
suggestion: suggestion:
'Check regex syntax for errors like unmatched brackets or invalid escape sequences', 'Check regex syntax for errors like unmatched brackets or invalid escape sequences',
}); });
// Skip this rule if regex compilation fails
return null; return null;
} }
if (!isSafeRegExp(argsPattern)) {
errors.push({
filePath,
fileName: file,
tier: tierName,
errorType: 'regex_compilation',
message: 'Unsafe regex pattern (potential ReDoS)',
details: `Pattern: ${argsPattern}`,
suggestion:
'Avoid nested quantifiers or extremely long patterns',
});
return null;
}
policyRule.argsPattern = new RegExp(argsPattern);
} }
return policyRule; return policyRule;
@@ -421,7 +436,7 @@ export async function loadPoliciesFromToml(
if (argsPattern) { if (argsPattern) {
try { try {
safetyCheckerRule.argsPattern = new RegExp(argsPattern); new RegExp(argsPattern);
} catch (e) { } catch (e) {
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const error = e as Error; const error = e as Error;
@@ -435,6 +450,21 @@ export async function loadPoliciesFromToml(
}); });
return null; return null;
} }
if (!isSafeRegExp(argsPattern)) {
errors.push({
filePath,
fileName: file,
tier: tierName,
errorType: 'regex_compilation',
message:
'Unsafe regex pattern in safety checker (potential ReDoS)',
details: `Pattern: ${argsPattern}`,
});
return null;
}
safetyCheckerRule.argsPattern = new RegExp(argsPattern);
} }
return safetyCheckerRule; return safetyCheckerRule;

View File

@@ -5,7 +5,7 @@
*/ */
import { describe, it, expect } from 'vitest'; import { describe, it, expect } from 'vitest';
import { escapeRegex, buildArgsPatterns } from './utils.js'; import { escapeRegex, buildArgsPatterns, isSafeRegExp } from './utils.js';
describe('policy/utils', () => { describe('policy/utils', () => {
describe('escapeRegex', () => { describe('escapeRegex', () => {
@@ -23,6 +23,44 @@ describe('policy/utils', () => {
}); });
}); });
describe('isSafeRegExp', () => {
it('should return true for simple regexes', () => {
expect(isSafeRegExp('abc')).toBe(true);
expect(isSafeRegExp('^abc$')).toBe(true);
expect(isSafeRegExp('a|b')).toBe(true);
});
it('should return true for safe quantifiers', () => {
expect(isSafeRegExp('a+')).toBe(true);
expect(isSafeRegExp('a*')).toBe(true);
expect(isSafeRegExp('a?')).toBe(true);
expect(isSafeRegExp('a{1,3}')).toBe(true);
});
it('should return true for safe groups', () => {
expect(isSafeRegExp('(abc)*')).toBe(true);
expect(isSafeRegExp('(a|b)+')).toBe(true);
});
it('should return false for invalid regexes', () => {
expect(isSafeRegExp('([a-z)')).toBe(false);
expect(isSafeRegExp('*')).toBe(false);
});
it('should return false for extremely long regexes', () => {
expect(isSafeRegExp('a'.repeat(2049))).toBe(false);
});
it('should return false for nested quantifiers (potential ReDoS)', () => {
expect(isSafeRegExp('(a+)+')).toBe(false);
expect(isSafeRegExp('(a+)*')).toBe(false);
expect(isSafeRegExp('(a*)+')).toBe(false);
expect(isSafeRegExp('(a*)*')).toBe(false);
expect(isSafeRegExp('(a|b+)+')).toBe(false);
expect(isSafeRegExp('(.*)+')).toBe(false);
});
});
describe('buildArgsPatterns', () => { describe('buildArgsPatterns', () => {
it('should return argsPattern if provided and no commandPrefix/regex', () => { it('should return argsPattern if provided and no commandPrefix/regex', () => {
const result = buildArgsPatterns('my-pattern', undefined, undefined); const result = buildArgsPatterns('my-pattern', undefined, undefined);

View File

@@ -11,6 +11,37 @@ export function escapeRegex(text: string): string {
return text.replace(/[-[\]{}()*+?.,\\^$|#\s"]/g, '\\$&'); return text.replace(/[-[\]{}()*+?.,\\^$|#\s"]/g, '\\$&');
} }
/**
* Basic validation for regular expressions to prevent common ReDoS patterns.
* This is a heuristic check and not a substitute for a full ReDoS scanner.
*/
export function isSafeRegExp(pattern: string): boolean {
try {
// 1. Ensure it's a valid regex
new RegExp(pattern);
} catch {
return false;
}
// 2. Limit length to prevent extremely long regexes
if (pattern.length > 2048) {
return false;
}
// 3. Heuristic: Check for nested quantifiers which are a primary source of ReDoS.
// Examples: (a+)+, (a|b)*, (.*)*, ([a-z]+)+
// We look for a group (...) followed by a quantifier (+, *, or {n,m})
// where the group itself contains a quantifier.
// This matches a '(' followed by some content including a quantifier, then ')',
// followed by another quantifier.
const nestedQuantifierPattern = /\([^)]*[*+?{].*\)[*+?{]/;
if (nestedQuantifierPattern.test(pattern)) {
return false;
}
return true;
}
/** /**
* Builds a list of args patterns for policy matching. * Builds a list of args patterns for policy matching.
* *