diff --git a/packages/core/src/confirmation-bus/message-bus.ts b/packages/core/src/confirmation-bus/message-bus.ts index 0b021049f8..09ee11d20e 100644 --- a/packages/core/src/confirmation-bus/message-bus.ts +++ b/packages/core/src/confirmation-bus/message-bus.ts @@ -4,10 +4,16 @@ * SPDX-License-Identifier: Apache-2.0 */ +import { randomUUID } from 'node:crypto'; import { EventEmitter } from 'node:events'; import type { PolicyEngine } from '../policy/policy-engine.js'; -import { PolicyDecision } from '../policy/types.js'; -import { MessageBusType, type Message } from './types.js'; +import { PolicyDecision, getHookSource } from '../policy/types.js'; +import { + MessageBusType, + type Message, + type HookExecutionRequest, + type HookPolicyDecision, +} from './types.js'; import { safeJsonStringify } from '../utils/safeJsonStringify.js'; export class MessageBus extends EventEmitter { @@ -83,6 +89,39 @@ export class MessageBus extends EventEmitter { default: throw new Error(`Unknown policy decision: ${decision}`); } + } else if (message.type === MessageBusType.HOOK_EXECUTION_REQUEST) { + // Handle hook execution requests through policy evaluation + const hookRequest = message as HookExecutionRequest; + const decision = await this.policyEngine.checkHook(hookRequest); + + // Map decision to allow/deny for observability (ASK_USER treated as deny for hooks) + const effectiveDecision = + decision === PolicyDecision.ALLOW ? 'allow' : 'deny'; + + // Emit policy decision for observability + this.emitMessage({ + type: MessageBusType.HOOK_POLICY_DECISION, + eventName: hookRequest.eventName, + hookSource: getHookSource(hookRequest.input), + decision: effectiveDecision, + reason: + decision !== PolicyDecision.ALLOW + ? 'Hook execution denied by policy' + : undefined, + } as HookPolicyDecision); + + // If allowed, emit the request for hook system to handle + if (decision === PolicyDecision.ALLOW) { + this.emitMessage(message); + } else { + // If denied or ASK_USER, emit error response (hooks don't support interactive confirmation) + this.emitMessage({ + type: MessageBusType.HOOK_EXECUTION_RESPONSE, + correlationId: hookRequest.correlationId, + success: false, + error: new Error('Hook execution denied by policy'), + }); + } } else { // For all other message types, just emit them this.emitMessage(message); @@ -105,4 +144,46 @@ export class MessageBus extends EventEmitter { ): void { this.off(type, listener); } + + /** + * Request-response pattern: Publish a message and wait for a correlated response + * This enables synchronous-style communication over the async MessageBus + * The correlation ID is generated internally and added to the request + */ + async request( + request: Omit, + responseType: TResponse['type'], + timeoutMs: number = 60000, + ): Promise { + const correlationId = randomUUID(); + + return new Promise((resolve, reject) => { + const timeoutId = setTimeout(() => { + cleanup(); + reject(new Error(`Request timed out waiting for ${responseType}`)); + }, timeoutMs); + + const cleanup = () => { + clearTimeout(timeoutId); + this.unsubscribe(responseType, responseHandler); + }; + + const responseHandler = (response: TResponse) => { + // Check if this response matches our request + if ( + 'correlationId' in response && + response.correlationId === correlationId + ) { + cleanup(); + resolve(response); + } + }; + + // Subscribe to responses + this.subscribe(responseType, responseHandler); + + // Publish the request with correlation ID + this.publish({ ...request, correlationId } as TRequest); + }); + } } diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts index 52d7bd2e9f..7c1d010934 100644 --- a/packages/core/src/confirmation-bus/types.ts +++ b/packages/core/src/confirmation-bus/types.ts @@ -13,6 +13,9 @@ export enum MessageBusType { TOOL_EXECUTION_SUCCESS = 'tool-execution-success', TOOL_EXECUTION_FAILURE = 'tool-execution-failure', UPDATE_POLICY = 'update-policy', + HOOK_EXECUTION_REQUEST = 'hook-execution-request', + HOOK_EXECUTION_RESPONSE = 'hook-execution-response', + HOOK_POLICY_DECISION = 'hook-policy-decision', } export interface ToolConfirmationRequest { @@ -55,10 +58,36 @@ export interface ToolExecutionFailure { error: E; } +export interface HookExecutionRequest { + type: MessageBusType.HOOK_EXECUTION_REQUEST; + eventName: string; + input: Record; + correlationId: string; +} + +export interface HookExecutionResponse { + type: MessageBusType.HOOK_EXECUTION_RESPONSE; + correlationId: string; + success: boolean; + output?: Record; + error?: Error; +} + +export interface HookPolicyDecision { + type: MessageBusType.HOOK_POLICY_DECISION; + eventName: string; + hookSource: 'project' | 'user' | 'system' | 'extension'; + decision: 'allow' | 'deny'; + reason?: string; +} + export type Message = | ToolConfirmationRequest | ToolConfirmationResponse | ToolPolicyRejection | ToolExecutionSuccess | ToolExecutionFailure - | UpdatePolicy; + | UpdatePolicy + | HookExecutionRequest + | HookExecutionResponse + | HookPolicyDecision; diff --git a/packages/core/src/hooks/hookEventHandler.test.ts b/packages/core/src/hooks/hookEventHandler.test.ts new file mode 100644 index 0000000000..3ee9c8d43e --- /dev/null +++ b/packages/core/src/hooks/hookEventHandler.test.ts @@ -0,0 +1,524 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { HookEventHandler } from './hookEventHandler.js'; +import type { Config } from '../config/config.js'; +import type { HookConfig } from './types.js'; +import type { Logger } from '@opentelemetry/api-logs'; +import type { HookPlanner } from './hookPlanner.js'; +import type { HookRunner } from './hookRunner.js'; +import type { HookAggregator } from './hookAggregator.js'; +import { HookEventName, HookType } from './types.js'; +import { + NotificationType, + SessionStartSource, + type HookExecutionResult, +} from './types.js'; + +// Mock debugLogger +const mockDebugLogger = vi.hoisted(() => ({ + log: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), +})); + +vi.mock('../utils/debugLogger.js', () => ({ + debugLogger: mockDebugLogger, +})); + +describe('HookEventHandler', () => { + let hookEventHandler: HookEventHandler; + let mockConfig: Config; + let mockLogger: Logger; + let mockHookPlanner: HookPlanner; + let mockHookRunner: HookRunner; + let mockHookAggregator: HookAggregator; + + beforeEach(() => { + vi.resetAllMocks(); + + mockConfig = { + getSessionId: vi.fn().mockReturnValue('test-session'), + getWorkingDir: vi.fn().mockReturnValue('/test/project'), + } as unknown as Config; + + mockLogger = {} as Logger; + + mockHookPlanner = { + createExecutionPlan: vi.fn(), + } as unknown as HookPlanner; + + mockHookRunner = { + executeHooksParallel: vi.fn(), + executeHooksSequential: vi.fn(), + } as unknown as HookRunner; + + mockHookAggregator = { + aggregateResults: vi.fn(), + } as unknown as HookAggregator; + + hookEventHandler = new HookEventHandler( + mockConfig, + mockLogger, + mockHookPlanner, + mockHookRunner, + mockHookAggregator, + ); + }); + + describe('fireBeforeToolEvent', () => { + it('should fire BeforeTool event with correct input', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './test.sh', + } as unknown as HookConfig, + eventName: HookEventName.BeforeTool, + }, + ]; + const mockResults: HookExecutionResult[] = [ + { + success: true, + duration: 100, + hookConfig: { + type: HookType.Command, + command: './test.sh', + timeout: 30000, + }, + eventName: HookEventName.BeforeTool, + }, + ]; + const mockAggregated = { + success: true, + allOutputs: [], + errors: [], + totalDuration: 100, + }; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.BeforeTool, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue( + mockResults, + ); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue( + mockAggregated, + ); + + const result = await hookEventHandler.fireBeforeToolEvent('EditTool', { + file: 'test.txt', + }); + + expect(mockHookPlanner.createExecutionPlan).toHaveBeenCalledWith( + HookEventName.BeforeTool, + { toolName: 'EditTool' }, + ); + + expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith( + [mockPlan[0].hookConfig], + HookEventName.BeforeTool, + expect.objectContaining({ + session_id: 'test-session', + cwd: '/test/project', + hook_event_name: 'BeforeTool', + tool_name: 'EditTool', + tool_input: { file: 'test.txt' }, + }), + ); + + expect(result).toBe(mockAggregated); + }); + + it('should return empty result when no hooks to execute', async () => { + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue(null); + + const result = await hookEventHandler.fireBeforeToolEvent('EditTool', {}); + + expect(result.success).toBe(true); + expect(result.allOutputs).toHaveLength(0); + expect(result.errors).toHaveLength(0); + expect(result.totalDuration).toBe(0); + }); + + it('should handle execution errors gracefully', async () => { + vi.mocked(mockHookPlanner.createExecutionPlan).mockImplementation(() => { + throw new Error('Planning failed'); + }); + + const result = await hookEventHandler.fireBeforeToolEvent('EditTool', {}); + + expect(result.success).toBe(false); + expect(result.errors).toHaveLength(1); + expect(result.errors[0].message).toBe('Planning failed'); + expect(mockDebugLogger.error).toHaveBeenCalled(); + }); + }); + + describe('fireAfterToolEvent', () => { + it('should fire AfterTool event with tool response', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './after.sh', + } as unknown as HookConfig, + eventName: HookEventName.AfterTool, + }, + ]; + const mockResults: HookExecutionResult[] = [ + { + success: true, + duration: 100, + hookConfig: { + type: HookType.Command, + command: './test.sh', + timeout: 30000, + }, + eventName: HookEventName.BeforeTool, + }, + ]; + const mockAggregated = { + success: true, + allOutputs: [], + errors: [], + totalDuration: 100, + }; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.BeforeTool, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue( + mockResults, + ); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue( + mockAggregated, + ); + + const toolInput = { file: 'test.txt' }; + const toolResponse = { success: true, content: 'File edited' }; + + const result = await hookEventHandler.fireAfterToolEvent( + 'EditTool', + toolInput, + toolResponse, + ); + + expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith( + [mockPlan[0].hookConfig], + HookEventName.AfterTool, + expect.objectContaining({ + tool_name: 'EditTool', + tool_input: toolInput, + tool_response: toolResponse, + }), + ); + + expect(result).toBe(mockAggregated); + }); + }); + + describe('fireBeforeAgentEvent', () => { + it('should fire BeforeAgent event with prompt', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './before_agent.sh', + } as unknown as HookConfig, + eventName: HookEventName.BeforeAgent, + }, + ]; + const mockResults: HookExecutionResult[] = [ + { + success: true, + duration: 100, + hookConfig: { + type: HookType.Command, + command: './test.sh', + timeout: 30000, + }, + eventName: HookEventName.BeforeTool, + }, + ]; + const mockAggregated = { + success: true, + allOutputs: [], + errors: [], + totalDuration: 100, + }; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.BeforeTool, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue( + mockResults, + ); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue( + mockAggregated, + ); + + const prompt = 'Please help me with this task'; + + const result = await hookEventHandler.fireBeforeAgentEvent(prompt); + + expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith( + [mockPlan[0].hookConfig], + HookEventName.BeforeAgent, + expect.objectContaining({ + prompt, + }), + ); + + expect(result).toBe(mockAggregated); + }); + }); + + describe('fireNotificationEvent', () => { + it('should fire Notification event', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './notification-hook.sh', + } as HookConfig, + eventName: HookEventName.Notification, + }, + ]; + const mockResults: HookExecutionResult[] = [ + { + success: true, + duration: 50, + hookConfig: { + type: HookType.Command, + command: './notification-hook.sh', + timeout: 30000, + }, + eventName: HookEventName.Notification, + }, + ]; + const mockAggregated = { + success: true, + allOutputs: [], + errors: [], + totalDuration: 50, + }; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.Notification, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue( + mockResults, + ); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue( + mockAggregated, + ); + + const message = 'Tool execution requires permission'; + + const result = await hookEventHandler.fireNotificationEvent( + NotificationType.ToolPermission, + message, + { type: 'ToolPermission', title: 'Test Permission' }, + ); + + expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith( + [mockPlan[0].hookConfig], + HookEventName.Notification, + expect.objectContaining({ + notification_type: 'ToolPermission', + details: { type: 'ToolPermission', title: 'Test Permission' }, + }), + ); + + expect(result).toBe(mockAggregated); + }); + }); + + describe('fireSessionStartEvent', () => { + it('should fire SessionStart event with source', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './session_start.sh', + } as unknown as HookConfig, + eventName: HookEventName.SessionStart, + }, + ]; + const mockResults: HookExecutionResult[] = [ + { + success: true, + duration: 200, + hookConfig: { + type: HookType.Command, + command: './session_start.sh', + timeout: 30000, + }, + eventName: HookEventName.SessionStart, + }, + ]; + const mockAggregated = { + success: true, + allOutputs: [], + errors: [], + totalDuration: 200, + }; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.SessionStart, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue( + mockResults, + ); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue( + mockAggregated, + ); + + const result = await hookEventHandler.fireSessionStartEvent( + SessionStartSource.Startup, + ); + + expect(mockHookPlanner.createExecutionPlan).toHaveBeenCalledWith( + HookEventName.SessionStart, + { trigger: 'startup' }, + ); + + expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith( + [mockPlan[0].hookConfig], + HookEventName.SessionStart, + expect.objectContaining({ + source: 'startup', + }), + ); + + expect(result).toBe(mockAggregated); + }); + }); + + describe('fireBeforeModelEvent', () => { + it('should fire BeforeModel event with LLM request', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './model-hook.sh', + } as HookConfig, + eventName: HookEventName.BeforeModel, + }, + ]; + const mockResults: HookExecutionResult[] = [ + { + success: true, + duration: 150, + hookConfig: { + type: HookType.Command, + command: './model-hook.sh', + timeout: 30000, + }, + eventName: HookEventName.BeforeModel, + }, + ]; + const mockAggregated = { + success: true, + allOutputs: [], + errors: [], + totalDuration: 150, + }; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.BeforeModel, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue( + mockResults, + ); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue( + mockAggregated, + ); + + const llmRequest = { + model: 'gemini-pro', + config: { temperature: 0.7 }, + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + }; + + const result = await hookEventHandler.fireBeforeModelEvent(llmRequest); + + expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith( + [mockPlan[0].hookConfig], + HookEventName.BeforeModel, + expect.objectContaining({ + llm_request: expect.objectContaining({ + model: 'gemini-pro', + messages: expect.arrayContaining([ + expect.objectContaining({ + role: 'user', + content: 'Hello', + }), + ]), + }), + }), + ); + + expect(result).toBe(mockAggregated); + }); + }); + + describe('createBaseInput', () => { + it('should create base input with correct fields', async () => { + const mockPlan = [ + { + hookConfig: { + type: HookType.Command, + command: './test.sh', + } as unknown as HookConfig, + eventName: HookEventName.BeforeTool, + }, + ]; + + vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({ + eventName: HookEventName.BeforeTool, + hookConfigs: mockPlan.map((p) => p.hookConfig), + sequential: false, + }); + vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue([]); + vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue({ + success: true, + allOutputs: [], + errors: [], + totalDuration: 0, + }); + + await hookEventHandler.fireBeforeToolEvent('TestTool', {}); + + expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith( + expect.any(Array), + HookEventName.BeforeTool, + expect.objectContaining({ + session_id: 'test-session', + transcript_path: '', + cwd: '/test/project', + hook_event_name: 'BeforeTool', + timestamp: expect.any(String), + }), + ); + }); + }); +}); diff --git a/packages/core/src/hooks/hookEventHandler.ts b/packages/core/src/hooks/hookEventHandler.ts new file mode 100644 index 0000000000..67b61e3588 --- /dev/null +++ b/packages/core/src/hooks/hookEventHandler.ts @@ -0,0 +1,732 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Logger } from '@opentelemetry/api-logs'; +import type { Config } from '../config/config.js'; +import type { HookPlanner, HookEventContext } from './hookPlanner.js'; +import type { HookRunner } from './hookRunner.js'; +import type { HookAggregator, AggregatedHookResult } from './hookAggregator.js'; +import { HookEventName } from './types.js'; +import type { + HookInput, + BeforeToolInput, + AfterToolInput, + BeforeAgentInput, + NotificationInput, + AfterAgentInput, + SessionStartInput, + SessionEndInput, + PreCompressInput, + BeforeModelInput, + AfterModelInput, + BeforeToolSelectionInput, + NotificationType, + SessionStartSource, + SessionEndReason, + PreCompressTrigger, + HookExecutionResult, +} from './types.js'; +import { defaultHookTranslator } from './hookTranslator.js'; +import type { + GenerateContentParameters, + GenerateContentResponse, +} from '@google/genai'; +import { logHookCall } from '../telemetry/loggers.js'; +import { HookCallEvent } from '../telemetry/types.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { + MessageBusType, + type HookExecutionRequest, +} from '../confirmation-bus/types.js'; +import { debugLogger } from '../utils/debugLogger.js'; + +/** + * Validates that a value is a non-null object + */ +function isObject(value: unknown): value is Record { + return typeof value === 'object' && value !== null; +} + +/** + * Validates BeforeTool input fields + */ +function validateBeforeToolInput(input: Record): { + toolName: string; + toolInput: Record; +} { + const toolName = input['tool_name']; + const toolInput = input['tool_input']; + if (typeof toolName !== 'string') { + throw new Error( + 'Invalid input for BeforeTool hook event: tool_name must be a string', + ); + } + if (!isObject(toolInput)) { + throw new Error( + 'Invalid input for BeforeTool hook event: tool_input must be an object', + ); + } + return { toolName, toolInput }; +} + +/** + * Validates AfterTool input fields + */ +function validateAfterToolInput(input: Record): { + toolName: string; + toolInput: Record; + toolResponse: Record; +} { + const toolName = input['tool_name']; + const toolInput = input['tool_input']; + const toolResponse = input['tool_response']; + if (typeof toolName !== 'string') { + throw new Error( + 'Invalid input for AfterTool hook event: tool_name must be a string', + ); + } + if (!isObject(toolInput)) { + throw new Error( + 'Invalid input for AfterTool hook event: tool_input must be an object', + ); + } + if (!isObject(toolResponse)) { + throw new Error( + 'Invalid input for AfterTool hook event: tool_response must be an object', + ); + } + return { toolName, toolInput, toolResponse }; +} + +/** + * Validates BeforeAgent input fields + */ +function validateBeforeAgentInput(input: Record): { + prompt: string; +} { + const prompt = input['prompt']; + if (typeof prompt !== 'string') { + throw new Error( + 'Invalid input for BeforeAgent hook event: prompt must be a string', + ); + } + return { prompt }; +} + +/** + * Validates AfterAgent input fields + */ +function validateAfterAgentInput(input: Record): { + prompt: string; + promptResponse: string; + stopHookActive: boolean; +} { + const prompt = input['prompt']; + const promptResponse = input['prompt_response']; + const stopHookActive = input['stop_hook_active']; + if (typeof prompt !== 'string') { + throw new Error( + 'Invalid input for AfterAgent hook event: prompt must be a string', + ); + } + if (typeof promptResponse !== 'string') { + throw new Error( + 'Invalid input for AfterAgent hook event: prompt_response must be a string', + ); + } + // stopHookActive defaults to false if not a boolean + return { + prompt, + promptResponse, + stopHookActive: + typeof stopHookActive === 'boolean' ? stopHookActive : false, + }; +} + +/** + * Validates model-related input fields (llm_request) + */ +function validateModelInput( + input: Record, + eventName: string, +): { llmRequest: GenerateContentParameters } { + const llmRequest = input['llm_request']; + if (!isObject(llmRequest)) { + throw new Error( + `Invalid input for ${eventName} hook event: llm_request must be an object`, + ); + } + return { llmRequest: llmRequest as unknown as GenerateContentParameters }; +} + +/** + * Validates AfterModel input fields + */ +function validateAfterModelInput(input: Record): { + llmRequest: GenerateContentParameters; + llmResponse: GenerateContentResponse; +} { + const llmRequest = input['llm_request']; + const llmResponse = input['llm_response']; + if (!isObject(llmRequest)) { + throw new Error( + 'Invalid input for AfterModel hook event: llm_request must be an object', + ); + } + if (!isObject(llmResponse)) { + throw new Error( + 'Invalid input for AfterModel hook event: llm_response must be an object', + ); + } + return { + llmRequest: llmRequest as unknown as GenerateContentParameters, + llmResponse: llmResponse as unknown as GenerateContentResponse, + }; +} + +/** + * Validates Notification input fields + */ +function validateNotificationInput(input: Record): { + notificationType: NotificationType; + message: string; + details: Record; +} { + const notificationType = input['notification_type']; + const message = input['message']; + const details = input['details']; + if (typeof notificationType !== 'string') { + throw new Error( + 'Invalid input for Notification hook event: notification_type must be a string', + ); + } + if (typeof message !== 'string') { + throw new Error( + 'Invalid input for Notification hook event: message must be a string', + ); + } + if (!isObject(details)) { + throw new Error( + 'Invalid input for Notification hook event: details must be an object', + ); + } + return { + notificationType: notificationType as NotificationType, + message, + details, + }; +} + +/** + * Hook event bus that coordinates hook execution across the system + */ +export class HookEventHandler { + private readonly config: Config; + private readonly hookPlanner: HookPlanner; + private readonly hookRunner: HookRunner; + private readonly hookAggregator: HookAggregator; + private readonly messageBus?: MessageBus; + + constructor( + config: Config, + logger: Logger, + hookPlanner: HookPlanner, + hookRunner: HookRunner, + hookAggregator: HookAggregator, + messageBus?: MessageBus, + ) { + this.config = config; + this.hookPlanner = hookPlanner; + this.hookRunner = hookRunner; + this.hookAggregator = hookAggregator; + this.messageBus = messageBus; + + // Subscribe to hook execution requests from MessageBus + if (this.messageBus) { + this.messageBus.subscribe( + MessageBusType.HOOK_EXECUTION_REQUEST, + (request) => this.handleHookExecutionRequest(request), + ); + } + } + + /** + * Fire a BeforeTool event + * Called by handleHookExecutionRequest - executes hooks directly + */ + async fireBeforeToolEvent( + toolName: string, + toolInput: Record, + ): Promise { + const input: BeforeToolInput = { + ...this.createBaseInput(HookEventName.BeforeTool), + tool_name: toolName, + tool_input: toolInput, + }; + + const context: HookEventContext = { toolName }; + return await this.executeHooks(HookEventName.BeforeTool, input, context); + } + + /** + * Fire an AfterTool event + * Called by handleHookExecutionRequest - executes hooks directly + */ + async fireAfterToolEvent( + toolName: string, + toolInput: Record, + toolResponse: Record, + ): Promise { + const input: AfterToolInput = { + ...this.createBaseInput(HookEventName.AfterTool), + tool_name: toolName, + tool_input: toolInput, + tool_response: toolResponse, + }; + + const context: HookEventContext = { toolName }; + return await this.executeHooks(HookEventName.AfterTool, input, context); + } + + /** + * Fire a BeforeAgent event + * Called by handleHookExecutionRequest - executes hooks directly + */ + async fireBeforeAgentEvent(prompt: string): Promise { + const input: BeforeAgentInput = { + ...this.createBaseInput(HookEventName.BeforeAgent), + prompt, + }; + + return await this.executeHooks(HookEventName.BeforeAgent, input); + } + + /** + * Fire a Notification event + */ + async fireNotificationEvent( + type: NotificationType, + message: string, + details: Record, + ): Promise { + const input: NotificationInput = { + ...this.createBaseInput(HookEventName.Notification), + notification_type: type, + message, + details, + }; + + return await this.executeHooks(HookEventName.Notification, input); + } + + /** + * Fire an AfterAgent event + * Called by handleHookExecutionRequest - executes hooks directly + */ + async fireAfterAgentEvent( + prompt: string, + promptResponse: string, + stopHookActive: boolean = false, + ): Promise { + const input: AfterAgentInput = { + ...this.createBaseInput(HookEventName.AfterAgent), + prompt, + prompt_response: promptResponse, + stop_hook_active: stopHookActive, + }; + + return await this.executeHooks(HookEventName.AfterAgent, input); + } + + /** + * Fire a SessionStart event + */ + async fireSessionStartEvent( + source: SessionStartSource, + ): Promise { + const input: SessionStartInput = { + ...this.createBaseInput(HookEventName.SessionStart), + source, + }; + + const context: HookEventContext = { trigger: source }; + return await this.executeHooks(HookEventName.SessionStart, input, context); + } + + /** + * Fire a SessionEnd event + */ + async fireSessionEndEvent( + reason: SessionEndReason, + ): Promise { + const input: SessionEndInput = { + ...this.createBaseInput(HookEventName.SessionEnd), + reason, + }; + + const context: HookEventContext = { trigger: reason }; + return await this.executeHooks(HookEventName.SessionEnd, input, context); + } + + /** + * Fire a PreCompress event + */ + async firePreCompressEvent( + trigger: PreCompressTrigger, + ): Promise { + const input: PreCompressInput = { + ...this.createBaseInput(HookEventName.PreCompress), + trigger, + }; + + const context: HookEventContext = { trigger }; + return await this.executeHooks(HookEventName.PreCompress, input, context); + } + + /** + * Fire a BeforeModel event + * Called by handleHookExecutionRequest - executes hooks directly + */ + async fireBeforeModelEvent( + llmRequest: GenerateContentParameters, + ): Promise { + const input: BeforeModelInput = { + ...this.createBaseInput(HookEventName.BeforeModel), + llm_request: defaultHookTranslator.toHookLLMRequest(llmRequest), + }; + + return await this.executeHooks(HookEventName.BeforeModel, input); + } + + /** + * Fire an AfterModel event + * Called by handleHookExecutionRequest - executes hooks directly + */ + async fireAfterModelEvent( + llmRequest: GenerateContentParameters, + llmResponse: GenerateContentResponse, + ): Promise { + const input: AfterModelInput = { + ...this.createBaseInput(HookEventName.AfterModel), + llm_request: defaultHookTranslator.toHookLLMRequest(llmRequest), + llm_response: defaultHookTranslator.toHookLLMResponse(llmResponse), + }; + + return await this.executeHooks(HookEventName.AfterModel, input); + } + + /** + * Fire a BeforeToolSelection event + * Called by handleHookExecutionRequest - executes hooks directly + */ + async fireBeforeToolSelectionEvent( + llmRequest: GenerateContentParameters, + ): Promise { + const input: BeforeToolSelectionInput = { + ...this.createBaseInput(HookEventName.BeforeToolSelection), + llm_request: defaultHookTranslator.toHookLLMRequest(llmRequest), + }; + + return await this.executeHooks(HookEventName.BeforeToolSelection, input); + } + + /** + * Execute hooks for a specific event (direct execution without MessageBus) + * Used as fallback when MessageBus is not available + */ + private async executeHooks( + eventName: HookEventName, + input: HookInput, + context?: HookEventContext, + ): Promise { + try { + // Create execution plan + const plan = this.hookPlanner.createExecutionPlan(eventName, context); + + if (!plan || plan.hookConfigs.length === 0) { + return { + success: true, + allOutputs: [], + errors: [], + totalDuration: 0, + }; + } + + // Execute hooks according to the plan's strategy + const results = plan.sequential + ? await this.hookRunner.executeHooksSequential( + plan.hookConfigs, + eventName, + input, + ) + : await this.hookRunner.executeHooksParallel( + plan.hookConfigs, + eventName, + input, + ); + + // Aggregate results + const aggregated = this.hookAggregator.aggregateResults( + results, + eventName, + ); + + // Process common hook output fields centrally + this.processCommonHookOutputFields(aggregated); + + // Log hook execution + this.logHookExecution(eventName, input, results, aggregated); + + return aggregated; + } catch (error) { + debugLogger.error(`Hook event bus error for ${eventName}: ${error}`); + + return { + success: false, + allOutputs: [], + errors: [error instanceof Error ? error : new Error(String(error))], + totalDuration: 0, + }; + } + } + + /** + * Create base hook input with common fields + */ + private createBaseInput(eventName: HookEventName): HookInput { + return { + session_id: this.config.getSessionId(), + transcript_path: '', // TODO: Implement transcript path when supported + cwd: this.config.getWorkingDir(), + hook_event_name: eventName, + timestamp: new Date().toISOString(), + }; + } + + /** + * Log hook execution for observability + */ + private logHookExecution( + eventName: HookEventName, + input: HookInput, + results: HookExecutionResult[], + aggregated: AggregatedHookResult, + ): void { + const successCount = results.filter((r) => r.success).length; + const errorCount = results.length - successCount; + + if (errorCount > 0) { + debugLogger.warn( + `Hook execution for ${eventName}: ${successCount} succeeded, ${errorCount} failed, ` + + `total duration: ${aggregated.totalDuration}ms`, + ); + } else { + debugLogger.debug( + `Hook execution for ${eventName}: ${successCount} hooks executed successfully, ` + + `total duration: ${aggregated.totalDuration}ms`, + ); + } + + // Log individual hook calls to telemetry + for (const result of results) { + // Determine hook name and type for telemetry + const hookName = this.getHookNameFromResult(result); + const hookType = this.getHookTypeFromResult(result); + + const hookCallEvent = new HookCallEvent( + eventName, + hookType, + hookName, + { ...input }, + result.duration, + result.success, + result.output ? { ...result.output } : undefined, + result.exitCode, + result.stdout, + result.stderr, + result.error?.message, + ); + + logHookCall(this.config, hookCallEvent); + } + + // Log individual errors + for (const error of aggregated.errors) { + debugLogger.error(`Hook execution error: ${error.message}`); + } + } + + /** + * Process common hook output fields centrally + */ + private processCommonHookOutputFields( + aggregated: AggregatedHookResult, + ): void { + if (!aggregated.finalOutput) { + return; + } + + // Handle systemMessage - show to user in transcript mode (not to agent) + const systemMessage = aggregated.finalOutput.systemMessage; + if (systemMessage && !aggregated.finalOutput.suppressOutput) { + debugLogger.warn(`Hook system message: ${systemMessage}`); + } + + // Handle suppressOutput - already handled by not logging above when true + + // Handle continue=false - this should stop the entire agent execution + if (aggregated.finalOutput.shouldStopExecution()) { + const stopReason = aggregated.finalOutput.getEffectiveReason(); + debugLogger.log(`Hook requested to stop execution: ${stopReason}`); + + // Note: The actual stopping of execution must be handled by integration points + // as they need to interpret this signal in the context of their specific workflow + // This is just logging the request centrally + } + + // Other common fields like decision/reason are handled by specific hook output classes + } + + /** + * Get hook name from execution result for telemetry + */ + private getHookNameFromResult(result: HookExecutionResult): string { + return result.hookConfig.command || 'unknown-command'; + } + + /** + * Get hook type from execution result for telemetry + */ + private getHookTypeFromResult(result: HookExecutionResult): 'command' { + return result.hookConfig.type; + } + + /** + * Handle hook execution requests from MessageBus + * This method routes the request to the appropriate fire*Event method + * and publishes the response back through MessageBus + * + * The request input only contains event-specific fields. This method adds + * the common base fields (session_id, cwd, etc.) before routing. + */ + private async handleHookExecutionRequest( + request: HookExecutionRequest, + ): Promise { + try { + // Add base fields to the input + const enrichedInput = { + ...this.createBaseInput(request.eventName as HookEventName), + ...request.input, + } as Record; + + let result: AggregatedHookResult; + + // Route to appropriate event handler based on eventName + switch (request.eventName) { + case HookEventName.BeforeTool: { + const { toolName, toolInput } = + validateBeforeToolInput(enrichedInput); + result = await this.fireBeforeToolEvent(toolName, toolInput); + break; + } + case HookEventName.AfterTool: { + const { toolName, toolInput, toolResponse } = + validateAfterToolInput(enrichedInput); + result = await this.fireAfterToolEvent( + toolName, + toolInput, + toolResponse, + ); + break; + } + case HookEventName.BeforeAgent: { + const { prompt } = validateBeforeAgentInput(enrichedInput); + result = await this.fireBeforeAgentEvent(prompt); + break; + } + case HookEventName.AfterAgent: { + const { prompt, promptResponse, stopHookActive } = + validateAfterAgentInput(enrichedInput); + result = await this.fireAfterAgentEvent( + prompt, + promptResponse, + stopHookActive, + ); + break; + } + case HookEventName.BeforeModel: { + const { llmRequest } = validateModelInput( + enrichedInput, + 'BeforeModel', + ); + const translatedRequest = + defaultHookTranslator.toHookLLMRequest(llmRequest); + // Update the enrichedInput with translated request + enrichedInput['llm_request'] = translatedRequest; + result = await this.fireBeforeModelEvent(llmRequest); + break; + } + case HookEventName.AfterModel: { + const { llmRequest, llmResponse } = + validateAfterModelInput(enrichedInput); + const translatedRequest = + defaultHookTranslator.toHookLLMRequest(llmRequest); + const translatedResponse = + defaultHookTranslator.toHookLLMResponse(llmResponse); + // Update the enrichedInput with translated versions + enrichedInput['llm_request'] = translatedRequest; + enrichedInput['llm_response'] = translatedResponse; + result = await this.fireAfterModelEvent(llmRequest, llmResponse); + break; + } + case HookEventName.BeforeToolSelection: { + const { llmRequest } = validateModelInput( + enrichedInput, + 'BeforeToolSelection', + ); + const translatedRequest = + defaultHookTranslator.toHookLLMRequest(llmRequest); + // Update the enrichedInput with translated request + enrichedInput['llm_request'] = translatedRequest; + result = await this.fireBeforeToolSelectionEvent(llmRequest); + break; + } + case HookEventName.Notification: { + const { notificationType, message, details } = + validateNotificationInput(enrichedInput); + result = await this.fireNotificationEvent( + notificationType, + message, + details, + ); + break; + } + default: + throw new Error(`Unsupported hook event: ${request.eventName}`); + } + + // Publish response through MessageBus + if (this.messageBus) { + this.messageBus.publish({ + type: MessageBusType.HOOK_EXECUTION_RESPONSE, + correlationId: request.correlationId, + success: result.success, + output: result.finalOutput as unknown as Record, + }); + } + } catch (error) { + // Publish error response + if (this.messageBus) { + this.messageBus.publish({ + type: MessageBusType.HOOK_EXECUTION_RESPONSE, + correlationId: request.correlationId, + success: false, + error: error instanceof Error ? error : new Error(String(error)), + }); + } + } + } +} diff --git a/packages/core/src/policy/policy-engine.test.ts b/packages/core/src/policy/policy-engine.test.ts index 5cb7cd3b9a..a58725f8f2 100644 --- a/packages/core/src/policy/policy-engine.test.ts +++ b/packages/core/src/policy/policy-engine.test.ts @@ -1250,4 +1250,291 @@ describe('PolicyEngine', () => { expect(result.decision).toBe(PolicyDecision.DENY); }); }); + + describe('checkHook', () => { + it('should allow hooks by default', async () => { + engine = new PolicyEngine({}, mockCheckerRunner); + const decision = await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'user', + }); + expect(decision).toBe(PolicyDecision.ALLOW); + }); + + it('should deny all hooks when allowHooks is false', async () => { + engine = new PolicyEngine({ allowHooks: false }, mockCheckerRunner); + const decision = await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'user', + }); + expect(decision).toBe(PolicyDecision.DENY); + }); + + it('should deny project hooks in untrusted folders', async () => { + engine = new PolicyEngine({}, mockCheckerRunner); + const decision = await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'project', + trustedFolder: false, + }); + expect(decision).toBe(PolicyDecision.DENY); + }); + + it('should allow project hooks in trusted folders', async () => { + engine = new PolicyEngine({}, mockCheckerRunner); + const decision = await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'project', + trustedFolder: true, + }); + expect(decision).toBe(PolicyDecision.ALLOW); + }); + + it('should allow user hooks in untrusted folders', async () => { + engine = new PolicyEngine({}, mockCheckerRunner); + const decision = await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'user', + trustedFolder: false, + }); + expect(decision).toBe(PolicyDecision.ALLOW); + }); + + it('should run hook checkers and deny on DENY decision', async () => { + const hookCheckers = [ + { + eventName: 'BeforeTool', + checker: { type: 'external' as const, name: 'test-hook-checker' }, + }, + ]; + engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.DENY, + reason: 'Hook checker denied', + }); + + const decision = await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'user', + }); + + expect(decision).toBe(PolicyDecision.DENY); + expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith( + expect.objectContaining({ name: 'hook:BeforeTool' }), + expect.objectContaining({ name: 'test-hook-checker' }), + ); + }); + + it('should run hook checkers and allow on ALLOW decision', async () => { + const hookCheckers = [ + { + eventName: 'BeforeTool', + checker: { type: 'external' as const, name: 'test-hook-checker' }, + }, + ]; + engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.ALLOW, + }); + + const decision = await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'user', + }); + + expect(decision).toBe(PolicyDecision.ALLOW); + }); + + it('should return ASK_USER when checker requests it', async () => { + const hookCheckers = [ + { + checker: { type: 'external' as const, name: 'test-hook-checker' }, + }, + ]; + engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.ASK_USER, + reason: 'Needs confirmation', + }); + + const decision = await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'user', + }); + + expect(decision).toBe(PolicyDecision.ASK_USER); + }); + + it('should return DENY for ASK_USER in non-interactive mode', async () => { + const hookCheckers = [ + { + checker: { type: 'external' as const, name: 'test-hook-checker' }, + }, + ]; + engine = new PolicyEngine( + { hookCheckers, nonInteractive: true }, + mockCheckerRunner, + ); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.ASK_USER, + reason: 'Needs confirmation', + }); + + const decision = await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'user', + }); + + expect(decision).toBe(PolicyDecision.DENY); + }); + + it('should match hook checkers by eventName', async () => { + const hookCheckers = [ + { + eventName: 'AfterTool', + checker: { type: 'external' as const, name: 'after-tool-checker' }, + }, + { + eventName: 'BeforeTool', + checker: { type: 'external' as const, name: 'before-tool-checker' }, + }, + ]; + engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.ALLOW, + }); + + await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'user', + }); + + expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ name: 'before-tool-checker' }), + ); + expect(mockCheckerRunner.runChecker).not.toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ name: 'after-tool-checker' }), + ); + }); + + it('should match hook checkers by hookSource', async () => { + const hookCheckers = [ + { + hookSource: 'project' as const, + checker: { type: 'external' as const, name: 'project-checker' }, + }, + { + hookSource: 'user' as const, + checker: { type: 'external' as const, name: 'user-checker' }, + }, + ]; + engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); + + vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({ + decision: SafetyCheckDecision.ALLOW, + }); + + await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'user', + }); + + expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ name: 'user-checker' }), + ); + expect(mockCheckerRunner.runChecker).not.toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ name: 'project-checker' }), + ); + }); + + it('should deny when hook checker throws an error', async () => { + const hookCheckers = [ + { + checker: { type: 'external' as const, name: 'failing-checker' }, + }, + ]; + engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); + + vi.mocked(mockCheckerRunner.runChecker).mockRejectedValue( + new Error('Checker failed'), + ); + + const decision = await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'user', + }); + + expect(decision).toBe(PolicyDecision.DENY); + }); + + it('should run hook checkers in priority order', async () => { + const hookCheckers = [ + { + priority: 5, + checker: { type: 'external' as const, name: 'low-priority' }, + }, + { + priority: 20, + checker: { type: 'external' as const, name: 'high-priority' }, + }, + { + priority: 10, + checker: { type: 'external' as const, name: 'medium-priority' }, + }, + ]; + engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner); + + vi.mocked(mockCheckerRunner.runChecker).mockImplementation( + async (_call, config) => { + if (config.name === 'high-priority') { + return { decision: SafetyCheckDecision.DENY, reason: 'denied' }; + } + return { decision: SafetyCheckDecision.ALLOW }; + }, + ); + + await engine.checkHook({ + eventName: 'BeforeTool', + hookSource: 'user', + }); + + // Should only call the high-priority checker (first in sorted order) + expect(mockCheckerRunner.runChecker).toHaveBeenCalledTimes(1); + expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ name: 'high-priority' }), + ); + }); + }); + + describe('addHookChecker', () => { + it('should add a new hook checker and maintain priority order', () => { + engine = new PolicyEngine({}, mockCheckerRunner); + + engine.addHookChecker({ + priority: 5, + checker: { type: 'external', name: 'checker1' }, + }); + engine.addHookChecker({ + priority: 10, + checker: { type: 'external', name: 'checker2' }, + }); + + const checkers = engine.getHookCheckers(); + expect(checkers).toHaveLength(2); + expect(checkers[0].priority).toBe(10); + expect(checkers[0].checker.name).toBe('checker2'); + expect(checkers[1].priority).toBe(5); + expect(checkers[1].checker.name).toBe('checker1'); + }); + }); }); diff --git a/packages/core/src/policy/policy-engine.ts b/packages/core/src/policy/policy-engine.ts index f1fb05ec43..295270d156 100644 --- a/packages/core/src/policy/policy-engine.ts +++ b/packages/core/src/policy/policy-engine.ts @@ -10,11 +10,15 @@ import { type PolicyEngineConfig, type PolicyRule, type SafetyCheckerRule, + type HookCheckerRule, + type HookExecutionContext, + getHookSource, } from './types.js'; import { stableStringify } from './stable-stringify.js'; import { debugLogger } from '../utils/debugLogger.js'; import type { CheckerRunner } from '../safety/checker-runner.js'; import { SafetyCheckDecision } from '../safety/protocol.js'; +import type { HookExecutionRequest } from '../confirmation-bus/types.js'; function ruleMatches( rule: PolicyRule | SafetyCheckerRule, @@ -61,12 +65,34 @@ function ruleMatches( return true; } +/** + * Check if a hook checker rule matches a hook execution context. + */ +function hookCheckerMatches( + rule: HookCheckerRule, + context: HookExecutionContext, +): boolean { + // Check event name if specified + if (rule.eventName && rule.eventName !== context.eventName) { + return false; + } + + // Check hook source if specified + if (rule.hookSource && rule.hookSource !== context.hookSource) { + return false; + } + + return true; +} + export class PolicyEngine { private rules: PolicyRule[]; private checkers: SafetyCheckerRule[]; + private hookCheckers: HookCheckerRule[]; private readonly defaultDecision: PolicyDecision; private readonly nonInteractive: boolean; private readonly checkerRunner?: CheckerRunner; + private readonly allowHooks: boolean; constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) { this.rules = (config.rules ?? []).sort( @@ -75,9 +101,13 @@ export class PolicyEngine { this.checkers = (config.checkers ?? []).sort( (a, b) => (b.priority ?? 0) - (a.priority ?? 0), ); + this.hookCheckers = (config.hookCheckers ?? []).sort( + (a, b) => (b.priority ?? 0) - (a.priority ?? 0), + ); this.defaultDecision = config.defaultDecision ?? PolicyDecision.ASK_USER; this.nonInteractive = config.nonInteractive ?? false; this.checkerRunner = checkerRunner; + this.allowHooks = config.allowHooks ?? true; } /** @@ -206,6 +236,99 @@ export class PolicyEngine { return this.checkers; } + /** + * Add a new hook checker to the policy engine. + */ + addHookChecker(checker: HookCheckerRule): void { + this.hookCheckers.push(checker); + this.hookCheckers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0)); + } + + /** + * Get all current hook checkers. + */ + getHookCheckers(): readonly HookCheckerRule[] { + return this.hookCheckers; + } + + /** + * Check if a hook execution is allowed based on the configured policies. + * Runs hook-specific safety checkers if configured. + */ + async checkHook( + request: HookExecutionRequest | HookExecutionContext, + ): Promise { + // If hooks are globally disabled, deny all hook executions + if (!this.allowHooks) { + return PolicyDecision.DENY; + } + + const context: HookExecutionContext = + 'input' in request + ? { + eventName: request.eventName, + hookSource: getHookSource(request.input), + trustedFolder: + typeof request.input['trusted_folder'] === 'boolean' + ? request.input['trusted_folder'] + : undefined, + } + : request; + + // In untrusted folders, deny project-level hooks + if (context.trustedFolder === false && context.hookSource === 'project') { + return PolicyDecision.DENY; + } + + // Run hook-specific safety checkers if configured + if (this.checkerRunner && this.hookCheckers.length > 0) { + for (const checkerRule of this.hookCheckers) { + if (hookCheckerMatches(checkerRule, context)) { + debugLogger.debug( + `[PolicyEngine.checkHook] Running hook checker: ${checkerRule.checker.name} for event: ${context.eventName}`, + ); + try { + // Create a synthetic function call for the checker runner + // This allows reusing the existing checker infrastructure + const syntheticCall = { + name: `hook:${context.eventName}`, + args: { + hookSource: context.hookSource, + trustedFolder: context.trustedFolder, + }, + }; + + const result = await this.checkerRunner.runChecker( + syntheticCall, + checkerRule.checker, + ); + + if (result.decision === SafetyCheckDecision.DENY) { + debugLogger.debug( + `[PolicyEngine.checkHook] Hook checker denied: ${result.reason}`, + ); + return PolicyDecision.DENY; + } else if (result.decision === SafetyCheckDecision.ASK_USER) { + debugLogger.debug( + `[PolicyEngine.checkHook] Hook checker requested ASK_USER: ${result.reason}`, + ); + // For hooks, ASK_USER is treated as DENY in non-interactive mode + return this.applyNonInteractiveMode(PolicyDecision.ASK_USER); + } + } catch (error) { + debugLogger.debug( + `[PolicyEngine.checkHook] Hook checker failed: ${error}`, + ); + return PolicyDecision.DENY; + } + } + } + } + + // Default: Allow hooks + return PolicyDecision.ALLOW; + } + private applyNonInteractiveMode(decision: PolicyDecision): PolicyDecision { // In non-interactive mode, ASK_USER becomes DENY if (this.nonInteractive && decision === PolicyDecision.ASK_USER) { diff --git a/packages/core/src/policy/types.ts b/packages/core/src/policy/types.ts index d244211e70..410d9ff1c9 100644 --- a/packages/core/src/policy/types.ts +++ b/packages/core/src/policy/types.ts @@ -12,6 +12,36 @@ export enum PolicyDecision { ASK_USER = 'ask_user', } +/** + * Valid sources for hook execution + */ +export type HookSource = 'project' | 'user' | 'system' | 'extension'; + +/** + * Array of valid hook source values for runtime validation + */ +const VALID_HOOK_SOURCES: HookSource[] = [ + 'project', + 'user', + 'system', + 'extension', +]; + +/** + * Safely extract and validate hook source from input + * Returns 'project' as default if the value is invalid or missing + */ +export function getHookSource(input: Record): HookSource { + const source = input['hook_source']; + if ( + typeof source === 'string' && + VALID_HOOK_SOURCES.includes(source as HookSource) + ) { + return source as HookSource; + } + return 'project'; +} + export enum ApprovalMode { DEFAULT = 'default', AUTO_EDIT = 'autoEdit', @@ -115,6 +145,42 @@ export interface SafetyCheckerRule { checker: SafetyCheckerConfig; } +export interface HookExecutionContext { + eventName: string; + hookSource?: HookSource; + trustedFolder?: boolean; +} + +/** + * Rule for applying safety checkers to hook executions. + * Similar to SafetyCheckerRule but with hook-specific matching criteria. + */ +export interface HookCheckerRule { + /** + * The name of the hook event this rule applies to. + * If undefined, the rule applies to all hook events. + */ + eventName?: string; + + /** + * The source of hooks this rule applies to. + * If undefined, the rule applies to all hook sources. + */ + hookSource?: HookSource; + + /** + * Priority of this checker. Higher numbers run first. + * Default is 0. + */ + priority?: number; + + /** + * Specifies an external or built-in safety checker to execute for + * additional validation of a hook execution. + */ + checker: SafetyCheckerConfig; +} + export interface PolicyEngineConfig { /** * List of policy rules to apply. @@ -122,10 +188,15 @@ export interface PolicyEngineConfig { rules?: PolicyRule[]; /** - * List of safety checkers to apply. + * List of safety checkers to apply to tool calls. */ checkers?: SafetyCheckerRule[]; + /** + * List of safety checkers to apply to hook executions. + */ + hookCheckers?: HookCheckerRule[]; + /** * Default decision when no rules match. * Defaults to ASK_USER. @@ -137,6 +208,13 @@ export interface PolicyEngineConfig { * When true, ASK_USER decisions become DENY. */ nonInteractive?: boolean; + + /** + * Whether to allow hooks to execute. + * When false, all hooks are denied. + * Defaults to true. + */ + allowHooks?: boolean; } export interface PolicySettings { diff --git a/packages/core/src/test-utils/mock-message-bus.ts b/packages/core/src/test-utils/mock-message-bus.ts new file mode 100644 index 0000000000..3423cc3dcf --- /dev/null +++ b/packages/core/src/test-utils/mock-message-bus.ts @@ -0,0 +1,177 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { vi } from 'vitest'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import { + MessageBusType, + type Message, + type HookExecutionRequest, + type HookExecutionResponse, +} from '../confirmation-bus/types.js'; + +/** + * Mock MessageBus for testing hook execution through MessageBus + */ +export class MockMessageBus { + private subscriptions = new Map< + MessageBusType, + Set<(message: Message) => void> + >(); + publishedMessages: Message[] = []; + hookRequests: HookExecutionRequest[] = []; + hookResponses: HookExecutionResponse[] = []; + + /** + * Mock publish method that captures messages and simulates responses + */ + publish = vi.fn((message: Message) => { + this.publishedMessages.push(message); + + // Capture hook-specific messages + if (message.type === MessageBusType.HOOK_EXECUTION_REQUEST) { + this.hookRequests.push(message as HookExecutionRequest); + + // Auto-respond with success for testing + const response: HookExecutionResponse = { + type: MessageBusType.HOOK_EXECUTION_RESPONSE, + correlationId: (message as HookExecutionRequest).correlationId, + success: true, + output: { + decision: 'allow', + reason: 'Mock hook execution successful', + }, + }; + this.hookResponses.push(response); + + // Emit response to subscribers + this.emit(MessageBusType.HOOK_EXECUTION_RESPONSE, response); + } + }); + + /** + * Mock subscribe method that stores listeners + */ + subscribe = vi.fn( + (type: T['type'], listener: (message: T) => void) => { + if (!this.subscriptions.has(type)) { + this.subscriptions.set(type, new Set()); + } + this.subscriptions.get(type)!.add(listener as (message: Message) => void); + }, + ); + + /** + * Mock unsubscribe method + */ + unsubscribe = vi.fn( + (type: T['type'], listener: (message: T) => void) => { + const listeners = this.subscriptions.get(type); + if (listeners) { + listeners.delete(listener as (message: Message) => void); + } + }, + ); + + /** + * Emit a message to subscribers (for testing) + */ + private emit(type: MessageBusType, message: Message) { + const listeners = this.subscriptions.get(type); + if (listeners) { + listeners.forEach((listener) => listener(message)); + } + } + + /** + * Manually trigger a hook response (for testing custom scenarios) + */ + triggerHookResponse( + correlationId: string, + success: boolean, + output?: Record, + error?: Error, + ) { + const response: HookExecutionResponse = { + type: MessageBusType.HOOK_EXECUTION_RESPONSE, + correlationId, + success, + output, + error, + }; + this.hookResponses.push(response); + this.emit(MessageBusType.HOOK_EXECUTION_RESPONSE, response); + } + + /** + * Get the last hook request published + */ + getLastHookRequest(): HookExecutionRequest | undefined { + return this.hookRequests[this.hookRequests.length - 1]; + } + + /** + * Get all hook requests for a specific event + */ + getHookRequestsForEvent(eventName: string): HookExecutionRequest[] { + return this.hookRequests.filter((req) => req.eventName === eventName); + } + + /** + * Clear all captured messages (for test isolation) + */ + clear() { + this.publishedMessages = []; + this.hookRequests = []; + this.hookResponses = []; + this.subscriptions.clear(); + } + + /** + * Verify that a hook execution request was published + */ + expectHookRequest( + eventName: string, + input?: Partial>, + ) { + const request = this.hookRequests.find( + (req) => req.eventName === eventName, + ); + if (!request) { + throw new Error( + `Expected hook request for event "${eventName}" but none was found`, + ); + } + + if (input) { + Object.entries(input).forEach(([key, value]) => { + if (request.input[key] !== value) { + throw new Error( + `Expected hook input.${key} to be ${JSON.stringify(value)} but got ${JSON.stringify(request.input[key])}`, + ); + } + }); + } + + return request; + } +} + +/** + * Create a mock MessageBus for testing + */ +export function createMockMessageBus(): MessageBus { + return new MockMessageBus() as unknown as MessageBus; +} + +/** + * Get the MockMessageBus instance from a mocked MessageBus + */ +export function getMockMessageBusInstance( + messageBus: MessageBus, +): MockMessageBus { + return messageBus as unknown as MockMessageBus; +}