feat(hooks): Hook Event Handling (#9097)

This commit is contained in:
Edilmo Palencia
2025-11-24 13:51:39 -08:00
committed by GitHub
parent d53a5c4fb5
commit 2034098780
8 changed files with 2035 additions and 4 deletions

View File

@@ -4,10 +4,16 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { randomUUID } from 'node:crypto';
import { EventEmitter } from 'node:events';
import type { PolicyEngine } from '../policy/policy-engine.js';
import { PolicyDecision } from '../policy/types.js';
import { MessageBusType, type Message } from './types.js';
import { PolicyDecision, getHookSource } from '../policy/types.js';
import {
MessageBusType,
type Message,
type HookExecutionRequest,
type HookPolicyDecision,
} from './types.js';
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
export class MessageBus extends EventEmitter {
@@ -83,6 +89,39 @@ export class MessageBus extends EventEmitter {
default:
throw new Error(`Unknown policy decision: ${decision}`);
}
} else if (message.type === MessageBusType.HOOK_EXECUTION_REQUEST) {
// Handle hook execution requests through policy evaluation
const hookRequest = message as HookExecutionRequest;
const decision = await this.policyEngine.checkHook(hookRequest);
// Map decision to allow/deny for observability (ASK_USER treated as deny for hooks)
const effectiveDecision =
decision === PolicyDecision.ALLOW ? 'allow' : 'deny';
// Emit policy decision for observability
this.emitMessage({
type: MessageBusType.HOOK_POLICY_DECISION,
eventName: hookRequest.eventName,
hookSource: getHookSource(hookRequest.input),
decision: effectiveDecision,
reason:
decision !== PolicyDecision.ALLOW
? 'Hook execution denied by policy'
: undefined,
} as HookPolicyDecision);
// If allowed, emit the request for hook system to handle
if (decision === PolicyDecision.ALLOW) {
this.emitMessage(message);
} else {
// If denied or ASK_USER, emit error response (hooks don't support interactive confirmation)
this.emitMessage({
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: hookRequest.correlationId,
success: false,
error: new Error('Hook execution denied by policy'),
});
}
} else {
// For all other message types, just emit them
this.emitMessage(message);
@@ -105,4 +144,46 @@ export class MessageBus extends EventEmitter {
): void {
this.off(type, listener);
}
/**
* Request-response pattern: Publish a message and wait for a correlated response
* This enables synchronous-style communication over the async MessageBus
* The correlation ID is generated internally and added to the request
*/
async request<TRequest extends Message, TResponse extends Message>(
request: Omit<TRequest, 'correlationId'>,
responseType: TResponse['type'],
timeoutMs: number = 60000,
): Promise<TResponse> {
const correlationId = randomUUID();
return new Promise<TResponse>((resolve, reject) => {
const timeoutId = setTimeout(() => {
cleanup();
reject(new Error(`Request timed out waiting for ${responseType}`));
}, timeoutMs);
const cleanup = () => {
clearTimeout(timeoutId);
this.unsubscribe(responseType, responseHandler);
};
const responseHandler = (response: TResponse) => {
// Check if this response matches our request
if (
'correlationId' in response &&
response.correlationId === correlationId
) {
cleanup();
resolve(response);
}
};
// Subscribe to responses
this.subscribe<TResponse>(responseType, responseHandler);
// Publish the request with correlation ID
this.publish({ ...request, correlationId } as TRequest);
});
}
}

View File

@@ -13,6 +13,9 @@ export enum MessageBusType {
TOOL_EXECUTION_SUCCESS = 'tool-execution-success',
TOOL_EXECUTION_FAILURE = 'tool-execution-failure',
UPDATE_POLICY = 'update-policy',
HOOK_EXECUTION_REQUEST = 'hook-execution-request',
HOOK_EXECUTION_RESPONSE = 'hook-execution-response',
HOOK_POLICY_DECISION = 'hook-policy-decision',
}
export interface ToolConfirmationRequest {
@@ -55,10 +58,36 @@ export interface ToolExecutionFailure<E = Error> {
error: E;
}
export interface HookExecutionRequest {
type: MessageBusType.HOOK_EXECUTION_REQUEST;
eventName: string;
input: Record<string, unknown>;
correlationId: string;
}
export interface HookExecutionResponse {
type: MessageBusType.HOOK_EXECUTION_RESPONSE;
correlationId: string;
success: boolean;
output?: Record<string, unknown>;
error?: Error;
}
export interface HookPolicyDecision {
type: MessageBusType.HOOK_POLICY_DECISION;
eventName: string;
hookSource: 'project' | 'user' | 'system' | 'extension';
decision: 'allow' | 'deny';
reason?: string;
}
export type Message =
| ToolConfirmationRequest
| ToolConfirmationResponse
| ToolPolicyRejection
| ToolExecutionSuccess
| ToolExecutionFailure
| UpdatePolicy;
| UpdatePolicy
| HookExecutionRequest
| HookExecutionResponse
| HookPolicyDecision;

View File

@@ -0,0 +1,524 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { HookEventHandler } from './hookEventHandler.js';
import type { Config } from '../config/config.js';
import type { HookConfig } from './types.js';
import type { Logger } from '@opentelemetry/api-logs';
import type { HookPlanner } from './hookPlanner.js';
import type { HookRunner } from './hookRunner.js';
import type { HookAggregator } from './hookAggregator.js';
import { HookEventName, HookType } from './types.js';
import {
NotificationType,
SessionStartSource,
type HookExecutionResult,
} from './types.js';
// Mock debugLogger
const mockDebugLogger = vi.hoisted(() => ({
log: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
debug: vi.fn(),
}));
vi.mock('../utils/debugLogger.js', () => ({
debugLogger: mockDebugLogger,
}));
describe('HookEventHandler', () => {
let hookEventHandler: HookEventHandler;
let mockConfig: Config;
let mockLogger: Logger;
let mockHookPlanner: HookPlanner;
let mockHookRunner: HookRunner;
let mockHookAggregator: HookAggregator;
beforeEach(() => {
vi.resetAllMocks();
mockConfig = {
getSessionId: vi.fn().mockReturnValue('test-session'),
getWorkingDir: vi.fn().mockReturnValue('/test/project'),
} as unknown as Config;
mockLogger = {} as Logger;
mockHookPlanner = {
createExecutionPlan: vi.fn(),
} as unknown as HookPlanner;
mockHookRunner = {
executeHooksParallel: vi.fn(),
executeHooksSequential: vi.fn(),
} as unknown as HookRunner;
mockHookAggregator = {
aggregateResults: vi.fn(),
} as unknown as HookAggregator;
hookEventHandler = new HookEventHandler(
mockConfig,
mockLogger,
mockHookPlanner,
mockHookRunner,
mockHookAggregator,
);
});
describe('fireBeforeToolEvent', () => {
it('should fire BeforeTool event with correct input', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './test.sh',
} as unknown as HookConfig,
eventName: HookEventName.BeforeTool,
},
];
const mockResults: HookExecutionResult[] = [
{
success: true,
duration: 100,
hookConfig: {
type: HookType.Command,
command: './test.sh',
timeout: 30000,
},
eventName: HookEventName.BeforeTool,
},
];
const mockAggregated = {
success: true,
allOutputs: [],
errors: [],
totalDuration: 100,
};
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.BeforeTool,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
mockResults,
);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
mockAggregated,
);
const result = await hookEventHandler.fireBeforeToolEvent('EditTool', {
file: 'test.txt',
});
expect(mockHookPlanner.createExecutionPlan).toHaveBeenCalledWith(
HookEventName.BeforeTool,
{ toolName: 'EditTool' },
);
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
[mockPlan[0].hookConfig],
HookEventName.BeforeTool,
expect.objectContaining({
session_id: 'test-session',
cwd: '/test/project',
hook_event_name: 'BeforeTool',
tool_name: 'EditTool',
tool_input: { file: 'test.txt' },
}),
);
expect(result).toBe(mockAggregated);
});
it('should return empty result when no hooks to execute', async () => {
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue(null);
const result = await hookEventHandler.fireBeforeToolEvent('EditTool', {});
expect(result.success).toBe(true);
expect(result.allOutputs).toHaveLength(0);
expect(result.errors).toHaveLength(0);
expect(result.totalDuration).toBe(0);
});
it('should handle execution errors gracefully', async () => {
vi.mocked(mockHookPlanner.createExecutionPlan).mockImplementation(() => {
throw new Error('Planning failed');
});
const result = await hookEventHandler.fireBeforeToolEvent('EditTool', {});
expect(result.success).toBe(false);
expect(result.errors).toHaveLength(1);
expect(result.errors[0].message).toBe('Planning failed');
expect(mockDebugLogger.error).toHaveBeenCalled();
});
});
describe('fireAfterToolEvent', () => {
it('should fire AfterTool event with tool response', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './after.sh',
} as unknown as HookConfig,
eventName: HookEventName.AfterTool,
},
];
const mockResults: HookExecutionResult[] = [
{
success: true,
duration: 100,
hookConfig: {
type: HookType.Command,
command: './test.sh',
timeout: 30000,
},
eventName: HookEventName.BeforeTool,
},
];
const mockAggregated = {
success: true,
allOutputs: [],
errors: [],
totalDuration: 100,
};
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.BeforeTool,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
mockResults,
);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
mockAggregated,
);
const toolInput = { file: 'test.txt' };
const toolResponse = { success: true, content: 'File edited' };
const result = await hookEventHandler.fireAfterToolEvent(
'EditTool',
toolInput,
toolResponse,
);
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
[mockPlan[0].hookConfig],
HookEventName.AfterTool,
expect.objectContaining({
tool_name: 'EditTool',
tool_input: toolInput,
tool_response: toolResponse,
}),
);
expect(result).toBe(mockAggregated);
});
});
describe('fireBeforeAgentEvent', () => {
it('should fire BeforeAgent event with prompt', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './before_agent.sh',
} as unknown as HookConfig,
eventName: HookEventName.BeforeAgent,
},
];
const mockResults: HookExecutionResult[] = [
{
success: true,
duration: 100,
hookConfig: {
type: HookType.Command,
command: './test.sh',
timeout: 30000,
},
eventName: HookEventName.BeforeTool,
},
];
const mockAggregated = {
success: true,
allOutputs: [],
errors: [],
totalDuration: 100,
};
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.BeforeTool,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
mockResults,
);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
mockAggregated,
);
const prompt = 'Please help me with this task';
const result = await hookEventHandler.fireBeforeAgentEvent(prompt);
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
[mockPlan[0].hookConfig],
HookEventName.BeforeAgent,
expect.objectContaining({
prompt,
}),
);
expect(result).toBe(mockAggregated);
});
});
describe('fireNotificationEvent', () => {
it('should fire Notification event', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './notification-hook.sh',
} as HookConfig,
eventName: HookEventName.Notification,
},
];
const mockResults: HookExecutionResult[] = [
{
success: true,
duration: 50,
hookConfig: {
type: HookType.Command,
command: './notification-hook.sh',
timeout: 30000,
},
eventName: HookEventName.Notification,
},
];
const mockAggregated = {
success: true,
allOutputs: [],
errors: [],
totalDuration: 50,
};
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.Notification,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
mockResults,
);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
mockAggregated,
);
const message = 'Tool execution requires permission';
const result = await hookEventHandler.fireNotificationEvent(
NotificationType.ToolPermission,
message,
{ type: 'ToolPermission', title: 'Test Permission' },
);
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
[mockPlan[0].hookConfig],
HookEventName.Notification,
expect.objectContaining({
notification_type: 'ToolPermission',
details: { type: 'ToolPermission', title: 'Test Permission' },
}),
);
expect(result).toBe(mockAggregated);
});
});
describe('fireSessionStartEvent', () => {
it('should fire SessionStart event with source', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './session_start.sh',
} as unknown as HookConfig,
eventName: HookEventName.SessionStart,
},
];
const mockResults: HookExecutionResult[] = [
{
success: true,
duration: 200,
hookConfig: {
type: HookType.Command,
command: './session_start.sh',
timeout: 30000,
},
eventName: HookEventName.SessionStart,
},
];
const mockAggregated = {
success: true,
allOutputs: [],
errors: [],
totalDuration: 200,
};
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.SessionStart,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
mockResults,
);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
mockAggregated,
);
const result = await hookEventHandler.fireSessionStartEvent(
SessionStartSource.Startup,
);
expect(mockHookPlanner.createExecutionPlan).toHaveBeenCalledWith(
HookEventName.SessionStart,
{ trigger: 'startup' },
);
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
[mockPlan[0].hookConfig],
HookEventName.SessionStart,
expect.objectContaining({
source: 'startup',
}),
);
expect(result).toBe(mockAggregated);
});
});
describe('fireBeforeModelEvent', () => {
it('should fire BeforeModel event with LLM request', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './model-hook.sh',
} as HookConfig,
eventName: HookEventName.BeforeModel,
},
];
const mockResults: HookExecutionResult[] = [
{
success: true,
duration: 150,
hookConfig: {
type: HookType.Command,
command: './model-hook.sh',
timeout: 30000,
},
eventName: HookEventName.BeforeModel,
},
];
const mockAggregated = {
success: true,
allOutputs: [],
errors: [],
totalDuration: 150,
};
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.BeforeModel,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
mockResults,
);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
mockAggregated,
);
const llmRequest = {
model: 'gemini-pro',
config: { temperature: 0.7 },
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
};
const result = await hookEventHandler.fireBeforeModelEvent(llmRequest);
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
[mockPlan[0].hookConfig],
HookEventName.BeforeModel,
expect.objectContaining({
llm_request: expect.objectContaining({
model: 'gemini-pro',
messages: expect.arrayContaining([
expect.objectContaining({
role: 'user',
content: 'Hello',
}),
]),
}),
}),
);
expect(result).toBe(mockAggregated);
});
});
describe('createBaseInput', () => {
it('should create base input with correct fields', async () => {
const mockPlan = [
{
hookConfig: {
type: HookType.Command,
command: './test.sh',
} as unknown as HookConfig,
eventName: HookEventName.BeforeTool,
},
];
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
eventName: HookEventName.BeforeTool,
hookConfigs: mockPlan.map((p) => p.hookConfig),
sequential: false,
});
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue([]);
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue({
success: true,
allOutputs: [],
errors: [],
totalDuration: 0,
});
await hookEventHandler.fireBeforeToolEvent('TestTool', {});
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
expect.any(Array),
HookEventName.BeforeTool,
expect.objectContaining({
session_id: 'test-session',
transcript_path: '',
cwd: '/test/project',
hook_event_name: 'BeforeTool',
timestamp: expect.any(String),
}),
);
});
});
});

