feat(safety): Introduce safety checker framework (#12504)

This commit is contained in:
Allen Hutchison
2025-11-12 13:18:34 -08:00
committed by GitHub
parent aa9922bc98
commit 1ed163a666
21 changed files with 2636 additions and 328 deletions
+122 -6
View File
@@ -4,7 +4,14 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { type PolicyRule, PolicyDecision, type ApprovalMode } from './types.js';
import {
type PolicyRule,
PolicyDecision,
type ApprovalMode,
type SafetyCheckerConfig,
type SafetyCheckerRule,
InProcessCheckerType,
} from './types.js';
import fs from 'node:fs/promises';
import path from 'node:path';
import toml from '@iarna/toml';
@@ -39,11 +46,39 @@ const PolicyRuleSchema = z.object({
modes: z.array(z.string()).optional(),
});
// Accepts either a single string or a list of strings.
const stringOrStringListSchema = z.union([z.string(), z.array(z.string())]);

// Optional fields common to every checker variant, listed after `type`/`name`.
const checkerCommonFields = {
  required_context: z.array(z.string()).optional(),
  config: z.record(z.unknown()).optional(),
};

// Checker variant whose name must be one of the InProcessCheckerType values.
const inProcessCheckerSchema = z.object({
  type: z.literal('in-process'),
  name: z.nativeEnum(InProcessCheckerType),
  ...checkerCommonFields,
});

// Checker variant identified by an arbitrary string name.
const externalCheckerSchema = z.object({
  type: z.literal('external'),
  name: z.string(),
  ...checkerCommonFields,
});

/**
 * Validates one entry of the `safety_checker` array in a policy TOML file.
 *
 * Every matching field (`toolName`, `mcpName`, `argsPattern`,
 * `commandPrefix`, `commandRegex`, `modes`) is optional. `priority` must be
 * an integer and defaults to 0 when omitted. `checker` is required and is
 * discriminated on its `type` field: `'in-process'` or `'external'`.
 */
const SafetyCheckerRuleSchema = z.object({
  toolName: stringOrStringListSchema.optional(),
  mcpName: z.string().optional(),
  argsPattern: z.string().optional(),
  commandPrefix: stringOrStringListSchema.optional(),
  commandRegex: z.string().optional(),
  priority: z.number().int().default(0),
  modes: z.array(z.string()).optional(),
  checker: z.discriminatedUnion('type', [
    inProcessCheckerSchema,
    externalCheckerSchema,
  ]),
});
/**
 * Schema for the entire policy TOML file.
 *
 * Both top-level arrays are optional, so a file may declare only `rule`
 * entries, only `safety_checker` entries, both, or neither. Consumers must
 * treat a missing array as empty (e.g. `data.rule ?? []`).
 */
const PolicyFileSchema = z.object({
  // Declared exactly once: a duplicate `rule` key is a TypeScript error and
  // the earlier (non-optional) definition would be overridden at runtime.
  rule: z.array(PolicyRuleSchema).optional(),
  safety_checker: z.array(SafetyCheckerRuleSchema).optional(),
});
/**
@@ -80,6 +115,7 @@ export interface PolicyFileError {
*/
export interface PolicyLoadResult {
// Policy rules produced from the `rule` entries of the loaded TOML files.
rules: PolicyRule[];
// Safety checker rules produced from the `safety_checker` entries of the loaded TOML files.
checkers: SafetyCheckerRule[];
// Problems recorded while loading; an entry that fails (e.g. an invalid regex) is reported here instead of being returned.
errors: PolicyFileError[];
}
@@ -194,6 +230,7 @@ export async function loadPoliciesFromToml(
getPolicyTier: (dir: string) => number,
): Promise<PolicyLoadResult> {
const rules: PolicyRule[] = [];
const checkers: SafetyCheckerRule[] = [];
const errors: PolicyFileError[] = [];
for (const dir of policyDirs) {
@@ -267,8 +304,9 @@ export async function loadPoliciesFromToml(
}
// Validate shell command convenience syntax
for (let i = 0; i < validationResult.data.rule.length; i++) {
const rule = validationResult.data.rule[i];
const tomlRules = validationResult.data.rule ?? [];
for (let i = 0; i < tomlRules.length; i++) {
const rule = tomlRules[i];
const validationError = validateShellCommandSyntax(rule, i);
if (validationError) {
errors.push({
@@ -285,7 +323,7 @@ export async function loadPoliciesFromToml(
}
// Transform rules
const parsedRules: PolicyRule[] = validationResult.data.rule
const parsedRules: PolicyRule[] = (validationResult.data.rule ?? [])
.filter((rule) => {
// Filter by mode
if (!rule.modes || rule.modes.length === 0) {
@@ -369,6 +407,84 @@ export async function loadPoliciesFromToml(
.filter((rule): rule is PolicyRule => rule !== null);
rules.push(...parsedRules);
// Transform checkers: expand each validated `safety_checker` entry into one
// SafetyCheckerRule per (args pattern, tool name) combination and append the
// results to `checkers`.
const parsedCheckers: SafetyCheckerRule[] = (
validationResult.data.safety_checker ?? []
)
.filter((checker) => {
// An entry without `modes` (or with an empty list) applies in every mode;
// otherwise it is kept only when it lists the current approval mode.
if (!checker.modes || checker.modes.length === 0) {
return true;
}
return checker.modes.includes(approvalMode);
})
.flatMap((checker) => {
let effectiveArgsPattern = checker.argsPattern;
const commandPrefixes: string[] = [];
// NOTE(review): `commandPrefix` takes precedence. When it is set,
// `commandRegex` is never read, and a non-empty prefix list also replaces
// `argsPattern` below. Nothing in this block reports that conflict.
if (checker.commandPrefix) {
const prefixes = Array.isArray(checker.commandPrefix)
? checker.commandPrefix
: [checker.commandPrefix];
commandPrefixes.push(...prefixes);
} else if (checker.commandRegex) {
// The user-supplied regex is inserted as-is (not escaped) after the
// literal text `"command":"`, overriding any `argsPattern`.
effectiveArgsPattern = `"command":"${checker.commandRegex}`;
}
// One pattern per command prefix, or a single (possibly undefined) pattern
// when no prefixes were given. `escapeRegex` is defined outside this block;
// presumably it escapes regex metacharacters so the prefix matches literally.
const argsPatterns: Array<string | undefined> =
commandPrefixes.length > 0
? commandPrefixes.map(
(prefix) => `"command":"${escapeRegex(prefix)}`,
)
: [effectiveArgsPattern];
return argsPatterns.flatMap((argsPattern) => {
// Normalize `toolName` to a list; `[undefined]` ensures exactly one rule
// is still emitted when no tool name was specified.
const toolNames: Array<string | undefined> = checker.toolName
? Array.isArray(checker.toolName)
? checker.toolName
: [checker.toolName]
: [undefined];
return toolNames.map((toolName) => {
// With `mcpName`, the tool name becomes `<mcpName>__<toolName>`, or
// `<mcpName>__*` when no tool name was given.
let effectiveToolName: string | undefined;
if (checker.mcpName && toolName) {
effectiveToolName = `${checker.mcpName}__${toolName}`;
} else if (checker.mcpName) {
effectiveToolName = `${checker.mcpName}__*`;
} else {
effectiveToolName = toolName;
}
const safetyCheckerRule: SafetyCheckerRule = {
toolName: effectiveToolName,
priority: checker.priority,
checker: checker.checker as SafetyCheckerConfig,
};
if (argsPattern) {
try {
safetyCheckerRule.argsPattern = new RegExp(argsPattern);
} catch (e) {
// Invalid pattern: record the failure and drop only this rule (the
// `null` is removed by the final filter). Because compilation happens
// once per tool name, the same bad pattern is reported once per tool.
const error = e as Error;
errors.push({
filePath,
fileName: file,
tier: tierName,
errorType: 'regex_compilation',
message: 'Invalid regex pattern in safety checker',
details: `Pattern: ${argsPattern}\nError: ${error.message}`,
});
return null;
}
}
return safetyCheckerRule;
});
});
})
// Remove the `null` placeholders left by rules whose regex failed to compile.
.filter((checker): checker is SafetyCheckerRule => checker !== null);
checkers.push(...parsedCheckers);
} catch (e) {
const error = e as NodeJS.ErrnoException;
// Catch-all for unexpected errors
@@ -386,5 +502,5 @@ export async function loadPoliciesFromToml(
}
}
return { rules, errors };
return { rules, checkers, errors };
}