diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 84c2fdb434..fc49da7a54 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -71,6 +71,9 @@ import { WorkspaceContext } from '../utils/workspaceContext.js'; import { Storage } from './storage.js'; import { FileExclusions } from '../utils/ignorePatterns.js'; import type { EventEmitter } from 'node:events'; +import { MessageBus } from '../confirmation-bus/message-bus.js'; +import { PolicyEngine } from '../policy/policy-engine.js'; +import type { PolicyEngineConfig } from '../policy/types.js'; import type { UserTierId } from '../code_assist/types.js'; import { ProxyAgent, setGlobalDispatcher } from 'undici'; @@ -233,6 +236,7 @@ export interface ConfigParameters { enableToolOutputTruncation?: boolean; eventEmitter?: EventEmitter; useSmartEdit?: boolean; + policyEngineConfig?: PolicyEngineConfig; output?: OutputSettings; } @@ -316,6 +320,8 @@ export class Config { private readonly fileExclusions: FileExclusions; private readonly eventEmitter?: EventEmitter; private readonly useSmartEdit: boolean; + private readonly messageBus: MessageBus; + private readonly policyEngine: PolicyEngine; private readonly outputSettings: OutputSettings; constructor(params: ConfigParameters) { @@ -400,6 +406,8 @@ export class Config { this.enablePromptCompletion = params.enablePromptCompletion ?? false; this.fileExclusions = new FileExclusions(this); this.eventEmitter = params.eventEmitter; + this.policyEngine = new PolicyEngine(params.policyEngineConfig); + this.messageBus = new MessageBus(this.policyEngine); this.outputSettings = { format: params.output?.format ?? OutputFormat.TEXT, }; @@ -908,6 +916,14 @@ export class Config { return this.fileExclusions; } + getMessageBus(): MessageBus { + return this.messageBus; + } + + getPolicyEngine(): PolicyEngine { + return this.policyEngine; + } + async createToolRegistry(): Promise { const registry = new ToolRegistry(this, this.eventEmitter); diff --git a/packages/core/src/confirmation-bus/index.ts b/packages/core/src/confirmation-bus/index.ts new file mode 100644 index 0000000000..379d9aa4d8 --- /dev/null +++ b/packages/core/src/confirmation-bus/index.ts @@ -0,0 +1,8 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from './message-bus.js'; +export * from './types.js'; diff --git a/packages/core/src/confirmation-bus/message-bus.test.ts b/packages/core/src/confirmation-bus/message-bus.test.ts new file mode 100644 index 0000000000..8156671c9b --- /dev/null +++ b/packages/core/src/confirmation-bus/message-bus.test.ts @@ -0,0 +1,235 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, vi } from 'vitest'; +import { MessageBus } from './message-bus.js'; +import { PolicyEngine } from '../policy/policy-engine.js'; +import { PolicyDecision } from '../policy/types.js'; +import { + MessageBusType, + type ToolConfirmationRequest, + type ToolConfirmationResponse, + type ToolPolicyRejection, + type ToolExecutionSuccess, +} from './types.js'; + +describe('MessageBus', () => { + let messageBus: MessageBus; + let policyEngine: PolicyEngine; + + beforeEach(() => { + policyEngine = new PolicyEngine(); + messageBus = new MessageBus(policyEngine); + }); + + describe('publish', () => { + it('should emit error for invalid message', () => { + const errorHandler = vi.fn(); + messageBus.on('error', errorHandler); + + // @ts-expect-error - Testing invalid message + messageBus.publish({ invalid: 'message' }); + + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringContaining('Invalid message structure'), + }), + ); + }); + + it('should validate tool confirmation requests have correlationId', () => { + const errorHandler = vi.fn(); + messageBus.on('error', errorHandler); + + // @ts-expect-error - Testing missing correlationId + messageBus.publish({ + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall: { name: 'test' }, + }); + + expect(errorHandler).toHaveBeenCalled(); + }); + + it('should emit confirmation response when policy allows', () => { + vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.ALLOW); + + const responseHandler = vi.fn(); + messageBus.subscribe( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + responseHandler, + ); + + const request: ToolConfirmationRequest = { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall: { name: 'test-tool', args: {} }, + correlationId: '123', + }; + + messageBus.publish(request); + + const expectedResponse: ToolConfirmationResponse = { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: '123', + confirmed: true, + }; + expect(responseHandler).toHaveBeenCalledWith(expectedResponse); + }); + + it('should emit rejection and response when policy denies', () => { + vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.DENY); + + const responseHandler = vi.fn(); + const rejectionHandler = vi.fn(); + messageBus.subscribe( + MessageBusType.TOOL_CONFIRMATION_RESPONSE, + responseHandler, + ); + messageBus.subscribe( + MessageBusType.TOOL_POLICY_REJECTION, + rejectionHandler, + ); + + const request: ToolConfirmationRequest = { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall: { name: 'test-tool', args: {} }, + correlationId: '123', + }; + + messageBus.publish(request); + + const expectedRejection: ToolPolicyRejection = { + type: MessageBusType.TOOL_POLICY_REJECTION, + toolCall: { name: 'test-tool', args: {} }, + }; + expect(rejectionHandler).toHaveBeenCalledWith(expectedRejection); + + const expectedResponse: ToolConfirmationResponse = { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: '123', + confirmed: false, + }; + expect(responseHandler).toHaveBeenCalledWith(expectedResponse); + }); + + it('should pass through to UI when policy says ASK_USER', () => { + vi.spyOn(policyEngine, 'check').mockReturnValue(PolicyDecision.ASK_USER); + + const requestHandler = vi.fn(); + messageBus.subscribe( + MessageBusType.TOOL_CONFIRMATION_REQUEST, + requestHandler, + ); + + const request: ToolConfirmationRequest = { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall: { name: 'test-tool', args: {} }, + correlationId: '123', + }; + + messageBus.publish(request); + + expect(requestHandler).toHaveBeenCalledWith(request); + }); + + it('should emit other message types directly', () => { + const successHandler = vi.fn(); + messageBus.subscribe( + MessageBusType.TOOL_EXECUTION_SUCCESS, + successHandler, + ); + + const message: ToolExecutionSuccess = { + type: MessageBusType.TOOL_EXECUTION_SUCCESS as const, + toolCall: { name: 'test-tool' }, + result: 'success', + }; + + messageBus.publish(message); + + expect(successHandler).toHaveBeenCalledWith(message); + }); + }); + + describe('subscribe/unsubscribe', () => { + it('should allow subscribing to specific message types', () => { + const handler = vi.fn(); + messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler); + + const message: ToolExecutionSuccess = { + type: MessageBusType.TOOL_EXECUTION_SUCCESS as const, + toolCall: { name: 'test' }, + result: 'test', + }; + + messageBus.publish(message); + + expect(handler).toHaveBeenCalledWith(message); + }); + + it('should allow unsubscribing from message types', () => { + const handler = vi.fn(); + messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler); + messageBus.unsubscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler); + + const message: ToolExecutionSuccess = { + type: MessageBusType.TOOL_EXECUTION_SUCCESS as const, + toolCall: { name: 'test' }, + result: 'test', + }; + + messageBus.publish(message); + + expect(handler).not.toHaveBeenCalled(); + }); + + it('should support multiple subscribers for the same message type', () => { + const handler1 = vi.fn(); + const handler2 = vi.fn(); + + messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler1); + messageBus.subscribe(MessageBusType.TOOL_EXECUTION_SUCCESS, handler2); + + const message: ToolExecutionSuccess = { + type: MessageBusType.TOOL_EXECUTION_SUCCESS as const, + toolCall: { name: 'test' }, + result: 'test', + }; + + messageBus.publish(message); + + expect(handler1).toHaveBeenCalledWith(message); + expect(handler2).toHaveBeenCalledWith(message); + }); + }); + + describe('error handling', () => { + it('should not crash on errors during message processing', () => { + const errorHandler = vi.fn(); + messageBus.on('error', errorHandler); + + // Mock policyEngine to throw an error + vi.spyOn(policyEngine, 'check').mockImplementation(() => { + throw new Error('Policy check failed'); + }); + + const request: ToolConfirmationRequest = { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST, + toolCall: { name: 'test-tool' }, + correlationId: '123', + }; + + // Should not throw + expect(() => messageBus.publish(request)).not.toThrow(); + + // Should emit error + expect(errorHandler).toHaveBeenCalledWith( + expect.objectContaining({ + message: 'Policy check failed', + }), + ); + }); + }); +}); diff --git a/packages/core/src/confirmation-bus/message-bus.ts b/packages/core/src/confirmation-bus/message-bus.ts new file mode 100644 index 0000000000..b9d66eff6a --- /dev/null +++ b/packages/core/src/confirmation-bus/message-bus.ts @@ -0,0 +1,98 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { EventEmitter } from 'node:events'; +import type { PolicyEngine } from '../policy/policy-engine.js'; +import { PolicyDecision } from '../policy/types.js'; +import { MessageBusType, type Message } from './types.js'; +import { safeJsonStringify } from '../utils/safeJsonStringify.js'; + +export class MessageBus extends EventEmitter { + constructor(private readonly policyEngine: PolicyEngine) { + super(); + } + + private isValidMessage(message: Message): boolean { + if (!message || !message.type) { + return false; + } + + if ( + message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST && + !('correlationId' in message) + ) { + return false; + } + + return true; + } + + private emitMessage(message: Message): void { + this.emit(message.type, message); + } + + publish(message: Message): void { + try { + if (!this.isValidMessage(message)) { + throw new Error( + `Invalid message structure: ${safeJsonStringify(message)}`, + ); + } + + if (message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) { + const decision = this.policyEngine.check(message.toolCall); + + switch (decision) { + case PolicyDecision.ALLOW: + // Directly emit the response instead of recursive publish + this.emitMessage({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: message.correlationId, + confirmed: true, + }); + break; + case PolicyDecision.DENY: + // Emit both rejection and response messages + this.emitMessage({ + type: MessageBusType.TOOL_POLICY_REJECTION, + toolCall: message.toolCall, + }); + this.emitMessage({ + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE, + correlationId: message.correlationId, + confirmed: false, + }); + break; + case PolicyDecision.ASK_USER: + // Pass through to UI for user confirmation + this.emitMessage(message); + break; + default: + throw new Error(`Unknown policy decision: ${decision}`); + } + } else { + // For all other message types, just emit them + this.emitMessage(message); + } + } catch (error) { + this.emit('error', error); + } + } + + subscribe( + type: T['type'], + listener: (message: T) => void, + ): void { + this.on(type, listener); + } + + unsubscribe( + type: T['type'], + listener: (message: T) => void, + ): void { + this.off(type, listener); + } +} diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts new file mode 100644 index 0000000000..cb86595be9 --- /dev/null +++ b/packages/core/src/confirmation-bus/types.ts @@ -0,0 +1,51 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type FunctionCall } from '@google/genai'; + +export enum MessageBusType { + TOOL_CONFIRMATION_REQUEST = 'tool-confirmation-request', + TOOL_CONFIRMATION_RESPONSE = 'tool-confirmation-response', + TOOL_POLICY_REJECTION = 'tool-policy-rejection', + TOOL_EXECUTION_SUCCESS = 'tool-execution-success', + TOOL_EXECUTION_FAILURE = 'tool-execution-failure', +} + +export interface ToolConfirmationRequest { + type: MessageBusType.TOOL_CONFIRMATION_REQUEST; + toolCall: FunctionCall; + correlationId: string; +} + +export interface ToolConfirmationResponse { + type: MessageBusType.TOOL_CONFIRMATION_RESPONSE; + correlationId: string; + confirmed: boolean; +} + +export interface ToolPolicyRejection { + type: MessageBusType.TOOL_POLICY_REJECTION; + toolCall: FunctionCall; +} + +export interface ToolExecutionSuccess { + type: MessageBusType.TOOL_EXECUTION_SUCCESS; + toolCall: FunctionCall; + result: T; +} + +export interface ToolExecutionFailure { + type: MessageBusType.TOOL_EXECUTION_FAILURE; + toolCall: FunctionCall; + error: E; +} + +export type Message = + | ToolConfirmationRequest + | ToolConfirmationResponse + | ToolPolicyRejection + | ToolExecutionSuccess + | ToolExecutionFailure; diff --git a/packages/core/src/policy/index.ts b/packages/core/src/policy/index.ts new file mode 100644 index 0000000000..e15309ca69 --- /dev/null +++ b/packages/core/src/policy/index.ts @@ -0,0 +1,8 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +export * from './policy-engine.js'; +export * from './types.js'; diff --git a/packages/core/src/policy/policy-engine.test.ts b/packages/core/src/policy/policy-engine.test.ts new file mode 100644 index 0000000000..51f222b2e4 --- /dev/null +++ b/packages/core/src/policy/policy-engine.test.ts @@ -0,0 +1,624 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach } from 'vitest'; +import { PolicyEngine } from './policy-engine.js'; +import { + PolicyDecision, + type PolicyRule, + type PolicyEngineConfig, +} from './types.js'; +import type { FunctionCall } from '@google/genai'; + +describe('PolicyEngine', () => { + let engine: PolicyEngine; + + beforeEach(() => { + engine = new PolicyEngine(); + }); + + describe('constructor', () => { + it('should use default config when none provided', () => { + const decision = engine.check({ name: 'test' }); + expect(decision).toBe(PolicyDecision.ASK_USER); + }); + + it('should respect custom default decision', () => { + engine = new PolicyEngine({ defaultDecision: PolicyDecision.DENY }); + const decision = engine.check({ name: 'test' }); + expect(decision).toBe(PolicyDecision.DENY); + }); + + it('should sort rules by priority', () => { + const rules: PolicyRule[] = [ + { toolName: 'tool1', decision: PolicyDecision.DENY, priority: 1 }, + { toolName: 'tool2', decision: PolicyDecision.ALLOW, priority: 10 }, + { toolName: 'tool3', decision: PolicyDecision.ASK_USER, priority: 5 }, + ]; + + engine = new PolicyEngine({ rules }); + const sortedRules = engine.getRules(); + + expect(sortedRules[0].priority).toBe(10); + expect(sortedRules[1].priority).toBe(5); + expect(sortedRules[2].priority).toBe(1); + }); + }); + + describe('check', () => { + it('should match tool by name', () => { + const rules: PolicyRule[] = [ + { toolName: 'shell', decision: PolicyDecision.ALLOW }, + { toolName: 'edit', decision: PolicyDecision.DENY }, + ]; + + engine = new PolicyEngine({ rules }); + + expect(engine.check({ name: 'shell' })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'edit' })).toBe(PolicyDecision.DENY); + expect(engine.check({ name: 'other' })).toBe(PolicyDecision.ASK_USER); + }); + + it('should match by args pattern', () => { + const rules: PolicyRule[] = [ + { + toolName: 'shell', + argsPattern: /rm -rf/, + decision: PolicyDecision.DENY, + }, + { + toolName: 'shell', + decision: PolicyDecision.ALLOW, + }, + ]; + + engine = new PolicyEngine({ rules }); + + const dangerousCall: FunctionCall = { + name: 'shell', + args: { command: 'rm -rf /' }, + }; + + const safeCall: FunctionCall = { + name: 'shell', + args: { command: 'ls -la' }, + }; + + expect(engine.check(dangerousCall)).toBe(PolicyDecision.DENY); + expect(engine.check(safeCall)).toBe(PolicyDecision.ALLOW); + }); + + it('should apply rules by priority', () => { + const rules: PolicyRule[] = [ + { toolName: 'shell', decision: PolicyDecision.DENY, priority: 1 }, + { toolName: 'shell', decision: PolicyDecision.ALLOW, priority: 10 }, + ]; + + engine = new PolicyEngine({ rules }); + + // Higher priority rule (ALLOW) should win + expect(engine.check({ name: 'shell' })).toBe(PolicyDecision.ALLOW); + }); + + it('should apply wildcard rules (no toolName)', () => { + const rules: PolicyRule[] = [ + { decision: PolicyDecision.DENY }, // Applies to all tools + { toolName: 'safe-tool', decision: PolicyDecision.ALLOW, priority: 10 }, + ]; + + engine = new PolicyEngine({ rules }); + + expect(engine.check({ name: 'safe-tool' })).toBe(PolicyDecision.ALLOW); + expect(engine.check({ name: 'any-other-tool' })).toBe( + PolicyDecision.DENY, + ); + }); + + it('should handle non-interactive mode', () => { + const config: PolicyEngineConfig = { + nonInteractive: true, + rules: [ + { toolName: 'interactive-tool', decision: PolicyDecision.ASK_USER }, + { toolName: 'allowed-tool', decision: PolicyDecision.ALLOW }, + ], + }; + + engine = new PolicyEngine(config); + + // ASK_USER should become DENY in non-interactive mode + expect(engine.check({ name: 'interactive-tool' })).toBe( + PolicyDecision.DENY, + ); + // ALLOW should remain ALLOW + expect(engine.check({ name: 'allowed-tool' })).toBe(PolicyDecision.ALLOW); + // Default ASK_USER should also become DENY + expect(engine.check({ name: 'unknown-tool' })).toBe(PolicyDecision.DENY); + }); + }); + + describe('addRule', () => { + it('should add a new rule and maintain priority order', () => { + engine.addRule({ + toolName: 'tool1', + decision: PolicyDecision.ALLOW, + priority: 5, + }); + engine.addRule({ + toolName: 'tool2', + decision: PolicyDecision.DENY, + priority: 10, + }); + engine.addRule({ + toolName: 'tool3', + decision: PolicyDecision.ASK_USER, + priority: 1, + }); + + const rules = engine.getRules(); + expect(rules).toHaveLength(3); + expect(rules[0].priority).toBe(10); + expect(rules[1].priority).toBe(5); + expect(rules[2].priority).toBe(1); + }); + + it('should apply newly added rules', () => { + expect(engine.check({ name: 'new-tool' })).toBe(PolicyDecision.ASK_USER); + + engine.addRule({ toolName: 'new-tool', decision: PolicyDecision.ALLOW }); + + expect(engine.check({ name: 'new-tool' })).toBe(PolicyDecision.ALLOW); + }); + }); + + describe('removeRulesForTool', () => { + it('should remove rules for specific tool', () => { + engine.addRule({ toolName: 'tool1', decision: PolicyDecision.ALLOW }); + engine.addRule({ toolName: 'tool2', decision: PolicyDecision.DENY }); + engine.addRule({ + toolName: 'tool1', + decision: PolicyDecision.ASK_USER, + priority: 10, + }); + + expect(engine.getRules()).toHaveLength(3); + + engine.removeRulesForTool('tool1'); + + const remainingRules = engine.getRules(); + expect(remainingRules).toHaveLength(1); + expect(remainingRules.some((r) => r.toolName === 'tool1')).toBe(false); + expect(remainingRules.some((r) => r.toolName === 'tool2')).toBe(true); + }); + + it('should handle removing non-existent tool', () => { + engine.addRule({ toolName: 'existing', decision: PolicyDecision.ALLOW }); + + expect(() => engine.removeRulesForTool('non-existent')).not.toThrow(); + expect(engine.getRules()).toHaveLength(1); + }); + }); + + describe('getRules', () => { + it('should return readonly array of rules', () => { + const rules: PolicyRule[] = [ + { toolName: 'tool1', decision: PolicyDecision.ALLOW }, + { toolName: 'tool2', decision: PolicyDecision.DENY }, + ]; + + engine = new PolicyEngine({ rules }); + + const retrievedRules = engine.getRules(); + expect(retrievedRules).toHaveLength(2); + expect(retrievedRules[0].toolName).toBe('tool1'); + expect(retrievedRules[1].toolName).toBe('tool2'); + }); + }); + + describe('complex scenarios', () => { + it('should handle multiple matching rules with different priorities', () => { + const rules: PolicyRule[] = [ + { decision: PolicyDecision.DENY, priority: 0 }, // Default deny all + { toolName: 'shell', decision: PolicyDecision.ASK_USER, priority: 5 }, + { + toolName: 'shell', + argsPattern: /"command":"ls/, + decision: PolicyDecision.ALLOW, + priority: 10, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Matches highest priority rule (ls command) + expect(engine.check({ name: 'shell', args: { command: 'ls -la' } })).toBe( + PolicyDecision.ALLOW, + ); + + // Matches middle priority rule (shell without ls) + expect(engine.check({ name: 'shell', args: { command: 'pwd' } })).toBe( + PolicyDecision.ASK_USER, + ); + + // Matches lowest priority rule (not shell) + expect(engine.check({ name: 'edit' })).toBe(PolicyDecision.DENY); + }); + + it('should handle tools with no args', () => { + const rules: PolicyRule[] = [ + { + toolName: 'read', + argsPattern: /secret/, + decision: PolicyDecision.DENY, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Tool call without args should not match pattern + expect(engine.check({ name: 'read' })).toBe(PolicyDecision.ASK_USER); + + // Tool call with args not matching pattern + expect(engine.check({ name: 'read', args: { file: 'public.txt' } })).toBe( + PolicyDecision.ASK_USER, + ); + + // Tool call with args matching pattern + expect(engine.check({ name: 'read', args: { file: 'secret.txt' } })).toBe( + PolicyDecision.DENY, + ); + }); + + it('should match args pattern regardless of property order', () => { + const rules: PolicyRule[] = [ + { + toolName: 'shell', + // Pattern matches the stable stringified format + argsPattern: /"command":"rm[^"]*-rf/, + decision: PolicyDecision.DENY, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Same args with different property order should both match + const args1 = { command: 'rm -rf /', path: '/home' }; + const args2 = { path: '/home', command: 'rm -rf /' }; + + expect(engine.check({ name: 'shell', args: args1 })).toBe( + PolicyDecision.DENY, + ); + expect(engine.check({ name: 'shell', args: args2 })).toBe( + PolicyDecision.DENY, + ); + + // Verify safe command doesn't match + const safeArgs = { command: 'ls -la', path: '/home' }; + expect(engine.check({ name: 'shell', args: safeArgs })).toBe( + PolicyDecision.ASK_USER, + ); + }); + + it('should handle nested objects in args with stable stringification', () => { + const rules: PolicyRule[] = [ + { + toolName: 'api', + argsPattern: /"sensitive":true/, + decision: PolicyDecision.DENY, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Nested objects with different key orders should match consistently + const args1 = { + data: { sensitive: true, value: 'secret' }, + method: 'POST', + }; + const args2 = { + method: 'POST', + data: { value: 'secret', sensitive: true }, + }; + + expect(engine.check({ name: 'api', args: args1 })).toBe( + PolicyDecision.DENY, + ); + expect(engine.check({ name: 'api', args: args2 })).toBe( + PolicyDecision.DENY, + ); + }); + + it('should handle circular references without stack overflow', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /\[Circular\]/, + decision: PolicyDecision.DENY, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Create an object with a circular reference + type CircularArgs = Record & { + data?: Record; + }; + const circularArgs: CircularArgs = { + name: 'test', + data: {}, + }; + // Create circular reference - TypeScript allows this since data is Record + (circularArgs.data as Record)['self'] = + circularArgs.data; + + // Should not throw stack overflow error + expect(() => + engine.check({ name: 'test', args: circularArgs }), + ).not.toThrow(); + + // Should detect the circular reference pattern + expect(engine.check({ name: 'test', args: circularArgs })).toBe( + PolicyDecision.DENY, + ); + + // Non-circular object should not match + const normalArgs = { name: 'test', data: { value: 'normal' } }; + expect(engine.check({ name: 'test', args: normalArgs })).toBe( + PolicyDecision.ASK_USER, + ); + }); + + it('should handle deep circular references', () => { + const rules: PolicyRule[] = [ + { + toolName: 'deep', + argsPattern: /\[Circular\]/, + decision: PolicyDecision.DENY, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Create a deep circular reference + type DeepCircular = Record & { + level1?: { + level2?: { + level3?: Record; + }; + }; + }; + const deepCircular: DeepCircular = { + level1: { + level2: { + level3: {}, + }, + }, + }; + // Create circular reference with proper type assertions + const level3 = deepCircular.level1!.level2!.level3!; + level3['back'] = deepCircular.level1; + + // Should handle without stack overflow + expect(() => + engine.check({ name: 'deep', args: deepCircular }), + ).not.toThrow(); + + // Should detect the circular reference + expect(engine.check({ name: 'deep', args: deepCircular })).toBe( + PolicyDecision.DENY, + ); + }); + + it('should handle repeated non-circular objects correctly', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /\[Circular\]/, + decision: PolicyDecision.DENY, + }, + { + toolName: 'test', + argsPattern: /"value":"shared"/, + decision: PolicyDecision.ALLOW, + priority: 10, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Create an object with repeated references but no cycles + const sharedObj = { value: 'shared' }; + const args = { + first: sharedObj, + second: sharedObj, + third: { nested: sharedObj }, + }; + + // Should NOT mark repeated objects as circular, and should match the shared value pattern + expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + }); + + it('should omit undefined and function values from objects', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /"definedValue":"test"/, + decision: PolicyDecision.ALLOW, + }, + ]; + + engine = new PolicyEngine({ rules }); + + const args = { + definedValue: 'test', + undefinedValue: undefined, + functionValue: () => 'hello', + nullValue: null, + }; + + // Should match pattern with defined value, undefined and functions omitted + expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + + // Check that the pattern would NOT match if undefined was included + const rulesWithUndefined: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /undefinedValue/, + decision: PolicyDecision.DENY, + }, + ]; + engine = new PolicyEngine({ rules: rulesWithUndefined }); + expect(engine.check({ name: 'test', args })).toBe( + PolicyDecision.ASK_USER, + ); + + // Check that the pattern would NOT match if function was included + const rulesWithFunction: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /functionValue/, + decision: PolicyDecision.DENY, + }, + ]; + engine = new PolicyEngine({ rules: rulesWithFunction }); + expect(engine.check({ name: 'test', args })).toBe( + PolicyDecision.ASK_USER, + ); + }); + + it('should convert undefined and functions to null in arrays', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /\["value",null,null,null\]/, + decision: PolicyDecision.ALLOW, + }, + ]; + + engine = new PolicyEngine({ rules }); + + const args = { + array: ['value', undefined, () => 'hello', null], + }; + + // Should match pattern with undefined and functions converted to null + expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + }); + + it('should produce valid JSON for all inputs', () => { + const testCases: Array<{ input: Record; desc: string }> = + [ + { input: { simple: 'string' }, desc: 'simple object' }, + { + input: { nested: { deep: { value: 123 } } }, + desc: 'nested object', + }, + { input: { data: [1, 2, 3] }, desc: 'simple array' }, + { input: { mixed: [1, { a: 'b' }, null] }, desc: 'mixed array' }, + { + input: { undef: undefined, func: () => {}, normal: 'value' }, + desc: 'object with undefined and function', + }, + { + input: { data: ['a', undefined, () => {}, null] }, + desc: 'array with undefined and function', + }, + ]; + + for (const { input } of testCases) { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /.*/, + decision: PolicyDecision.ALLOW, + }, + ]; + engine = new PolicyEngine({ rules }); + + // Should not throw when checking (which internally uses stableStringify) + expect(() => engine.check({ name: 'test', args: input })).not.toThrow(); + + // The check should succeed + expect(engine.check({ name: 'test', args: input })).toBe( + PolicyDecision.ALLOW, + ); + } + }); + + it('should respect toJSON methods on objects', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /"sanitized":"safe"/, + decision: PolicyDecision.ALLOW, + }, + { + toolName: 'test', + argsPattern: /"dangerous":"data"/, + decision: PolicyDecision.DENY, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Object with toJSON that sanitizes output + const args = { + data: { + dangerous: 'data', + toJSON: () => ({ sanitized: 'safe' }), + }, + }; + + // Should match the sanitized pattern, not the dangerous one + expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + }); + + it('should handle toJSON that returns primitives', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /"value":"string-value"/, + decision: PolicyDecision.ALLOW, + }, + ]; + + engine = new PolicyEngine({ rules }); + + const args = { + value: { + complex: 'object', + toJSON: () => 'string-value', + }, + }; + + // toJSON returns a string, which should be properly stringified + expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + }); + + it('should handle toJSON that throws an error', () => { + const rules: PolicyRule[] = [ + { + toolName: 'test', + argsPattern: /"fallback":"value"/, + decision: PolicyDecision.ALLOW, + }, + ]; + + engine = new PolicyEngine({ rules }); + + const args = { + data: { + fallback: 'value', + toJSON: () => { + throw new Error('toJSON error'); + }, + }, + }; + + // Should fall back to regular object serialization when toJSON throws + expect(engine.check({ name: 'test', args })).toBe(PolicyDecision.ALLOW); + }); + }); +}); diff --git a/packages/core/src/policy/policy-engine.ts b/packages/core/src/policy/policy-engine.ts new file mode 100644 index 0000000000..e1006ffdef --- /dev/null +++ b/packages/core/src/policy/policy-engine.ts @@ -0,0 +1,107 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { type FunctionCall } from '@google/genai'; +import { + PolicyDecision, + type PolicyEngineConfig, + type PolicyRule, +} from './types.js'; +import { stableStringify } from './stable-stringify.js'; + +function ruleMatches( + rule: PolicyRule, + toolCall: FunctionCall, + stringifiedArgs: string | undefined, +): boolean { + // Check tool name if specified + if (rule.toolName && toolCall.name !== rule.toolName) { + return false; + } + + // Check args pattern if specified + if (rule.argsPattern) { + // If rule has an args pattern but tool has no args, no match + if (!toolCall.args) { + return false; + } + // Use stable JSON stringification with sorted keys to ensure consistent matching + if ( + stringifiedArgs === undefined || + !rule.argsPattern.test(stringifiedArgs) + ) { + return false; + } + } + + return true; +} + +export class PolicyEngine { + private rules: PolicyRule[]; + private readonly defaultDecision: PolicyDecision; + private readonly nonInteractive: boolean; + + constructor(config: PolicyEngineConfig = {}) { + this.rules = (config.rules ?? []).sort( + (a, b) => (b.priority ?? 0) - (a.priority ?? 0), + ); + this.defaultDecision = config.defaultDecision ?? PolicyDecision.ASK_USER; + this.nonInteractive = config.nonInteractive ?? false; + } + + /** + * Check if a tool call is allowed based on the configured policies. + */ + check(toolCall: FunctionCall): PolicyDecision { + let stringifiedArgs: string | undefined; + // Compute stringified args once before the loop + if (toolCall.args && this.rules.some((rule) => rule.argsPattern)) { + stringifiedArgs = stableStringify(toolCall.args); + } + + // Find the first matching rule (already sorted by priority) + for (const rule of this.rules) { + if (ruleMatches(rule, toolCall, stringifiedArgs)) { + return this.applyNonInteractiveMode(rule.decision); + } + } + + // No matching rule found, use default decision + return this.applyNonInteractiveMode(this.defaultDecision); + } + + /** + * Add a new rule to the policy engine. + */ + addRule(rule: PolicyRule): void { + this.rules.push(rule); + // Re-sort rules by priority + this.rules.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); + } + + /** + * Remove rules for a specific tool. + */ + removeRulesForTool(toolName: string): void { + this.rules = this.rules.filter((rule) => rule.toolName !== toolName); + } + + /** + * Get all current rules. + */ + getRules(): readonly PolicyRule[] { + return this.rules; + } + + private applyNonInteractiveMode(decision: PolicyDecision): PolicyDecision { + // In non-interactive mode, ASK_USER becomes DENY + if (this.nonInteractive && decision === PolicyDecision.ASK_USER) { + return PolicyDecision.DENY; + } + return decision; + } +} diff --git a/packages/core/src/policy/stable-stringify.ts b/packages/core/src/policy/stable-stringify.ts new file mode 100644 index 0000000000..78db692eab --- /dev/null +++ b/packages/core/src/policy/stable-stringify.ts @@ -0,0 +1,128 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Produces a stable, deterministic JSON string representation with sorted keys. + * + * This method is critical for security policy matching. It ensures that the same + * object always produces the same string representation, regardless of property + * insertion order, which could vary across different JavaScript engines or + * runtime conditions. + * + * Key behaviors: + * 1. **Sorted Keys**: Object properties are always serialized in alphabetical order, + * ensuring deterministic output for pattern matching. + * + * 2. **Circular Reference Protection**: Uses ancestor chain tracking (not just + * object identity) to detect true circular references while correctly handling + * repeated non-circular object references. Circular references are replaced + * with "[Circular]" to prevent stack overflow attacks. + * + * 3. **JSON Spec Compliance**: + * - undefined values: Omitted from objects, converted to null in arrays + * - Functions: Omitted from objects, converted to null in arrays + * - toJSON methods: Respected and called when present (per JSON.stringify spec) + * + * 4. **Security Considerations**: + * - Prevents DoS via circular references that would cause infinite recursion + * - Ensures consistent policy rule matching by normalizing property order + * - Respects toJSON for objects that sanitize their output + * - Handles toJSON methods that throw errors gracefully + * + * @param obj - The object to stringify (typically toolCall.args) + * @returns A deterministic JSON string representation + * + * @example + * // Different property orders produce the same output: + * stableStringify({b: 2, a: 1}) === stableStringify({a: 1, b: 2}) + * // Returns: '{"a":1,"b":2}' + * + * @example + * // Circular references are handled safely: + * const obj = {a: 1}; + * obj.self = obj; + * stableStringify(obj) + * // Returns: '{"a":1,"self":"[Circular]"}' + * + * @example + * // toJSON methods are respected: + * const obj = { + * sensitive: 'secret', + * toJSON: () => ({ safe: 'data' }) + * }; + * stableStringify(obj) + * // Returns: '{"safe":"data"}' + */ +export function stableStringify(obj: unknown): string { + const stringify = (currentObj: unknown, ancestors: Set): string => { + // Handle primitives and null + if (currentObj === undefined) { + return 'null'; // undefined in arrays becomes null in JSON + } + if (currentObj === null) { + return 'null'; + } + if (typeof currentObj === 'function') { + return 'null'; // functions in arrays become null in JSON + } + if (typeof currentObj !== 'object') { + return JSON.stringify(currentObj); + } + + // Check for circular reference (object is in ancestor chain) + if (ancestors.has(currentObj)) { + return '"[Circular]"'; + } + + ancestors.add(currentObj); + + try { + // Check for toJSON method and use it if present + const objWithToJSON = currentObj as { toJSON?: () => unknown }; + if (typeof objWithToJSON.toJSON === 'function') { + try { + const jsonValue = objWithToJSON.toJSON(); + // The result of toJSON needs to be stringified recursively + if (jsonValue === null) { + return 'null'; + } + return stringify(jsonValue, ancestors); + } catch { + // If toJSON throws, treat as a regular object + } + } + + if (Array.isArray(currentObj)) { + const items = currentObj.map((item) => { + // undefined and functions in arrays become null + if (item === undefined || typeof item === 'function') { + return 'null'; + } + return stringify(item, ancestors); + }); + return '[' + items.join(',') + ']'; + } + + // Handle objects - sort keys and filter out undefined/function values + const sortedKeys = Object.keys(currentObj).sort(); + const pairs: string[] = []; + + for (const key of sortedKeys) { + const value = (currentObj as Record)[key]; + // Skip undefined and function values in objects (per JSON spec) + if (value !== undefined && typeof value !== 'function') { + pairs.push(JSON.stringify(key) + ':' + stringify(value, ancestors)); + } + } + + return '{' + pairs.join(',') + '}'; + } finally { + ancestors.delete(currentObj); + } + }; + + return stringify(obj, new Set()); +} diff --git a/packages/core/src/policy/types.ts b/packages/core/src/policy/types.ts new file mode 100644 index 0000000000..f20a88e70c --- /dev/null +++ b/packages/core/src/policy/types.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +export enum PolicyDecision { + ALLOW = 'allow', + DENY = 'deny', + ASK_USER = 'ask_user', +} + +export interface PolicyRule { + /** + * The name of the tool this rule applies to. + * If undefined, the rule applies to all tools. + */ + toolName?: string; + + /** + * Pattern to match against tool arguments. + * Can be used for more fine-grained control. + */ + argsPattern?: RegExp; + + /** + * The decision to make when this rule matches. + */ + decision: PolicyDecision; + + /** + * Priority of this rule. Higher numbers take precedence. + * Default is 0. + */ + priority?: number; +} + +export interface PolicyEngineConfig { + /** + * List of policy rules to apply. + */ + rules?: PolicyRule[]; + + /** + * Default decision when no rules match. + * Defaults to ASK_USER. + */ + defaultDecision?: PolicyDecision; + + /** + * Whether to allow tools in non-interactive mode. + * When true, ASK_USER decisions become DENY. + */ + nonInteractive?: boolean; +}