diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 888ee942a1..ffa6258f11 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -361,7 +361,6 @@ export class Config { private userMemory: string; private geminiMdFileCount: number; private geminiMdFilePaths: string[]; - private approvalMode: ApprovalMode; private readonly showMemoryUsage: boolean; private readonly accessibility: AccessibilitySettings; private readonly telemetrySettings: TelemetrySettings; @@ -481,7 +480,6 @@ export class Config { this.userMemory = params.userMemory ?? ''; this.geminiMdFileCount = params.geminiMdFileCount ?? 0; this.geminiMdFilePaths = params.geminiMdFilePaths ?? []; - this.approvalMode = params.approvalMode ?? ApprovalMode.DEFAULT; this.showMemoryUsage = params.showMemoryUsage ?? false; this.accessibility = params.accessibility ?? {}; this.telemetrySettings = { @@ -596,7 +594,11 @@ export class Config { this.enablePromptCompletion = params.enablePromptCompletion ?? false; this.fileExclusions = new FileExclusions(this); this.eventEmitter = params.eventEmitter; - this.policyEngine = new PolicyEngine(params.policyEngineConfig); + this.policyEngine = new PolicyEngine({ + ...params.policyEngineConfig, + approvalMode: + params.approvalMode ?? params.policyEngineConfig?.approvalMode, + }); this.messageBus = new MessageBus(this.policyEngine, this.debugMode); this.outputSettings = { format: params.output?.format ?? OutputFormat.TEXT, @@ -1127,7 +1129,7 @@ export class Config { } getApprovalMode(): ApprovalMode { - return this.approvalMode; + return this.policyEngine.getApprovalMode(); } setApprovalMode(mode: ApprovalMode): void { @@ -1136,7 +1138,7 @@ export class Config { 'Cannot enable privileged approval modes in an untrusted folder.', ); } - this.approvalMode = mode; + this.policyEngine.setApprovalMode(mode); } isYoloModeDisabled(): boolean { diff --git a/packages/core/src/policy/config.ts b/packages/core/src/policy/config.ts index 3c6016f086..9057506ae2 100644 --- a/packages/core/src/policy/config.ts +++ b/packages/core/src/policy/config.ts @@ -124,7 +124,7 @@ export async function createPolicyEngineConfig( rules: tomlRules, checkers: tomlCheckers, errors, - } = await loadPoliciesFromToml(approvalMode, policyDirs, (dir) => + } = await loadPoliciesFromToml(policyDirs, (dir) => getPolicyTier(dir, defaultPoliciesDir), ); @@ -236,6 +236,7 @@ export async function createPolicyEngineConfig( rules, checkers, defaultDecision: PolicyDecision.ASK_USER, + approvalMode, }; } diff --git a/packages/core/src/policy/persistence.test.ts b/packages/core/src/policy/persistence.test.ts index e7916b8644..4954d7280e 100644 --- a/packages/core/src/policy/persistence.test.ts +++ b/packages/core/src/policy/persistence.test.ts @@ -20,6 +20,7 @@ import { PolicyEngine } from './policy-engine.js'; import { MessageBus } from '../confirmation-bus/message-bus.js'; import { MessageBusType } from '../confirmation-bus/types.js'; import { Storage } from '../config/storage.js'; +import { ApprovalMode } from './types.js'; vi.mock('node:fs/promises'); vi.mock('../config/storage.js'); @@ -29,7 +30,11 @@ describe('createPolicyUpdater', () => { let messageBus: MessageBus; beforeEach(() => { - policyEngine = new PolicyEngine({ rules: [], checkers: [] }); + policyEngine = new PolicyEngine({ + rules: [], + checkers: [], + approvalMode: ApprovalMode.DEFAULT, + }); messageBus = new MessageBus(policyEngine); vi.clearAllMocks(); }); diff --git a/packages/core/src/policy/policy-engine.test.ts b/packages/core/src/policy/policy-engine.test.ts index a58725f8f2..a362d1995a 100644 --- a/packages/core/src/policy/policy-engine.test.ts +++ b/packages/core/src/policy/policy-engine.test.ts @@ -12,6 +12,7 @@ import { type PolicyEngineConfig, type SafetyCheckerRule, InProcessCheckerType, + ApprovalMode, } from './types.js'; import type { FunctionCall } from '@google/genai'; import { SafetyCheckDecision } from '../safety/protocol.js'; @@ -25,7 +26,10 @@ describe('PolicyEngine', () => { mockCheckerRunner = { runChecker: vi.fn(), } as unknown as CheckerRunner; - engine = new PolicyEngine({}, mockCheckerRunner); + engine = new PolicyEngine( + { approvalMode: ApprovalMode.DEFAULT }, + mockCheckerRunner, + ); }); describe('constructor', () => { @@ -163,6 +167,41 @@ describe('PolicyEngine', () => { (await engine.check({ name: 'unknown-tool' }, undefined)).decision, ).toBe(PolicyDecision.DENY); }); + + it('should dynamically switch between modes and respect rule modes', async () => { + const rules: PolicyRule[] = [ + { + toolName: 'edit', + decision: PolicyDecision.ASK_USER, + priority: 10, + }, + { + toolName: 'edit', + decision: PolicyDecision.ALLOW, + priority: 20, + modes: [ApprovalMode.AUTO_EDIT], + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Default mode: priority 20 rule doesn't match, falls back to priority 10 + expect((await engine.check({ name: 'edit' }, undefined)).decision).toBe( + PolicyDecision.ASK_USER, + ); + + // Switch to autoEdit mode + engine.setApprovalMode(ApprovalMode.AUTO_EDIT); + expect((await engine.check({ name: 'edit' }, undefined)).decision).toBe( + PolicyDecision.ALLOW, + ); + + // Switch back to default + engine.setApprovalMode(ApprovalMode.DEFAULT); + expect((await engine.check({ name: 'edit' }, undefined)).decision).toBe( + PolicyDecision.ASK_USER, + ); + }); }); describe('addRule', () => { diff --git a/packages/core/src/policy/policy-engine.ts b/packages/core/src/policy/policy-engine.ts index 71a4ff232b..06e7adc00b 100644 --- a/packages/core/src/policy/policy-engine.ts +++ b/packages/core/src/policy/policy-engine.ts @@ -13,6 +13,7 @@ import { type HookCheckerRule, type HookExecutionContext, getHookSource, + ApprovalMode, } from './types.js'; import { stableStringify } from './stable-stringify.js'; import { debugLogger } from '../utils/debugLogger.js'; @@ -30,7 +31,15 @@ function ruleMatches( toolCall: FunctionCall, stringifiedArgs: string | undefined, serverName: string | undefined, + currentApprovalMode: ApprovalMode, ): boolean { + // Check if rule applies to current approval mode + if (rule.modes && rule.modes.length > 0) { + if (!rule.modes.includes(currentApprovalMode)) { + return false; + } + } + // Check tool name if specified if (rule.toolName) { // Support wildcard patterns: "serverName__*" matches "serverName__anyTool" @@ -98,6 +107,7 @@ export class PolicyEngine { private readonly nonInteractive: boolean; private readonly checkerRunner?: CheckerRunner; private readonly allowHooks: boolean; + private approvalMode: ApprovalMode; constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) { this.rules = (config.rules ?? []).sort( @@ -113,6 +123,21 @@ export class PolicyEngine { this.nonInteractive = config.nonInteractive ?? false; this.checkerRunner = checkerRunner; this.allowHooks = config.allowHooks ?? true; + this.approvalMode = config.approvalMode ?? ApprovalMode.DEFAULT; + } + + /** + * Update the current approval mode. + */ + setApprovalMode(mode: ApprovalMode): void { + this.approvalMode = mode; + } + + /** + * Get the current approval mode. + */ + getApprovalMode(): ApprovalMode { + return this.approvalMode; } /** @@ -145,7 +170,15 @@ export class PolicyEngine { let decision: PolicyDecision | undefined; for (const rule of this.rules) { - if (ruleMatches(rule, toolCall, stringifiedArgs, serverName)) { + if ( + ruleMatches( + rule, + toolCall, + stringifiedArgs, + serverName, + this.approvalMode, + ) + ) { debugLogger.debug( `[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`, ); @@ -225,7 +258,15 @@ export class PolicyEngine { // If decision is not DENY, run safety checkers if (decision !== PolicyDecision.DENY && this.checkerRunner) { for (const checkerRule of this.checkers) { - if (ruleMatches(checkerRule, toolCall, stringifiedArgs, serverName)) { + if ( + ruleMatches( + checkerRule, + toolCall, + stringifiedArgs, + serverName, + this.approvalMode, + ) + ) { debugLogger.debug( `[PolicyEngine.check] Running safety checker: ${checkerRule.checker.name}`, ); diff --git a/packages/core/src/policy/shell-safety.test.ts b/packages/core/src/policy/shell-safety.test.ts index 89c8362f5d..bcc9f562d9 100644 --- a/packages/core/src/policy/shell-safety.test.ts +++ b/packages/core/src/policy/shell-safety.test.ts @@ -6,7 +6,7 @@ import { describe, it, expect, beforeEach } from 'vitest'; import { PolicyEngine } from './policy-engine.js'; -import { PolicyDecision } from './types.js'; +import { PolicyDecision, ApprovalMode } from './types.js'; import type { FunctionCall } from '@google/genai'; describe('Shell Safety Policy', () => { @@ -25,6 +25,7 @@ describe('Shell Safety Policy', () => { }, ], defaultDecision: PolicyDecision.ASK_USER, + approvalMode: ApprovalMode.DEFAULT, }); }); diff --git a/packages/core/src/policy/toml-loader.test.ts b/packages/core/src/policy/toml-loader.test.ts index 75c743dfd7..1bb41fdfd6 100644 --- a/packages/core/src/policy/toml-loader.test.ts +++ b/packages/core/src/policy/toml-loader.test.ts @@ -5,7 +5,7 @@ */ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { ApprovalMode, PolicyDecision } from './types.js'; +import { PolicyDecision } from './types.js'; import * as fs from 'node:fs/promises'; import * as path from 'node:path'; import * as os from 'node:os'; @@ -36,7 +36,7 @@ describe('policy-toml-loader', () => { ): Promise { await fs.writeFile(path.join(tempDir, fileName), tomlContent); const getPolicyTier = (_dir: string) => 1; - return loadPoliciesFromToml(ApprovalMode.DEFAULT, [tempDir], getPolicyTier); + return loadPoliciesFromToml([tempDir], getPolicyTier); } describe('loadPoliciesFromToml', () => { @@ -133,7 +133,7 @@ priority = 100 expect(result.errors).toHaveLength(0); }); - it('should filter rules by mode', async () => { + it('should NOT filter rules by mode at load time but preserve modes property', async () => { const result = await runLoadPoliciesFromToml(` [[rule]] toolName = "glob" @@ -148,12 +148,39 @@ priority = 100 modes = ["yolo"] `); - // Only the first rule should be included (modes includes "default") - expect(result.rules).toHaveLength(1); + // Both rules should be included + expect(result.rules).toHaveLength(2); expect(result.rules[0].toolName).toBe('glob'); + expect(result.rules[0].modes).toEqual(['default', 'yolo']); + expect(result.rules[1].toolName).toBe('grep'); + expect(result.rules[1].modes).toEqual(['yolo']); expect(result.errors).toHaveLength(0); }); + it('should return error if modes property is used for Tier 2 and Tier 3 policies', async () => { + await fs.writeFile( + path.join(tempDir, 'tier2.toml'), + ` +[[rule]] +toolName = "tier2-tool" +decision = "allow" +priority = 100 +modes = ["autoEdit"] +`, + ); + + const getPolicyTier = (_dir: string) => 2; // Tier 2 + const result = await loadPoliciesFromToml([tempDir], getPolicyTier); + + // It still transforms the rule, but it should also report an error + expect(result.rules).toHaveLength(1); + expect(result.rules[0].toolName).toBe('tier2-tool'); + expect(result.rules[0].modes).toBeUndefined(); // Should be restricted + expect(result.errors).toHaveLength(1); + expect(result.errors[0].errorType).toBe('rule_validation'); + expect(result.errors[0].message).toContain('Restricted property "modes"'); + }); + it('should handle TOML parse errors', async () => { const result = await runLoadPoliciesFromToml(` [[rule] @@ -267,11 +294,7 @@ priority = -1 ); const getPolicyTier = (_dir: string) => 1; - const result = await loadPoliciesFromToml( - ApprovalMode.DEFAULT, - [tempDir], - getPolicyTier, - ); + const result = await loadPoliciesFromToml([tempDir], getPolicyTier); expect(result.rules).toHaveLength(1); expect(result.rules[0].toolName).toBe('glob'); @@ -439,11 +462,7 @@ priority = 100 await fs.writeFile(filePath, 'content'); const getPolicyTier = (_dir: string) => 1; - const result = await loadPoliciesFromToml( - ApprovalMode.DEFAULT, - [filePath], - getPolicyTier, - ); + const result = await loadPoliciesFromToml([filePath], getPolicyTier); expect(result.errors).toHaveLength(1); const error = result.errors[0]; diff --git a/packages/core/src/policy/toml-loader.ts b/packages/core/src/policy/toml-loader.ts index 79454c2e98..162a250906 100644 --- a/packages/core/src/policy/toml-loader.ts +++ b/packages/core/src/policy/toml-loader.ts @@ -7,7 +7,7 @@ import { type PolicyRule, PolicyDecision, - type ApprovalMode, + ApprovalMode, type SafetyCheckerConfig, type SafetyCheckerRule, InProcessCheckerType, @@ -43,7 +43,7 @@ const PolicyRuleSchema = z.object({ message: 'priority must be <= 999 to prevent tier overflow. Priorities >= 1000 would jump to the next tier.', }), - modes: z.array(z.string()).optional(), + modes: z.array(z.nativeEnum(ApprovalMode)).optional(), }); /** @@ -56,7 +56,7 @@ const SafetyCheckerRuleSchema = z.object({ commandPrefix: z.union([z.string(), z.array(z.string())]).optional(), commandRegex: z.string().optional(), priority: z.number().int().default(0), - modes: z.array(z.string()).optional(), + modes: z.array(z.nativeEnum(ApprovalMode)).optional(), checker: z.discriminatedUnion('type', [ z.object({ type: z.literal('in-process'), @@ -216,16 +216,13 @@ function transformPriority(priority: number, tier: number): number { * 1. Scans directories for .toml files * 2. Parses and validates each file * 3. Transforms rules (commandPrefix, arrays, mcpName, priorities) - * 4. Filters rules by approval mode - * 5. Collects detailed error information for any failures + * 4. Collects detailed error information for any failures * - * @param approvalMode The current approval mode (for filtering rules by mode) * @param policyDirs Array of directory paths to scan for policy files * @param getPolicyTier Function to determine tier (1-3) for a directory * @returns Object containing successfully parsed rules and any errors encountered */ export async function loadPoliciesFromToml( - approvalMode: ApprovalMode, policyDirs: string[], getPolicyTier: (dir: string) => number, ): Promise { @@ -305,6 +302,8 @@ export async function loadPoliciesFromToml( // Validate shell command convenience syntax const tomlRules = validationResult.data.rule ?? []; + const tomlCheckers = validationResult.data.safety_checker ?? []; + for (let i = 0; i < tomlRules.length; i++) { const rule = tomlRules[i]; const validationError = validateShellCommandSyntax(rule, i); @@ -320,17 +319,40 @@ export async function loadPoliciesFromToml( }); // Continue to next rule, don't skip the entire file } + + if (tier > 1 && rule.modes && rule.modes.length > 0) { + errors.push({ + filePath, + fileName: file, + tier: tierName, + ruleIndex: i, + errorType: 'rule_validation', + message: 'Restricted property "modes"', + details: `Rule #${i + 1}: The "modes" property is currently reserved for Tier 1 (system) policies and cannot be used in ${tierName} policies.`, + suggestion: 'Remove the "modes" property from this rule.', + }); + } + } + + for (let i = 0; i < tomlCheckers.length; i++) { + const checker = tomlCheckers[i]; + if (tier > 1 && checker.modes && checker.modes.length > 0) { + errors.push({ + filePath, + fileName: file, + tier: tierName, + ruleIndex: i, + errorType: 'rule_validation', + message: 'Restricted property "modes" in safety checker', + details: `Safety Checker #${i + 1}: The "modes" property is currently reserved for Tier 1 (system) policies and cannot be used in ${tierName} policies.`, + suggestion: + 'Remove the "modes" property from this safety checker.', + }); + } } // Transform rules const parsedRules: PolicyRule[] = (validationResult.data.rule ?? []) - .filter((rule) => { - // Filter by mode - if (!rule.modes || rule.modes.length === 0) { - return true; - } - return rule.modes.includes(approvalMode); - }) .flatMap((rule) => { // Transform commandPrefix/commandRegex to argsPattern let effectiveArgsPattern = rule.argsPattern; @@ -377,6 +399,7 @@ export async function loadPoliciesFromToml( toolName: effectiveToolName, decision: rule.decision, priority: transformPriority(rule.priority, tier), + modes: tier === 1 ? rule.modes : undefined, }; // Compile regex pattern @@ -412,12 +435,6 @@ export async function loadPoliciesFromToml( const parsedCheckers: SafetyCheckerRule[] = ( validationResult.data.safety_checker ?? [] ) - .filter((checker) => { - if (!checker.modes || checker.modes.length === 0) { - return true; - } - return checker.modes.includes(approvalMode); - }) .flatMap((checker) => { let effectiveArgsPattern = checker.argsPattern; const commandPrefixes: string[] = []; @@ -459,6 +476,7 @@ export async function loadPoliciesFromToml( toolName: effectiveToolName, priority: checker.priority, checker: checker.checker as SafetyCheckerConfig, + modes: tier === 1 ? checker.modes : undefined, }; if (argsPattern) { diff --git a/packages/core/src/policy/types.ts b/packages/core/src/policy/types.ts index 410d9ff1c9..426fdaac9c 100644 --- a/packages/core/src/policy/types.ts +++ b/packages/core/src/policy/types.ts @@ -117,6 +117,12 @@ export interface PolicyRule { * Default is 0. */ priority?: number; + + /** + * Approval modes this rule applies to. + * If undefined or empty, it applies to all modes. + */ + modes?: ApprovalMode[]; } export interface SafetyCheckerRule { @@ -143,6 +149,12 @@ export interface SafetyCheckerRule { * additional validation of a tool call. */ checker: SafetyCheckerConfig; + + /** + * Approval modes this rule applies to. + * If undefined or empty, it applies to all modes. + */ + modes?: ApprovalMode[]; } export interface HookExecutionContext { @@ -215,6 +227,12 @@ export interface PolicyEngineConfig { * Defaults to true. */ allowHooks?: boolean; + + /** + * Current approval mode. + * Used to filter rules that have specific 'modes' defined. + */ + approvalMode?: ApprovalMode; } export interface PolicySettings {