View File

@@ -0,0 +1,732 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Logger } from '@opentelemetry/api-logs';
import type { Config } from '../config/config.js';
import type { HookPlanner, HookEventContext } from './hookPlanner.js';
import type { HookRunner } from './hookRunner.js';
import type { HookAggregator, AggregatedHookResult } from './hookAggregator.js';
import { HookEventName } from './types.js';
import type {
HookInput,
BeforeToolInput,
AfterToolInput,
BeforeAgentInput,
NotificationInput,
AfterAgentInput,
SessionStartInput,
SessionEndInput,
PreCompressInput,
BeforeModelInput,
AfterModelInput,
BeforeToolSelectionInput,
NotificationType,
SessionStartSource,
SessionEndReason,
PreCompressTrigger,
HookExecutionResult,
} from './types.js';
import { defaultHookTranslator } from './hookTranslator.js';
import type {
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
import { logHookCall } from '../telemetry/loggers.js';
import { HookCallEvent } from '../telemetry/types.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import {
MessageBusType,
type HookExecutionRequest,
} from '../confirmation-bus/types.js';
import { debugLogger } from '../utils/debugLogger.js';
/**
* Validates that a value is a non-null object
*/
function isObject(value: unknown): value is Record<string, unknown> {
return typeof value === 'object' && value !== null;
}
/**
* Validates BeforeTool input fields
*/
function validateBeforeToolInput(input: Record<string, unknown>): {
toolName: string;
toolInput: Record<string, unknown>;
} {
const toolName = input['tool_name'];
const toolInput = input['tool_input'];
if (typeof toolName !== 'string') {
throw new Error(
'Invalid input for BeforeTool hook event: tool_name must be a string',
);
}
if (!isObject(toolInput)) {
throw new Error(
'Invalid input for BeforeTool hook event: tool_input must be an object',
);
}
return { toolName, toolInput };
}
/**
* Validates AfterTool input fields
*/
function validateAfterToolInput(input: Record<string, unknown>): {
toolName: string;
toolInput: Record<string, unknown>;
toolResponse: Record<string, unknown>;
} {
const toolName = input['tool_name'];
const toolInput = input['tool_input'];
const toolResponse = input['tool_response'];
if (typeof toolName !== 'string') {
throw new Error(
'Invalid input for AfterTool hook event: tool_name must be a string',
);
}
if (!isObject(toolInput)) {
throw new Error(
'Invalid input for AfterTool hook event: tool_input must be an object',
);
}
if (!isObject(toolResponse)) {
throw new Error(
'Invalid input for AfterTool hook event: tool_response must be an object',
);
}
return { toolName, toolInput, toolResponse };
}
/**
* Validates BeforeAgent input fields
*/
function validateBeforeAgentInput(input: Record<string, unknown>): {
prompt: string;
} {
const prompt = input['prompt'];
if (typeof prompt !== 'string') {
throw new Error(
'Invalid input for BeforeAgent hook event: prompt must be a string',
);
}
return { prompt };
}
/**
* Validates AfterAgent input fields
*/
function validateAfterAgentInput(input: Record<string, unknown>): {
prompt: string;
promptResponse: string;
stopHookActive: boolean;
} {
const prompt = input['prompt'];
const promptResponse = input['prompt_response'];
const stopHookActive = input['stop_hook_active'];
if (typeof prompt !== 'string') {
throw new Error(
'Invalid input for AfterAgent hook event: prompt must be a string',
);
}
if (typeof promptResponse !== 'string') {
throw new Error(
'Invalid input for AfterAgent hook event: prompt_response must be a string',
);
}
// stopHookActive defaults to false if not a boolean
return {
prompt,
promptResponse,
stopHookActive:
typeof stopHookActive === 'boolean' ? stopHookActive : false,
};
}
/**
* Validates model-related input fields (llm_request)
*/
function validateModelInput(
input: Record<string, unknown>,
eventName: string,
): { llmRequest: GenerateContentParameters } {
const llmRequest = input['llm_request'];
if (!isObject(llmRequest)) {
throw new Error(
`Invalid input for ${eventName} hook event: llm_request must be an object`,
);
}
return { llmRequest: llmRequest as unknown as GenerateContentParameters };
}
/**
* Validates AfterModel input fields
*/
function validateAfterModelInput(input: Record<string, unknown>): {
llmRequest: GenerateContentParameters;
llmResponse: GenerateContentResponse;
} {
const llmRequest = input['llm_request'];
const llmResponse = input['llm_response'];
if (!isObject(llmRequest)) {
throw new Error(
'Invalid input for AfterModel hook event: llm_request must be an object',
);
}
if (!isObject(llmResponse)) {
throw new Error(
'Invalid input for AfterModel hook event: llm_response must be an object',
);
}
return {
llmRequest: llmRequest as unknown as GenerateContentParameters,
llmResponse: llmResponse as unknown as GenerateContentResponse,
};
}
/**
* Validates Notification input fields
*/
function validateNotificationInput(input: Record<string, unknown>): {
notificationType: NotificationType;
message: string;
details: Record<string, unknown>;
} {
const notificationType = input['notification_type'];
const message = input['message'];
const details = input['details'];
if (typeof notificationType !== 'string') {
throw new Error(
'Invalid input for Notification hook event: notification_type must be a string',
);
}
if (typeof message !== 'string') {
throw new Error(
'Invalid input for Notification hook event: message must be a string',
);
}
if (!isObject(details)) {
throw new Error(
'Invalid input for Notification hook event: details must be an object',
);
}
return {
notificationType: notificationType as NotificationType,
message,
details,
};
}
/**
* Hook event bus that coordinates hook execution across the system
*/
export class HookEventHandler {
private readonly config: Config;
private readonly hookPlanner: HookPlanner;
private readonly hookRunner: HookRunner;
private readonly hookAggregator: HookAggregator;
private readonly messageBus?: MessageBus;
constructor(
config: Config,
logger: Logger,
hookPlanner: HookPlanner,
hookRunner: HookRunner,
hookAggregator: HookAggregator,
messageBus?: MessageBus,
) {
this.config = config;
this.hookPlanner = hookPlanner;
this.hookRunner = hookRunner;
this.hookAggregator = hookAggregator;
this.messageBus = messageBus;
// Subscribe to hook execution requests from MessageBus
if (this.messageBus) {
this.messageBus.subscribe<HookExecutionRequest>(
MessageBusType.HOOK_EXECUTION_REQUEST,
(request) => this.handleHookExecutionRequest(request),
);
}
}
/**
* Fire a BeforeTool event
* Called by handleHookExecutionRequest - executes hooks directly
*/
async fireBeforeToolEvent(
toolName: string,
toolInput: Record<string, unknown>,
): Promise<AggregatedHookResult> {
const input: BeforeToolInput = {
...this.createBaseInput(HookEventName.BeforeTool),
tool_name: toolName,
tool_input: toolInput,
};
const context: HookEventContext = { toolName };
return await this.executeHooks(HookEventName.BeforeTool, input, context);
}
/**
* Fire an AfterTool event
* Called by handleHookExecutionRequest - executes hooks directly
*/
async fireAfterToolEvent(
toolName: string,
toolInput: Record<string, unknown>,
toolResponse: Record<string, unknown>,
): Promise<AggregatedHookResult> {
const input: AfterToolInput = {
...this.createBaseInput(HookEventName.AfterTool),
tool_name: toolName,
tool_input: toolInput,
tool_response: toolResponse,
};
const context: HookEventContext = { toolName };
return await this.executeHooks(HookEventName.AfterTool, input, context);
}
/**
* Fire a BeforeAgent event
* Called by handleHookExecutionRequest - executes hooks directly
*/
async fireBeforeAgentEvent(prompt: string): Promise<AggregatedHookResult> {
const input: BeforeAgentInput = {
...this.createBaseInput(HookEventName.BeforeAgent),
prompt,
};
return await this.executeHooks(HookEventName.BeforeAgent, input);
}
/**
* Fire a Notification event
*/
async fireNotificationEvent(
type: NotificationType,
message: string,
details: Record<string, unknown>,
): Promise<AggregatedHookResult> {
const input: NotificationInput = {
...this.createBaseInput(HookEventName.Notification),
notification_type: type,
message,
details,
};
return await this.executeHooks(HookEventName.Notification, input);
}
/**
* Fire an AfterAgent event
* Called by handleHookExecutionRequest - executes hooks directly
*/
async fireAfterAgentEvent(
prompt: string,
promptResponse: string,
stopHookActive: boolean = false,
): Promise<AggregatedHookResult> {
const input: AfterAgentInput = {
...this.createBaseInput(HookEventName.AfterAgent),
prompt,
prompt_response: promptResponse,
stop_hook_active: stopHookActive,
};
return await this.executeHooks(HookEventName.AfterAgent, input);
}
/**
* Fire a SessionStart event
*/
async fireSessionStartEvent(
source: SessionStartSource,
): Promise<AggregatedHookResult> {
const input: SessionStartInput = {
...this.createBaseInput(HookEventName.SessionStart),
source,
};
const context: HookEventContext = { trigger: source };
return await this.executeHooks(HookEventName.SessionStart, input, context);
}
/**
* Fire a SessionEnd event
*/
async fireSessionEndEvent(
reason: SessionEndReason,
): Promise<AggregatedHookResult> {
const input: SessionEndInput = {
...this.createBaseInput(HookEventName.SessionEnd),
reason,
};
const context: HookEventContext = { trigger: reason };
return await this.executeHooks(HookEventName.SessionEnd, input, context);
}
/**
* Fire a PreCompress event
*/
async firePreCompressEvent(
trigger: PreCompressTrigger,
): Promise<AggregatedHookResult> {
const input: PreCompressInput = {
...this.createBaseInput(HookEventName.PreCompress),
trigger,
};
const context: HookEventContext = { trigger };
return await this.executeHooks(HookEventName.PreCompress, input, context);
}
/**
* Fire a BeforeModel event
* Called by handleHookExecutionRequest - executes hooks directly
*/
async fireBeforeModelEvent(
llmRequest: GenerateContentParameters,
): Promise<AggregatedHookResult> {
const input: BeforeModelInput = {
...this.createBaseInput(HookEventName.BeforeModel),
llm_request: defaultHookTranslator.toHookLLMRequest(llmRequest),
};
return await this.executeHooks(HookEventName.BeforeModel, input);
}
/**
* Fire an AfterModel event
* Called by handleHookExecutionRequest - executes hooks directly
*/
async fireAfterModelEvent(
llmRequest: GenerateContentParameters,
llmResponse: GenerateContentResponse,
): Promise<AggregatedHookResult> {
const input: AfterModelInput = {
...this.createBaseInput(HookEventName.AfterModel),
llm_request: defaultHookTranslator.toHookLLMRequest(llmRequest),
llm_response: defaultHookTranslator.toHookLLMResponse(llmResponse),
};
return await this.executeHooks(HookEventName.AfterModel, input);
}
/**
* Fire a BeforeToolSelection event
* Called by handleHookExecutionRequest - executes hooks directly
*/
async fireBeforeToolSelectionEvent(
llmRequest: GenerateContentParameters,
): Promise<AggregatedHookResult> {
const input: BeforeToolSelectionInput = {
...this.createBaseInput(HookEventName.BeforeToolSelection),
llm_request: defaultHookTranslator.toHookLLMRequest(llmRequest),
};
return await this.executeHooks(HookEventName.BeforeToolSelection, input);
}
/**
* Execute hooks for a specific event (direct execution without MessageBus)
* Used as fallback when MessageBus is not available
*/
private async executeHooks(
eventName: HookEventName,
input: HookInput,
context?: HookEventContext,
): Promise<AggregatedHookResult> {
try {
// Create execution plan
const plan = this.hookPlanner.createExecutionPlan(eventName, context);
if (!plan || plan.hookConfigs.length === 0) {
return {
success: true,
allOutputs: [],
errors: [],
totalDuration: 0,
};
}
// Execute hooks according to the plan's strategy
const results = plan.sequential
? await this.hookRunner.executeHooksSequential(
plan.hookConfigs,
eventName,
input,
)
: await this.hookRunner.executeHooksParallel(
plan.hookConfigs,
eventName,
input,
);
// Aggregate results
const aggregated = this.hookAggregator.aggregateResults(
results,
eventName,
);
// Process common hook output fields centrally
this.processCommonHookOutputFields(aggregated);
// Log hook execution
this.logHookExecution(eventName, input, results, aggregated);
return aggregated;
} catch (error) {
debugLogger.error(`Hook event bus error for ${eventName}: ${error}`);
return {
success: false,
allOutputs: [],
errors: [error instanceof Error ? error : new Error(String(error))],
totalDuration: 0,
};
}
}
/**
* Create base hook input with common fields
*/
private createBaseInput(eventName: HookEventName): HookInput {
return {
session_id: this.config.getSessionId(),
transcript_path: '', // TODO: Implement transcript path when supported
cwd: this.config.getWorkingDir(),
hook_event_name: eventName,
timestamp: new Date().toISOString(),
};
}
/**
* Log hook execution for observability
*/
private logHookExecution(
eventName: HookEventName,
input: HookInput,
results: HookExecutionResult[],
aggregated: AggregatedHookResult,
): void {
const successCount = results.filter((r) => r.success).length;
const errorCount = results.length - successCount;
if (errorCount > 0) {
debugLogger.warn(
`Hook execution for ${eventName}: ${successCount} succeeded, ${errorCount} failed, ` +
`total duration: ${aggregated.totalDuration}ms`,
);
} else {
debugLogger.debug(
`Hook execution for ${eventName}: ${successCount} hooks executed successfully, ` +
`total duration: ${aggregated.totalDuration}ms`,
);
}
// Log individual hook calls to telemetry
for (const result of results) {
// Determine hook name and type for telemetry
const hookName = this.getHookNameFromResult(result);
const hookType = this.getHookTypeFromResult(result);
const hookCallEvent = new HookCallEvent(
eventName,
hookType,
hookName,
{ ...input },
result.duration,
result.success,
result.output ? { ...result.output } : undefined,
result.exitCode,
result.stdout,
result.stderr,
result.error?.message,
);
logHookCall(this.config, hookCallEvent);
}
// Log individual errors
for (const error of aggregated.errors) {
debugLogger.error(`Hook execution error: ${error.message}`);
}
}
/**
* Process common hook output fields centrally
*/
private processCommonHookOutputFields(
aggregated: AggregatedHookResult,
): void {
if (!aggregated.finalOutput) {
return;
}
// Handle systemMessage - show to user in transcript mode (not to agent)
const systemMessage = aggregated.finalOutput.systemMessage;
if (systemMessage && !aggregated.finalOutput.suppressOutput) {
debugLogger.warn(`Hook system message: ${systemMessage}`);
}
// Handle suppressOutput - already handled by not logging above when true
// Handle continue=false - this should stop the entire agent execution
if (aggregated.finalOutput.shouldStopExecution()) {
const stopReason = aggregated.finalOutput.getEffectiveReason();
debugLogger.log(`Hook requested to stop execution: ${stopReason}`);
// Note: The actual stopping of execution must be handled by integration points
// as they need to interpret this signal in the context of their specific workflow
// This is just logging the request centrally
}
// Other common fields like decision/reason are handled by specific hook output classes
}
/**
* Get hook name from execution result for telemetry
*/
private getHookNameFromResult(result: HookExecutionResult): string {
return result.hookConfig.command || 'unknown-command';
}
/**
* Get hook type from execution result for telemetry
*/
private getHookTypeFromResult(result: HookExecutionResult): 'command' {
return result.hookConfig.type;
}
/**
* Handle hook execution requests from MessageBus
* This method routes the request to the appropriate fire*Event method
* and publishes the response back through MessageBus
*
* The request input only contains event-specific fields. This method adds
* the common base fields (session_id, cwd, etc.) before routing.
*/
private async handleHookExecutionRequest(
request: HookExecutionRequest,
): Promise<void> {
try {
// Add base fields to the input
const enrichedInput = {
...this.createBaseInput(request.eventName as HookEventName),
...request.input,
} as Record<string, unknown>;
let result: AggregatedHookResult;
// Route to appropriate event handler based on eventName
switch (request.eventName) {
case HookEventName.BeforeTool: {
const { toolName, toolInput } =
validateBeforeToolInput(enrichedInput);
result = await this.fireBeforeToolEvent(toolName, toolInput);
break;
}
case HookEventName.AfterTool: {
const { toolName, toolInput, toolResponse } =
validateAfterToolInput(enrichedInput);
result = await this.fireAfterToolEvent(
toolName,
toolInput,
toolResponse,
);
break;
}
case HookEventName.BeforeAgent: {
const { prompt } = validateBeforeAgentInput(enrichedInput);
result = await this.fireBeforeAgentEvent(prompt);
break;
}
case HookEventName.AfterAgent: {
const { prompt, promptResponse, stopHookActive } =
validateAfterAgentInput(enrichedInput);
result = await this.fireAfterAgentEvent(
prompt,
promptResponse,
stopHookActive,
);
break;
}
case HookEventName.BeforeModel: {
const { llmRequest } = validateModelInput(
enrichedInput,
'BeforeModel',
);
const translatedRequest =
defaultHookTranslator.toHookLLMRequest(llmRequest);
// Update the enrichedInput with translated request
enrichedInput['llm_request'] = translatedRequest;
result = await this.fireBeforeModelEvent(llmRequest);
break;
}
case HookEventName.AfterModel: {
const { llmRequest, llmResponse } =
validateAfterModelInput(enrichedInput);
const translatedRequest =
defaultHookTranslator.toHookLLMRequest(llmRequest);
const translatedResponse =
defaultHookTranslator.toHookLLMResponse(llmResponse);
// Update the enrichedInput with translated versions
enrichedInput['llm_request'] = translatedRequest;
enrichedInput['llm_response'] = translatedResponse;
result = await this.fireAfterModelEvent(llmRequest, llmResponse);
break;
}
case HookEventName.BeforeToolSelection: {
const { llmRequest } = validateModelInput(
enrichedInput,
'BeforeToolSelection',
);
const translatedRequest =
defaultHookTranslator.toHookLLMRequest(llmRequest);
// Update the enrichedInput with translated request
enrichedInput['llm_request'] = translatedRequest;
result = await this.fireBeforeToolSelectionEvent(llmRequest);
break;
}
case HookEventName.Notification: {
const { notificationType, message, details } =
validateNotificationInput(enrichedInput);
result = await this.fireNotificationEvent(
notificationType,
message,
details,
);
break;
}
default:
throw new Error(`Unsupported hook event: ${request.eventName}`);
}
// Publish response through MessageBus
if (this.messageBus) {
this.messageBus.publish({
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: request.correlationId,
success: result.success,
output: result.finalOutput as unknown as Record<string, unknown>,
});
}
} catch (error) {
// Publish error response
if (this.messageBus) {
this.messageBus.publish({
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: request.correlationId,
success: false,
error: error instanceof Error ? error : new Error(String(error)),
});
}
}
}
}

