mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 12:54:07 -07:00
feat(safety): Introduce safety checker framework (#12504)
This commit is contained in:
@@ -4,7 +4,14 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type PolicyRule, PolicyDecision, type ApprovalMode } from './types.js';
|
||||
import {
|
||||
type PolicyRule,
|
||||
PolicyDecision,
|
||||
type ApprovalMode,
|
||||
type SafetyCheckerConfig,
|
||||
type SafetyCheckerRule,
|
||||
InProcessCheckerType,
|
||||
} from './types.js';
|
||||
import fs from 'node:fs/promises';
|
||||
import path from 'node:path';
|
||||
import toml from '@iarna/toml';
|
||||
@@ -39,11 +46,39 @@ const PolicyRuleSchema = z.object({
|
||||
modes: z.array(z.string()).optional(),
|
||||
});
|
||||
|
||||
/**
|
||||
* Schema for a single safety checker rule in the TOML file.
|
||||
*/
|
||||
const SafetyCheckerRuleSchema = z.object({
|
||||
toolName: z.union([z.string(), z.array(z.string())]).optional(),
|
||||
mcpName: z.string().optional(),
|
||||
argsPattern: z.string().optional(),
|
||||
commandPrefix: z.union([z.string(), z.array(z.string())]).optional(),
|
||||
commandRegex: z.string().optional(),
|
||||
priority: z.number().int().default(0),
|
||||
modes: z.array(z.string()).optional(),
|
||||
checker: z.discriminatedUnion('type', [
|
||||
z.object({
|
||||
type: z.literal('in-process'),
|
||||
name: z.nativeEnum(InProcessCheckerType),
|
||||
required_context: z.array(z.string()).optional(),
|
||||
config: z.record(z.unknown()).optional(),
|
||||
}),
|
||||
z.object({
|
||||
type: z.literal('external'),
|
||||
name: z.string(),
|
||||
required_context: z.array(z.string()).optional(),
|
||||
config: z.record(z.unknown()).optional(),
|
||||
}),
|
||||
]),
|
||||
});
|
||||
|
||||
/**
|
||||
* Schema for the entire policy TOML file.
|
||||
*/
|
||||
const PolicyFileSchema = z.object({
|
||||
rule: z.array(PolicyRuleSchema),
|
||||
rule: z.array(PolicyRuleSchema).optional(),
|
||||
safety_checker: z.array(SafetyCheckerRuleSchema).optional(),
|
||||
});
|
||||
|
||||
/**
|
||||
@@ -80,6 +115,7 @@ export interface PolicyFileError {
|
||||
*/
|
||||
export interface PolicyLoadResult {
|
||||
rules: PolicyRule[];
|
||||
checkers: SafetyCheckerRule[];
|
||||
errors: PolicyFileError[];
|
||||
}
|
||||
|
||||
@@ -194,6 +230,7 @@ export async function loadPoliciesFromToml(
|
||||
getPolicyTier: (dir: string) => number,
|
||||
): Promise<PolicyLoadResult> {
|
||||
const rules: PolicyRule[] = [];
|
||||
const checkers: SafetyCheckerRule[] = [];
|
||||
const errors: PolicyFileError[] = [];
|
||||
|
||||
for (const dir of policyDirs) {
|
||||
@@ -267,8 +304,9 @@ export async function loadPoliciesFromToml(
|
||||
}
|
||||
|
||||
// Validate shell command convenience syntax
|
||||
for (let i = 0; i < validationResult.data.rule.length; i++) {
|
||||
const rule = validationResult.data.rule[i];
|
||||
const tomlRules = validationResult.data.rule ?? [];
|
||||
for (let i = 0; i < tomlRules.length; i++) {
|
||||
const rule = tomlRules[i];
|
||||
const validationError = validateShellCommandSyntax(rule, i);
|
||||
if (validationError) {
|
||||
errors.push({
|
||||
@@ -285,7 +323,7 @@ export async function loadPoliciesFromToml(
|
||||
}
|
||||
|
||||
// Transform rules
|
||||
const parsedRules: PolicyRule[] = validationResult.data.rule
|
||||
const parsedRules: PolicyRule[] = (validationResult.data.rule ?? [])
|
||||
.filter((rule) => {
|
||||
// Filter by mode
|
||||
if (!rule.modes || rule.modes.length === 0) {
|
||||
@@ -369,6 +407,84 @@ export async function loadPoliciesFromToml(
|
||||
.filter((rule): rule is PolicyRule => rule !== null);
|
||||
|
||||
rules.push(...parsedRules);
|
||||
|
||||
// Transform checkers
|
||||
const parsedCheckers: SafetyCheckerRule[] = (
|
||||
validationResult.data.safety_checker ?? []
|
||||
)
|
||||
.filter((checker) => {
|
||||
if (!checker.modes || checker.modes.length === 0) {
|
||||
return true;
|
||||
}
|
||||
return checker.modes.includes(approvalMode);
|
||||
})
|
||||
.flatMap((checker) => {
|
||||
let effectiveArgsPattern = checker.argsPattern;
|
||||
const commandPrefixes: string[] = [];
|
||||
|
||||
if (checker.commandPrefix) {
|
||||
const prefixes = Array.isArray(checker.commandPrefix)
|
||||
? checker.commandPrefix
|
||||
: [checker.commandPrefix];
|
||||
commandPrefixes.push(...prefixes);
|
||||
} else if (checker.commandRegex) {
|
||||
effectiveArgsPattern = `"command":"${checker.commandRegex}`;
|
||||
}
|
||||
|
||||
const argsPatterns: Array<string | undefined> =
|
||||
commandPrefixes.length > 0
|
||||
? commandPrefixes.map(
|
||||
(prefix) => `"command":"${escapeRegex(prefix)}`,
|
||||
)
|
||||
: [effectiveArgsPattern];
|
||||
|
||||
return argsPatterns.flatMap((argsPattern) => {
|
||||
const toolNames: Array<string | undefined> = checker.toolName
|
||||
? Array.isArray(checker.toolName)
|
||||
? checker.toolName
|
||||
: [checker.toolName]
|
||||
: [undefined];
|
||||
|
||||
return toolNames.map((toolName) => {
|
||||
let effectiveToolName: string | undefined;
|
||||
if (checker.mcpName && toolName) {
|
||||
effectiveToolName = `${checker.mcpName}__${toolName}`;
|
||||
} else if (checker.mcpName) {
|
||||
effectiveToolName = `${checker.mcpName}__*`;
|
||||
} else {
|
||||
effectiveToolName = toolName;
|
||||
}
|
||||
|
||||
const safetyCheckerRule: SafetyCheckerRule = {
|
||||
toolName: effectiveToolName,
|
||||
priority: checker.priority,
|
||||
checker: checker.checker as SafetyCheckerConfig,
|
||||
};
|
||||
|
||||
if (argsPattern) {
|
||||
try {
|
||||
safetyCheckerRule.argsPattern = new RegExp(argsPattern);
|
||||
} catch (e) {
|
||||
const error = e as Error;
|
||||
errors.push({
|
||||
filePath,
|
||||
fileName: file,
|
||||
tier: tierName,
|
||||
errorType: 'regex_compilation',
|
||||
message: 'Invalid regex pattern in safety checker',
|
||||
details: `Pattern: ${argsPattern}\nError: ${error.message}`,
|
||||
});
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return safetyCheckerRule;
|
||||
});
|
||||
});
|
||||
})
|
||||
.filter((checker): checker is SafetyCheckerRule => checker !== null);
|
||||
|
||||
checkers.push(...parsedCheckers);
|
||||
} catch (e) {
|
||||
const error = e as NodeJS.ErrnoException;
|
||||
// Catch-all for unexpected errors
|
||||
@@ -386,5 +502,5 @@ export async function loadPoliciesFromToml(
|
||||
}
|
||||
}
|
||||
|
||||
return { rules, errors };
|
||||
return { rules, checkers, errors };
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user