feat(policy): implement dynamic mode-aware policy evaluation (#15307)

This commit is contained in:
Abhi
2025-12-22 15:25:07 -05:00
committed by Abhi
parent 81b171c1b4
commit 999f9b4bf1
9 changed files with 190 additions and 46 deletions

View File

@@ -361,7 +361,6 @@ export class Config {
private userMemory: string;
private geminiMdFileCount: number;
private geminiMdFilePaths: string[];
private approvalMode: ApprovalMode;
private readonly showMemoryUsage: boolean;
private readonly accessibility: AccessibilitySettings;
private readonly telemetrySettings: TelemetrySettings;
@@ -481,7 +480,6 @@ export class Config {
this.userMemory = params.userMemory ?? '';
this.geminiMdFileCount = params.geminiMdFileCount ?? 0;
this.geminiMdFilePaths = params.geminiMdFilePaths ?? [];
this.approvalMode = params.approvalMode ?? ApprovalMode.DEFAULT;
this.showMemoryUsage = params.showMemoryUsage ?? false;
this.accessibility = params.accessibility ?? {};
this.telemetrySettings = {
@@ -596,7 +594,11 @@ export class Config {
this.enablePromptCompletion = params.enablePromptCompletion ?? false;
this.fileExclusions = new FileExclusions(this);
this.eventEmitter = params.eventEmitter;
this.policyEngine = new PolicyEngine(params.policyEngineConfig);
this.policyEngine = new PolicyEngine({
...params.policyEngineConfig,
approvalMode:
params.approvalMode ?? params.policyEngineConfig?.approvalMode,
});
this.messageBus = new MessageBus(this.policyEngine, this.debugMode);
this.outputSettings = {
format: params.output?.format ?? OutputFormat.TEXT,
@@ -1127,7 +1129,7 @@ export class Config {
}
getApprovalMode(): ApprovalMode {
return this.approvalMode;
return this.policyEngine.getApprovalMode();
}
setApprovalMode(mode: ApprovalMode): void {
@@ -1136,7 +1138,7 @@ export class Config {
'Cannot enable privileged approval modes in an untrusted folder.',
);
}
this.approvalMode = mode;
this.policyEngine.setApprovalMode(mode);
}
isYoloModeDisabled(): boolean {

View File

@@ -124,7 +124,7 @@ export async function createPolicyEngineConfig(
rules: tomlRules,
checkers: tomlCheckers,
errors,
} = await loadPoliciesFromToml(approvalMode, policyDirs, (dir) =>
} = await loadPoliciesFromToml(policyDirs, (dir) =>
getPolicyTier(dir, defaultPoliciesDir),
);
@@ -236,6 +236,7 @@ export async function createPolicyEngineConfig(
rules,
checkers,
defaultDecision: PolicyDecision.ASK_USER,
approvalMode,
};
}

View File

@@ -20,6 +20,7 @@ import { PolicyEngine } from './policy-engine.js';
import { MessageBus } from '../confirmation-bus/message-bus.js';
import { MessageBusType } from '../confirmation-bus/types.js';
import { Storage } from '../config/storage.js';
import { ApprovalMode } from './types.js';
vi.mock('node:fs/promises');
vi.mock('../config/storage.js');
@@ -29,7 +30,11 @@ describe('createPolicyUpdater', () => {
let messageBus: MessageBus;
beforeEach(() => {
policyEngine = new PolicyEngine({ rules: [], checkers: [] });
policyEngine = new PolicyEngine({
rules: [],
checkers: [],
approvalMode: ApprovalMode.DEFAULT,
});
messageBus = new MessageBus(policyEngine);
vi.clearAllMocks();
});

View File

@@ -12,6 +12,7 @@ import {
type PolicyEngineConfig,
type SafetyCheckerRule,
InProcessCheckerType,
ApprovalMode,
} from './types.js';
import type { FunctionCall } from '@google/genai';
import { SafetyCheckDecision } from '../safety/protocol.js';
@@ -25,7 +26,10 @@ describe('PolicyEngine', () => {
mockCheckerRunner = {
runChecker: vi.fn(),
} as unknown as CheckerRunner;
engine = new PolicyEngine({}, mockCheckerRunner);
engine = new PolicyEngine(
{ approvalMode: ApprovalMode.DEFAULT },
mockCheckerRunner,
);
});
describe('constructor', () => {
@@ -163,6 +167,41 @@ describe('PolicyEngine', () => {
(await engine.check({ name: 'unknown-tool' }, undefined)).decision,
).toBe(PolicyDecision.DENY);
});
it('should dynamically switch between modes and respect rule modes', async () => {
const rules: PolicyRule[] = [
{
toolName: 'edit',
decision: PolicyDecision.ASK_USER,
priority: 10,
},
{
toolName: 'edit',
decision: PolicyDecision.ALLOW,
priority: 20,
modes: [ApprovalMode.AUTO_EDIT],
},
];
engine = new PolicyEngine({ rules });
// Default mode: priority 20 rule doesn't match, falls back to priority 10
expect((await engine.check({ name: 'edit' }, undefined)).decision).toBe(
PolicyDecision.ASK_USER,
);
// Switch to autoEdit mode
engine.setApprovalMode(ApprovalMode.AUTO_EDIT);
expect((await engine.check({ name: 'edit' }, undefined)).decision).toBe(
PolicyDecision.ALLOW,
);
// Switch back to default
engine.setApprovalMode(ApprovalMode.DEFAULT);
expect((await engine.check({ name: 'edit' }, undefined)).decision).toBe(
PolicyDecision.ASK_USER,
);
});
});
describe('addRule', () => {

View File

@@ -13,6 +13,7 @@ import {
type HookCheckerRule,
type HookExecutionContext,
getHookSource,
ApprovalMode,
} from './types.js';
import { stableStringify } from './stable-stringify.js';
import { debugLogger } from '../utils/debugLogger.js';
@@ -30,7 +31,15 @@ function ruleMatches(
toolCall: FunctionCall,
stringifiedArgs: string | undefined,
serverName: string | undefined,
currentApprovalMode: ApprovalMode,
): boolean {
// Check if rule applies to current approval mode
if (rule.modes && rule.modes.length > 0) {
if (!rule.modes.includes(currentApprovalMode)) {
return false;
}
}
// Check tool name if specified
if (rule.toolName) {
// Support wildcard patterns: "serverName__*" matches "serverName__anyTool"
@@ -98,6 +107,7 @@ export class PolicyEngine {
private readonly nonInteractive: boolean;
private readonly checkerRunner?: CheckerRunner;
private readonly allowHooks: boolean;
private approvalMode: ApprovalMode;
constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) {
this.rules = (config.rules ?? []).sort(
@@ -113,6 +123,21 @@ export class PolicyEngine {
this.nonInteractive = config.nonInteractive ?? false;
this.checkerRunner = checkerRunner;
this.allowHooks = config.allowHooks ?? true;
this.approvalMode = config.approvalMode ?? ApprovalMode.DEFAULT;
}
/**
* Update the current approval mode.
*/
setApprovalMode(mode: ApprovalMode): void {
this.approvalMode = mode;
}
/**
* Get the current approval mode.
*/
getApprovalMode(): ApprovalMode {
return this.approvalMode;
}
/**
@@ -145,7 +170,15 @@ export class PolicyEngine {
let decision: PolicyDecision | undefined;
for (const rule of this.rules) {
if (ruleMatches(rule, toolCall, stringifiedArgs, serverName)) {
if (
ruleMatches(
rule,
toolCall,
stringifiedArgs,
serverName,
this.approvalMode,
)
) {
debugLogger.debug(
`[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`,
);
@@ -225,7 +258,15 @@ export class PolicyEngine {
// If decision is not DENY, run safety checkers
if (decision !== PolicyDecision.DENY && this.checkerRunner) {
for (const checkerRule of this.checkers) {
if (ruleMatches(checkerRule, toolCall, stringifiedArgs, serverName)) {
if (
ruleMatches(
checkerRule,
toolCall,
stringifiedArgs,
serverName,
this.approvalMode,
)
) {
debugLogger.debug(
`[PolicyEngine.check] Running safety checker: ${checkerRule.checker.name}`,
);

View File

@@ -6,7 +6,7 @@
import { describe, it, expect, beforeEach } from 'vitest';
import { PolicyEngine } from './policy-engine.js';
import { PolicyDecision } from './types.js';
import { PolicyDecision, ApprovalMode } from './types.js';
import type { FunctionCall } from '@google/genai';
describe('Shell Safety Policy', () => {
@@ -25,6 +25,7 @@ describe('Shell Safety Policy', () => {
},
],
defaultDecision: PolicyDecision.ASK_USER,
approvalMode: ApprovalMode.DEFAULT,
});
});

View File

@@ -5,7 +5,7 @@
*/
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
import { ApprovalMode, PolicyDecision } from './types.js';
import { PolicyDecision } from './types.js';
import * as fs from 'node:fs/promises';
import * as path from 'node:path';
import * as os from 'node:os';
@@ -36,7 +36,7 @@ describe('policy-toml-loader', () => {
): Promise<PolicyLoadResult> {
await fs.writeFile(path.join(tempDir, fileName), tomlContent);
const getPolicyTier = (_dir: string) => 1;
return loadPoliciesFromToml(ApprovalMode.DEFAULT, [tempDir], getPolicyTier);
return loadPoliciesFromToml([tempDir], getPolicyTier);
}
describe('loadPoliciesFromToml', () => {
@@ -133,7 +133,7 @@ priority = 100
expect(result.errors).toHaveLength(0);
});
it('should filter rules by mode', async () => {
it('should NOT filter rules by mode at load time but preserve modes property', async () => {
const result = await runLoadPoliciesFromToml(`
[[rule]]
toolName = "glob"
@@ -148,12 +148,39 @@ priority = 100
modes = ["yolo"]
`);
// Only the first rule should be included (modes includes "default")
expect(result.rules).toHaveLength(1);
// Both rules should be included
expect(result.rules).toHaveLength(2);
expect(result.rules[0].toolName).toBe('glob');
expect(result.rules[0].modes).toEqual(['default', 'yolo']);
expect(result.rules[1].toolName).toBe('grep');
expect(result.rules[1].modes).toEqual(['yolo']);
expect(result.errors).toHaveLength(0);
});
it('should return error if modes property is used for Tier 2 and Tier 3 policies', async () => {
await fs.writeFile(
path.join(tempDir, 'tier2.toml'),
`
[[rule]]
toolName = "tier2-tool"
decision = "allow"
priority = 100
modes = ["autoEdit"]
`,
);
const getPolicyTier = (_dir: string) => 2; // Tier 2
const result = await loadPoliciesFromToml([tempDir], getPolicyTier);
// It still transforms the rule, but it should also report an error
expect(result.rules).toHaveLength(1);
expect(result.rules[0].toolName).toBe('tier2-tool');
expect(result.rules[0].modes).toBeUndefined(); // Should be restricted
expect(result.errors).toHaveLength(1);
expect(result.errors[0].errorType).toBe('rule_validation');
expect(result.errors[0].message).toContain('Restricted property "modes"');
});
it('should handle TOML parse errors', async () => {
const result = await runLoadPoliciesFromToml(`
[[rule]
@@ -267,11 +294,7 @@ priority = -1
);
const getPolicyTier = (_dir: string) => 1;
const result = await loadPoliciesFromToml(
ApprovalMode.DEFAULT,
[tempDir],
getPolicyTier,
);
const result = await loadPoliciesFromToml([tempDir], getPolicyTier);
expect(result.rules).toHaveLength(1);
expect(result.rules[0].toolName).toBe('glob');
@@ -439,11 +462,7 @@ priority = 100
await fs.writeFile(filePath, 'content');
const getPolicyTier = (_dir: string) => 1;
const result = await loadPoliciesFromToml(
ApprovalMode.DEFAULT,
[filePath],
getPolicyTier,
);
const result = await loadPoliciesFromToml([filePath], getPolicyTier);
expect(result.errors).toHaveLength(1);
const error = result.errors[0];

View File

@@ -7,7 +7,7 @@
import {
type PolicyRule,
PolicyDecision,
type ApprovalMode,
ApprovalMode,
type SafetyCheckerConfig,
type SafetyCheckerRule,
InProcessCheckerType,
@@ -43,7 +43,7 @@ const PolicyRuleSchema = z.object({
message:
'priority must be <= 999 to prevent tier overflow. Priorities >= 1000 would jump to the next tier.',
}),
modes: z.array(z.string()).optional(),
modes: z.array(z.nativeEnum(ApprovalMode)).optional(),
});
/**
@@ -56,7 +56,7 @@ const SafetyCheckerRuleSchema = z.object({
commandPrefix: z.union([z.string(), z.array(z.string())]).optional(),
commandRegex: z.string().optional(),
priority: z.number().int().default(0),
modes: z.array(z.string()).optional(),
modes: z.array(z.nativeEnum(ApprovalMode)).optional(),
checker: z.discriminatedUnion('type', [
z.object({
type: z.literal('in-process'),
@@ -216,16 +216,13 @@ function transformPriority(priority: number, tier: number): number {
* 1. Scans directories for .toml files
* 2. Parses and validates each file
* 3. Transforms rules (commandPrefix, arrays, mcpName, priorities)
* 4. Filters rules by approval mode
* 5. Collects detailed error information for any failures
* 4. Collects detailed error information for any failures
*
* @param approvalMode The current approval mode (for filtering rules by mode)
* @param policyDirs Array of directory paths to scan for policy files
* @param getPolicyTier Function to determine tier (1-3) for a directory
* @returns Object containing successfully parsed rules and any errors encountered
*/
export async function loadPoliciesFromToml(
approvalMode: ApprovalMode,
policyDirs: string[],
getPolicyTier: (dir: string) => number,
): Promise<PolicyLoadResult> {
@@ -305,6 +302,8 @@ export async function loadPoliciesFromToml(
// Validate shell command convenience syntax
const tomlRules = validationResult.data.rule ?? [];
const tomlCheckers = validationResult.data.safety_checker ?? [];
for (let i = 0; i < tomlRules.length; i++) {
const rule = tomlRules[i];
const validationError = validateShellCommandSyntax(rule, i);
@@ -320,17 +319,40 @@ export async function loadPoliciesFromToml(
});
// Continue to next rule, don't skip the entire file
}
if (tier > 1 && rule.modes && rule.modes.length > 0) {
errors.push({
filePath,
fileName: file,
tier: tierName,
ruleIndex: i,
errorType: 'rule_validation',
message: 'Restricted property "modes"',
details: `Rule #${i + 1}: The "modes" property is currently reserved for Tier 1 (system) policies and cannot be used in ${tierName} policies.`,
suggestion: 'Remove the "modes" property from this rule.',
});
}
}
for (let i = 0; i < tomlCheckers.length; i++) {
const checker = tomlCheckers[i];
if (tier > 1 && checker.modes && checker.modes.length > 0) {
errors.push({
filePath,
fileName: file,
tier: tierName,
ruleIndex: i,
errorType: 'rule_validation',
message: 'Restricted property "modes" in safety checker',
details: `Safety Checker #${i + 1}: The "modes" property is currently reserved for Tier 1 (system) policies and cannot be used in ${tierName} policies.`,
suggestion:
'Remove the "modes" property from this safety checker.',
});
}
}
// Transform rules
const parsedRules: PolicyRule[] = (validationResult.data.rule ?? [])
.filter((rule) => {
// Filter by mode
if (!rule.modes || rule.modes.length === 0) {
return true;
}
return rule.modes.includes(approvalMode);
})
.flatMap((rule) => {
// Transform commandPrefix/commandRegex to argsPattern
let effectiveArgsPattern = rule.argsPattern;
@@ -377,6 +399,7 @@ export async function loadPoliciesFromToml(
toolName: effectiveToolName,
decision: rule.decision,
priority: transformPriority(rule.priority, tier),
modes: tier === 1 ? rule.modes : undefined,
};
// Compile regex pattern
@@ -412,12 +435,6 @@ export async function loadPoliciesFromToml(
const parsedCheckers: SafetyCheckerRule[] = (
validationResult.data.safety_checker ?? []
)
.filter((checker) => {
if (!checker.modes || checker.modes.length === 0) {
return true;
}
return checker.modes.includes(approvalMode);
})
.flatMap((checker) => {
let effectiveArgsPattern = checker.argsPattern;
const commandPrefixes: string[] = [];
@@ -459,6 +476,7 @@ export async function loadPoliciesFromToml(
toolName: effectiveToolName,
priority: checker.priority,
checker: checker.checker as SafetyCheckerConfig,
modes: tier === 1 ? checker.modes : undefined,
};
if (argsPattern) {

View File

@@ -117,6 +117,12 @@ export interface PolicyRule {
* Default is 0.
*/
priority?: number;
/**
* Approval modes this rule applies to.
* If undefined or empty, it applies to all modes.
*/
modes?: ApprovalMode[];
}
export interface SafetyCheckerRule {
@@ -143,6 +149,12 @@ export interface SafetyCheckerRule {
* additional validation of a tool call.
*/
checker: SafetyCheckerConfig;
/**
* Approval modes this rule applies to.
* If undefined or empty, it applies to all modes.
*/
modes?: ApprovalMode[];
}
export interface HookExecutionContext {
@@ -215,6 +227,12 @@ export interface PolicyEngineConfig {
* Defaults to true.
*/
allowHooks?: boolean;
/**
* Current approval mode.
* Used to filter rules that have specific 'modes' defined.
*/
approvalMode?: ApprovalMode;
}
export interface PolicySettings {