View File

@@ -1250,4 +1250,291 @@ describe('PolicyEngine', () => {
expect(result.decision).toBe(PolicyDecision.DENY);
});
});
describe('checkHook', () => {
it('should allow hooks by default', async () => {
engine = new PolicyEngine({}, mockCheckerRunner);
const decision = await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'user',
});
expect(decision).toBe(PolicyDecision.ALLOW);
});
it('should deny all hooks when allowHooks is false', async () => {
engine = new PolicyEngine({ allowHooks: false }, mockCheckerRunner);
const decision = await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'user',
});
expect(decision).toBe(PolicyDecision.DENY);
});
it('should deny project hooks in untrusted folders', async () => {
engine = new PolicyEngine({}, mockCheckerRunner);
const decision = await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'project',
trustedFolder: false,
});
expect(decision).toBe(PolicyDecision.DENY);
});
it('should allow project hooks in trusted folders', async () => {
engine = new PolicyEngine({}, mockCheckerRunner);
const decision = await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'project',
trustedFolder: true,
});
expect(decision).toBe(PolicyDecision.ALLOW);
});
it('should allow user hooks in untrusted folders', async () => {
engine = new PolicyEngine({}, mockCheckerRunner);
const decision = await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'user',
trustedFolder: false,
});
expect(decision).toBe(PolicyDecision.ALLOW);
});
it('should run hook checkers and deny on DENY decision', async () => {
const hookCheckers = [
{
eventName: 'BeforeTool',
checker: { type: 'external' as const, name: 'test-hook-checker' },
},
];
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({
decision: SafetyCheckDecision.DENY,
reason: 'Hook checker denied',
});
const decision = await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'user',
});
expect(decision).toBe(PolicyDecision.DENY);
expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith(
expect.objectContaining({ name: 'hook:BeforeTool' }),
expect.objectContaining({ name: 'test-hook-checker' }),
);
});
it('should run hook checkers and allow on ALLOW decision', async () => {
const hookCheckers = [
{
eventName: 'BeforeTool',
checker: { type: 'external' as const, name: 'test-hook-checker' },
},
];
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({
decision: SafetyCheckDecision.ALLOW,
});
const decision = await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'user',
});
expect(decision).toBe(PolicyDecision.ALLOW);
});
it('should return ASK_USER when checker requests it', async () => {
const hookCheckers = [
{
checker: { type: 'external' as const, name: 'test-hook-checker' },
},
];
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({
decision: SafetyCheckDecision.ASK_USER,
reason: 'Needs confirmation',
});
const decision = await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'user',
});
expect(decision).toBe(PolicyDecision.ASK_USER);
});
it('should return DENY for ASK_USER in non-interactive mode', async () => {
const hookCheckers = [
{
checker: { type: 'external' as const, name: 'test-hook-checker' },
},
];
engine = new PolicyEngine(
{ hookCheckers, nonInteractive: true },
mockCheckerRunner,
);
vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({
decision: SafetyCheckDecision.ASK_USER,
reason: 'Needs confirmation',
});
const decision = await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'user',
});
expect(decision).toBe(PolicyDecision.DENY);
});
it('should match hook checkers by eventName', async () => {
const hookCheckers = [
{
eventName: 'AfterTool',
checker: { type: 'external' as const, name: 'after-tool-checker' },
},
{
eventName: 'BeforeTool',
checker: { type: 'external' as const, name: 'before-tool-checker' },
},
];
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({
decision: SafetyCheckDecision.ALLOW,
});
await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'user',
});
expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({ name: 'before-tool-checker' }),
);
expect(mockCheckerRunner.runChecker).not.toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({ name: 'after-tool-checker' }),
);
});
it('should match hook checkers by hookSource', async () => {
const hookCheckers = [
{
hookSource: 'project' as const,
checker: { type: 'external' as const, name: 'project-checker' },
},
{
hookSource: 'user' as const,
checker: { type: 'external' as const, name: 'user-checker' },
},
];
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({
decision: SafetyCheckDecision.ALLOW,
});
await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'user',
});
expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({ name: 'user-checker' }),
);
expect(mockCheckerRunner.runChecker).not.toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({ name: 'project-checker' }),
);
});
it('should deny when hook checker throws an error', async () => {
const hookCheckers = [
{
checker: { type: 'external' as const, name: 'failing-checker' },
},
];
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
vi.mocked(mockCheckerRunner.runChecker).mockRejectedValue(
new Error('Checker failed'),
);
const decision = await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'user',
});
expect(decision).toBe(PolicyDecision.DENY);
});
it('should run hook checkers in priority order', async () => {
const hookCheckers = [
{
priority: 5,
checker: { type: 'external' as const, name: 'low-priority' },
},
{
priority: 20,
checker: { type: 'external' as const, name: 'high-priority' },
},
{
priority: 10,
checker: { type: 'external' as const, name: 'medium-priority' },
},
];
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
vi.mocked(mockCheckerRunner.runChecker).mockImplementation(
async (_call, config) => {
if (config.name === 'high-priority') {
return { decision: SafetyCheckDecision.DENY, reason: 'denied' };
}
return { decision: SafetyCheckDecision.ALLOW };
},
);
await engine.checkHook({
eventName: 'BeforeTool',
hookSource: 'user',
});
// Should only call the high-priority checker (first in sorted order)
expect(mockCheckerRunner.runChecker).toHaveBeenCalledTimes(1);
expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith(
expect.anything(),
expect.objectContaining({ name: 'high-priority' }),
);
});
});
describe('addHookChecker', () => {
it('should add a new hook checker and maintain priority order', () => {
engine = new PolicyEngine({}, mockCheckerRunner);
engine.addHookChecker({
priority: 5,
checker: { type: 'external', name: 'checker1' },
});
engine.addHookChecker({
priority: 10,
checker: { type: 'external', name: 'checker2' },
});
const checkers = engine.getHookCheckers();
expect(checkers).toHaveLength(2);
expect(checkers[0].priority).toBe(10);
expect(checkers[0].checker.name).toBe('checker2');
expect(checkers[1].priority).toBe(5);
expect(checkers[1].checker.name).toBe('checker1');
});
});
});

