feat(policy): implement dynamic mode-aware policy evaluation (#15307)

This commit is contained in:
Abhi
2025-12-22 15:25:07 -05:00
committed by GitHub
parent d18c96d6a1
commit b0d5c4c058
9 changed files with 190 additions and 46 deletions
+7 -5
View File
@@ -365,7 +365,6 @@ export class Config {
private userMemory: string; private userMemory: string;
private geminiMdFileCount: number; private geminiMdFileCount: number;
private geminiMdFilePaths: string[]; private geminiMdFilePaths: string[];
private approvalMode: ApprovalMode;
private readonly showMemoryUsage: boolean; private readonly showMemoryUsage: boolean;
private readonly accessibility: AccessibilitySettings; private readonly accessibility: AccessibilitySettings;
private readonly telemetrySettings: TelemetrySettings; private readonly telemetrySettings: TelemetrySettings;
@@ -483,7 +482,6 @@ export class Config {
this.userMemory = params.userMemory ?? ''; this.userMemory = params.userMemory ?? '';
this.geminiMdFileCount = params.geminiMdFileCount ?? 0; this.geminiMdFileCount = params.geminiMdFileCount ?? 0;
this.geminiMdFilePaths = params.geminiMdFilePaths ?? []; this.geminiMdFilePaths = params.geminiMdFilePaths ?? [];
this.approvalMode = params.approvalMode ?? ApprovalMode.DEFAULT;
this.showMemoryUsage = params.showMemoryUsage ?? false; this.showMemoryUsage = params.showMemoryUsage ?? false;
this.accessibility = params.accessibility ?? {}; this.accessibility = params.accessibility ?? {};
this.telemetrySettings = { this.telemetrySettings = {
@@ -600,7 +598,11 @@ export class Config {
this.enablePromptCompletion = params.enablePromptCompletion ?? false; this.enablePromptCompletion = params.enablePromptCompletion ?? false;
this.fileExclusions = new FileExclusions(this); this.fileExclusions = new FileExclusions(this);
this.eventEmitter = params.eventEmitter; this.eventEmitter = params.eventEmitter;
this.policyEngine = new PolicyEngine(params.policyEngineConfig); this.policyEngine = new PolicyEngine({
...params.policyEngineConfig,
approvalMode:
params.approvalMode ?? params.policyEngineConfig?.approvalMode,
});
this.messageBus = new MessageBus(this.policyEngine, this.debugMode); this.messageBus = new MessageBus(this.policyEngine, this.debugMode);
this.outputSettings = { this.outputSettings = {
format: params.output?.format ?? OutputFormat.TEXT, format: params.output?.format ?? OutputFormat.TEXT,
@@ -1126,7 +1128,7 @@ export class Config {
} }
getApprovalMode(): ApprovalMode { getApprovalMode(): ApprovalMode {
return this.approvalMode; return this.policyEngine.getApprovalMode();
} }
setApprovalMode(mode: ApprovalMode): void { setApprovalMode(mode: ApprovalMode): void {
@@ -1135,7 +1137,7 @@ export class Config {
'Cannot enable privileged approval modes in an untrusted folder.', 'Cannot enable privileged approval modes in an untrusted folder.',
); );
} }
this.approvalMode = mode; this.policyEngine.setApprovalMode(mode);
} }
isYoloModeDisabled(): boolean { isYoloModeDisabled(): boolean {
+2 -1
View File
@@ -124,7 +124,7 @@ export async function createPolicyEngineConfig(
rules: tomlRules, rules: tomlRules,
checkers: tomlCheckers, checkers: tomlCheckers,
errors, errors,
} = await loadPoliciesFromToml(approvalMode, policyDirs, (dir) => } = await loadPoliciesFromToml(policyDirs, (dir) =>
getPolicyTier(dir, defaultPoliciesDir), getPolicyTier(dir, defaultPoliciesDir),
); );
@@ -236,6 +236,7 @@ export async function createPolicyEngineConfig(
rules, rules,
checkers, checkers,
defaultDecision: PolicyDecision.ASK_USER, defaultDecision: PolicyDecision.ASK_USER,
approvalMode,
}; };
} }
+6 -1
View File
@@ -20,6 +20,7 @@ import { PolicyEngine } from './policy-engine.js';
import { MessageBus } from '../confirmation-bus/message-bus.js'; import { MessageBus } from '../confirmation-bus/message-bus.js';
import { MessageBusType } from '../confirmation-bus/types.js'; import { MessageBusType } from '../confirmation-bus/types.js';
import { Storage } from '../config/storage.js'; import { Storage } from '../config/storage.js';
import { ApprovalMode } from './types.js';
vi.mock('node:fs/promises'); vi.mock('node:fs/promises');
vi.mock('../config/storage.js'); vi.mock('../config/storage.js');
@@ -29,7 +30,11 @@ describe('createPolicyUpdater', () => {
let messageBus: MessageBus; let messageBus: MessageBus;
beforeEach(() => { beforeEach(() => {
policyEngine = new PolicyEngine({ rules: [], checkers: [] }); policyEngine = new PolicyEngine({
rules: [],
checkers: [],
approvalMode: ApprovalMode.DEFAULT,
});
messageBus = new MessageBus(policyEngine); messageBus = new MessageBus(policyEngine);
vi.clearAllMocks(); vi.clearAllMocks();
}); });
+40 -1
View File
@@ -12,6 +12,7 @@ import {
type PolicyEngineConfig, type PolicyEngineConfig,
type SafetyCheckerRule, type SafetyCheckerRule,
InProcessCheckerType, InProcessCheckerType,
ApprovalMode,
} from './types.js'; } from './types.js';
import type { FunctionCall } from '@google/genai'; import type { FunctionCall } from '@google/genai';
import { SafetyCheckDecision } from '../safety/protocol.js'; import { SafetyCheckDecision } from '../safety/protocol.js';
@@ -25,7 +26,10 @@ describe('PolicyEngine', () => {
mockCheckerRunner = { mockCheckerRunner = {
runChecker: vi.fn(), runChecker: vi.fn(),
} as unknown as CheckerRunner; } as unknown as CheckerRunner;
engine = new PolicyEngine({}, mockCheckerRunner); engine = new PolicyEngine(
{ approvalMode: ApprovalMode.DEFAULT },
mockCheckerRunner,
);
}); });
describe('constructor', () => { describe('constructor', () => {
@@ -163,6 +167,41 @@ describe('PolicyEngine', () => {
(await engine.check({ name: 'unknown-tool' }, undefined)).decision, (await engine.check({ name: 'unknown-tool' }, undefined)).decision,
).toBe(PolicyDecision.DENY); ).toBe(PolicyDecision.DENY);
}); });
it('should dynamically switch between modes and respect rule modes', async () => {
const rules: PolicyRule[] = [
{
toolName: 'edit',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
{
toolName: 'edit',
decision: PolicyDecision.ALLOW,
priority: 20,
modes: [ApprovalMode.AUTO_EDIT],
},
];
engine = new PolicyEngine({ rules });
// Default mode: priority 20 rule doesn't match, falls back to priority 10
expect((await engine.check({ name: 'edit' }, undefined)).decision).toBe(
PolicyDecision.ASK_USER,
);
// Switch to autoEdit mode
engine.setApprovalMode(ApprovalMode.AUTO_EDIT);
expect((await engine.check({ name: 'edit' }, undefined)).decision).toBe(
PolicyDecision.ALLOW,
);
// Switch back to default
engine.setApprovalMode(ApprovalMode.DEFAULT);
expect((await engine.check({ name: 'edit' }, undefined)).decision).toBe(
PolicyDecision.ASK_USER,
);
});
}); });
describe('addRule', () => { describe('addRule', () => {
+43 -2
View File
@@ -13,6 +13,7 @@ import {
type HookCheckerRule, type HookCheckerRule,
type HookExecutionContext, type HookExecutionContext,
getHookSource, getHookSource,
ApprovalMode,
} from './types.js'; } from './types.js';
import { stableStringify } from './stable-stringify.js'; import { stableStringify } from './stable-stringify.js';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
@@ -30,7 +31,15 @@ function ruleMatches(
toolCall: FunctionCall, toolCall: FunctionCall,
stringifiedArgs: string | undefined, stringifiedArgs: string | undefined,
serverName: string | undefined, serverName: string | undefined,
currentApprovalMode: ApprovalMode,
): boolean { ): boolean {
// Check if rule applies to current approval mode
if (rule.modes && rule.modes.length > 0) {
if (!rule.modes.includes(currentApprovalMode)) {
return false;
}
}
// Check tool name if specified // Check tool name if specified
if (rule.toolName) { if (rule.toolName) {
// Support wildcard patterns: "serverName__*" matches "serverName__anyTool" // Support wildcard patterns: "serverName__*" matches "serverName__anyTool"
@@ -98,6 +107,7 @@ export class PolicyEngine {
private readonly nonInteractive: boolean; private readonly nonInteractive: boolean;
private readonly checkerRunner?: CheckerRunner; private readonly checkerRunner?: CheckerRunner;
private readonly allowHooks: boolean; private readonly allowHooks: boolean;
private approvalMode: ApprovalMode;
constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) { constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) {
this.rules = (config.rules ?? []).sort( this.rules = (config.rules ?? []).sort(
@@ -113,6 +123,21 @@ export class PolicyEngine {
this.nonInteractive = config.nonInteractive ?? false; this.nonInteractive = config.nonInteractive ?? false;
this.checkerRunner = checkerRunner; this.checkerRunner = checkerRunner;
this.allowHooks = config.allowHooks ?? true; this.allowHooks = config.allowHooks ?? true;
this.approvalMode = config.approvalMode ?? ApprovalMode.DEFAULT;
}
/**
* Update the current approval mode.
*/
setApprovalMode(mode: ApprovalMode): void {
this.approvalMode = mode;
}
/**
* Get the current approval mode.
*/
getApprovalMode(): ApprovalMode {
return this.approvalMode;
} }
/** /**
@@ -145,7 +170,15 @@ export class PolicyEngine {
let decision: PolicyDecision | undefined; let decision: PolicyDecision | undefined;
for (const rule of this.rules) { for (const rule of this.rules) {
if (ruleMatches(rule, toolCall, stringifiedArgs, serverName)) { if (
ruleMatches(
rule,
toolCall,
stringifiedArgs,
serverName,
this.approvalMode,
)
) {
debugLogger.debug( debugLogger.debug(
`[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`, `[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`,
); );
@@ -225,7 +258,15 @@ export class PolicyEngine {
// If decision is not DENY, run safety checkers // If decision is not DENY, run safety checkers
if (decision !== PolicyDecision.DENY && this.checkerRunner) { if (decision !== PolicyDecision.DENY && this.checkerRunner) {
for (const checkerRule of this.checkers) { for (const checkerRule of this.checkers) {
if (ruleMatches(checkerRule, toolCall, stringifiedArgs, serverName)) { if (
ruleMatches(
checkerRule,
toolCall,
stringifiedArgs,
serverName,
this.approvalMode,
)
) {
debugLogger.debug( debugLogger.debug(
`[PolicyEngine.check] Running safety checker: ${checkerRule.checker.name}`, `[PolicyEngine.check] Running safety checker: ${checkerRule.checker.name}`,
); );
@@ -6,7 +6,7 @@
import { describe, it, expect, beforeEach } from 'vitest'; import { describe, it, expect, beforeEach } from 'vitest';
import { PolicyEngine } from './policy-engine.js'; import { PolicyEngine } from './policy-engine.js';
import { PolicyDecision } from './types.js'; import { PolicyDecision, ApprovalMode } from './types.js';
import type { FunctionCall } from '@google/genai'; import type { FunctionCall } from '@google/genai';
describe('Shell Safety Policy', () => { describe('Shell Safety Policy', () => {
@@ -25,6 +25,7 @@ describe('Shell Safety Policy', () => {
}, },
], ],
defaultDecision: PolicyDecision.ASK_USER, defaultDecision: PolicyDecision.ASK_USER,
approvalMode: ApprovalMode.DEFAULT,
}); });
}); });
+34 -15
View File
@@ -5,7 +5,7 @@
*/ */
import { describe, it, expect, beforeEach, afterEach } from 'vitest'; import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { ApprovalMode, PolicyDecision } from './types.js'; import { PolicyDecision } from './types.js';
import * as fs from 'node:fs/promises'; import * as fs from 'node:fs/promises';
import * as path from 'node:path'; import * as path from 'node:path';
import * as os from 'node:os'; import * as os from 'node:os';
@@ -36,7 +36,7 @@ describe('policy-toml-loader', () => {
): Promise<PolicyLoadResult> { ): Promise<PolicyLoadResult> {
await fs.writeFile(path.join(tempDir, fileName), tomlContent); await fs.writeFile(path.join(tempDir, fileName), tomlContent);
const getPolicyTier = (_dir: string) => 1; const getPolicyTier = (_dir: string) => 1;
return loadPoliciesFromToml(ApprovalMode.DEFAULT, [tempDir], getPolicyTier); return loadPoliciesFromToml([tempDir], getPolicyTier);
} }
describe('loadPoliciesFromToml', () => { describe('loadPoliciesFromToml', () => {
@@ -133,7 +133,7 @@ priority = 100
expect(result.errors).toHaveLength(0); expect(result.errors).toHaveLength(0);
}); });
it('should filter rules by mode', async () => { it('should NOT filter rules by mode at load time but preserve modes property', async () => {
const result = await runLoadPoliciesFromToml(` const result = await runLoadPoliciesFromToml(`
[[rule]] [[rule]]
toolName = "glob" toolName = "glob"
@@ -148,12 +148,39 @@ priority = 100
modes = ["yolo"] modes = ["yolo"]
`); `);
// Only the first rule should be included (modes includes "default") // Both rules should be included
expect(result.rules).toHaveLength(1); expect(result.rules).toHaveLength(2);
expect(result.rules[0].toolName).toBe('glob'); expect(result.rules[0].toolName).toBe('glob');
expect(result.rules[0].modes).toEqual(['default', 'yolo']);
expect(result.rules[1].toolName).toBe('grep');
expect(result.rules[1].modes).toEqual(['yolo']);
expect(result.errors).toHaveLength(0); expect(result.errors).toHaveLength(0);
}); });
it('should return error if modes property is used for Tier 2 and Tier 3 policies', async () => {
await fs.writeFile(
path.join(tempDir, 'tier2.toml'),
`
[[rule]]
toolName = "tier2-tool"
decision = "allow"
priority = 100
modes = ["autoEdit"]
`,
);
const getPolicyTier = (_dir: string) => 2; // Tier 2
const result = await loadPoliciesFromToml([tempDir], getPolicyTier);
// It still transforms the rule, but it should also report an error
expect(result.rules).toHaveLength(1);
expect(result.rules[0].toolName).toBe('tier2-tool');
expect(result.rules[0].modes).toBeUndefined(); // Should be restricted
expect(result.errors).toHaveLength(1);
expect(result.errors[0].errorType).toBe('rule_validation');
expect(result.errors[0].message).toContain('Restricted property "modes"');
});
it('should handle TOML parse errors', async () => { it('should handle TOML parse errors', async () => {
const result = await runLoadPoliciesFromToml(` const result = await runLoadPoliciesFromToml(`
[[rule] [[rule]
@@ -267,11 +294,7 @@ priority = -1
); );
const getPolicyTier = (_dir: string) => 1; const getPolicyTier = (_dir: string) => 1;
const result = await loadPoliciesFromToml( const result = await loadPoliciesFromToml([tempDir], getPolicyTier);
ApprovalMode.DEFAULT,
[tempDir],
getPolicyTier,
);
expect(result.rules).toHaveLength(1); expect(result.rules).toHaveLength(1);
expect(result.rules[0].toolName).toBe('glob'); expect(result.rules[0].toolName).toBe('glob');
@@ -439,11 +462,7 @@ priority = 100
await fs.writeFile(filePath, 'content'); await fs.writeFile(filePath, 'content');
const getPolicyTier = (_dir: string) => 1; const getPolicyTier = (_dir: string) => 1;
const result = await loadPoliciesFromToml( const result = await loadPoliciesFromToml([filePath], getPolicyTier);
ApprovalMode.DEFAULT,
[filePath],
getPolicyTier,
);
expect(result.errors).toHaveLength(1); expect(result.errors).toHaveLength(1);
const error = result.errors[0]; const error = result.errors[0];
+38 -20
View File
@@ -7,7 +7,7 @@
import { import {
type PolicyRule, type PolicyRule,
PolicyDecision, PolicyDecision,
type ApprovalMode, ApprovalMode,
type SafetyCheckerConfig, type SafetyCheckerConfig,
type SafetyCheckerRule, type SafetyCheckerRule,
InProcessCheckerType, InProcessCheckerType,
@@ -43,7 +43,7 @@ const PolicyRuleSchema = z.object({
message: message:
'priority must be <= 999 to prevent tier overflow. Priorities >= 1000 would jump to the next tier.', 'priority must be <= 999 to prevent tier overflow. Priorities >= 1000 would jump to the next tier.',
}), }),
modes: z.array(z.string()).optional(), modes: z.array(z.nativeEnum(ApprovalMode)).optional(),
}); });
/** /**
@@ -56,7 +56,7 @@ const SafetyCheckerRuleSchema = z.object({
commandPrefix: z.union([z.string(), z.array(z.string())]).optional(), commandPrefix: z.union([z.string(), z.array(z.string())]).optional(),
commandRegex: z.string().optional(), commandRegex: z.string().optional(),
priority: z.number().int().default(0), priority: z.number().int().default(0),
modes: z.array(z.string()).optional(), modes: z.array(z.nativeEnum(ApprovalMode)).optional(),
checker: z.discriminatedUnion('type', [ checker: z.discriminatedUnion('type', [
z.object({ z.object({
type: z.literal('in-process'), type: z.literal('in-process'),
@@ -216,16 +216,13 @@ function transformPriority(priority: number, tier: number): number {
* 1. Scans directories for .toml files * 1. Scans directories for .toml files
* 2. Parses and validates each file * 2. Parses and validates each file
* 3. Transforms rules (commandPrefix, arrays, mcpName, priorities) * 3. Transforms rules (commandPrefix, arrays, mcpName, priorities)
* 4. Filters rules by approval mode * 4. Collects detailed error information for any failures
* 5. Collects detailed error information for any failures
* *
* @param approvalMode The current approval mode (for filtering rules by mode)
* @param policyDirs Array of directory paths to scan for policy files * @param policyDirs Array of directory paths to scan for policy files
* @param getPolicyTier Function to determine tier (1-3) for a directory * @param getPolicyTier Function to determine tier (1-3) for a directory
* @returns Object containing successfully parsed rules and any errors encountered * @returns Object containing successfully parsed rules and any errors encountered
*/ */
export async function loadPoliciesFromToml( export async function loadPoliciesFromToml(
approvalMode: ApprovalMode,
policyDirs: string[], policyDirs: string[],
getPolicyTier: (dir: string) => number, getPolicyTier: (dir: string) => number,
): Promise<PolicyLoadResult> { ): Promise<PolicyLoadResult> {
@@ -305,6 +302,8 @@ export async function loadPoliciesFromToml(
// Validate shell command convenience syntax // Validate shell command convenience syntax
const tomlRules = validationResult.data.rule ?? []; const tomlRules = validationResult.data.rule ?? [];
const tomlCheckers = validationResult.data.safety_checker ?? [];
for (let i = 0; i < tomlRules.length; i++) { for (let i = 0; i < tomlRules.length; i++) {
const rule = tomlRules[i]; const rule = tomlRules[i];
const validationError = validateShellCommandSyntax(rule, i); const validationError = validateShellCommandSyntax(rule, i);
@@ -320,17 +319,40 @@ export async function loadPoliciesFromToml(
}); });
// Continue to next rule, don't skip the entire file // Continue to next rule, don't skip the entire file
} }
if (tier > 1 && rule.modes && rule.modes.length > 0) {
errors.push({
filePath,
fileName: file,
tier: tierName,
ruleIndex: i,
errorType: 'rule_validation',
message: 'Restricted property "modes"',
details: `Rule #${i + 1}: The "modes" property is currently reserved for Tier 1 (system) policies and cannot be used in ${tierName} policies.`,
suggestion: 'Remove the "modes" property from this rule.',
});
}
}
for (let i = 0; i < tomlCheckers.length; i++) {
const checker = tomlCheckers[i];
if (tier > 1 && checker.modes && checker.modes.length > 0) {
errors.push({
filePath,
fileName: file,
tier: tierName,
ruleIndex: i,
errorType: 'rule_validation',
message: 'Restricted property "modes" in safety checker',
details: `Safety Checker #${i + 1}: The "modes" property is currently reserved for Tier 1 (system) policies and cannot be used in ${tierName} policies.`,
suggestion:
'Remove the "modes" property from this safety checker.',
});
}
} }
// Transform rules // Transform rules
const parsedRules: PolicyRule[] = (validationResult.data.rule ?? []) const parsedRules: PolicyRule[] = (validationResult.data.rule ?? [])
.filter((rule) => {
// Filter by mode
if (!rule.modes || rule.modes.length === 0) {
return true;
}
return rule.modes.includes(approvalMode);
})
.flatMap((rule) => { .flatMap((rule) => {
// Transform commandPrefix/commandRegex to argsPattern // Transform commandPrefix/commandRegex to argsPattern
let effectiveArgsPattern = rule.argsPattern; let effectiveArgsPattern = rule.argsPattern;
@@ -377,6 +399,7 @@ export async function loadPoliciesFromToml(
toolName: effectiveToolName, toolName: effectiveToolName,
decision: rule.decision, decision: rule.decision,
priority: transformPriority(rule.priority, tier), priority: transformPriority(rule.priority, tier),
modes: tier === 1 ? rule.modes : undefined,
}; };
// Compile regex pattern // Compile regex pattern
@@ -412,12 +435,6 @@ export async function loadPoliciesFromToml(
const parsedCheckers: SafetyCheckerRule[] = ( const parsedCheckers: SafetyCheckerRule[] = (
validationResult.data.safety_checker ?? [] validationResult.data.safety_checker ?? []
) )
.filter((checker) => {
if (!checker.modes || checker.modes.length === 0) {
return true;
}
return checker.modes.includes(approvalMode);
})
.flatMap((checker) => { .flatMap((checker) => {
let effectiveArgsPattern = checker.argsPattern; let effectiveArgsPattern = checker.argsPattern;
const commandPrefixes: string[] = []; const commandPrefixes: string[] = [];
@@ -459,6 +476,7 @@ export async function loadPoliciesFromToml(
toolName: effectiveToolName, toolName: effectiveToolName,
priority: checker.priority, priority: checker.priority,
checker: checker.checker as SafetyCheckerConfig, checker: checker.checker as SafetyCheckerConfig,
modes: tier === 1 ? checker.modes : undefined,
}; };
if (argsPattern) { if (argsPattern) {
+18
View File
@@ -117,6 +117,12 @@ export interface PolicyRule {
* Default is 0. * Default is 0.
*/ */
priority?: number; priority?: number;
/**
* Approval modes this rule applies to.
* If undefined or empty, it applies to all modes.
*/
modes?: ApprovalMode[];
} }
export interface SafetyCheckerRule { export interface SafetyCheckerRule {
@@ -143,6 +149,12 @@ export interface SafetyCheckerRule {
* additional validation of a tool call. * additional validation of a tool call.
*/ */
checker: SafetyCheckerConfig; checker: SafetyCheckerConfig;
/**
* Approval modes this rule applies to.
* If undefined or empty, it applies to all modes.
*/
modes?: ApprovalMode[];
} }
export interface HookExecutionContext { export interface HookExecutionContext {
@@ -215,6 +227,12 @@ export interface PolicyEngineConfig {
* Defaults to true. * Defaults to true.
*/ */
allowHooks?: boolean; allowHooks?: boolean;
/**
* Current approval mode.
* Used to filter rules that have specific 'modes' defined.
*/
approvalMode?: ApprovalMode;
} }
export interface PolicySettings { export interface PolicySettings {