From 1ed163a6660eea44ba4254b4b0eb41824d84bc11 Mon Sep 17 00:00:00 2001 From: Allen Hutchison Date: Wed, 12 Nov 2025 13:18:34 -0800 Subject: [PATCH] feat(safety): Introduce safety checker framework (#12504) --- .../config/policy-engine.integration.test.ts | 268 ++++--- .../src/confirmation-bus/message-bus.test.ts | 54 +- .../core/src/confirmation-bus/message-bus.ts | 4 +- packages/core/src/policy/config.test.ts | 190 ++++- packages/core/src/policy/config.ts | 12 +- packages/core/src/policy/policies/write.toml | 10 + .../core/src/policy/policy-engine.test.ts | 756 ++++++++++++++---- packages/core/src/policy/policy-engine.ts | 100 ++- packages/core/src/policy/toml-loader.test.ts | 1 + packages/core/src/policy/toml-loader.ts | 128 ++- packages/core/src/policy/types.ts | 79 ++ packages/core/src/safety/built-in.test.ts | 244 ++++++ packages/core/src/safety/built-in.ts | 154 ++++ .../core/src/safety/checker-runner.test.ts | 303 +++++++ packages/core/src/safety/checker-runner.ts | 297 +++++++ .../core/src/safety/context-builder.test.ts | 56 ++ packages/core/src/safety/context-builder.ts | 54 ++ packages/core/src/safety/protocol.ts | 100 +++ packages/core/src/safety/registry.test.ts | 47 ++ packages/core/src/safety/registry.ts | 83 ++ .../src/tools/base-tool-invocation.test.ts | 24 +- 21 files changed, 2636 insertions(+), 328 deletions(-) create mode 100644 packages/core/src/safety/built-in.test.ts create mode 100644 packages/core/src/safety/built-in.ts create mode 100644 packages/core/src/safety/checker-runner.test.ts create mode 100644 packages/core/src/safety/checker-runner.ts create mode 100644 packages/core/src/safety/context-builder.test.ts create mode 100644 packages/core/src/safety/context-builder.ts create mode 100644 packages/core/src/safety/protocol.ts create mode 100644 packages/core/src/safety/registry.test.ts create mode 100644 packages/core/src/safety/registry.ts diff --git a/packages/cli/src/config/policy-engine.integration.test.ts 
b/packages/cli/src/config/policy-engine.integration.test.ts index 0c22cfeba9..092567fb6d 100644 --- a/packages/cli/src/config/policy-engine.integration.test.ts +++ b/packages/cli/src/config/policy-engine.integration.test.ts @@ -30,24 +30,24 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Allowed tool should be allowed - expect(engine.check({ name: 'run_shell_command' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'run_shell_command' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); // Excluded tool should be denied - expect(engine.check({ name: 'write_file' }, undefined)).toBe( - PolicyDecision.DENY, - ); + expect( + (await engine.check({ name: 'write_file' }, undefined)).decision, + ).toBe(PolicyDecision.DENY); // Other write tools should ask user - expect(engine.check({ name: 'replace' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check({ name: 'replace' }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); // Unknown tools should use default - expect(engine.check({ name: 'unknown_tool' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check({ name: 'unknown_tool' }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); }); it('should handle MCP server wildcard patterns correctly', async () => { @@ -72,33 +72,49 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Tools from allowed server should be allowed - expect(engine.check({ name: 'allowed-server__tool1' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + // Tools from allowed server should be allowed expect( - engine.check({ name: 'allowed-server__another_tool' }, undefined), + (await engine.check({ name: 'allowed-server__tool1' }, undefined)) + .decision, + ).toBe(PolicyDecision.ALLOW); + expect( + ( + await engine.check( + { name: 'allowed-server__another_tool' }, + undefined, + ) + 
).decision, ).toBe(PolicyDecision.ALLOW); // Tools from trusted server should be allowed - expect(engine.check({ name: 'trusted-server__tool1' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); expect( - engine.check({ name: 'trusted-server__special_tool' }, undefined), + (await engine.check({ name: 'trusted-server__tool1' }, undefined)) + .decision, + ).toBe(PolicyDecision.ALLOW); + expect( + ( + await engine.check( + { name: 'trusted-server__special_tool' }, + undefined, + ) + ).decision, ).toBe(PolicyDecision.ALLOW); // Tools from blocked server should be denied - expect(engine.check({ name: 'blocked-server__tool1' }, undefined)).toBe( - PolicyDecision.DENY, - ); expect( - engine.check({ name: 'blocked-server__any_tool' }, undefined), + (await engine.check({ name: 'blocked-server__tool1' }, undefined)) + .decision, + ).toBe(PolicyDecision.DENY); + expect( + (await engine.check({ name: 'blocked-server__any_tool' }, undefined)) + .decision, ).toBe(PolicyDecision.DENY); // Tools from unknown servers should use default - expect(engine.check({ name: 'unknown-server__tool' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check({ name: 'unknown-server__tool' }, undefined)) + .decision, + ).toBe(PolicyDecision.ASK_USER); }); it('should correctly prioritize specific tool excludes over MCP server wildcards', async () => { @@ -118,12 +134,15 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // MCP server allowed (priority 2.1) provides general allow for server - expect(engine.check({ name: 'my-server__safe-tool' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + // MCP server allowed (priority 2.1) provides general allow for server + expect( + (await engine.check({ name: 'my-server__safe-tool' }, undefined)) + .decision, + ).toBe(PolicyDecision.ALLOW); // But specific tool exclude (priority 2.4) wins over server allow expect( - engine.check({ name: 'my-server__dangerous-tool' }, undefined), + 
(await engine.check({ name: 'my-server__dangerous-tool' }, undefined)) + .decision, ).toBe(PolicyDecision.DENY); }); @@ -154,46 +173,50 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Read-only tools should be allowed (autoAccept) - expect(engine.check({ name: 'read_file' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'list_directory' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'read_file' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); + expect( + (await engine.check({ name: 'list_directory' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); // But glob is explicitly excluded, so it should be denied - expect(engine.check({ name: 'glob' }, undefined)).toBe( + expect((await engine.check({ name: 'glob' }, undefined)).decision).toBe( PolicyDecision.DENY, ); // Replace should ask user (normal write tool behavior) - expect(engine.check({ name: 'replace' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check({ name: 'replace' }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); // Explicitly allowed tools - expect(engine.check({ name: 'custom-tool' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'my-server__special-tool' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'custom-tool' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); + expect( + (await engine.check({ name: 'my-server__special-tool' }, undefined)) + .decision, + ).toBe(PolicyDecision.ALLOW); // MCP server tools - expect(engine.check({ name: 'allowed-server__tool' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'trusted-server__tool' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'blocked-server__tool' }, undefined)).toBe( - PolicyDecision.DENY, - ); + expect( + (await engine.check({ 
name: 'allowed-server__tool' }, undefined)) + .decision, + ).toBe(PolicyDecision.ALLOW); + expect( + (await engine.check({ name: 'trusted-server__tool' }, undefined)) + .decision, + ).toBe(PolicyDecision.ALLOW); + expect( + (await engine.check({ name: 'blocked-server__tool' }, undefined)) + .decision, + ).toBe(PolicyDecision.DENY); // Write tools should ask by default - expect(engine.check({ name: 'write_file' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check({ name: 'write_file' }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); }); it('should handle YOLO mode correctly', async () => { @@ -210,20 +233,20 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Most tools should be allowed in YOLO mode - expect(engine.check({ name: 'run_shell_command' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'write_file' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'unknown_tool' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'run_shell_command' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); + expect( + (await engine.check({ name: 'write_file' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); + expect( + (await engine.check({ name: 'unknown_tool' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); // But explicitly excluded tools should still be denied - expect(engine.check({ name: 'dangerous-tool' }, undefined)).toBe( - PolicyDecision.DENY, - ); + expect( + (await engine.check({ name: 'dangerous-tool' }, undefined)).decision, + ).toBe(PolicyDecision.DENY); }); it('should handle AUTO_EDIT mode correctly', async () => { @@ -236,17 +259,17 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Edit tools should be allowed in AUTO_EDIT mode - expect(engine.check({ name: 'replace' }, undefined)).toBe( - 
PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'write_file' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'replace' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); + expect( + (await engine.check({ name: 'write_file' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); // Other tools should follow normal rules - expect(engine.check({ name: 'run_shell_command' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check({ name: 'run_shell_command' }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); }); it('should verify priority ordering works correctly in practice', async () => { @@ -305,22 +328,24 @@ describe('Policy Engine Integration Tests', () => { expect(readOnlyToolRule?.priority).toBeCloseTo(1.05, 5); // Verify the engine applies these priorities correctly - expect(engine.check({ name: 'blocked-tool' }, undefined)).toBe( - PolicyDecision.DENY, - ); - expect(engine.check({ name: 'blocked-server__any' }, undefined)).toBe( - PolicyDecision.DENY, - ); - expect(engine.check({ name: 'specific-tool' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'trusted-server__any' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'mcp-server__any' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'glob' }, undefined)).toBe( + expect( + (await engine.check({ name: 'blocked-tool' }, undefined)).decision, + ).toBe(PolicyDecision.DENY); + expect( + (await engine.check({ name: 'blocked-server__any' }, undefined)) + .decision, + ).toBe(PolicyDecision.DENY); + expect( + (await engine.check({ name: 'specific-tool' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); + expect( + (await engine.check({ name: 'trusted-server__any' }, undefined)) + .decision, + ).toBe(PolicyDecision.ALLOW); + expect( + (await engine.check({ name: 'mcp-server__any' }, undefined)).decision, + 
).toBe(PolicyDecision.ALLOW); + expect((await engine.check({ name: 'glob' }, undefined)).decision).toBe( PolicyDecision.ALLOW, ); }); @@ -346,9 +371,10 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Exclusion (195) should win over trust (90) - expect(engine.check({ name: 'conflicted-server__tool' }, undefined)).toBe( - PolicyDecision.DENY, - ); + expect( + (await engine.check({ name: 'conflicted-server__tool' }, undefined)) + .decision, + ).toBe(PolicyDecision.DENY); }); it('should handle edge case: specific tool allowed but server excluded', async () => { @@ -369,12 +395,14 @@ describe('Policy Engine Integration Tests', () => { // Server exclusion (195) wins over specific tool allow (100) // This might be counterintuitive but follows the priority system - expect(engine.check({ name: 'my-server__special-tool' }, undefined)).toBe( - PolicyDecision.DENY, - ); - expect(engine.check({ name: 'my-server__other-tool' }, undefined)).toBe( - PolicyDecision.DENY, - ); + expect( + (await engine.check({ name: 'my-server__special-tool' }, undefined)) + .decision, + ).toBe(PolicyDecision.DENY); + expect( + (await engine.check({ name: 'my-server__other-tool' }, undefined)) + .decision, + ).toBe(PolicyDecision.DENY); }); it('should verify non-interactive mode transformation', async () => { @@ -389,12 +417,12 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(engineConfig); // ASK_USER should become DENY in non-interactive mode - expect(engine.check({ name: 'unknown_tool' }, undefined)).toBe( - PolicyDecision.DENY, - ); - expect(engine.check({ name: 'run_shell_command' }, undefined)).toBe( - PolicyDecision.DENY, - ); + expect( + (await engine.check({ name: 'unknown_tool' }, undefined)).decision, + ).toBe(PolicyDecision.DENY); + expect( + (await engine.check({ name: 'run_shell_command' }, undefined)).decision, + ).toBe(PolicyDecision.DENY); }); it('should handle empty settings gracefully', async 
() => { @@ -407,17 +435,17 @@ describe('Policy Engine Integration Tests', () => { const engine = new PolicyEngine(config); // Should have default rules for write tools - expect(engine.check({ name: 'write_file' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); - expect(engine.check({ name: 'replace' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check({ name: 'write_file' }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); + expect( + (await engine.check({ name: 'replace' }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); // Unknown tools should use default - expect(engine.check({ name: 'unknown' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check({ name: 'unknown' }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); }); it('should verify rules are created with correct priorities', async () => { diff --git a/packages/core/src/confirmation-bus/message-bus.test.ts b/packages/core/src/confirmation-bus/message-bus.test.ts index 8156671c9b..e240df1532 100644 --- a/packages/core/src/confirmation-bus/message-bus.test.ts +++ b/packages/core/src/confirmation-bus/message-bus.test.ts @@ -26,12 +26,12 @@ describe('MessageBus', () => { }); describe('publish', () => { - it('should emit error for invalid message', () => { + it('should emit error for invalid message', async () => { const errorHandler = vi.fn(); messageBus.on('error', errorHandler); // @ts-expect-error - Testing invalid message - messageBus.publish({ invalid: 'message' }); + await messageBus.publish({ invalid: 'message' }); expect(errorHandler).toHaveBeenCalledWith( expect.objectContaining({ @@ -40,12 +40,12 @@ describe('MessageBus', () => { ); }); - it('should validate tool confirmation requests have correlationId', () => { + it('should validate tool confirmation requests have correlationId', async () => { const errorHandler = vi.fn(); messageBus.on('error', errorHandler); // @ts-expect-error - Testing missing correlationId 
- messageBus.publish({ + await messageBus.publish({ type: MessageBusType.TOOL_CONFIRMATION_REQUEST, toolCall: { name: 'test' }, }); @@ -53,8 +53,10 @@ describe('MessageBus', () => { expect(errorHandler).toHaveBeenCalled(); }); - it('should emit confirmation response when policy allows', () => { - vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.ALLOW); + it('should emit confirmation response when policy allows', async () => { + vi.spyOn(policyEngine, 'check').mockResolvedValue({ + decision: PolicyDecision.ALLOW, + }); const responseHandler = vi.fn(); messageBus.subscribe( @@ -68,7 +70,7 @@ describe('MessageBus', () => { correlationId: '123', }; - messageBus.publish(request); + await messageBus.publish(request); const expectedResponse: ToolConfirmationResponse = { type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, @@ -78,8 +80,10 @@ describe('MessageBus', () => { expect(responseHandler).toHaveBeenCalledWith(expectedResponse); }); - it('should emit rejection and response when policy denies', () => { - vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.DENY); + it('should emit rejection and response when policy denies', async () => { + vi.spyOn(policyEngine, 'check').mockResolvedValue({ + decision: PolicyDecision.DENY, + }); const responseHandler = vi.fn(); const rejectionHandler = vi.fn(); @@ -98,7 +102,7 @@ describe('MessageBus', () => { correlationId: '123', }; - messageBus.publish(request); + await messageBus.publish(request); const expectedRejection: ToolPolicyRejection = { type: MessageBusType.TOOL_POLICY_REJECTION, @@ -114,8 +118,10 @@ describe('MessageBus', () => { expect(responseHandler).toHaveBeenCalledWith(expectedResponse); }); - it('should pass through to UI when policy says ASK_USER', () => { - vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.ASK_USER); + it('should pass through to UI when policy says ASK_USER', async () => { + vi.spyOn(policyEngine, 'check').mockResolvedValue({ + decision: PolicyDecision.ASK_USER, + 
}); const requestHandler = vi.fn(); messageBus.subscribe( @@ -129,12 +135,12 @@ describe('MessageBus', () => { correlationId: '123', }; - messageBus.publish(request); + await messageBus.publish(request); expect(requestHandler).toHaveBeenCalledWith(request); }); - it('should emit other message types directly', () => { + it('should emit other message types directly', async () => { const successHandler = vi.fn(); messageBus.subscribe( MessageBusType.TOOL_EXECUTION_SUCCESS, @@ -147,14 +153,14 @@ describe('MessageBus', () => { result: 'success', }; - messageBus.publish(message); + await messageBus.publish(message); expect(successHandler).toHaveBeenCalledWith(message); }); }); describe('subscribe/unsubscribe', () => { - it('should allow subscribing to specific message types', () => { + it('should allow subscribing to specific message types', async () => { const handler = vi.fn(); messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler); @@ -164,12 +170,12 @@ describe('MessageBus', () => { result: 'test', }; - messageBus.publish(message); + await messageBus.publish(message); expect(handler).toHaveBeenCalledWith(message); }); - it('should allow unsubscribing from message types', () => { + it('should allow unsubscribing from message types', async () => { const handler = vi.fn(); messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler); messageBus.unsubscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler); @@ -180,12 +186,12 @@ describe('MessageBus', () => { result: 'test', }; - messageBus.publish(message); + await messageBus.publish(message); expect(handler).not.toHaveBeenCalled(); }); - it('should support multiple subscribers for the same message type', () => { + it('should support multiple subscribers for the same message type', async () => { const handler1 = vi.fn(); const handler2 = vi.fn(); @@ -198,7 +204,7 @@ describe('MessageBus', () => { result: 'test', }; - messageBus.publish(message); + await messageBus.publish(message); 
expect(handler1).toHaveBeenCalledWith(message); expect(handler2).toHaveBeenCalledWith(message); @@ -206,12 +212,12 @@ describe('MessageBus', () => { }); describe('error handling', () => { - it('should not crash on errors during message processing', () => { + it('should not crash on errors during message processing', async () => { const errorHandler = vi.fn(); messageBus.on('error', errorHandler); // Mock policyEngine to throw an error - vi.spyOn(policyEngine, 'check').mockImplementation(() => { + vi.spyOn(policyEngine, 'check').mockImplementation(async () => { throw new Error('Policy check failed'); }); @@ -222,7 +228,7 @@ describe('MessageBus', () => { }; // Should not throw - expect(() => messageBus.publish(request)).not.toThrow(); + await expect(messageBus.publish(request)).resolves.not.toThrow(); // Should emit error expect(errorHandler).toHaveBeenCalledWith( diff --git a/packages/core/src/confirmation-bus/message-bus.ts b/packages/core/src/confirmation-bus/message-bus.ts index cd293a72a6..0b021049f8 100644 --- a/packages/core/src/confirmation-bus/message-bus.ts +++ b/packages/core/src/confirmation-bus/message-bus.ts @@ -38,7 +38,7 @@ export class MessageBus extends EventEmitter { this.emit(message.type, message); } - publish(message: Message): void { + async publish(message: Message): Promise<void> { if (this.debug) { console.debug(`[MESSAGE_BUS] publish: ${safeJsonStringify(message)}`); } @@ -50,7 +50,7 @@ export class MessageBus extends EventEmitter { } if (message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { - const decision = this.policyEngine.check( + const { decision } = await this.policyEngine.check( message.toolCall, message.serverName, ); diff --git a/packages/core/src/policy/config.test.ts b/packages/core/src/policy/config.test.ts index 460087639b..b8e76f73b6 100644 --- a/packages/core/src/policy/config.test.ts +++ b/packages/core/src/policy/config.test.ts @@ -9,7 +9,7 @@ import { describe, it, expect, vi, afterEach, beforeEach } from 'vitest'; 
import nodePath from 'node:path'; import type { PolicySettings } from './types.js'; -import { ApprovalMode, PolicyDecision } from './types.js'; +import { ApprovalMode, PolicyDecision, InProcessCheckerType } from './types.js'; import { Storage } from '../config/storage.js'; @@ -642,6 +642,194 @@ priority = 150 vi.doUnmock('node:fs/promises'); }); + it('should load safety_checker configuration from TOML', async () => { + const actualFs = + await vi.importActual<typeof import('node:fs/promises')>( + 'node:fs/promises', + ); + + const mockReaddir = vi.fn( + async ( + path: string | Buffer | URL, + options?: Parameters<typeof actualFs.readdir>[1], + ) => { + if ( + typeof path === 'string' && + nodePath + .normalize(path) + .includes(nodePath.normalize('.gemini/policies')) + ) { + return [ + { + name: 'safety.toml', + isFile: () => true, + isDirectory: () => false, + }, + ] as unknown as Awaited<ReturnType<typeof actualFs.readdir>>; + } + return actualFs.readdir( + path, + options as Parameters<typeof actualFs.readdir>[1], + ); + }, + ); + + const mockReadFile = vi.fn( + async ( + path: Parameters<typeof actualFs.readFile>[0], + options: Parameters<typeof actualFs.readFile>[1], + ) => { + if ( + typeof path === 'string' && + nodePath + .normalize(path) + .includes(nodePath.normalize('.gemini/policies/safety.toml')) + ) { + return ` +[[rule]] +toolName = "write_file" +decision = "allow" +priority = 10 + +[[rule]] +toolName = "write_file" +decision = "allow" +priority = 10 + +[[safety_checker]] +toolName = "write_file" +priority = 10 +[safety_checker.checker] +type = "in-process" +name = "allowed-path" +required_context = ["environment"] +[safety_checker.checker.config] +`; + } + return actualFs.readFile(path, options); + }, + ); + + vi.doMock('node:fs/promises', () => ({ + ...actualFs, + default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, + readFile: mockReadFile, + readdir: mockReaddir, + })); + + vi.resetModules(); + const { createPolicyEngineConfig } = await import('./config.js'); + + const settings: PolicySettings = {}; + const config = await createPolicyEngineConfig( + settings, + ApprovalMode.DEFAULT, + 
'/tmp/mock/default/policies', + ); + + const rule = config.rules?.find( + (r) => r.toolName === 'write_file' && r.decision === PolicyDecision.ALLOW, + ); + expect(rule).toBeDefined(); + + const checker = config.checkers?.find( + (c) => c.toolName === 'write_file' && c.checker.type === 'in-process', + ); + expect(checker).toBeDefined(); + expect(checker?.checker.type).toBe('in-process'); + expect(checker?.checker.name).toBe(InProcessCheckerType.ALLOWED_PATH); + expect(checker?.checker.required_context).toEqual(['environment']); + + vi.doUnmock('node:fs/promises'); + }); + + it('should reject invalid in-process checker names', async () => { + const actualFs = + await vi.importActual<typeof import('node:fs/promises')>( + 'node:fs/promises', + ); + + const mockReaddir = vi.fn( + async ( + path: string | Buffer | URL, + options?: Parameters<typeof actualFs.readdir>[1], + ) => { + if ( + typeof path === 'string' && + nodePath + .normalize(path) + .includes(nodePath.normalize('.gemini/policies')) + ) { + return [ + { + name: 'invalid_safety.toml', + isFile: () => true, + isDirectory: () => false, + }, + ] as unknown as Awaited<ReturnType<typeof actualFs.readdir>>; + } + return actualFs.readdir( + path, + options as Parameters<typeof actualFs.readdir>[1], + ); + }, + ); + + const mockReadFile = vi.fn( + async ( + path: Parameters<typeof actualFs.readFile>[0], + options: Parameters<typeof actualFs.readFile>[1], + ) => { + if ( + typeof path === 'string' && + nodePath + .normalize(path) + .includes( + nodePath.normalize('.gemini/policies/invalid_safety.toml'), + ) + ) { + return ` +[[rule]] +toolName = "write_file" +decision = "allow" +priority = 10 + +[[safety_checker]] +toolName = "write_file" +priority = 10 +[safety_checker.checker] +type = "in-process" +name = "invalid-name" +`; + } + return actualFs.readFile(path, options); + }, + ); + + vi.doMock('node:fs/promises', () => ({ + ...actualFs, + default: { ...actualFs, readFile: mockReadFile, readdir: mockReaddir }, + readFile: mockReadFile, + readdir: mockReaddir, + })); + + vi.resetModules(); + const { createPolicyEngineConfig } = await import('./config.js'); + + const settings: 
PolicySettings = {}; + const config = await createPolicyEngineConfig( + settings, + ApprovalMode.DEFAULT, + '/tmp/mock/default/policies', + ); + + // The rule should be rejected because 'invalid-name' is not in the enum + const rule = config.rules?.find((r) => r.toolName === 'write_file'); + expect(rule).toBeUndefined(); + + vi.doUnmock('node:fs/promises'); + }); + it('should have default ASK_USER rule for discovered tools', async () => { vi.resetModules(); vi.doUnmock('node:fs/promises'); diff --git a/packages/core/src/policy/config.ts b/packages/core/src/policy/config.ts index 64d6476d10..6ea78d30ca 100644 --- a/packages/core/src/policy/config.ts +++ b/packages/core/src/policy/config.ts @@ -114,10 +114,12 @@ export async function createPolicyEngineConfig( const policyDirs = getPolicyDirectories(defaultPoliciesDir); // Load policies from TOML files - const { rules: tomlRules, errors } = await loadPoliciesFromToml( - approvalMode, - policyDirs, - (dir) => getPolicyTier(dir, defaultPoliciesDir), + const { + rules: tomlRules, + checkers: tomlCheckers, + errors, + } = await loadPoliciesFromToml(approvalMode, policyDirs, (dir) => + getPolicyTier(dir, defaultPoliciesDir), ); // Emit any errors encountered during TOML loading to the UI @@ -129,6 +131,7 @@ export async function createPolicyEngineConfig( } const rules: PolicyRule[] = [...tomlRules]; + const checkers = [...tomlCheckers]; // Priority system for policy rules: // - Higher priority numbers win over lower priority numbers @@ -225,6 +228,7 @@ export async function createPolicyEngineConfig( return { rules, + checkers, defaultDecision: PolicyDecision.ASK_USER, }; } diff --git a/packages/core/src/policy/policies/write.toml b/packages/core/src/policy/policies/write.toml index 8e4c1ae70e..09387b59c1 100644 --- a/packages/core/src/policy/policies/write.toml +++ b/packages/core/src/policy/policies/write.toml @@ -36,6 +36,11 @@ decision = "allow" priority = 15 modes = ["autoEdit"] +[rule.safety_checker] +type = 
"in-process" +name = "allowed-path" +required_context = ["environment"] + [[rule]] toolName = "save_memory" decision = "ask_user" @@ -57,6 +62,11 @@ decision = "allow" priority = 15 modes = ["autoEdit"] +[rule.safety_checker] +type = "in-process" +name = "allowed-path" +required_context = ["environment"] + [[rule]] toolName = "web_fetch" decision = "ask_user" diff --git a/packages/core/src/policy/policy-engine.test.ts b/packages/core/src/policy/policy-engine.test.ts index fd3a4b62b2..5cb7cd3b9a 100644 --- a/packages/core/src/policy/policy-engine.test.ts +++ b/packages/core/src/policy/policy-engine.test.ts @@ -4,31 +4,39 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, beforeEach } from 'vitest'; +import { describe, it, expect, beforeEach, vi } from 'vitest'; import { PolicyEngine } from './policy-engine.js'; import { PolicyDecision, type PolicyRule, type PolicyEngineConfig, + type SafetyCheckerRule, + InProcessCheckerType, } from './types.js'; import type { FunctionCall } from '@google/genai'; +import { SafetyCheckDecision } from '../safety/protocol.js'; +import type { CheckerRunner } from '../safety/checker-runner.js'; describe('PolicyEngine', () => { let engine: PolicyEngine; + let mockCheckerRunner: CheckerRunner; beforeEach(() => { - engine = new PolicyEngine(); + mockCheckerRunner = { + runChecker: vi.fn(), + } as unknown as CheckerRunner; + engine = new PolicyEngine({}, mockCheckerRunner); }); describe('constructor', () => { - it('should use default config when none provided', () => { - const decision = engine.check({ name: 'test' }, undefined); + it('should use default config when none provided', async () => { + const { decision } = await engine.check({ name: 'test' }, undefined); expect(decision).toBe(PolicyDecision.ASK_USER); }); - it('should respect custom default decision', () => { + it('should respect custom default decision', async () => { engine = new PolicyEngine({ defaultDecision: PolicyDecision.DENY }); - const decision = 
engine.check({ name: 'test' }, undefined); + const { decision } = await engine.check({ name: 'test' }, undefined); expect(decision).toBe(PolicyDecision.DENY); }); @@ -49,7 +57,7 @@ describe('PolicyEngine', () => { }); describe('check', () => { - it('should match tool by name', () => { + it('should match tool by name', async () => { const rules: PolicyRule[] = [ { toolName: 'shell', decision: PolicyDecision.ALLOW }, { toolName: 'edit', decision: PolicyDecision.DENY }, @@ -57,18 +65,18 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); - expect(engine.check({ name: 'shell' }, undefined)).toBe( + expect((await engine.check({ name: 'shell' }, undefined)).decision).toBe( PolicyDecision.ALLOW, ); - expect(engine.check({ name: 'edit' }, undefined)).toBe( + expect((await engine.check({ name: 'edit' }, undefined)).decision).toBe( PolicyDecision.DENY, ); - expect(engine.check({ name: 'other' }, undefined)).toBe( + expect((await engine.check({ name: 'other' }, undefined)).decision).toBe( PolicyDecision.ASK_USER, ); }); - it('should match by args pattern', () => { + it('should match by args pattern', async () => { const rules: PolicyRule[] = [ { toolName: 'shell', @@ -93,11 +101,15 @@ describe('PolicyEngine', () => { args: { command: 'ls -la' }, }; - expect(engine.check(dangerousCall, undefined)).toBe(PolicyDecision.DENY); - expect(engine.check(safeCall, undefined)).toBe(PolicyDecision.ALLOW); + expect((await engine.check(dangerousCall, undefined)).decision).toBe( + PolicyDecision.DENY, + ); + expect((await engine.check(safeCall, undefined)).decision).toBe( + PolicyDecision.ALLOW, + ); }); - it('should apply rules by priority', () => { + it('should apply rules by priority', async () => { const rules: PolicyRule[] = [ { toolName: 'shell', decision: PolicyDecision.DENY, priority: 1 }, { toolName: 'shell', decision: PolicyDecision.ALLOW, priority: 10 }, @@ -106,12 +118,12 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); // Higher 
priority rule (ALLOW) should win - expect(engine.check({ name: 'shell' }, undefined)).toBe( + expect((await engine.check({ name: 'shell' }, undefined)).decision).toBe( PolicyDecision.ALLOW, ); }); - it('should apply wildcard rules (no toolName)', () => { + it('should apply wildcard rules (no toolName)', async () => { const rules: PolicyRule[] = [ { decision: PolicyDecision.DENY }, // Applies to all tools { toolName: 'safe-tool', decision: PolicyDecision.ALLOW, priority: 10 }, @@ -119,15 +131,15 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); - expect(engine.check({ name: 'safe-tool' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'any-other-tool' }, undefined)).toBe( - PolicyDecision.DENY, - ); + expect( + (await engine.check({ name: 'safe-tool' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); + expect( + (await engine.check({ name: 'any-other-tool' }, undefined)).decision, + ).toBe(PolicyDecision.DENY); }); - it('should handle non-interactive mode', () => { + it('should handle non-interactive mode', async () => { const config: PolicyEngineConfig = { nonInteractive: true, rules: [ @@ -139,17 +151,17 @@ describe('PolicyEngine', () => { engine = new PolicyEngine(config); // ASK_USER should become DENY in non-interactive mode - expect(engine.check({ name: 'interactive-tool' }, undefined)).toBe( - PolicyDecision.DENY, - ); + expect( + (await engine.check({ name: 'interactive-tool' }, undefined)).decision, + ).toBe(PolicyDecision.DENY); // ALLOW should remain ALLOW - expect(engine.check({ name: 'allowed-tool' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'allowed-tool' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); // Default ASK_USER should also become DENY - expect(engine.check({ name: 'unknown-tool' }, undefined)).toBe( - PolicyDecision.DENY, - ); + expect( + (await engine.check({ name: 'unknown-tool' }, undefined)).decision, + 
).toBe(PolicyDecision.DENY); }); }); @@ -178,16 +190,16 @@ describe('PolicyEngine', () => { expect(rules[2].priority).toBe(1); }); - it('should apply newly added rules', () => { - expect(engine.check({ name: 'new-tool' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + it('should apply newly added rules', async () => { + expect( + (await engine.check({ name: 'new-tool' }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); engine.addRule({ toolName: 'new-tool', decision: PolicyDecision.ALLOW }); - expect(engine.check({ name: 'new-tool' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'new-tool' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); }); }); @@ -236,7 +248,7 @@ describe('PolicyEngine', () => { }); describe('MCP server wildcard patterns', () => { - it('should match MCP server wildcard patterns', () => { + it('should match MCP server wildcard patterns', async () => { const rules: PolicyRule[] = [ { toolName: 'my-server__*', @@ -253,34 +265,38 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); // Should match my-server tools - expect(engine.check({ name: 'my-server__tool1' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); - expect(engine.check({ name: 'my-server__another_tool' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'my-server__tool1' }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); + expect( + (await engine.check({ name: 'my-server__another_tool' }, undefined)) + .decision, + ).toBe(PolicyDecision.ALLOW); // Should match blocked-server tools - expect(engine.check({ name: 'blocked-server__tool1' }, undefined)).toBe( - PolicyDecision.DENY, - ); expect( - engine.check({ name: 'blocked-server__dangerous' }, undefined), + (await engine.check({ name: 'blocked-server__tool1' }, undefined)) + .decision, + ).toBe(PolicyDecision.DENY); + expect( + (await engine.check({ name: 'blocked-server__dangerous' }, undefined)) + .decision, 
).toBe(PolicyDecision.DENY); // Should not match other patterns - expect(engine.check({ name: 'other-server__tool' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); - expect(engine.check({ name: 'my-server-tool' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); // No __ separator - expect(engine.check({ name: 'my-server' }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); // No tool name + expect( + (await engine.check({ name: 'other-server__tool' }, undefined)) + .decision, + ).toBe(PolicyDecision.ASK_USER); + expect( + (await engine.check({ name: 'my-server-tool' }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); // No __ separator + expect( + (await engine.check({ name: 'my-server' }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); // No tool name }); - it('should prioritize specific tool rules over server wildcards', () => { + it('should prioritize specific tool rules over server wildcards', async () => { const rules: PolicyRule[] = [ { toolName: 'my-server__*', @@ -298,14 +314,16 @@ describe('PolicyEngine', () => { // Specific tool deny should override server allow expect( - engine.check({ name: 'my-server__dangerous-tool' }, undefined), + (await engine.check({ name: 'my-server__dangerous-tool' }, undefined)) + .decision, ).toBe(PolicyDecision.DENY); - expect(engine.check({ name: 'my-server__safe-tool' }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'my-server__safe-tool' }, undefined)) + .decision, + ).toBe(PolicyDecision.ALLOW); }); - it('should NOT match spoofed server names when using wildcards', () => { + it('should NOT match spoofed server names when using wildcards', async () => { // Vulnerability: A rule for 'prefix__*' matches 'prefix__suffix__tool' // effectively allowing a server named 'prefix__suffix' to spoof 'prefix'. 
const rules: PolicyRule[] = [ @@ -321,12 +339,13 @@ describe('PolicyEngine', () => { // CURRENT BEHAVIOR (FIXED): Matches because it starts with 'safe_server__' BUT serverName doesn't match 'safe_server' // We expect this to FAIL matching the ALLOW rule, thus falling back to default (ASK_USER) - expect(engine.check(spoofedToolCall, 'safe_server__malicious')).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check(spoofedToolCall, 'safe_server__malicious')) + .decision, + ).toBe(PolicyDecision.ASK_USER); }); - it('should verify tool name prefix even if serverName matches', () => { + it('should verify tool name prefix even if serverName matches', async () => { const rules: PolicyRule[] = [ { toolName: 'safe_server__*', @@ -337,12 +356,12 @@ describe('PolicyEngine', () => { // serverName matches, but tool name does not start with prefix const invalidToolCall = { name: 'other_server__tool' }; - expect(engine.check(invalidToolCall, 'safe_server')).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check(invalidToolCall, 'safe_server')).decision, + ).toBe(PolicyDecision.ASK_USER); }); - it('should allow when both serverName and tool name prefix match', () => { + it('should allow when both serverName and tool name prefix match', async () => { const rules: PolicyRule[] = [ { toolName: 'safe_server__*', @@ -352,14 +371,14 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); const validToolCall = { name: 'safe_server__tool' }; - expect(engine.check(validToolCall, 'safe_server')).toBe( + expect((await engine.check(validToolCall, 'safe_server')).decision).toBe( PolicyDecision.ALLOW, ); }); }); describe('complex scenarios', () => { - it('should handle multiple matching rules with different priorities', () => { + it('should handle multiple matching rules with different priorities', async () => { const rules: PolicyRule[] = [ { decision: PolicyDecision.DENY, priority: 0 }, // Default deny all { toolName: 'shell', decision: 
PolicyDecision.ASK_USER, priority: 5 }, @@ -375,21 +394,31 @@ describe('PolicyEngine', () => { // Matches highest priority rule (ls command) expect( - engine.check({ name: 'shell', args: { command: 'ls -la' } }, undefined), + ( + await engine.check( + { name: 'shell', args: { command: 'ls -la' } }, + undefined, + ) + ).decision, ).toBe(PolicyDecision.ALLOW); // Matches middle priority rule (shell without ls) expect( - engine.check({ name: 'shell', args: { command: 'pwd' } }, undefined), + ( + await engine.check( + { name: 'shell', args: { command: 'pwd' } }, + undefined, + ) + ).decision, ).toBe(PolicyDecision.ASK_USER); // Matches lowest priority rule (not shell) - expect(engine.check({ name: 'edit' }, undefined)).toBe( + expect((await engine.check({ name: 'edit' }, undefined)).decision).toBe( PolicyDecision.DENY, ); }); - it('should handle tools with no args', () => { + it('should handle tools with no args', async () => { const rules: PolicyRule[] = [ { toolName: 'read', @@ -401,22 +430,32 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); // Tool call without args should not match pattern - expect(engine.check({ name: 'read' }, undefined)).toBe( + expect((await engine.check({ name: 'read' }, undefined)).decision).toBe( PolicyDecision.ASK_USER, ); // Tool call with args not matching pattern expect( - engine.check({ name: 'read', args: { file: 'public.txt' } }, undefined), + ( + await engine.check( + { name: 'read', args: { file: 'public.txt' } }, + undefined, + ) + ).decision, ).toBe(PolicyDecision.ASK_USER); // Tool call with args matching pattern expect( - engine.check({ name: 'read', args: { file: 'secret.txt' } }, undefined), + ( + await engine.check( + { name: 'read', args: { file: 'secret.txt' } }, + undefined, + ) + ).decision, ).toBe(PolicyDecision.DENY); }); - it('should match args pattern regardless of property order', () => { + it('should match args pattern regardless of property order', async () => { const rules: PolicyRule[] = 
[ { toolName: 'shell', @@ -432,21 +471,24 @@ describe('PolicyEngine', () => { const args1 = { command: 'rm -rf /', path: '/home' }; const args2 = { path: '/home', command: 'rm -rf /' }; - expect(engine.check({ name: 'shell', args: args1 }, undefined)).toBe( - PolicyDecision.DENY, - ); - expect(engine.check({ name: 'shell', args: args2 }, undefined)).toBe( - PolicyDecision.DENY, - ); + expect( + (await engine.check({ name: 'shell', args: args1 }, undefined)) + .decision, + ).toBe(PolicyDecision.DENY); + expect( + (await engine.check({ name: 'shell', args: args2 }, undefined)) + .decision, + ).toBe(PolicyDecision.DENY); // Verify safe command doesn't match const safeArgs = { command: 'ls -la', path: '/home' }; - expect(engine.check({ name: 'shell', args: safeArgs }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check({ name: 'shell', args: safeArgs }, undefined)) + .decision, + ).toBe(PolicyDecision.ASK_USER); }); - it('should handle nested objects in args with stable stringification', () => { + it('should handle nested objects in args with stable stringification', async () => { const rules: PolicyRule[] = [ { toolName: 'api', @@ -467,15 +509,15 @@ describe('PolicyEngine', () => { data: { value: 'secret', sensitive: true }, }; - expect(engine.check({ name: 'api', args: args1 }, undefined)).toBe( - PolicyDecision.DENY, - ); - expect(engine.check({ name: 'api', args: args2 }, undefined)).toBe( - PolicyDecision.DENY, - ); + expect( + (await engine.check({ name: 'api', args: args1 }, undefined)).decision, + ).toBe(PolicyDecision.DENY); + expect( + (await engine.check({ name: 'api', args: args2 }, undefined)).decision, + ).toBe(PolicyDecision.DENY); }); - it('should handle circular references without stack overflow', () => { + it('should handle circular references without stack overflow', async () => { const rules: PolicyRule[] = [ { toolName: 'test', @@ -499,23 +541,25 @@ describe('PolicyEngine', () => { circularArgs.data; // Should not 
throw stack overflow error - expect(() => + await expect( engine.check({ name: 'test', args: circularArgs }, undefined), - ).not.toThrow(); + ).resolves.not.toThrow(); // Should detect the circular reference pattern expect( - engine.check({ name: 'test', args: circularArgs }, undefined), + (await engine.check({ name: 'test', args: circularArgs }, undefined)) + .decision, ).toBe(PolicyDecision.DENY); // Non-circular object should not match const normalArgs = { name: 'test', data: { value: 'normal' } }; - expect(engine.check({ name: 'test', args: normalArgs }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check({ name: 'test', args: normalArgs }, undefined)) + .decision, + ).toBe(PolicyDecision.ASK_USER); }); - it('should handle deep circular references', () => { + it('should handle deep circular references', async () => { const rules: PolicyRule[] = [ { toolName: 'deep', @@ -546,17 +590,18 @@ describe('PolicyEngine', () => { level3['back'] = deepCircular.level1; // Should handle without stack overflow - expect(() => + await expect( engine.check({ name: 'deep', args: deepCircular }, undefined), - ).not.toThrow(); + ).resolves.not.toThrow(); // Should detect the circular reference expect( - engine.check({ name: 'deep', args: deepCircular }, undefined), + (await engine.check({ name: 'deep', args: deepCircular }, undefined)) + .decision, ).toBe(PolicyDecision.DENY); }); - it('should handle repeated non-circular objects correctly', () => { + it('should handle repeated non-circular objects correctly', async () => { const rules: PolicyRule[] = [ { toolName: 'test', @@ -582,12 +627,12 @@ describe('PolicyEngine', () => { }; // Should NOT mark repeated objects as circular, and should match the shared value pattern - expect(engine.check({ name: 'test', args }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'test', args }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); }); - it('should omit undefined 
and function values from objects', () => { + it('should omit undefined and function values from objects', async () => { const rules: PolicyRule[] = [ { toolName: 'test', @@ -606,9 +651,9 @@ describe('PolicyEngine', () => { }; // Should match pattern with defined value, undefined and functions omitted - expect(engine.check({ name: 'test', args }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'test', args }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); // Check that the pattern would NOT match if undefined was included const rulesWithUndefined: PolicyRule[] = [ @@ -619,9 +664,9 @@ describe('PolicyEngine', () => { }, ]; engine = new PolicyEngine({ rules: rulesWithUndefined }); - expect(engine.check({ name: 'test', args }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check({ name: 'test', args }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); // Check that the pattern would NOT match if function was included const rulesWithFunction: PolicyRule[] = [ @@ -632,12 +677,12 @@ describe('PolicyEngine', () => { }, ]; engine = new PolicyEngine({ rules: rulesWithFunction }); - expect(engine.check({ name: 'test', args }, undefined)).toBe( - PolicyDecision.ASK_USER, - ); + expect( + (await engine.check({ name: 'test', args }, undefined)).decision, + ).toBe(PolicyDecision.ASK_USER); }); - it('should convert undefined and functions to null in arrays', () => { + it('should convert undefined and functions to null in arrays', async () => { const rules: PolicyRule[] = [ { toolName: 'test', @@ -653,12 +698,12 @@ describe('PolicyEngine', () => { }; // Should match pattern with undefined and functions converted to null - expect(engine.check({ name: 'test', args }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'test', args }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); }); - it('should produce valid JSON for all inputs', () => { + it('should produce 
valid JSON for all inputs', async () => { const testCases: Array<{ input: Record; desc: string }> = [ { input: { simple: 'string' }, desc: 'simple object' }, @@ -689,18 +734,19 @@ describe('PolicyEngine', () => { engine = new PolicyEngine({ rules }); // Should not throw when checking (which internally uses stableStringify) - expect(() => + await expect( engine.check({ name: 'test', args: input }, undefined), - ).not.toThrow(); + ).resolves.not.toThrow(); // The check should succeed - expect(engine.check({ name: 'test', args: input }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'test', args: input }, undefined)) + .decision, + ).toBe(PolicyDecision.ALLOW); } }); - it('should respect toJSON methods on objects', () => { + it('should respect toJSON methods on objects', async () => { const rules: PolicyRule[] = [ { toolName: 'test', @@ -725,12 +771,12 @@ describe('PolicyEngine', () => { }; // Should match the sanitized pattern, not the dangerous one - expect(engine.check({ name: 'test', args }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'test', args }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); }); - it('should handle toJSON that returns primitives', () => { + it('should handle toJSON that returns primitives', async () => { const rules: PolicyRule[] = [ { toolName: 'test', @@ -749,12 +795,12 @@ describe('PolicyEngine', () => { }; // toJSON returns a string, which should be properly stringified - expect(engine.check({ name: 'test', args }, undefined)).toBe( - PolicyDecision.ALLOW, - ); + expect( + (await engine.check({ name: 'test', args }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); }); - it('should handle toJSON that throws an error', () => { + it('should handle toJSON that throws an error', async () => { const rules: PolicyRule[] = [ { toolName: 'test', @@ -775,23 +821,433 @@ describe('PolicyEngine', () => { }; // Should fall back to regular object serialization 
when toJSON throws - expect(engine.check({ name: 'test', args }, undefined)).toBe( - PolicyDecision.ALLOW, + expect( + (await engine.check({ name: 'test', args }, undefined)).decision, + ).toBe(PolicyDecision.ALLOW); + }); + }); + + describe('safety checker integration', () => { + it('should call checker when rule allows and has safety_checker', async () => { + const rules: PolicyRule[] = [ + { + toolName: 'test-tool', + decision: PolicyDecision.ALLOW, + }, + ]; + const checkers: SafetyCheckerRule[] = [ + { + toolName: 'test-tool', + checker: { + type: 'external', + name: 'test-checker', + config: { content: 'test-content' }, + }, + }, + ]; + engine = new PolicyEngine({ rules, checkers }, mockCheckerRunner); + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.ALLOW, + }); + + const result = await engine.check( + { name: 'test-tool', args: { foo: 'bar' } }, + undefined, ); + + expect(result.decision).toBe(PolicyDecision.ALLOW); + expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith( + { name: 'test-tool', args: { foo: 'bar' } }, + { + type: 'external', + name: 'test-checker', + config: { content: 'test-content' }, + }, + ); + }); + + it('should handle checker errors as DENY', async () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + decision: PolicyDecision.ALLOW, + }, + ]; + const checkers: SafetyCheckerRule[] = [ + { + toolName: 'test', + checker: { + type: 'in-process', + name: InProcessCheckerType.ALLOWED_PATH, + }, + }, + ]; + + mockCheckerRunner.runChecker = vi + .fn() + .mockRejectedValue(new Error('Checker failed')); + + engine = new PolicyEngine({ rules, checkers }, mockCheckerRunner); + const { decision } = await engine.check({ name: 'test' }, undefined); + + expect(decision).toBe(PolicyDecision.DENY); + }); + + it('should return DENY when checker denies', async () => { + const rules: PolicyRule[] = [ + { + toolName: 'test-tool', + decision: PolicyDecision.ALLOW, + }, + ]; + const checkers: 
SafetyCheckerRule[] = [ + { + toolName: 'test-tool', + checker: { + type: 'external', + name: 'test-checker', + config: { content: 'test-content' }, + }, + }, + ]; + engine = new PolicyEngine({ rules, checkers }, mockCheckerRunner); + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.DENY, + reason: 'test reason', + }); + + const result = await engine.check( + { name: 'test-tool', args: { foo: 'bar' } }, + undefined, + ); + + expect(result.decision).toBe(PolicyDecision.DENY); + expect(mockCheckerRunner.runChecker).toHaveBeenCalled(); + }); + + it('should call checker when rule decision is ASK_USER', async () => { + const rules: PolicyRule[] = [ + { + toolName: 'test-tool', + decision: PolicyDecision.ASK_USER, + }, + ]; + const checkers: SafetyCheckerRule[] = [ + { + toolName: 'test-tool', + checker: { + type: 'external', + name: 'test-checker', + config: { content: 'test-content' }, + }, + }, + ]; + engine = new PolicyEngine({ rules, checkers }, mockCheckerRunner); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.ALLOW, + }); + + const result = await engine.check( + { name: 'test-tool', args: { foo: 'bar' } }, + undefined, + ); + + expect(result.decision).toBe(PolicyDecision.ASK_USER); + expect(mockCheckerRunner.runChecker).toHaveBeenCalled(); + }); + + it('should run checkers when rule allows', async () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + decision: PolicyDecision.ALLOW, + }, + ]; + const checkers: SafetyCheckerRule[] = [ + { + toolName: 'test', + checker: { + type: 'in-process', + name: InProcessCheckerType.ALLOWED_PATH, + }, + }, + ]; + + mockCheckerRunner.runChecker = vi.fn().mockResolvedValue({ + decision: SafetyCheckDecision.ALLOW, + }); + + engine = new PolicyEngine({ rules, checkers }, mockCheckerRunner); + const { decision } = await engine.check({ name: 'test' }, undefined); + + expect(decision).toBe(PolicyDecision.ALLOW); +
expect(mockCheckerRunner.runChecker).toHaveBeenCalledTimes(1); + }); + + it('should not call checker if rule has no safety_checker', async () => { + const rules: PolicyRule[] = [ + { + toolName: 'test-tool', + decision: PolicyDecision.ALLOW, + }, + ]; + engine = new PolicyEngine({ rules }, mockCheckerRunner); + + const result = await engine.check( + { name: 'test-tool', args: { foo: 'bar' } }, + undefined, + ); + + expect(result.decision).toBe(PolicyDecision.ALLOW); + expect(mockCheckerRunner.runChecker).not.toHaveBeenCalled(); }); }); describe('serverName requirement', () => { - it('should require serverName for checks', () => { + it('should require serverName for checks', async () => { // @ts-expect-error - intentionally testing missing serverName - expect(engine.check({ name: 'test' })).toBe(PolicyDecision.ASK_USER); - // When serverName is provided (even undefined), it should work - expect(engine.check({ name: 'test' }, undefined)).toBe( + expect((await engine.check({ name: 'test' })).decision).toBe( PolicyDecision.ASK_USER, ); - expect(engine.check({ name: 'test' }, 'some-server')).toBe( + // When serverName is provided (even undefined), it should work + expect((await engine.check({ name: 'test' }, undefined)).decision).toBe( PolicyDecision.ASK_USER, ); + expect( + (await engine.check({ name: 'test' }, 'some-server')).decision, + ).toBe(PolicyDecision.ASK_USER); + }); + it('should run multiple checkers in priority order and stop at first denial', async () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + decision: PolicyDecision.ALLOW, + }, + ]; + const checkers: SafetyCheckerRule[] = [ + { + toolName: 'test', + priority: 10, + checker: { type: 'external', name: 'checker1' }, + }, + { + toolName: 'test', + priority: 20, // Should run first + checker: { type: 'external', name: 'checker2' }, + }, + ]; + + mockCheckerRunner.runChecker = vi + .fn() + .mockImplementation(async (_toolCall, config) => { + if (config.name === 'checker2') { + return { + 
decision: SafetyCheckDecision.DENY, + reason: 'checker2 denied', + }; + } + return { decision: SafetyCheckDecision.ALLOW }; + }); + + engine = new PolicyEngine({ rules, checkers }, mockCheckerRunner); + const { decision, rule } = await engine.check( + { name: 'test' }, + undefined, + ); + + expect(decision).toBe(PolicyDecision.DENY); + expect(rule).toBeDefined(); + expect(mockCheckerRunner.runChecker).toHaveBeenCalledTimes(1); + expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ name: 'checker2' }), + ); + }); + }); + + describe('addChecker', () => { + it('should add a new checker and maintain priority order', () => { + const checker1: SafetyCheckerRule = { + checker: { type: 'external', name: 'checker1' }, + priority: 5, + }; + const checker2: SafetyCheckerRule = { + checker: { type: 'external', name: 'checker2' }, + priority: 10, + }; + + engine.addChecker(checker1); + engine.addChecker(checker2); + + const checkers = engine.getCheckers(); + expect(checkers).toHaveLength(2); + expect(checkers[0].priority).toBe(10); + expect(checkers[0].checker.name).toBe('checker2'); + expect(checkers[1].priority).toBe(5); + expect(checkers[1].checker.name).toBe('checker1'); + }); + }); + + describe('checker matching logic', () => { + it('should match checkers using toolName and argsPattern', async () => { + const rules: PolicyRule[] = [ + { toolName: 'tool', decision: PolicyDecision.ALLOW }, + ]; + const matchingChecker: SafetyCheckerRule = { + checker: { type: 'external', name: 'matching' }, + toolName: 'tool', + argsPattern: /"safe":true/, + }; + const nonMatchingChecker: SafetyCheckerRule = { + checker: { type: 'external', name: 'non-matching' }, + toolName: 'other', + }; + + engine = new PolicyEngine( + { rules, checkers: [matchingChecker, nonMatchingChecker] }, + mockCheckerRunner, + ); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.ALLOW, + }); + + await engine.check({ 
name: 'tool', args: { safe: true } }, undefined); + + expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ name: 'matching' }), + ); + expect(mockCheckerRunner.runChecker).not.toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ name: 'non-matching' }), + ); + }); + + it('should support wildcard patterns for checkers', async () => { + const rules: PolicyRule[] = [ + { toolName: 'server__tool', decision: PolicyDecision.ALLOW }, + ]; + const wildcardChecker: SafetyCheckerRule = { + checker: { type: 'external', name: 'wildcard' }, + toolName: 'server__*', + }; + + engine = new PolicyEngine( + { rules, checkers: [wildcardChecker] }, + mockCheckerRunner, + ); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.ALLOW, + }); + + await engine.check({ name: 'server__tool' }, 'server'); + + expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ name: 'wildcard' }), + ); + }); + it('should run safety checkers when decision is ASK_USER and downgrade to DENY on failure', async () => { + const rules: PolicyRule[] = [ + { toolName: 'tool', decision: PolicyDecision.ASK_USER }, + ]; + const checkers: SafetyCheckerRule[] = [ + { + checker: { + type: 'in-process', + name: InProcessCheckerType.ALLOWED_PATH, + }, + }, + ]; + + engine = new PolicyEngine({ rules, checkers }, mockCheckerRunner); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.DENY, + reason: 'Safety check failed', + }); + + const result = await engine.check({ name: 'tool' }, undefined); + expect(result.decision).toBe(PolicyDecision.DENY); + expect(mockCheckerRunner.runChecker).toHaveBeenCalled(); + }); + + it('should run safety checkers when decision is ASK_USER and keep ASK_USER on success', async () => { + const rules: PolicyRule[] = [ + { toolName: 'tool', decision: PolicyDecision.ASK_USER }, + ]; + const 
checkers: SafetyCheckerRule[] = [ + { + checker: { + type: 'in-process', + name: InProcessCheckerType.ALLOWED_PATH, + }, + }, + ]; + + engine = new PolicyEngine({ rules, checkers }, mockCheckerRunner); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.ALLOW, + }); + + const result = await engine.check({ name: 'tool' }, undefined); + expect(result.decision).toBe(PolicyDecision.ASK_USER); + expect(mockCheckerRunner.runChecker).toHaveBeenCalled(); + }); + + it('should downgrade ALLOW to ASK_USER if checker returns ASK_USER', async () => { + const rules: PolicyRule[] = [ + { toolName: 'tool', decision: PolicyDecision.ALLOW }, + ]; + const checkers: SafetyCheckerRule[] = [ + { + checker: { + type: 'in-process', + name: InProcessCheckerType.ALLOWED_PATH, + }, + }, + ]; + + engine = new PolicyEngine({ rules, checkers }, mockCheckerRunner); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.ASK_USER, + reason: 'Suspicious path', + }); + + const result = await engine.check({ name: 'tool' }, undefined); + expect(result.decision).toBe(PolicyDecision.ASK_USER); + }); + + it('should DENY if checker returns ASK_USER in non-interactive mode', async () => { + const rules: PolicyRule[] = [ + { toolName: 'tool', decision: PolicyDecision.ALLOW }, + ]; + const checkers: SafetyCheckerRule[] = [ + { + checker: { + type: 'in-process', + name: InProcessCheckerType.ALLOWED_PATH, + }, + }, + ]; + + engine = new PolicyEngine( + { rules, checkers, nonInteractive: true }, + mockCheckerRunner, + ); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.ASK_USER, + reason: 'Suspicious path', + }); + + const result = await engine.check({ name: 'tool' }, undefined); + expect(result.decision).toBe(PolicyDecision.DENY); }); }); }); diff --git a/packages/core/src/policy/policy-engine.ts b/packages/core/src/policy/policy-engine.ts index 034cce2c8b..f1fb05ec43 100644 --- 
a/packages/core/src/policy/policy-engine.ts +++ b/packages/core/src/policy/policy-engine.ts @@ -9,12 +9,15 @@ import { PolicyDecision, type PolicyEngineConfig, type PolicyRule, + type SafetyCheckerRule, } from './types.js'; import { stableStringify } from './stable-stringify.js'; import { debugLogger } from '../utils/debugLogger.js'; +import type { CheckerRunner } from '../safety/checker-runner.js'; +import { SafetyCheckDecision } from '../safety/protocol.js'; function ruleMatches( - rule: PolicyRule, + rule: PolicyRule | SafetyCheckerRule, toolCall: FunctionCall, stringifiedArgs: string | undefined, serverName: string | undefined, @@ -60,27 +63,41 @@ function ruleMatches( export class PolicyEngine { private rules: PolicyRule[]; + private checkers: SafetyCheckerRule[]; private readonly defaultDecision: PolicyDecision; private readonly nonInteractive: boolean; + private readonly checkerRunner?: CheckerRunner; - constructor(config: PolicyEngineConfig = {}) { + constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) { this.rules = (config.rules ?? []).sort( (a, b) => (b.priority ?? 0) - (a.priority ?? 0), ); + this.checkers = (config.checkers ?? []).sort( + (a, b) => (b.priority ?? 0) - (a.priority ?? 0), + ); this.defaultDecision = config.defaultDecision ?? PolicyDecision.ASK_USER; this.nonInteractive = config.nonInteractive ?? false; + this.checkerRunner = checkerRunner; } /** * Check if a tool call is allowed based on the configured policies. + * Returns the decision and the matching rule (if any). 
*/ - check( + async check( toolCall: FunctionCall, serverName: string | undefined, - ): PolicyDecision { + ): Promise<{ + decision: PolicyDecision; + rule?: PolicyRule; + }> { let stringifiedArgs: string | undefined; // Compute stringified args once before the loop - if (toolCall.args && this.rules.some((rule) => rule.argsPattern)) { + if ( + toolCall.args && + (this.rules.some((rule) => rule.argsPattern) || + this.checkers.some((checker) => checker.argsPattern)) + ) { stringifiedArgs = stableStringify(toolCall.args); } @@ -89,20 +106,72 @@ export class PolicyEngine { ); // Find the first matching rule (already sorted by priority) + let matchedRule: PolicyRule | undefined; + let decision: PolicyDecision | undefined; + for (const rule of this.rules) { if (ruleMatches(rule, toolCall, stringifiedArgs, serverName)) { debugLogger.debug( `[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`, ); - return this.applyNonInteractiveMode(rule.decision); + matchedRule = rule; + decision = this.applyNonInteractiveMode(rule.decision); + break; } } - // No matching rule found, use default decision - debugLogger.debug( - `[PolicyEngine.check] NO MATCH - using default decision: ${this.defaultDecision}`, - ); - return this.applyNonInteractiveMode(this.defaultDecision); + if (!decision) { + // No matching rule found, use default decision + debugLogger.debug( + `[PolicyEngine.check] NO MATCH - using default decision: ${this.defaultDecision}`, + ); + decision = this.applyNonInteractiveMode(this.defaultDecision); + } + + // If decision is not DENY, run safety checkers + if (decision !== PolicyDecision.DENY && this.checkerRunner) { + for (const checkerRule of this.checkers) { + if (ruleMatches(checkerRule, toolCall, stringifiedArgs, serverName)) { + debugLogger.debug( + `[PolicyEngine.check] Running safety checker: ${checkerRule.checker.name}`, + ); + try { + const result = 
await this.checkerRunner.runChecker( + toolCall, + checkerRule.checker, + ); + + if (result.decision === SafetyCheckDecision.DENY) { + debugLogger.debug( + `[PolicyEngine.check] Safety checker denied: ${result.reason}`, + ); + return { + decision: PolicyDecision.DENY, + rule: matchedRule, + }; + } else if (result.decision === SafetyCheckDecision.ASK_USER) { + debugLogger.debug( + `[PolicyEngine.check] Safety checker requested ASK_USER: ${result.reason}`, + ); + decision = PolicyDecision.ASK_USER; + } + } catch (error) { + debugLogger.debug( + `[PolicyEngine.check] Safety checker failed: ${error}`, + ); + return { + decision: PolicyDecision.DENY, + rule: matchedRule, + }; + } + } + } + } + + return { + decision: this.applyNonInteractiveMode(decision), + rule: matchedRule, + }; } /** @@ -114,6 +183,11 @@ export class PolicyEngine { this.rules.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); } + addChecker(checker: SafetyCheckerRule): void { + this.checkers.push(checker); + this.checkers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); + } + /** * Remove rules for a specific tool. 
*/ @@ -128,6 +202,10 @@ export class PolicyEngine { return this.rules; } + getCheckers(): readonly SafetyCheckerRule[] { + return this.checkers; + } + private applyNonInteractiveMode(decision: PolicyDecision): PolicyDecision { // In non-interactive mode, ASK_USER becomes DENY if (this.nonInteractive && decision === PolicyDecision.ASK_USER) { diff --git a/packages/core/src/policy/toml-loader.test.ts b/packages/core/src/policy/toml-loader.test.ts index 77b44bc2ca..38f29b15fd 100644 --- a/packages/core/src/policy/toml-loader.test.ts +++ b/packages/core/src/policy/toml-loader.test.ts @@ -85,6 +85,7 @@ priority = 100 decision: PolicyDecision.ALLOW, priority: 1.1, // tier 1 + 100/1000 }); + expect(result.checkers).toHaveLength(0); expect(result.errors).toHaveLength(0); }); diff --git a/packages/core/src/policy/toml-loader.ts b/packages/core/src/policy/toml-loader.ts index a57d4aa77a..ed63db0929 100644 --- a/packages/core/src/policy/toml-loader.ts +++ b/packages/core/src/policy/toml-loader.ts @@ -4,7 +4,14 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { type PolicyRule, PolicyDecision, type ApprovalMode } from './types.js'; +import { + type PolicyRule, + PolicyDecision, + type ApprovalMode, + type SafetyCheckerConfig, + type SafetyCheckerRule, + InProcessCheckerType, +} from './types.js'; import fs from 'node:fs/promises'; import path from 'node:path'; import toml from '@iarna/toml'; @@ -39,11 +46,39 @@ const PolicyRuleSchema = z.object({ modes: z.array(z.string()).optional(), }); +/** + * Schema for a single safety checker rule in the TOML file. 
+ */ +const SafetyCheckerRuleSchema = z.object({ + toolName: z.union([z.string(), z.array(z.string())]).optional(), + mcpName: z.string().optional(), + argsPattern: z.string().optional(), + commandPrefix: z.union([z.string(), z.array(z.string())]).optional(), + commandRegex: z.string().optional(), + priority: z.number().int().default(0), + modes: z.array(z.string()).optional(), + checker: z.discriminatedUnion('type', [ + z.object({ + type: z.literal('in-process'), + name: z.nativeEnum(InProcessCheckerType), + required_context: z.array(z.string()).optional(), + config: z.record(z.unknown()).optional(), + }), + z.object({ + type: z.literal('external'), + name: z.string(), + required_context: z.array(z.string()).optional(), + config: z.record(z.unknown()).optional(), + }), + ]), +}); + /** * Schema for the entire policy TOML file. */ const PolicyFileSchema = z.object({ - rule: z.array(PolicyRuleSchema), + rule: z.array(PolicyRuleSchema).optional(), + safety_checker: z.array(SafetyCheckerRuleSchema).optional(), }); /** @@ -80,6 +115,7 @@ export interface PolicyFileError { */ export interface PolicyLoadResult { rules: PolicyRule[]; + checkers: SafetyCheckerRule[]; errors: PolicyFileError[]; } @@ -194,6 +230,7 @@ export async function loadPoliciesFromToml( getPolicyTier: (dir: string) => number, ): Promise { const rules: PolicyRule[] = []; + const checkers: SafetyCheckerRule[] = []; const errors: PolicyFileError[] = []; for (const dir of policyDirs) { @@ -267,8 +304,9 @@ export async function loadPoliciesFromToml( } // Validate shell command convenience syntax - for (let i = 0; i < validationResult.data.rule.length; i++) { - const rule = validationResult.data.rule[i]; + const tomlRules = validationResult.data.rule ?? 
[]; + for (let i = 0; i < tomlRules.length; i++) { + const rule = tomlRules[i]; const validationError = validateShellCommandSyntax(rule, i); if (validationError) { errors.push({ @@ -285,7 +323,7 @@ export async function loadPoliciesFromToml( } // Transform rules - const parsedRules: PolicyRule[] = validationResult.data.rule + const parsedRules: PolicyRule[] = (validationResult.data.rule ?? []) .filter((rule) => { // Filter by mode if (!rule.modes || rule.modes.length === 0) { @@ -369,6 +407,84 @@ export async function loadPoliciesFromToml( .filter((rule): rule is PolicyRule => rule !== null); rules.push(...parsedRules); + + // Transform checkers + const parsedCheckers: SafetyCheckerRule[] = ( + validationResult.data.safety_checker ?? [] + ) + .filter((checker) => { + if (!checker.modes || checker.modes.length === 0) { + return true; + } + return checker.modes.includes(approvalMode); + }) + .flatMap((checker) => { + let effectiveArgsPattern = checker.argsPattern; + const commandPrefixes: string[] = []; + + if (checker.commandPrefix) { + const prefixes = Array.isArray(checker.commandPrefix) + ? checker.commandPrefix + : [checker.commandPrefix]; + commandPrefixes.push(...prefixes); + } else if (checker.commandRegex) { + effectiveArgsPattern = `"command":"${checker.commandRegex}`; + } + + const argsPatterns: Array = + commandPrefixes.length > 0 + ? commandPrefixes.map( + (prefix) => `"command":"${escapeRegex(prefix)}`, + ) + : [effectiveArgsPattern]; + + return argsPatterns.flatMap((argsPattern) => { + const toolNames: Array = checker.toolName + ? Array.isArray(checker.toolName) + ? 
checker.toolName + : [checker.toolName] + : [undefined]; + + return toolNames.map((toolName) => { + let effectiveToolName: string | undefined; + if (checker.mcpName && toolName) { + effectiveToolName = `${checker.mcpName}__${toolName}`; + } else if (checker.mcpName) { + effectiveToolName = `${checker.mcpName}__*`; + } else { + effectiveToolName = toolName; + } + + const safetyCheckerRule: SafetyCheckerRule = { + toolName: effectiveToolName, + priority: checker.priority, + checker: checker.checker as SafetyCheckerConfig, + }; + + if (argsPattern) { + try { + safetyCheckerRule.argsPattern = new RegExp(argsPattern); + } catch (e) { + const error = e as Error; + errors.push({ + filePath, + fileName: file, + tier: tierName, + errorType: 'regex_compilation', + message: 'Invalid regex pattern in safety checker', + details: `Pattern: ${argsPattern}\nError: ${error.message}`, + }); + return null; + } + } + + return safetyCheckerRule; + }); + }); + }) + .filter((checker): checker is SafetyCheckerRule => checker !== null); + + checkers.push(...parsedCheckers); } catch (e) { const error = e as NodeJS.ErrnoException; // Catch-all for unexpected errors @@ -386,5 +502,5 @@ export async function loadPoliciesFromToml( } } - return { rules, errors }; + return { rules, checkers, errors }; } diff --git a/packages/core/src/policy/types.ts b/packages/core/src/policy/types.ts index 2ab7ba5ed0..d244211e70 100644 --- a/packages/core/src/policy/types.ts +++ b/packages/core/src/policy/types.ts @@ -4,6 +4,8 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { SafetyCheckInput } from '../safety/protocol.js'; + export enum PolicyDecision { ALLOW = 'allow', DENY = 'deny', @@ -16,6 +18,52 @@ export enum ApprovalMode { YOLO = 'yolo', } +/** + * Configuration for the built-in allowed-path checker. + */ +export interface AllowedPathConfig { + /** + * Explicitly include argument keys to be checked as paths. 
+ */ + included_args?: string[]; + + /** + * Explicitly exclude argument keys from being checked as paths. + */ + excluded_args?: string[]; +} + +/** + * Base interface for external checkers. + */ +export interface ExternalCheckerConfig { + type: 'external'; + name: string; + config?: unknown; + required_context?: Array; +} + +export enum InProcessCheckerType { + ALLOWED_PATH = 'allowed-path', +} + +/** + * Base interface for in-process checkers. + */ +export interface InProcessCheckerConfig { + type: 'in-process'; + name: InProcessCheckerType; + config?: AllowedPathConfig; + required_context?: Array; +} + +/** + * A discriminated union for all safety checker configurations. + */ +export type SafetyCheckerConfig = + | ExternalCheckerConfig + | InProcessCheckerConfig; + export interface PolicyRule { /** * The name of the tool this rule applies to. @@ -41,12 +89,43 @@ export interface PolicyRule { priority?: number; } +export interface SafetyCheckerRule { + /** + * The name of the tool this rule applies to. + * If undefined, the rule applies to all tools. + */ + toolName?: string; + + /** + * Pattern to match against tool arguments. + * Can be used for more fine-grained control. + */ + argsPattern?: RegExp; + + /** + * Priority of this checker. Higher numbers run first. + * Default is 0. + */ + priority?: number; + + /** + * Specifies an external or built-in safety checker to execute for + * additional validation of a tool call. + */ + checker: SafetyCheckerConfig; +} + export interface PolicyEngineConfig { /** * List of policy rules to apply. */ rules?: PolicyRule[]; + /** + * List of safety checkers to apply. + */ + checkers?: SafetyCheckerRule[]; + /** * Default decision when no rules match. * Defaults to ASK_USER. 
diff --git a/packages/core/src/safety/built-in.test.ts b/packages/core/src/safety/built-in.test.ts new file mode 100644 index 0000000000..d940929009 --- /dev/null +++ b/packages/core/src/safety/built-in.test.ts @@ -0,0 +1,244 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, afterEach } from 'vitest'; +import * as fs from 'node:fs/promises'; +import * as os from 'node:os'; +import * as path from 'node:path'; +import { AllowedPathChecker } from './built-in.js'; +import type { SafetyCheckInput } from './protocol.js'; +import { SafetyCheckDecision } from './protocol.js'; +import type { FunctionCall } from '@google/genai'; + +describe('AllowedPathChecker', () => { + let checker: AllowedPathChecker; + let testRootDir: string; + let mockCwd: string; + let mockWorkspaces: string[]; + + beforeEach(async () => { + checker = new AllowedPathChecker(); + testRootDir = await fs.mkdtemp(path.join(os.tmpdir(), 'safety-test-')); + mockCwd = path.join(testRootDir, 'home', 'user', 'project'); + await fs.mkdir(mockCwd, { recursive: true }); + mockWorkspaces = [ + mockCwd, + path.join(testRootDir, 'home', 'user', 'other-project'), + ]; + await fs.mkdir(mockWorkspaces[1], { recursive: true }); + }); + + afterEach(async () => { + await fs.rm(testRootDir, { recursive: true, force: true }); + }); + + const createInput = ( + toolArgs: Record, + config?: Record, + ): SafetyCheckInput => ({ + protocolVersion: '1.0.0', + toolCall: { + name: 'test_tool', + args: toolArgs, + } as unknown as FunctionCall, + context: { + environment: { + cwd: mockCwd, + workspaces: mockWorkspaces, + }, + }, + config, + }); + + it('should allow paths within CWD', async () => { + const filePath = path.join(mockCwd, 'file.txt'); + await fs.writeFile(filePath, 'test content'); + const input = createInput({ + path: filePath, + }); + const result = await checker.check(input); + 
expect(result.decision).toBe(SafetyCheckDecision.ALLOW); + }); + + it('should allow paths within workspace roots', async () => { + const filePath = path.join(mockWorkspaces[1], 'data.json'); + await fs.writeFile(filePath, 'test content'); + const input = createInput({ + path: filePath, + }); + const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.ALLOW); + }); + + it('should deny paths outside allowed areas', async () => { + const outsidePath = path.join(testRootDir, 'etc', 'passwd'); + await fs.mkdir(path.dirname(outsidePath), { recursive: true }); + await fs.writeFile(outsidePath, 'secret'); + const input = createInput({ path: outsidePath }); + const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.DENY); + expect(result.reason).toContain('outside of the allowed workspace'); + }); + + it('should deny paths using ../ to escape', async () => { + const secretPath = path.join(testRootDir, 'home', 'user', 'secret.txt'); + await fs.writeFile(secretPath, 'secret'); + const input = createInput({ + path: path.join(mockCwd, '..', 'secret.txt'), + }); + const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.DENY); + }); + + it('should check multiple path arguments', async () => { + const passwdPath = path.join(testRootDir, 'etc', 'passwd'); + await fs.mkdir(path.dirname(passwdPath), { recursive: true }); + await fs.writeFile(passwdPath, 'secret'); + const srcPath = path.join(mockCwd, 'src.txt'); + await fs.writeFile(srcPath, 'source content'); + + const input = createInput({ + source: srcPath, + destination: passwdPath, + }); + const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.DENY); + expect(result.reason).toContain(passwdPath); + }); + + it('should handle non-existent paths gracefully if they are inside allowed dir', async () => { + const input = createInput({ + path: path.join(mockCwd, 'new-file.txt'), + }); + 
const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.ALLOW); + }); + + it('should deny access if path contains a symlink pointing outside allowed directories', async () => { + const symlinkPath = path.join(mockCwd, 'symlink'); + const targetPath = path.join(testRootDir, 'etc', 'passwd'); + await fs.mkdir(path.dirname(targetPath), { recursive: true }); + await fs.writeFile(targetPath, 'secret'); + + // Create symlink: mockCwd/symlink -> targetPath + await fs.symlink(targetPath, symlinkPath); + + const input = createInput({ path: symlinkPath }); + const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.DENY); + expect(result.reason).toContain( + 'outside of the allowed workspace directories', + ); + }); + + it('should allow access if path contains a symlink pointing INSIDE allowed directories', async () => { + const symlinkPath = path.join(mockCwd, 'symlink-inside'); + const realFilePath = path.join(mockCwd, 'real-file'); + await fs.writeFile(realFilePath, 'real content'); + + // Create symlink: mockCwd/symlink-inside -> mockCwd/real-file + await fs.symlink(realFilePath, symlinkPath); + + const input = createInput({ path: symlinkPath }); + const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.ALLOW); + }); + + it('should check explicitly included arguments', async () => { + const outsidePath = path.join(testRootDir, 'etc', 'passwd'); + await fs.mkdir(path.dirname(outsidePath), { recursive: true }); + await fs.writeFile(outsidePath, 'secret'); + const input = createInput( + { custom_arg: outsidePath }, + { included_args: ['custom_arg'] }, + ); + const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.DENY); + expect(result.reason).toContain('outside of the allowed workspace'); + }); + + it('should skip explicitly excluded arguments', async () => { + const outsidePath = path.join(testRootDir, 'etc', 'passwd'); + 
await fs.mkdir(path.dirname(outsidePath), { recursive: true }); + await fs.writeFile(outsidePath, 'secret'); + // Normally 'path' would be checked, but we exclude it + const input = createInput( + { path: outsidePath }, + { excluded_args: ['path'] }, + ); + const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.ALLOW); + }); + + it('should handle both included and excluded arguments', async () => { + const outsidePath = path.join(testRootDir, 'etc', 'passwd'); + await fs.mkdir(path.dirname(outsidePath), { recursive: true }); + await fs.writeFile(outsidePath, 'secret'); + const input = createInput( + { + path: outsidePath, // Excluded + custom_arg: outsidePath, // Included + }, + { + excluded_args: ['path'], + included_args: ['custom_arg'], + }, + ); + const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.DENY); + // Should be denied because of custom_arg, not path + expect(result.reason).toContain(outsidePath); + }); + + it('should check nested path arguments', async () => { + const outsidePath = path.join(testRootDir, 'etc', 'passwd'); + await fs.mkdir(path.dirname(outsidePath), { recursive: true }); + await fs.writeFile(outsidePath, 'secret'); + const input = createInput({ + nested: { + path: outsidePath, + }, + }); + const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.DENY); + expect(result.reason).toContain(outsidePath); + expect(result.reason).toContain('nested.path'); + }); + + it('should support dot notation for included_args', async () => { + const outsidePath = path.join(testRootDir, 'etc', 'passwd'); + await fs.mkdir(path.dirname(outsidePath), { recursive: true }); + await fs.writeFile(outsidePath, 'secret'); + const input = createInput( + { + nested: { + custom: outsidePath, + }, + }, + { included_args: ['nested.custom'] }, + ); + const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.DENY); + 
expect(result.reason).toContain(outsidePath); + expect(result.reason).toContain('nested.custom'); + }); + + it('should support dot notation for excluded_args', async () => { + const outsidePath = path.join(testRootDir, 'etc', 'passwd'); + await fs.mkdir(path.dirname(outsidePath), { recursive: true }); + await fs.writeFile(outsidePath, 'secret'); + const input = createInput( + { + nested: { + path: outsidePath, + }, + }, + { excluded_args: ['nested.path'] }, + ); + const result = await checker.check(input); + expect(result.decision).toBe(SafetyCheckDecision.ALLOW); + }); +}); diff --git a/packages/core/src/safety/built-in.ts b/packages/core/src/safety/built-in.ts new file mode 100644 index 0000000000..57a22d55e3 --- /dev/null +++ b/packages/core/src/safety/built-in.ts @@ -0,0 +1,154 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as path from 'node:path'; +import * as fs from 'node:fs'; +import type { SafetyCheckInput, SafetyCheckResult } from './protocol.js'; +import { SafetyCheckDecision } from './protocol.js'; +import type { AllowedPathConfig } from '../policy/types.js'; + +/** + * Interface for all in-process safety checkers. + */ +export interface InProcessChecker { + check(input: SafetyCheckInput): Promise; +} + +/** + * An in-process checker to validate file paths. + */ +export class AllowedPathChecker implements InProcessChecker { + async check(input: SafetyCheckInput): Promise { + const { toolCall, context } = input; + const config = input.config as AllowedPathConfig | undefined; + + // Build list of allowed directories + const allowedDirs = [ + context.environment.cwd, + ...context.environment.workspaces, + ]; + + // Find all arguments that look like paths + const includedArgs = config?.included_args ?? []; + const excludedArgs = config?.excluded_args ?? 
[]; + + const pathsToCheck = this.collectPathsToCheck( + toolCall.args, + includedArgs, + excludedArgs, + ); + + // Check each path + for (const { path: p, argName } of pathsToCheck) { + const resolvedPath = this.safelyResolvePath(p, context.environment.cwd); + + if (!resolvedPath) { + // If path cannot be resolved, deny it + return { + decision: SafetyCheckDecision.DENY, + reason: `Cannot resolve path "${p}" in argument "${argName}"`, + }; + } + + const isAllowed = allowedDirs.some((dir) => { + // Also resolve allowed directories to handle symlinks + const resolvedDir = this.safelyResolvePath( + dir, + context.environment.cwd, + ); + if (!resolvedDir) return false; + return this.isPathAllowed(resolvedPath, resolvedDir); + }); + + if (!isAllowed) { + return { + decision: SafetyCheckDecision.DENY, + reason: `Path "${p}" in argument "${argName}" is outside of the allowed workspace directories.`, + }; + } + } + + return { decision: SafetyCheckDecision.ALLOW }; + } + + private safelyResolvePath(inputPath: string, cwd: string): string | null { + try { + const resolved = path.resolve(cwd, inputPath); + + // Walk up the directory tree until we find a path that exists + let current = resolved; + // Stop at root (dirname(root) === root on many systems, or it becomes empty/'.' 
depending on implementation) + while (current && current !== path.dirname(current)) { + if (fs.existsSync(current)) { + const canonical = fs.realpathSync(current); + // Re-construct the full path from this canonical base + const relative = path.relative(current, resolved); + // path.join handles empty relative paths correctly (returns canonical) + return path.join(canonical, relative); + } + current = path.dirname(current); + } + + // Fallback if nothing exists (unlikely if root exists) + return resolved; + } catch (_error) { + return null; + } + } + + private isPathAllowed(targetPath: string, allowedDir: string): boolean { + const relative = path.relative(allowedDir, targetPath); + return ( + relative === '' || + (!relative.startsWith('..') && !path.isAbsolute(relative)) + ); + } + + private collectPathsToCheck( + args: unknown, + includedArgs: string[], + excludedArgs: string[], + prefix = '', + ): Array<{ path: string; argName: string }> { + const paths: Array<{ path: string; argName: string }> = []; + + if (typeof args !== 'object' || args === null) { + return paths; + } + + for (const [key, value] of Object.entries(args)) { + const fullKey = prefix ? 
`${prefix}.${key}` : key; + + if (excludedArgs.includes(fullKey)) { + continue; + } + + if (typeof value === 'string') { + if ( + includedArgs.includes(fullKey) || + key.includes('path') || + key.includes('directory') || + key.includes('file') || + key === 'source' || + key === 'destination' + ) { + paths.push({ path: value, argName: fullKey }); + } + } else if (typeof value === 'object') { + paths.push( + ...this.collectPathsToCheck( + value, + includedArgs, + excludedArgs, + fullKey, + ), + ); + } + } + + return paths; + } +} diff --git a/packages/core/src/safety/checker-runner.test.ts b/packages/core/src/safety/checker-runner.test.ts new file mode 100644 index 0000000000..cd3c0e18ba --- /dev/null +++ b/packages/core/src/safety/checker-runner.test.ts @@ -0,0 +1,303 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { spawn } from 'node:child_process'; +import { CheckerRunner } from './checker-runner.js'; +import { ContextBuilder } from './context-builder.js'; +import { CheckerRegistry } from './registry.js'; +import { + type InProcessCheckerConfig, + InProcessCheckerType, +} from '../policy/types.js'; +import type { SafetyCheckResult } from './protocol.js'; +import { SafetyCheckDecision } from './protocol.js'; +import type { Config } from '../config/config.js'; + +// Mock dependencies +vi.mock('./registry.js'); +vi.mock('./context-builder.js'); +vi.mock('node:child_process'); + +describe('CheckerRunner', () => { + let runner: CheckerRunner; + let mockContextBuilder: ContextBuilder; + let mockRegistry: CheckerRegistry; + + const mockToolCall = { name: 'test_tool', args: {} }; + const mockInProcessConfig: InProcessCheckerConfig = { + type: 'in-process', + name: InProcessCheckerType.ALLOWED_PATH, + }; + + beforeEach(() => { + mockContextBuilder = new ContextBuilder({} as Config); + mockRegistry = new CheckerRegistry('/mock/dist'); + 
CheckerRegistry.prototype.resolveInProcess = vi.fn(); + + runner = new CheckerRunner(mockContextBuilder, mockRegistry, { + checkersPath: '/mock/dist', + }); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should run in-process checker successfully', async () => { + const mockResult: SafetyCheckResult = { + decision: SafetyCheckDecision.ALLOW, + }; + const mockChecker = { + check: vi.fn().mockResolvedValue(mockResult), + }; + vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker); + vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({ + environment: { cwd: '/tmp', workspaces: [] }, + }); + + const result = await runner.runChecker(mockToolCall, mockInProcessConfig); + + expect(result).toEqual(mockResult); + expect(mockRegistry.resolveInProcess).toHaveBeenCalledWith( + InProcessCheckerType.ALLOWED_PATH, + ); + expect(mockChecker.check).toHaveBeenCalled(); + }); + + it('should handle in-process checker errors', async () => { + const mockChecker = { + check: vi.fn().mockRejectedValue(new Error('Checker failed')), + }; + vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker); + vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({ + environment: { cwd: '/tmp', workspaces: [] }, + }); + + const result = await runner.runChecker(mockToolCall, mockInProcessConfig); + + expect(result.decision).toBe(SafetyCheckDecision.DENY); + expect(result.reason).toContain('Failed to run in-process checker'); + expect(result.reason).toContain('Checker failed'); + }); + + it('should respect timeout for in-process checkers', async () => { + vi.useFakeTimers(); + const mockChecker = { + check: vi.fn().mockImplementation(async () => { + await new Promise((resolve) => setTimeout(resolve, 6000)); // Longer than default 5s timeout + return { decision: SafetyCheckDecision.ALLOW }; + }), + }; + vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker); + vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({ 
+ environment: { cwd: '/tmp', workspaces: [] }, + }); + + const runPromise = runner.runChecker(mockToolCall, mockInProcessConfig); + vi.advanceTimersByTime(5001); + + const result = await runPromise; + expect(result.decision).toBe(SafetyCheckDecision.DENY); + expect(result.reason).toContain('timed out'); + + vi.useRealTimers(); + }); + + it('should use minimal context when requested', async () => { + const configWithContext: InProcessCheckerConfig = { + ...mockInProcessConfig, + required_context: ['environment'], + }; + const mockChecker = { + check: vi.fn().mockResolvedValue({ decision: SafetyCheckDecision.ALLOW }), + }; + vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker); + vi.mocked(mockContextBuilder.buildMinimalContext).mockReturnValue({ + environment: { cwd: '/tmp', workspaces: [] }, + }); + + await runner.runChecker(mockToolCall, configWithContext); + + expect(mockContextBuilder.buildMinimalContext).toHaveBeenCalledWith([ + 'environment', + ]); + expect(mockContextBuilder.buildFullContext).not.toHaveBeenCalled(); + }); + + it('should pass config to in-process checker via toolCall', async () => { + const mockConfig = { included_args: ['foo'] }; + const configWithConfig: InProcessCheckerConfig = { + ...mockInProcessConfig, + config: mockConfig, + }; + const mockResult: SafetyCheckResult = { + decision: SafetyCheckDecision.ALLOW, + }; + const mockChecker = { + check: vi.fn().mockResolvedValue(mockResult), + }; + vi.mocked(mockRegistry.resolveInProcess).mockReturnValue(mockChecker); + vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({ + environment: { cwd: '/tmp', workspaces: [] }, + }); + + await runner.runChecker(mockToolCall, configWithConfig); + + expect(mockChecker.check).toHaveBeenCalledWith( + expect.objectContaining({ + toolCall: mockToolCall, + config: mockConfig, + }), + ); + }); + + describe('External Checkers', () => { + const mockExternalConfig = { + type: 'external' as const, + name: 'python-checker', + }; + + 
it('should spawn external checker directly', async () => { + const mockCheckerPath = '/mock/dist/python-checker'; + vi.mocked(mockRegistry.resolveExternal).mockReturnValue(mockCheckerPath); + vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({ + environment: { cwd: '/tmp', workspaces: [] }, + }); + + const mockStdout = { + on: vi.fn().mockImplementation((event, callback) => { + if (event === 'data') { + callback( + Buffer.from( + JSON.stringify({ decision: SafetyCheckDecision.ALLOW }), + ), + ); + } + }), + }; + const mockChildProcess = { + stdin: { write: vi.fn(), end: vi.fn() }, + stdout: mockStdout, + stderr: { on: vi.fn() }, + on: vi.fn().mockImplementation((event, callback) => { + if (event === 'close') { + // Defer the close callback slightly to allow stdout 'data' to be registered + setTimeout(() => callback(0), 0); + } + }), + kill: vi.fn(), + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + vi.mocked(spawn).mockReturnValue(mockChildProcess as any); + + const result = await runner.runChecker(mockToolCall, mockExternalConfig); + + expect(result.decision).toBe(SafetyCheckDecision.ALLOW); + expect(spawn).toHaveBeenCalledWith( + mockCheckerPath, + [], + expect.anything(), + ); + }); + + it('should include checker name in timeout error message', async () => { + vi.useFakeTimers(); + const mockCheckerPath = '/mock/dist/python-checker'; + vi.mocked(mockRegistry.resolveExternal).mockReturnValue(mockCheckerPath); + vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({ + environment: { cwd: '/tmp', workspaces: [] }, + }); + + const mockChildProcess = { + stdin: { write: vi.fn(), end: vi.fn() }, + stdout: { on: vi.fn() }, + stderr: { on: vi.fn() }, + on: vi.fn(), // Never calls 'close' + kill: vi.fn(), + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + vi.mocked(spawn).mockReturnValue(mockChildProcess as any); + + const runPromise = runner.runChecker(mockToolCall, mockExternalConfig); + 
vi.advanceTimersByTime(5001); + + const result = await runPromise; + expect(result.decision).toBe(SafetyCheckDecision.DENY); + expect(result.reason).toContain( + 'Safety checker "python-checker" timed out', + ); + + vi.useRealTimers(); + }); + + it('should send SIGKILL if process ignores SIGTERM', async () => { + vi.useFakeTimers(); + const mockCheckerPath = '/mock/dist/python-checker'; + vi.mocked(mockRegistry.resolveExternal).mockReturnValue(mockCheckerPath); + vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({ + environment: { cwd: '/tmp', workspaces: [] }, + }); + + const mockChildProcess = { + stdin: { write: vi.fn(), end: vi.fn() }, + stdout: { on: vi.fn() }, + stderr: { on: vi.fn() }, + on: vi.fn(), // Never calls 'close' automatically + kill: vi.fn(), + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + vi.mocked(spawn).mockReturnValue(mockChildProcess as any); + + const runPromise = runner.runChecker(mockToolCall, mockExternalConfig); + + // Trigger main timeout + vi.advanceTimersByTime(5001); + + // Should have sent SIGTERM + expect(mockChildProcess.kill).toHaveBeenCalledWith('SIGTERM'); + + // Advance past cleanup timeout (5000ms) + vi.advanceTimersByTime(5000); + + // Should have sent SIGKILL + expect(mockChildProcess.kill).toHaveBeenCalledWith('SIGKILL'); + + // Clean up promise + await runPromise; + vi.useRealTimers(); + }); + + it('should include checker name in non-zero exit code error message', async () => { + const mockCheckerPath = '/mock/dist/python-checker'; + vi.mocked(mockRegistry.resolveExternal).mockReturnValue(mockCheckerPath); + vi.mocked(mockContextBuilder.buildFullContext).mockReturnValue({ + environment: { cwd: '/tmp', workspaces: [] }, + }); + + const mockChildProcess = { + stdin: { write: vi.fn(), end: vi.fn() }, + stdout: { on: vi.fn() }, + stderr: { on: vi.fn() }, + on: vi.fn().mockImplementation((event, callback) => { + if (event === 'close') { + callback(1); // Exit code 1 + } + }), + kill: 
vi.fn(), + }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + vi.mocked(spawn).mockReturnValue(mockChildProcess as any); + + const result = await runner.runChecker(mockToolCall, mockExternalConfig); + + expect(result.decision).toBe(SafetyCheckDecision.DENY); + expect(result.reason).toContain( + 'Safety checker "python-checker" exited with code 1', + ); + }); + }); +}); diff --git a/packages/core/src/safety/checker-runner.ts b/packages/core/src/safety/checker-runner.ts new file mode 100644 index 0000000000..e9748cd5c3 --- /dev/null +++ b/packages/core/src/safety/checker-runner.ts @@ -0,0 +1,297 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { spawn } from 'node:child_process'; +import type { FunctionCall } from '@google/genai'; +import type { + SafetyCheckerConfig, + InProcessCheckerConfig, + ExternalCheckerConfig, +} from '../policy/types.js'; +import type { SafetyCheckInput, SafetyCheckResult } from './protocol.js'; +import { SafetyCheckDecision } from './protocol.js'; +import type { CheckerRegistry } from './registry.js'; +import type { ContextBuilder } from './context-builder.js'; + +/** + * Configuration for the checker runner. + */ +export interface CheckerRunnerConfig { + /** + * Maximum time (in milliseconds) to wait for a checker to complete. + * Default: 5000 (5 seconds) + */ + timeout?: number; + + /** + * Path to the directory containing external checkers. + */ + checkersPath: string; +} + +/** + * Service for executing safety checker processes. 
+ */ +export class CheckerRunner { + private static readonly DEFAULT_TIMEOUT = 5000; // 5 seconds + + private readonly registry: CheckerRegistry; + private readonly contextBuilder: ContextBuilder; + private readonly timeout: number; + + constructor( + contextBuilder: ContextBuilder, + registry: CheckerRegistry, + config: CheckerRunnerConfig, + ) { + this.contextBuilder = contextBuilder; + this.registry = registry; + this.timeout = config.timeout ?? CheckerRunner.DEFAULT_TIMEOUT; + } + + /** + * Runs a safety checker and returns the result. + */ + async runChecker( + toolCall: FunctionCall, + checkerConfig: SafetyCheckerConfig, + ): Promise { + if (checkerConfig.type === 'in-process') { + return this.runInProcessChecker(toolCall, checkerConfig); + } + return this.runExternalChecker(toolCall, checkerConfig); + } + + private async runInProcessChecker( + toolCall: FunctionCall, + checkerConfig: InProcessCheckerConfig, + ): Promise { + try { + const checker = this.registry.resolveInProcess(checkerConfig.name); + const context = checkerConfig.required_context + ? this.contextBuilder.buildMinimalContext( + checkerConfig.required_context, + ) + : this.contextBuilder.buildFullContext(); + + const input: SafetyCheckInput = { + protocolVersion: '1.0.0', + toolCall, + context, + config: checkerConfig.config, + }; + + // In-process checkers can be async, but we'll also apply a timeout + // for safety, in case of infinite loops or unexpected delays. + return await this.executeWithTimeout(checker.check(input)); + } catch (error) { + return { + decision: SafetyCheckDecision.DENY, + reason: `Failed to run in-process checker "${checkerConfig.name}": ${ + error instanceof Error ? 
error.message : String(error) + }`, + }; + } + } + + private async runExternalChecker( + toolCall: FunctionCall, + checkerConfig: ExternalCheckerConfig, + ): Promise { + try { + // Resolve the checker executable path + const checkerPath = this.registry.resolveExternal(checkerConfig.name); + + // Build the appropriate context + const context = checkerConfig.required_context + ? this.contextBuilder.buildMinimalContext( + checkerConfig.required_context, + ) + : this.contextBuilder.buildFullContext(); + + // Create the input payload + const input: SafetyCheckInput = { + protocolVersion: '1.0.0', + toolCall, + context, + config: checkerConfig.config, + }; + + // Run the checker process + return await this.executeCheckerProcess( + checkerPath, + input, + checkerConfig.name, + ); + } catch (error) { + // If anything goes wrong, deny the operation + return { + decision: SafetyCheckDecision.DENY, + reason: `Failed to run safety checker "${checkerConfig.name}": ${ + error instanceof Error ? error.message : String(error) + }`, + }; + } + } + + /** + * Executes an external checker process and handles its lifecycle. 
+ */ + private executeCheckerProcess( + checkerPath: string, + input: SafetyCheckInput, + checkerName: string, + ): Promise { + return new Promise((resolve) => { + const child = spawn(checkerPath, [], { + stdio: ['pipe', 'pipe', 'pipe'], + }); + + let stdout = ''; + let stderr = ''; + let timeoutHandle: NodeJS.Timeout | null = null; + let killed = false; + + let exited = false; + + // Set up timeout + timeoutHandle = setTimeout(() => { + killed = true; + child.kill('SIGTERM'); + resolve({ + decision: SafetyCheckDecision.DENY, + reason: `Safety checker "${checkerName}" timed out after ${this.timeout}ms`, + }); + + // Fallback: if process doesn't exit after 5s, force kill + setTimeout(() => { + if (!exited) { + child.kill('SIGKILL'); + } + }, 5000).unref(); + }, this.timeout); + + // Collect output + if (child.stdout) { + child.stdout.on('data', (data: Buffer) => { + stdout += data.toString(); + }); + } + + if (child.stderr) { + child.stderr.on('data', (data: Buffer) => { + stderr += data.toString(); + }); + } + + // Handle process completion + child.on('close', (code: number | null) => { + exited = true; + if (timeoutHandle) { + clearTimeout(timeoutHandle); + } + + // If we already killed it due to timeout, don't process the result + if (killed) { + return; + } + + // Non-zero exit code is a failure + if (code !== 0) { + resolve({ + decision: SafetyCheckDecision.DENY, + reason: `Safety checker "${checkerName}" exited with code ${code}${ + stderr ? 
`: ${stderr}` : '' + }`, + }); + return; + } + + // Try to parse the output + try { + const result: SafetyCheckResult = JSON.parse(stdout); + + // Validate the result structure + if ( + !result.decision || + !Object.values(SafetyCheckDecision).includes(result.decision) + ) { + throw new Error( + 'Invalid result: missing or invalid "decision" field', + ); + } + + resolve(result); + } catch (parseError) { + resolve({ + decision: SafetyCheckDecision.DENY, + reason: `Failed to parse output from safety checker "${checkerName}": ${ + parseError instanceof Error + ? parseError.message + : String(parseError) + }`, + }); + } + }); + + // Handle process errors + child.on('error', (error: Error) => { + if (timeoutHandle) { + clearTimeout(timeoutHandle); + } + + if (!killed) { + resolve({ + decision: SafetyCheckDecision.DENY, + reason: `Failed to spawn safety checker "${checkerName}": ${error.message}`, + }); + } + }); + + // Send input to the checker + try { + if (child.stdin) { + child.stdin.write(JSON.stringify(input)); + child.stdin.end(); + } else { + throw new Error('Failed to open stdin for checker process'); + } + } catch (writeError) { + if (timeoutHandle) { + clearTimeout(timeoutHandle); + } + + child.kill(); + resolve({ + decision: SafetyCheckDecision.DENY, + reason: `Failed to write to stdin of safety checker "${checkerName}": ${ + writeError instanceof Error + ? writeError.message + : String(writeError) + }`, + }); + } + }); + } + + /** + * Executes a promise with a timeout. 
+ */ + private executeWithTimeout(promise: Promise): Promise { + return new Promise((resolve, reject) => { + const timeoutHandle = setTimeout(() => { + reject(new Error(`Checker timed out after ${this.timeout}ms`)); + }, this.timeout); + + promise + .then(resolve) + .catch(reject) + .finally(() => { + clearTimeout(timeoutHandle); + }); + }); + } +} diff --git a/packages/core/src/safety/context-builder.test.ts b/packages/core/src/safety/context-builder.test.ts new file mode 100644 index 0000000000..3ee9da432c --- /dev/null +++ b/packages/core/src/safety/context-builder.test.ts @@ -0,0 +1,56 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { ContextBuilder } from './context-builder.js'; +import type { Config } from '../config/config.js'; +import type { ConversationTurn } from './protocol.js'; + +describe('ContextBuilder', () => { + let contextBuilder: ContextBuilder; + let mockConfig: Config; + const mockHistory: ConversationTurn[] = [ + { user: { text: 'hello' }, model: { text: 'hi' } }, + ]; + const mockCwd = '/home/user/project'; + const mockWorkspaces = ['/home/user/project']; + + beforeEach(() => { + vi.spyOn(process, 'cwd').mockReturnValue(mockCwd); + mockConfig = { + getWorkspaceContext: vi.fn().mockReturnValue({ + getDirectories: vi.fn().mockReturnValue(mockWorkspaces), + }), + apiKey: 'secret-api-key', + somePublicConfig: 'public-value', + nested: { + secretToken: 'hidden', + public: 'visible', + }, + } as unknown as Config; + contextBuilder = new ContextBuilder(mockConfig, mockHistory); + }); + + it('should build full context with all fields', () => { + const context = contextBuilder.buildFullContext(); + expect(context.environment.cwd).toBe(mockCwd); + expect(context.environment.workspaces).toEqual(mockWorkspaces); + expect(context.history?.turns).toEqual(mockHistory); + }); + + it('should build minimal context with only required keys', 
() => { + const context = contextBuilder.buildMinimalContext(['environment']); + expect(context).toHaveProperty('environment'); + expect(context).not.toHaveProperty('config'); + expect(context).not.toHaveProperty('history'); + }); + + it('should handle missing history', () => { + contextBuilder = new ContextBuilder(mockConfig); + const context = contextBuilder.buildFullContext(); + expect(context.history?.turns).toEqual([]); + }); +}); diff --git a/packages/core/src/safety/context-builder.ts b/packages/core/src/safety/context-builder.ts new file mode 100644 index 0000000000..9c20a1d7ab --- /dev/null +++ b/packages/core/src/safety/context-builder.ts @@ -0,0 +1,54 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { SafetyCheckInput, ConversationTurn } from './protocol.js'; +import type { Config } from '../config/config.js'; + +/** + * Builds context objects for safety checkers, ensuring sensitive data is filtered. + */ +export class ContextBuilder { + constructor( + private readonly config: Config, + private readonly conversationHistory: ConversationTurn[] = [], + ) {} + + /** + * Builds the full context object with all available data. + */ + buildFullContext(): SafetyCheckInput['context'] { + return { + environment: { + cwd: process.cwd(), + workspaces: this.config + .getWorkspaceContext() + .getDirectories() as string[], + }, + history: { + turns: this.conversationHistory, + }, + }; + } + + /** + * Builds a minimal context with only the specified keys. 
+ */ + buildMinimalContext( + requiredKeys: Array, + ): SafetyCheckInput['context'] { + const fullContext = this.buildFullContext(); + const minimalContext: Partial = {}; + + for (const key of requiredKeys) { + if (key in fullContext) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (minimalContext as any)[key] = fullContext[key]; + } + } + + return minimalContext as SafetyCheckInput['context']; + } +} diff --git a/packages/core/src/safety/protocol.ts b/packages/core/src/safety/protocol.ts new file mode 100644 index 0000000000..5028bd6897 --- /dev/null +++ b/packages/core/src/safety/protocol.ts @@ -0,0 +1,100 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { FunctionCall } from '@google/genai'; + +/** + * Represents a single turn in the conversation between the user and the model. + * This provides semantic context for why a tool call might be happening. + */ +export interface ConversationTurn { + user: { + text: string; + }; + model: { + text?: string; + toolCalls?: FunctionCall[]; + }; +} + +/** + * The data structure passed from the CLI to a safety checker process via stdin. + */ +export interface SafetyCheckInput { + /** + * The semantic version of the protocol (e.g., "1.0.0"). This allows + * for introducing breaking changes in the future while maintaining + * support for older checkers. + */ + protocolVersion: '1.0.0'; + + /** + * The specific tool call that is being validated. + */ + toolCall: FunctionCall; + + /** + * A container for all contextual information from the CLI's internal state. + * By grouping data into categories, we can easily add new context in the + * future without creating a flat, unmanageable object. + */ + context: { + /** + * Information about the user's file system and execution environment. + */ + environment: { + cwd: string; + workspaces: string[]; // A list of user-configured workspace roots + }; + + /** + * The recent history of the conversation. 
This can be used by checkers + * that need to understand the intent behind a tool call. + */ + history?: { + turns: ConversationTurn[]; + }; + }; + + /** + * Configuration for the safety checker. + * This allows checkers to be parameterized (e.g. allowed paths). + */ + config?: unknown; +} + +/** + * The possible decisions a safety checker can make. + */ +export enum SafetyCheckDecision { + ALLOW = 'allow', + DENY = 'deny', + ASK_USER = 'ask_user', +} + +/** + * The data structure returned by a safety checker process via stdout. + */ +export type SafetyCheckResult = + | { + /** + * The decision made by the safety checker. + */ + decision: SafetyCheckDecision.ALLOW; + /** + * If not allowed, a message explaining why the tool call was blocked. + * This will be shown to the user. + */ + reason?: string; + } + | { + decision: SafetyCheckDecision.DENY; + reason: string; + } + | { + decision: SafetyCheckDecision.ASK_USER; + reason: string; + }; diff --git a/packages/core/src/safety/registry.test.ts b/packages/core/src/safety/registry.test.ts new file mode 100644 index 0000000000..b0f9d26744 --- /dev/null +++ b/packages/core/src/safety/registry.test.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach } from 'vitest'; +import { CheckerRegistry } from './registry.js'; +import { InProcessCheckerType } from '../policy/types.js'; +import { AllowedPathChecker } from './built-in.js'; + +describe('CheckerRegistry', () => { + let registry: CheckerRegistry; + const mockCheckersPath = '/mock/checkers/path'; + + beforeEach(() => { + registry = new CheckerRegistry(mockCheckersPath); + }); + + it('should resolve built-in in-process checkers', () => { + const checker = registry.resolveInProcess( + InProcessCheckerType.ALLOWED_PATH, + ); + expect(checker).toBeInstanceOf(AllowedPathChecker); + }); + + it('should throw for unknown in-process checkers', () => { + expect(() => 
registry.resolveInProcess('unknown-checker')).toThrow( + 'Unknown in-process checker "unknown-checker"', + ); + }); + + it('should validate checker names', () => { + expect(() => registry.resolveInProcess('invalid name!')).toThrow( + 'Invalid checker name', + ); + expect(() => registry.resolveInProcess('../escape')).toThrow( + 'Invalid checker name', + ); + }); + + it('should throw for unknown external checkers (for now)', () => { + expect(() => registry.resolveExternal('some-external')).toThrow( + 'Unknown external checker "some-external"', + ); + }); +}); diff --git a/packages/core/src/safety/registry.ts b/packages/core/src/safety/registry.ts new file mode 100644 index 0000000000..2775a82fd4 --- /dev/null +++ b/packages/core/src/safety/registry.ts @@ -0,0 +1,83 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as path from 'node:path'; +import * as fs from 'node:fs'; +import { type InProcessChecker, AllowedPathChecker } from './built-in.js'; +import { InProcessCheckerType } from '../policy/types.js'; + +/** + * Registry for managing safety checker resolution. + */ +export class CheckerRegistry { + private static readonly BUILT_IN_EXTERNAL_CHECKERS = new Map([ + // No external built-ins for now + ]); + + private static readonly BUILT_IN_IN_PROCESS_CHECKERS = new Map< + string, + InProcessChecker + >([[InProcessCheckerType.ALLOWED_PATH, new AllowedPathChecker()]]); + + // Regex to validate checker names (alphanumeric and hyphens only) + private static readonly VALID_NAME_PATTERN = /^[a-z0-9-]+$/; + + constructor(private readonly checkersPath: string) {} + + /** + * Resolves an external checker name to an absolute executable path. + */ + resolveExternal(name: string): string { + if (!CheckerRegistry.isValidCheckerName(name)) { + throw new Error( + `Invalid checker name "${name}". 
Checker names must contain only lowercase letters, numbers, and hyphens.`, + ); + } + + const builtInPath = CheckerRegistry.BUILT_IN_EXTERNAL_CHECKERS.get(name); + if (builtInPath) { + const fullPath = path.join(this.checkersPath, builtInPath); + if (!fs.existsSync(fullPath)) { + throw new Error(`Built-in checker "${name}" not found at ${fullPath}`); + } + return fullPath; + } + + // TODO: Phase 5 - Add support for custom external checkers + throw new Error(`Unknown external checker "${name}".`); + } + + /** + * Resolves an in-process checker name to a checker instance. + */ + resolveInProcess(name: string): InProcessChecker { + if (!CheckerRegistry.isValidCheckerName(name)) { + throw new Error(`Invalid checker name "${name}".`); + } + + const checker = CheckerRegistry.BUILT_IN_IN_PROCESS_CHECKERS.get(name); + if (checker) { + return checker; + } + + throw new Error( + `Unknown in-process checker "${name}". Available: ${Array.from( + CheckerRegistry.BUILT_IN_IN_PROCESS_CHECKERS.keys(), + ).join(', ')}`, + ); + } + + private static isValidCheckerName(name: string): boolean { + return this.VALID_NAME_PATTERN.test(name) && !name.includes('..'); + } + + static getBuiltInCheckers(): string[] { + return [ + ...Array.from(this.BUILT_IN_EXTERNAL_CHECKERS.keys()), + ...Array.from(this.BUILT_IN_IN_PROCESS_CHECKERS.keys()), + ]; + } +} diff --git a/packages/core/src/tools/base-tool-invocation.test.ts b/packages/core/src/tools/base-tool-invocation.test.ts index 38d651f076..f25bc8828f 100644 --- a/packages/core/src/tools/base-tool-invocation.test.ts +++ b/packages/core/src/tools/base-tool-invocation.test.ts @@ -47,11 +47,13 @@ describe('BaseToolInvocation', () => { ); let capturedRequest: ToolConfirmationRequest | undefined; - vi.mocked(messageBus.publish).mockImplementation((request: Message) => { - if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { - capturedRequest = request; - } - }); + vi.mocked(messageBus.publish).mockImplementation( + async (request: 
Message) => { + if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { + capturedRequest = request; + } + }, + ); let responseHandler: | ((response: ToolConfirmationResponse) => void) @@ -102,11 +104,13 @@ describe('BaseToolInvocation', () => { ); let capturedRequest: ToolConfirmationRequest | undefined; - vi.mocked(messageBus.publish).mockImplementation((request: Message) => { - if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { - capturedRequest = request; - } - }); + vi.mocked(messageBus.publish).mockImplementation( + async (request: Message) => { + if (request.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { + capturedRequest = request; + } + }, + ); // We need to mock subscribe to avoid hanging if we want to await the promise, // but for this test we just need to check publish.