View File

@@ -10,11 +10,15 @@ import {
type PolicyEngineConfig,
type PolicyRule,
type SafetyCheckerRule,
type HookCheckerRule,
type HookExecutionContext,
getHookSource,
} from './types.js';
import { stableStringify } from './stable-stringify.js';
import { debugLogger } from '../utils/debugLogger.js';
import type { CheckerRunner } from '../safety/checker-runner.js';
import { SafetyCheckDecision } from '../safety/protocol.js';
import type { HookExecutionRequest } from '../confirmation-bus/types.js';
function ruleMatches(
rule: PolicyRule | SafetyCheckerRule,
@@ -61,12 +65,34 @@ function ruleMatches(
return true;
}
/**
* Check if a hook checker rule matches a hook execution context.
*/
function hookCheckerMatches(
rule: HookCheckerRule,
context: HookExecutionContext,
): boolean {
// Check event name if specified
if (rule.eventName && rule.eventName !== context.eventName) {
return false;
}
// Check hook source if specified
if (rule.hookSource && rule.hookSource !== context.hookSource) {
return false;
}
return true;
}
export class PolicyEngine {
private rules: PolicyRule[];
private checkers: SafetyCheckerRule[];
private hookCheckers: HookCheckerRule[];
private readonly defaultDecision: PolicyDecision;
private readonly nonInteractive: boolean;
private readonly checkerRunner?: CheckerRunner;
private readonly allowHooks: boolean;
constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) {
this.rules = (config.rules ?? []).sort(
@@ -75,9 +101,13 @@ export class PolicyEngine {
this.checkers = (config.checkers ?? []).sort(
(a, b) => (b.priority ?? 0) - (a.priority ?? 0),
);
this.hookCheckers = (config.hookCheckers ?? []).sort(
(a, b) => (b.priority ?? 0) - (a.priority ?? 0),
);
this.defaultDecision = config.defaultDecision ?? PolicyDecision.ASK_USER;
this.nonInteractive = config.nonInteractive ?? false;
this.checkerRunner = checkerRunner;
this.allowHooks = config.allowHooks ?? true;
}
/**
@@ -206,6 +236,99 @@ export class PolicyEngine {
return this.checkers;
}
/**
* Add a new hook checker to the policy engine.
*/
addHookChecker(checker: HookCheckerRule): void {
this.hookCheckers.push(checker);
this.hookCheckers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0));
}
/**
* Get all current hook checkers.
*/
getHookCheckers(): readonly HookCheckerRule[] {
return this.hookCheckers;
}
/**
* Check if a hook execution is allowed based on the configured policies.
* Runs hook-specific safety checkers if configured.
*/
async checkHook(
request: HookExecutionRequest | HookExecutionContext,
): Promise<PolicyDecision> {
// If hooks are globally disabled, deny all hook executions
if (!this.allowHooks) {
return PolicyDecision.DENY;
}
const context: HookExecutionContext =
'input' in request
? {
eventName: request.eventName,
hookSource: getHookSource(request.input),
trustedFolder:
typeof request.input['trusted_folder'] === 'boolean'
? request.input['trusted_folder']
: undefined,
}
: request;
// In untrusted folders, deny project-level hooks
if (context.trustedFolder === false && context.hookSource === 'project') {
return PolicyDecision.DENY;
}
// Run hook-specific safety checkers if configured
if (this.checkerRunner && this.hookCheckers.length > 0) {
for (const checkerRule of this.hookCheckers) {
if (hookCheckerMatches(checkerRule, context)) {
debugLogger.debug(
`[PolicyEngine.checkHook] Running hook checker: ${checkerRule.checker.name} for event: ${context.eventName}`,
);
try {
// Create a synthetic function call for the checker runner
// This allows reusing the existing checker infrastructure
const syntheticCall = {
name: `hook:${context.eventName}`,
args: {
hookSource: context.hookSource,
trustedFolder: context.trustedFolder,
},
};
const result = await this.checkerRunner.runChecker(
syntheticCall,
checkerRule.checker,
);
if (result.decision === SafetyCheckDecision.DENY) {
debugLogger.debug(
`[PolicyEngine.checkHook] Hook checker denied: ${result.reason}`,
);
return PolicyDecision.DENY;
} else if (result.decision === SafetyCheckDecision.ASK_USER) {
debugLogger.debug(
`[PolicyEngine.checkHook] Hook checker requested ASK_USER: ${result.reason}`,
);
// For hooks, ASK_USER is treated as DENY in non-interactive mode
return this.applyNonInteractiveMode(PolicyDecision.ASK_USER);
}
} catch (error) {
debugLogger.debug(
`[PolicyEngine.checkHook] Hook checker failed: ${error}`,
);
return PolicyDecision.DENY;
}
}
}
}
// Default: Allow hooks
return PolicyDecision.ALLOW;
}
private applyNonInteractiveMode(decision: PolicyDecision): PolicyDecision {
// In non-interactive mode, ASK_USER becomes DENY
if (this.nonInteractive && decision === PolicyDecision.ASK_USER) {

View File

@@ -12,6 +12,36 @@ export enum PolicyDecision {
ASK_USER = 'ask_user',
}
/**
* Valid sources for hook execution
*/
export type HookSource = 'project' | 'user' | 'system' | 'extension';
/**
* Array of valid hook source values for runtime validation
*/
const VALID_HOOK_SOURCES: HookSource[] = [
'project',
'user',
'system',
'extension',
];
/**
* Safely extract and validate hook source from input
* Returns 'project' as default if the value is invalid or missing
*/
export function getHookSource(input: Record<string, unknown>): HookSource {
const source = input['hook_source'];
if (
typeof source === 'string' &&
VALID_HOOK_SOURCES.includes(source as HookSource)
) {
return source as HookSource;
}
return 'project';
}
export enum ApprovalMode {
DEFAULT = 'default',
AUTO_EDIT = 'autoEdit',
@@ -115,6 +145,42 @@ export interface SafetyCheckerRule {
checker: SafetyCheckerConfig;
}
export interface HookExecutionContext {
eventName: string;
hookSource?: HookSource;
trustedFolder?: boolean;
}
/**
* Rule for applying safety checkers to hook executions.
* Similar to SafetyCheckerRule but with hook-specific matching criteria.
*/
export interface HookCheckerRule {
/**
* The name of the hook event this rule applies to.
* If undefined, the rule applies to all hook events.
*/
eventName?: string;
/**
* The source of hooks this rule applies to.
* If undefined, the rule applies to all hook sources.
*/
hookSource?: HookSource;
/**
* Priority of this checker. Higher numbers run first.
* Default is 0.
*/
priority?: number;
/**
* Specifies an external or built-in safety checker to execute for
* additional validation of a hook execution.
*/
checker: SafetyCheckerConfig;
}
export interface PolicyEngineConfig {
/**
* List of policy rules to apply.
@@ -122,10 +188,15 @@ export interface PolicyEngineConfig {
rules?: PolicyRule[];
/**
* List of safety checkers to apply.
* List of safety checkers to apply to tool calls.
*/
checkers?: SafetyCheckerRule[];
/**
* List of safety checkers to apply to hook executions.
*/
hookCheckers?: HookCheckerRule[];
/**
* Default decision when no rules match.
* Defaults to ASK_USER.
@@ -137,6 +208,13 @@ export interface PolicyEngineConfig {
* When true, ASK_USER decisions become DENY.
*/
nonInteractive?: boolean;
/**
* Whether to allow hooks to execute.
* When false, all hooks are denied.
* Defaults to true.
*/
allowHooks?: boolean;
}
export interface PolicySettings {

View File

@@ -0,0 +1,177 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi } from 'vitest';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import {
MessageBusType,
type Message,
type HookExecutionRequest,
type HookExecutionResponse,
} from '../confirmation-bus/types.js';
/**
* Mock MessageBus for testing hook execution through MessageBus
*/
export class MockMessageBus {
private subscriptions = new Map<
MessageBusType,
Set<(message: Message) => void>
>();
publishedMessages: Message[] = [];
hookRequests: HookExecutionRequest[] = [];
hookResponses: HookExecutionResponse[] = [];
/**
* Mock publish method that captures messages and simulates responses
*/
publish = vi.fn((message: Message) => {
this.publishedMessages.push(message);
// Capture hook-specific messages
if (message.type === MessageBusType.HOOK_EXECUTION_REQUEST) {
this.hookRequests.push(message as HookExecutionRequest);
// Auto-respond with success for testing
const response: HookExecutionResponse = {
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId: (message as HookExecutionRequest).correlationId,
success: true,
output: {
decision: 'allow',
reason: 'Mock hook execution successful',
},
};
this.hookResponses.push(response);
// Emit response to subscribers
this.emit(MessageBusType.HOOK_EXECUTION_RESPONSE, response);
}
});
/**
* Mock subscribe method that stores listeners
*/
subscribe = vi.fn(
<T extends Message>(type: T['type'], listener: (message: T) => void) => {
if (!this.subscriptions.has(type)) {
this.subscriptions.set(type, new Set());
}
this.subscriptions.get(type)!.add(listener as (message: Message) => void);
},
);
/**
* Mock unsubscribe method
*/
unsubscribe = vi.fn(
<T extends Message>(type: T['type'], listener: (message: T) => void) => {
const listeners = this.subscriptions.get(type);
if (listeners) {
listeners.delete(listener as (message: Message) => void);
}
},
);
/**
* Emit a message to subscribers (for testing)
*/
private emit(type: MessageBusType, message: Message) {
const listeners = this.subscriptions.get(type);
if (listeners) {
listeners.forEach((listener) => listener(message));
}
}
/**
* Manually trigger a hook response (for testing custom scenarios)
*/
triggerHookResponse(
correlationId: string,
success: boolean,
output?: Record<string, unknown>,
error?: Error,
) {
const response: HookExecutionResponse = {
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
correlationId,
success,
output,
error,
};
this.hookResponses.push(response);
this.emit(MessageBusType.HOOK_EXECUTION_RESPONSE, response);
}
/**
* Get the last hook request published
*/
getLastHookRequest(): HookExecutionRequest | undefined {
return this.hookRequests[this.hookRequests.length - 1];
}
/**
* Get all hook requests for a specific event
*/
getHookRequestsForEvent(eventName: string): HookExecutionRequest[] {
return this.hookRequests.filter((req) => req.eventName === eventName);
}
/**
* Clear all captured messages (for test isolation)
*/
clear() {
this.publishedMessages = [];
this.hookRequests = [];
this.hookResponses = [];
this.subscriptions.clear();
}
/**
* Verify that a hook execution request was published
*/
expectHookRequest(
eventName: string,
input?: Partial<Record<string, unknown>>,
) {
const request = this.hookRequests.find(
(req) => req.eventName === eventName,
);
if (!request) {
throw new Error(
`Expected hook request for event "${eventName}" but none was found`,
);
}
if (input) {
Object.entries(input).forEach(([key, value]) => {
if (request.input[key] !== value) {
throw new Error(
`Expected hook input.${key} to be ${JSON.stringify(value)} but got ${JSON.stringify(request.input[key])}`,
);
}
});
}
return request;
}
}
/**
* Create a mock MessageBus for testing
*/
export function createMockMessageBus(): MessageBus {
return new MockMessageBus() as unknown as MessageBus;
}
/**
* Get the MockMessageBus instance from a mocked MessageBus
*/
export function getMockMessageBusInstance(
messageBus: MessageBus,
): MockMessageBus {
return messageBus as unknown as MockMessageBus;
}