mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
feat(hooks): Hook Event Handling (#9097)
This commit is contained in:
@@ -4,10 +4,16 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import { EventEmitter } from 'node:events';
|
||||
import type { PolicyEngine } from '../policy/policy-engine.js';
|
||||
import { PolicyDecision } from '../policy/types.js';
|
||||
import { MessageBusType, type Message } from './types.js';
|
||||
import { PolicyDecision, getHookSource } from '../policy/types.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type Message,
|
||||
type HookExecutionRequest,
|
||||
type HookPolicyDecision,
|
||||
} from './types.js';
|
||||
import { safeJsonStringify } from '../utils/safeJsonStringify.js';
|
||||
|
||||
export class MessageBus extends EventEmitter {
|
||||
@@ -83,6 +89,39 @@ export class MessageBus extends EventEmitter {
|
||||
default:
|
||||
throw new Error(`Unknown policy decision: ${decision}`);
|
||||
}
|
||||
} else if (message.type === MessageBusType.HOOK_EXECUTION_REQUEST) {
|
||||
// Handle hook execution requests through policy evaluation
|
||||
const hookRequest = message as HookExecutionRequest;
|
||||
const decision = await this.policyEngine.checkHook(hookRequest);
|
||||
|
||||
// Map decision to allow/deny for observability (ASK_USER treated as deny for hooks)
|
||||
const effectiveDecision =
|
||||
decision === PolicyDecision.ALLOW ? 'allow' : 'deny';
|
||||
|
||||
// Emit policy decision for observability
|
||||
this.emitMessage({
|
||||
type: MessageBusType.HOOK_POLICY_DECISION,
|
||||
eventName: hookRequest.eventName,
|
||||
hookSource: getHookSource(hookRequest.input),
|
||||
decision: effectiveDecision,
|
||||
reason:
|
||||
decision !== PolicyDecision.ALLOW
|
||||
? 'Hook execution denied by policy'
|
||||
: undefined,
|
||||
} as HookPolicyDecision);
|
||||
|
||||
// If allowed, emit the request for hook system to handle
|
||||
if (decision === PolicyDecision.ALLOW) {
|
||||
this.emitMessage(message);
|
||||
} else {
|
||||
// If denied or ASK_USER, emit error response (hooks don't support interactive confirmation)
|
||||
this.emitMessage({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: hookRequest.correlationId,
|
||||
success: false,
|
||||
error: new Error('Hook execution denied by policy'),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// For all other message types, just emit them
|
||||
this.emitMessage(message);
|
||||
@@ -105,4 +144,46 @@ export class MessageBus extends EventEmitter {
|
||||
): void {
|
||||
this.off(type, listener);
|
||||
}
|
||||
|
||||
/**
|
||||
* Request-response pattern: Publish a message and wait for a correlated response
|
||||
* This enables synchronous-style communication over the async MessageBus
|
||||
* The correlation ID is generated internally and added to the request
|
||||
*/
|
||||
async request<TRequest extends Message, TResponse extends Message>(
|
||||
request: Omit<TRequest, 'correlationId'>,
|
||||
responseType: TResponse['type'],
|
||||
timeoutMs: number = 60000,
|
||||
): Promise<TResponse> {
|
||||
const correlationId = randomUUID();
|
||||
|
||||
return new Promise<TResponse>((resolve, reject) => {
|
||||
const timeoutId = setTimeout(() => {
|
||||
cleanup();
|
||||
reject(new Error(`Request timed out waiting for ${responseType}`));
|
||||
}, timeoutMs);
|
||||
|
||||
const cleanup = () => {
|
||||
clearTimeout(timeoutId);
|
||||
this.unsubscribe(responseType, responseHandler);
|
||||
};
|
||||
|
||||
const responseHandler = (response: TResponse) => {
|
||||
// Check if this response matches our request
|
||||
if (
|
||||
'correlationId' in response &&
|
||||
response.correlationId === correlationId
|
||||
) {
|
||||
cleanup();
|
||||
resolve(response);
|
||||
}
|
||||
};
|
||||
|
||||
// Subscribe to responses
|
||||
this.subscribe<TResponse>(responseType, responseHandler);
|
||||
|
||||
// Publish the request with correlation ID
|
||||
this.publish({ ...request, correlationId } as TRequest);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,9 @@ export enum MessageBusType {
|
||||
TOOL_EXECUTION_SUCCESS = 'tool-execution-success',
|
||||
TOOL_EXECUTION_FAILURE = 'tool-execution-failure',
|
||||
UPDATE_POLICY = 'update-policy',
|
||||
HOOK_EXECUTION_REQUEST = 'hook-execution-request',
|
||||
HOOK_EXECUTION_RESPONSE = 'hook-execution-response',
|
||||
HOOK_POLICY_DECISION = 'hook-policy-decision',
|
||||
}
|
||||
|
||||
export interface ToolConfirmationRequest {
|
||||
@@ -55,10 +58,36 @@ export interface ToolExecutionFailure<E = Error> {
|
||||
error: E;
|
||||
}
|
||||
|
||||
export interface HookExecutionRequest {
|
||||
type: MessageBusType.HOOK_EXECUTION_REQUEST;
|
||||
eventName: string;
|
||||
input: Record<string, unknown>;
|
||||
correlationId: string;
|
||||
}
|
||||
|
||||
export interface HookExecutionResponse {
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE;
|
||||
correlationId: string;
|
||||
success: boolean;
|
||||
output?: Record<string, unknown>;
|
||||
error?: Error;
|
||||
}
|
||||
|
||||
export interface HookPolicyDecision {
|
||||
type: MessageBusType.HOOK_POLICY_DECISION;
|
||||
eventName: string;
|
||||
hookSource: 'project' | 'user' | 'system' | 'extension';
|
||||
decision: 'allow' | 'deny';
|
||||
reason?: string;
|
||||
}
|
||||
|
||||
export type Message =
|
||||
| ToolConfirmationRequest
|
||||
| ToolConfirmationResponse
|
||||
| ToolPolicyRejection
|
||||
| ToolExecutionSuccess
|
||||
| ToolExecutionFailure
|
||||
| UpdatePolicy;
|
||||
| UpdatePolicy
|
||||
| HookExecutionRequest
|
||||
| HookExecutionResponse
|
||||
| HookPolicyDecision;
|
||||
|
||||
524
packages/core/src/hooks/hookEventHandler.test.ts
Normal file
524
packages/core/src/hooks/hookEventHandler.test.ts
Normal file
@@ -0,0 +1,524 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { HookEventHandler } from './hookEventHandler.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { HookConfig } from './types.js';
|
||||
import type { Logger } from '@opentelemetry/api-logs';
|
||||
import type { HookPlanner } from './hookPlanner.js';
|
||||
import type { HookRunner } from './hookRunner.js';
|
||||
import type { HookAggregator } from './hookAggregator.js';
|
||||
import { HookEventName, HookType } from './types.js';
|
||||
import {
|
||||
NotificationType,
|
||||
SessionStartSource,
|
||||
type HookExecutionResult,
|
||||
} from './types.js';
|
||||
|
||||
// Mock debugLogger
|
||||
const mockDebugLogger = vi.hoisted(() => ({
|
||||
log: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/debugLogger.js', () => ({
|
||||
debugLogger: mockDebugLogger,
|
||||
}));
|
||||
|
||||
describe('HookEventHandler', () => {
|
||||
let hookEventHandler: HookEventHandler;
|
||||
let mockConfig: Config;
|
||||
let mockLogger: Logger;
|
||||
let mockHookPlanner: HookPlanner;
|
||||
let mockHookRunner: HookRunner;
|
||||
let mockHookAggregator: HookAggregator;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
|
||||
mockConfig = {
|
||||
getSessionId: vi.fn().mockReturnValue('test-session'),
|
||||
getWorkingDir: vi.fn().mockReturnValue('/test/project'),
|
||||
} as unknown as Config;
|
||||
|
||||
mockLogger = {} as Logger;
|
||||
|
||||
mockHookPlanner = {
|
||||
createExecutionPlan: vi.fn(),
|
||||
} as unknown as HookPlanner;
|
||||
|
||||
mockHookRunner = {
|
||||
executeHooksParallel: vi.fn(),
|
||||
executeHooksSequential: vi.fn(),
|
||||
} as unknown as HookRunner;
|
||||
|
||||
mockHookAggregator = {
|
||||
aggregateResults: vi.fn(),
|
||||
} as unknown as HookAggregator;
|
||||
|
||||
hookEventHandler = new HookEventHandler(
|
||||
mockConfig,
|
||||
mockLogger,
|
||||
mockHookPlanner,
|
||||
mockHookRunner,
|
||||
mockHookAggregator,
|
||||
);
|
||||
});
|
||||
|
||||
describe('fireBeforeToolEvent', () => {
|
||||
it('should fire BeforeTool event with correct input', async () => {
|
||||
const mockPlan = [
|
||||
{
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './test.sh',
|
||||
} as unknown as HookConfig,
|
||||
eventName: HookEventName.BeforeTool,
|
||||
},
|
||||
];
|
||||
const mockResults: HookExecutionResult[] = [
|
||||
{
|
||||
success: true,
|
||||
duration: 100,
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './test.sh',
|
||||
timeout: 30000,
|
||||
},
|
||||
eventName: HookEventName.BeforeTool,
|
||||
},
|
||||
];
|
||||
const mockAggregated = {
|
||||
success: true,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 100,
|
||||
};
|
||||
|
||||
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||
eventName: HookEventName.BeforeTool,
|
||||
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||
sequential: false,
|
||||
});
|
||||
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
|
||||
mockResults,
|
||||
);
|
||||
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
|
||||
mockAggregated,
|
||||
);
|
||||
|
||||
const result = await hookEventHandler.fireBeforeToolEvent('EditTool', {
|
||||
file: 'test.txt',
|
||||
});
|
||||
|
||||
expect(mockHookPlanner.createExecutionPlan).toHaveBeenCalledWith(
|
||||
HookEventName.BeforeTool,
|
||||
{ toolName: 'EditTool' },
|
||||
);
|
||||
|
||||
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
|
||||
[mockPlan[0].hookConfig],
|
||||
HookEventName.BeforeTool,
|
||||
expect.objectContaining({
|
||||
session_id: 'test-session',
|
||||
cwd: '/test/project',
|
||||
hook_event_name: 'BeforeTool',
|
||||
tool_name: 'EditTool',
|
||||
tool_input: { file: 'test.txt' },
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result).toBe(mockAggregated);
|
||||
});
|
||||
|
||||
it('should return empty result when no hooks to execute', async () => {
|
||||
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue(null);
|
||||
|
||||
const result = await hookEventHandler.fireBeforeToolEvent('EditTool', {});
|
||||
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.allOutputs).toHaveLength(0);
|
||||
expect(result.errors).toHaveLength(0);
|
||||
expect(result.totalDuration).toBe(0);
|
||||
});
|
||||
|
||||
it('should handle execution errors gracefully', async () => {
|
||||
vi.mocked(mockHookPlanner.createExecutionPlan).mockImplementation(() => {
|
||||
throw new Error('Planning failed');
|
||||
});
|
||||
|
||||
const result = await hookEventHandler.fireBeforeToolEvent('EditTool', {});
|
||||
|
||||
expect(result.success).toBe(false);
|
||||
expect(result.errors).toHaveLength(1);
|
||||
expect(result.errors[0].message).toBe('Planning failed');
|
||||
expect(mockDebugLogger.error).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('fireAfterToolEvent', () => {
|
||||
it('should fire AfterTool event with tool response', async () => {
|
||||
const mockPlan = [
|
||||
{
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './after.sh',
|
||||
} as unknown as HookConfig,
|
||||
eventName: HookEventName.AfterTool,
|
||||
},
|
||||
];
|
||||
const mockResults: HookExecutionResult[] = [
|
||||
{
|
||||
success: true,
|
||||
duration: 100,
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './test.sh',
|
||||
timeout: 30000,
|
||||
},
|
||||
eventName: HookEventName.BeforeTool,
|
||||
},
|
||||
];
|
||||
const mockAggregated = {
|
||||
success: true,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 100,
|
||||
};
|
||||
|
||||
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||
eventName: HookEventName.BeforeTool,
|
||||
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||
sequential: false,
|
||||
});
|
||||
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
|
||||
mockResults,
|
||||
);
|
||||
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
|
||||
mockAggregated,
|
||||
);
|
||||
|
||||
const toolInput = { file: 'test.txt' };
|
||||
const toolResponse = { success: true, content: 'File edited' };
|
||||
|
||||
const result = await hookEventHandler.fireAfterToolEvent(
|
||||
'EditTool',
|
||||
toolInput,
|
||||
toolResponse,
|
||||
);
|
||||
|
||||
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
|
||||
[mockPlan[0].hookConfig],
|
||||
HookEventName.AfterTool,
|
||||
expect.objectContaining({
|
||||
tool_name: 'EditTool',
|
||||
tool_input: toolInput,
|
||||
tool_response: toolResponse,
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result).toBe(mockAggregated);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fireBeforeAgentEvent', () => {
|
||||
it('should fire BeforeAgent event with prompt', async () => {
|
||||
const mockPlan = [
|
||||
{
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './before_agent.sh',
|
||||
} as unknown as HookConfig,
|
||||
eventName: HookEventName.BeforeAgent,
|
||||
},
|
||||
];
|
||||
const mockResults: HookExecutionResult[] = [
|
||||
{
|
||||
success: true,
|
||||
duration: 100,
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './test.sh',
|
||||
timeout: 30000,
|
||||
},
|
||||
eventName: HookEventName.BeforeTool,
|
||||
},
|
||||
];
|
||||
const mockAggregated = {
|
||||
success: true,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 100,
|
||||
};
|
||||
|
||||
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||
eventName: HookEventName.BeforeTool,
|
||||
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||
sequential: false,
|
||||
});
|
||||
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
|
||||
mockResults,
|
||||
);
|
||||
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
|
||||
mockAggregated,
|
||||
);
|
||||
|
||||
const prompt = 'Please help me with this task';
|
||||
|
||||
const result = await hookEventHandler.fireBeforeAgentEvent(prompt);
|
||||
|
||||
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
|
||||
[mockPlan[0].hookConfig],
|
||||
HookEventName.BeforeAgent,
|
||||
expect.objectContaining({
|
||||
prompt,
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result).toBe(mockAggregated);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fireNotificationEvent', () => {
|
||||
it('should fire Notification event', async () => {
|
||||
const mockPlan = [
|
||||
{
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './notification-hook.sh',
|
||||
} as HookConfig,
|
||||
eventName: HookEventName.Notification,
|
||||
},
|
||||
];
|
||||
const mockResults: HookExecutionResult[] = [
|
||||
{
|
||||
success: true,
|
||||
duration: 50,
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './notification-hook.sh',
|
||||
timeout: 30000,
|
||||
},
|
||||
eventName: HookEventName.Notification,
|
||||
},
|
||||
];
|
||||
const mockAggregated = {
|
||||
success: true,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 50,
|
||||
};
|
||||
|
||||
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||
eventName: HookEventName.Notification,
|
||||
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||
sequential: false,
|
||||
});
|
||||
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
|
||||
mockResults,
|
||||
);
|
||||
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
|
||||
mockAggregated,
|
||||
);
|
||||
|
||||
const message = 'Tool execution requires permission';
|
||||
|
||||
const result = await hookEventHandler.fireNotificationEvent(
|
||||
NotificationType.ToolPermission,
|
||||
message,
|
||||
{ type: 'ToolPermission', title: 'Test Permission' },
|
||||
);
|
||||
|
||||
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
|
||||
[mockPlan[0].hookConfig],
|
||||
HookEventName.Notification,
|
||||
expect.objectContaining({
|
||||
notification_type: 'ToolPermission',
|
||||
details: { type: 'ToolPermission', title: 'Test Permission' },
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result).toBe(mockAggregated);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fireSessionStartEvent', () => {
|
||||
it('should fire SessionStart event with source', async () => {
|
||||
const mockPlan = [
|
||||
{
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './session_start.sh',
|
||||
} as unknown as HookConfig,
|
||||
eventName: HookEventName.SessionStart,
|
||||
},
|
||||
];
|
||||
const mockResults: HookExecutionResult[] = [
|
||||
{
|
||||
success: true,
|
||||
duration: 200,
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './session_start.sh',
|
||||
timeout: 30000,
|
||||
},
|
||||
eventName: HookEventName.SessionStart,
|
||||
},
|
||||
];
|
||||
const mockAggregated = {
|
||||
success: true,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 200,
|
||||
};
|
||||
|
||||
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||
eventName: HookEventName.SessionStart,
|
||||
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||
sequential: false,
|
||||
});
|
||||
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
|
||||
mockResults,
|
||||
);
|
||||
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
|
||||
mockAggregated,
|
||||
);
|
||||
|
||||
const result = await hookEventHandler.fireSessionStartEvent(
|
||||
SessionStartSource.Startup,
|
||||
);
|
||||
|
||||
expect(mockHookPlanner.createExecutionPlan).toHaveBeenCalledWith(
|
||||
HookEventName.SessionStart,
|
||||
{ trigger: 'startup' },
|
||||
);
|
||||
|
||||
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
|
||||
[mockPlan[0].hookConfig],
|
||||
HookEventName.SessionStart,
|
||||
expect.objectContaining({
|
||||
source: 'startup',
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result).toBe(mockAggregated);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fireBeforeModelEvent', () => {
|
||||
it('should fire BeforeModel event with LLM request', async () => {
|
||||
const mockPlan = [
|
||||
{
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './model-hook.sh',
|
||||
} as HookConfig,
|
||||
eventName: HookEventName.BeforeModel,
|
||||
},
|
||||
];
|
||||
const mockResults: HookExecutionResult[] = [
|
||||
{
|
||||
success: true,
|
||||
duration: 150,
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './model-hook.sh',
|
||||
timeout: 30000,
|
||||
},
|
||||
eventName: HookEventName.BeforeModel,
|
||||
},
|
||||
];
|
||||
const mockAggregated = {
|
||||
success: true,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 150,
|
||||
};
|
||||
|
||||
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||
eventName: HookEventName.BeforeModel,
|
||||
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||
sequential: false,
|
||||
});
|
||||
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue(
|
||||
mockResults,
|
||||
);
|
||||
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue(
|
||||
mockAggregated,
|
||||
);
|
||||
|
||||
const llmRequest = {
|
||||
model: 'gemini-pro',
|
||||
config: { temperature: 0.7 },
|
||||
contents: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
};
|
||||
|
||||
const result = await hookEventHandler.fireBeforeModelEvent(llmRequest);
|
||||
|
||||
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
|
||||
[mockPlan[0].hookConfig],
|
||||
HookEventName.BeforeModel,
|
||||
expect.objectContaining({
|
||||
llm_request: expect.objectContaining({
|
||||
model: 'gemini-pro',
|
||||
messages: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
role: 'user',
|
||||
content: 'Hello',
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result).toBe(mockAggregated);
|
||||
});
|
||||
});
|
||||
|
||||
describe('createBaseInput', () => {
|
||||
it('should create base input with correct fields', async () => {
|
||||
const mockPlan = [
|
||||
{
|
||||
hookConfig: {
|
||||
type: HookType.Command,
|
||||
command: './test.sh',
|
||||
} as unknown as HookConfig,
|
||||
eventName: HookEventName.BeforeTool,
|
||||
},
|
||||
];
|
||||
|
||||
vi.mocked(mockHookPlanner.createExecutionPlan).mockReturnValue({
|
||||
eventName: HookEventName.BeforeTool,
|
||||
hookConfigs: mockPlan.map((p) => p.hookConfig),
|
||||
sequential: false,
|
||||
});
|
||||
vi.mocked(mockHookRunner.executeHooksParallel).mockResolvedValue([]);
|
||||
vi.mocked(mockHookAggregator.aggregateResults).mockReturnValue({
|
||||
success: true,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 0,
|
||||
});
|
||||
|
||||
await hookEventHandler.fireBeforeToolEvent('TestTool', {});
|
||||
|
||||
expect(mockHookRunner.executeHooksParallel).toHaveBeenCalledWith(
|
||||
expect.any(Array),
|
||||
HookEventName.BeforeTool,
|
||||
expect.objectContaining({
|
||||
session_id: 'test-session',
|
||||
transcript_path: '',
|
||||
cwd: '/test/project',
|
||||
hook_event_name: 'BeforeTool',
|
||||
timestamp: expect.any(String),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
732
packages/core/src/hooks/hookEventHandler.ts
Normal file
732
packages/core/src/hooks/hookEventHandler.ts
Normal file
@@ -0,0 +1,732 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Logger } from '@opentelemetry/api-logs';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { HookPlanner, HookEventContext } from './hookPlanner.js';
|
||||
import type { HookRunner } from './hookRunner.js';
|
||||
import type { HookAggregator, AggregatedHookResult } from './hookAggregator.js';
|
||||
import { HookEventName } from './types.js';
|
||||
import type {
|
||||
HookInput,
|
||||
BeforeToolInput,
|
||||
AfterToolInput,
|
||||
BeforeAgentInput,
|
||||
NotificationInput,
|
||||
AfterAgentInput,
|
||||
SessionStartInput,
|
||||
SessionEndInput,
|
||||
PreCompressInput,
|
||||
BeforeModelInput,
|
||||
AfterModelInput,
|
||||
BeforeToolSelectionInput,
|
||||
NotificationType,
|
||||
SessionStartSource,
|
||||
SessionEndReason,
|
||||
PreCompressTrigger,
|
||||
HookExecutionResult,
|
||||
} from './types.js';
|
||||
import { defaultHookTranslator } from './hookTranslator.js';
|
||||
import type {
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import { logHookCall } from '../telemetry/loggers.js';
|
||||
import { HookCallEvent } from '../telemetry/types.js';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type HookExecutionRequest,
|
||||
} from '../confirmation-bus/types.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
/**
|
||||
* Validates that a value is a non-null object
|
||||
*/
|
||||
function isObject(value: unknown): value is Record<string, unknown> {
|
||||
return typeof value === 'object' && value !== null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates BeforeTool input fields
|
||||
*/
|
||||
function validateBeforeToolInput(input: Record<string, unknown>): {
|
||||
toolName: string;
|
||||
toolInput: Record<string, unknown>;
|
||||
} {
|
||||
const toolName = input['tool_name'];
|
||||
const toolInput = input['tool_input'];
|
||||
if (typeof toolName !== 'string') {
|
||||
throw new Error(
|
||||
'Invalid input for BeforeTool hook event: tool_name must be a string',
|
||||
);
|
||||
}
|
||||
if (!isObject(toolInput)) {
|
||||
throw new Error(
|
||||
'Invalid input for BeforeTool hook event: tool_input must be an object',
|
||||
);
|
||||
}
|
||||
return { toolName, toolInput };
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates AfterTool input fields
|
||||
*/
|
||||
function validateAfterToolInput(input: Record<string, unknown>): {
|
||||
toolName: string;
|
||||
toolInput: Record<string, unknown>;
|
||||
toolResponse: Record<string, unknown>;
|
||||
} {
|
||||
const toolName = input['tool_name'];
|
||||
const toolInput = input['tool_input'];
|
||||
const toolResponse = input['tool_response'];
|
||||
if (typeof toolName !== 'string') {
|
||||
throw new Error(
|
||||
'Invalid input for AfterTool hook event: tool_name must be a string',
|
||||
);
|
||||
}
|
||||
if (!isObject(toolInput)) {
|
||||
throw new Error(
|
||||
'Invalid input for AfterTool hook event: tool_input must be an object',
|
||||
);
|
||||
}
|
||||
if (!isObject(toolResponse)) {
|
||||
throw new Error(
|
||||
'Invalid input for AfterTool hook event: tool_response must be an object',
|
||||
);
|
||||
}
|
||||
return { toolName, toolInput, toolResponse };
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates BeforeAgent input fields
|
||||
*/
|
||||
function validateBeforeAgentInput(input: Record<string, unknown>): {
|
||||
prompt: string;
|
||||
} {
|
||||
const prompt = input['prompt'];
|
||||
if (typeof prompt !== 'string') {
|
||||
throw new Error(
|
||||
'Invalid input for BeforeAgent hook event: prompt must be a string',
|
||||
);
|
||||
}
|
||||
return { prompt };
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates AfterAgent input fields
|
||||
*/
|
||||
function validateAfterAgentInput(input: Record<string, unknown>): {
|
||||
prompt: string;
|
||||
promptResponse: string;
|
||||
stopHookActive: boolean;
|
||||
} {
|
||||
const prompt = input['prompt'];
|
||||
const promptResponse = input['prompt_response'];
|
||||
const stopHookActive = input['stop_hook_active'];
|
||||
if (typeof prompt !== 'string') {
|
||||
throw new Error(
|
||||
'Invalid input for AfterAgent hook event: prompt must be a string',
|
||||
);
|
||||
}
|
||||
if (typeof promptResponse !== 'string') {
|
||||
throw new Error(
|
||||
'Invalid input for AfterAgent hook event: prompt_response must be a string',
|
||||
);
|
||||
}
|
||||
// stopHookActive defaults to false if not a boolean
|
||||
return {
|
||||
prompt,
|
||||
promptResponse,
|
||||
stopHookActive:
|
||||
typeof stopHookActive === 'boolean' ? stopHookActive : false,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates model-related input fields (llm_request)
|
||||
*/
|
||||
function validateModelInput(
|
||||
input: Record<string, unknown>,
|
||||
eventName: string,
|
||||
): { llmRequest: GenerateContentParameters } {
|
||||
const llmRequest = input['llm_request'];
|
||||
if (!isObject(llmRequest)) {
|
||||
throw new Error(
|
||||
`Invalid input for ${eventName} hook event: llm_request must be an object`,
|
||||
);
|
||||
}
|
||||
return { llmRequest: llmRequest as unknown as GenerateContentParameters };
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates AfterModel input fields
|
||||
*/
|
||||
function validateAfterModelInput(input: Record<string, unknown>): {
|
||||
llmRequest: GenerateContentParameters;
|
||||
llmResponse: GenerateContentResponse;
|
||||
} {
|
||||
const llmRequest = input['llm_request'];
|
||||
const llmResponse = input['llm_response'];
|
||||
if (!isObject(llmRequest)) {
|
||||
throw new Error(
|
||||
'Invalid input for AfterModel hook event: llm_request must be an object',
|
||||
);
|
||||
}
|
||||
if (!isObject(llmResponse)) {
|
||||
throw new Error(
|
||||
'Invalid input for AfterModel hook event: llm_response must be an object',
|
||||
);
|
||||
}
|
||||
return {
|
||||
llmRequest: llmRequest as unknown as GenerateContentParameters,
|
||||
llmResponse: llmResponse as unknown as GenerateContentResponse,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates Notification input fields
|
||||
*/
|
||||
function validateNotificationInput(input: Record<string, unknown>): {
|
||||
notificationType: NotificationType;
|
||||
message: string;
|
||||
details: Record<string, unknown>;
|
||||
} {
|
||||
const notificationType = input['notification_type'];
|
||||
const message = input['message'];
|
||||
const details = input['details'];
|
||||
if (typeof notificationType !== 'string') {
|
||||
throw new Error(
|
||||
'Invalid input for Notification hook event: notification_type must be a string',
|
||||
);
|
||||
}
|
||||
if (typeof message !== 'string') {
|
||||
throw new Error(
|
||||
'Invalid input for Notification hook event: message must be a string',
|
||||
);
|
||||
}
|
||||
if (!isObject(details)) {
|
||||
throw new Error(
|
||||
'Invalid input for Notification hook event: details must be an object',
|
||||
);
|
||||
}
|
||||
return {
|
||||
notificationType: notificationType as NotificationType,
|
||||
message,
|
||||
details,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook event bus that coordinates hook execution across the system
|
||||
*/
|
||||
export class HookEventHandler {
|
||||
private readonly config: Config;
|
||||
private readonly hookPlanner: HookPlanner;
|
||||
private readonly hookRunner: HookRunner;
|
||||
private readonly hookAggregator: HookAggregator;
|
||||
private readonly messageBus?: MessageBus;
|
||||
|
||||
constructor(
|
||||
config: Config,
|
||||
logger: Logger,
|
||||
hookPlanner: HookPlanner,
|
||||
hookRunner: HookRunner,
|
||||
hookAggregator: HookAggregator,
|
||||
messageBus?: MessageBus,
|
||||
) {
|
||||
this.config = config;
|
||||
this.hookPlanner = hookPlanner;
|
||||
this.hookRunner = hookRunner;
|
||||
this.hookAggregator = hookAggregator;
|
||||
this.messageBus = messageBus;
|
||||
|
||||
// Subscribe to hook execution requests from MessageBus
|
||||
if (this.messageBus) {
|
||||
this.messageBus.subscribe<HookExecutionRequest>(
|
||||
MessageBusType.HOOK_EXECUTION_REQUEST,
|
||||
(request) => this.handleHookExecutionRequest(request),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a BeforeTool event
|
||||
* Called by handleHookExecutionRequest - executes hooks directly
|
||||
*/
|
||||
async fireBeforeToolEvent(
|
||||
toolName: string,
|
||||
toolInput: Record<string, unknown>,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: BeforeToolInput = {
|
||||
...this.createBaseInput(HookEventName.BeforeTool),
|
||||
tool_name: toolName,
|
||||
tool_input: toolInput,
|
||||
};
|
||||
|
||||
const context: HookEventContext = { toolName };
|
||||
return await this.executeHooks(HookEventName.BeforeTool, input, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire an AfterTool event
|
||||
* Called by handleHookExecutionRequest - executes hooks directly
|
||||
*/
|
||||
async fireAfterToolEvent(
|
||||
toolName: string,
|
||||
toolInput: Record<string, unknown>,
|
||||
toolResponse: Record<string, unknown>,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: AfterToolInput = {
|
||||
...this.createBaseInput(HookEventName.AfterTool),
|
||||
tool_name: toolName,
|
||||
tool_input: toolInput,
|
||||
tool_response: toolResponse,
|
||||
};
|
||||
|
||||
const context: HookEventContext = { toolName };
|
||||
return await this.executeHooks(HookEventName.AfterTool, input, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a BeforeAgent event
|
||||
* Called by handleHookExecutionRequest - executes hooks directly
|
||||
*/
|
||||
async fireBeforeAgentEvent(prompt: string): Promise<AggregatedHookResult> {
|
||||
const input: BeforeAgentInput = {
|
||||
...this.createBaseInput(HookEventName.BeforeAgent),
|
||||
prompt,
|
||||
};
|
||||
|
||||
return await this.executeHooks(HookEventName.BeforeAgent, input);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a Notification event
|
||||
*/
|
||||
async fireNotificationEvent(
|
||||
type: NotificationType,
|
||||
message: string,
|
||||
details: Record<string, unknown>,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: NotificationInput = {
|
||||
...this.createBaseInput(HookEventName.Notification),
|
||||
notification_type: type,
|
||||
message,
|
||||
details,
|
||||
};
|
||||
|
||||
return await this.executeHooks(HookEventName.Notification, input);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire an AfterAgent event
|
||||
* Called by handleHookExecutionRequest - executes hooks directly
|
||||
*/
|
||||
async fireAfterAgentEvent(
|
||||
prompt: string,
|
||||
promptResponse: string,
|
||||
stopHookActive: boolean = false,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: AfterAgentInput = {
|
||||
...this.createBaseInput(HookEventName.AfterAgent),
|
||||
prompt,
|
||||
prompt_response: promptResponse,
|
||||
stop_hook_active: stopHookActive,
|
||||
};
|
||||
|
||||
return await this.executeHooks(HookEventName.AfterAgent, input);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a SessionStart event
|
||||
*/
|
||||
async fireSessionStartEvent(
|
||||
source: SessionStartSource,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: SessionStartInput = {
|
||||
...this.createBaseInput(HookEventName.SessionStart),
|
||||
source,
|
||||
};
|
||||
|
||||
const context: HookEventContext = { trigger: source };
|
||||
return await this.executeHooks(HookEventName.SessionStart, input, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a SessionEnd event
|
||||
*/
|
||||
async fireSessionEndEvent(
|
||||
reason: SessionEndReason,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: SessionEndInput = {
|
||||
...this.createBaseInput(HookEventName.SessionEnd),
|
||||
reason,
|
||||
};
|
||||
|
||||
const context: HookEventContext = { trigger: reason };
|
||||
return await this.executeHooks(HookEventName.SessionEnd, input, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a PreCompress event
|
||||
*/
|
||||
async firePreCompressEvent(
|
||||
trigger: PreCompressTrigger,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: PreCompressInput = {
|
||||
...this.createBaseInput(HookEventName.PreCompress),
|
||||
trigger,
|
||||
};
|
||||
|
||||
const context: HookEventContext = { trigger };
|
||||
return await this.executeHooks(HookEventName.PreCompress, input, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a BeforeModel event
|
||||
* Called by handleHookExecutionRequest - executes hooks directly
|
||||
*/
|
||||
async fireBeforeModelEvent(
|
||||
llmRequest: GenerateContentParameters,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: BeforeModelInput = {
|
||||
...this.createBaseInput(HookEventName.BeforeModel),
|
||||
llm_request: defaultHookTranslator.toHookLLMRequest(llmRequest),
|
||||
};
|
||||
|
||||
return await this.executeHooks(HookEventName.BeforeModel, input);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire an AfterModel event
|
||||
* Called by handleHookExecutionRequest - executes hooks directly
|
||||
*/
|
||||
async fireAfterModelEvent(
|
||||
llmRequest: GenerateContentParameters,
|
||||
llmResponse: GenerateContentResponse,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: AfterModelInput = {
|
||||
...this.createBaseInput(HookEventName.AfterModel),
|
||||
llm_request: defaultHookTranslator.toHookLLMRequest(llmRequest),
|
||||
llm_response: defaultHookTranslator.toHookLLMResponse(llmResponse),
|
||||
};
|
||||
|
||||
return await this.executeHooks(HookEventName.AfterModel, input);
|
||||
}
|
||||
|
||||
/**
|
||||
* Fire a BeforeToolSelection event
|
||||
* Called by handleHookExecutionRequest - executes hooks directly
|
||||
*/
|
||||
async fireBeforeToolSelectionEvent(
|
||||
llmRequest: GenerateContentParameters,
|
||||
): Promise<AggregatedHookResult> {
|
||||
const input: BeforeToolSelectionInput = {
|
||||
...this.createBaseInput(HookEventName.BeforeToolSelection),
|
||||
llm_request: defaultHookTranslator.toHookLLMRequest(llmRequest),
|
||||
};
|
||||
|
||||
return await this.executeHooks(HookEventName.BeforeToolSelection, input);
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute hooks for a specific event (direct execution without MessageBus)
|
||||
* Used as fallback when MessageBus is not available
|
||||
*/
|
||||
private async executeHooks(
|
||||
eventName: HookEventName,
|
||||
input: HookInput,
|
||||
context?: HookEventContext,
|
||||
): Promise<AggregatedHookResult> {
|
||||
try {
|
||||
// Create execution plan
|
||||
const plan = this.hookPlanner.createExecutionPlan(eventName, context);
|
||||
|
||||
if (!plan || plan.hookConfigs.length === 0) {
|
||||
return {
|
||||
success: true,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 0,
|
||||
};
|
||||
}
|
||||
|
||||
// Execute hooks according to the plan's strategy
|
||||
const results = plan.sequential
|
||||
? await this.hookRunner.executeHooksSequential(
|
||||
plan.hookConfigs,
|
||||
eventName,
|
||||
input,
|
||||
)
|
||||
: await this.hookRunner.executeHooksParallel(
|
||||
plan.hookConfigs,
|
||||
eventName,
|
||||
input,
|
||||
);
|
||||
|
||||
// Aggregate results
|
||||
const aggregated = this.hookAggregator.aggregateResults(
|
||||
results,
|
||||
eventName,
|
||||
);
|
||||
|
||||
// Process common hook output fields centrally
|
||||
this.processCommonHookOutputFields(aggregated);
|
||||
|
||||
// Log hook execution
|
||||
this.logHookExecution(eventName, input, results, aggregated);
|
||||
|
||||
return aggregated;
|
||||
} catch (error) {
|
||||
debugLogger.error(`Hook event bus error for ${eventName}: ${error}`);
|
||||
|
||||
return {
|
||||
success: false,
|
||||
allOutputs: [],
|
||||
errors: [error instanceof Error ? error : new Error(String(error))],
|
||||
totalDuration: 0,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create base hook input with common fields
|
||||
*/
|
||||
private createBaseInput(eventName: HookEventName): HookInput {
|
||||
return {
|
||||
session_id: this.config.getSessionId(),
|
||||
transcript_path: '', // TODO: Implement transcript path when supported
|
||||
cwd: this.config.getWorkingDir(),
|
||||
hook_event_name: eventName,
|
||||
timestamp: new Date().toISOString(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Log hook execution for observability
|
||||
*/
|
||||
private logHookExecution(
|
||||
eventName: HookEventName,
|
||||
input: HookInput,
|
||||
results: HookExecutionResult[],
|
||||
aggregated: AggregatedHookResult,
|
||||
): void {
|
||||
const successCount = results.filter((r) => r.success).length;
|
||||
const errorCount = results.length - successCount;
|
||||
|
||||
if (errorCount > 0) {
|
||||
debugLogger.warn(
|
||||
`Hook execution for ${eventName}: ${successCount} succeeded, ${errorCount} failed, ` +
|
||||
`total duration: ${aggregated.totalDuration}ms`,
|
||||
);
|
||||
} else {
|
||||
debugLogger.debug(
|
||||
`Hook execution for ${eventName}: ${successCount} hooks executed successfully, ` +
|
||||
`total duration: ${aggregated.totalDuration}ms`,
|
||||
);
|
||||
}
|
||||
|
||||
// Log individual hook calls to telemetry
|
||||
for (const result of results) {
|
||||
// Determine hook name and type for telemetry
|
||||
const hookName = this.getHookNameFromResult(result);
|
||||
const hookType = this.getHookTypeFromResult(result);
|
||||
|
||||
const hookCallEvent = new HookCallEvent(
|
||||
eventName,
|
||||
hookType,
|
||||
hookName,
|
||||
{ ...input },
|
||||
result.duration,
|
||||
result.success,
|
||||
result.output ? { ...result.output } : undefined,
|
||||
result.exitCode,
|
||||
result.stdout,
|
||||
result.stderr,
|
||||
result.error?.message,
|
||||
);
|
||||
|
||||
logHookCall(this.config, hookCallEvent);
|
||||
}
|
||||
|
||||
// Log individual errors
|
||||
for (const error of aggregated.errors) {
|
||||
debugLogger.error(`Hook execution error: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process common hook output fields centrally
|
||||
*/
|
||||
private processCommonHookOutputFields(
|
||||
aggregated: AggregatedHookResult,
|
||||
): void {
|
||||
if (!aggregated.finalOutput) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle systemMessage - show to user in transcript mode (not to agent)
|
||||
const systemMessage = aggregated.finalOutput.systemMessage;
|
||||
if (systemMessage && !aggregated.finalOutput.suppressOutput) {
|
||||
debugLogger.warn(`Hook system message: ${systemMessage}`);
|
||||
}
|
||||
|
||||
// Handle suppressOutput - already handled by not logging above when true
|
||||
|
||||
// Handle continue=false - this should stop the entire agent execution
|
||||
if (aggregated.finalOutput.shouldStopExecution()) {
|
||||
const stopReason = aggregated.finalOutput.getEffectiveReason();
|
||||
debugLogger.log(`Hook requested to stop execution: ${stopReason}`);
|
||||
|
||||
// Note: The actual stopping of execution must be handled by integration points
|
||||
// as they need to interpret this signal in the context of their specific workflow
|
||||
// This is just logging the request centrally
|
||||
}
|
||||
|
||||
// Other common fields like decision/reason are handled by specific hook output classes
|
||||
}
|
||||
|
||||
/**
|
||||
* Get hook name from execution result for telemetry
|
||||
*/
|
||||
private getHookNameFromResult(result: HookExecutionResult): string {
|
||||
return result.hookConfig.command || 'unknown-command';
|
||||
}
|
||||
|
||||
/**
|
||||
* Get hook type from execution result for telemetry
|
||||
*/
|
||||
private getHookTypeFromResult(result: HookExecutionResult): 'command' {
|
||||
return result.hookConfig.type;
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle hook execution requests from MessageBus
|
||||
* This method routes the request to the appropriate fire*Event method
|
||||
* and publishes the response back through MessageBus
|
||||
*
|
||||
* The request input only contains event-specific fields. This method adds
|
||||
* the common base fields (session_id, cwd, etc.) before routing.
|
||||
*/
|
||||
private async handleHookExecutionRequest(
|
||||
request: HookExecutionRequest,
|
||||
): Promise<void> {
|
||||
try {
|
||||
// Add base fields to the input
|
||||
const enrichedInput = {
|
||||
...this.createBaseInput(request.eventName as HookEventName),
|
||||
...request.input,
|
||||
} as Record<string, unknown>;
|
||||
|
||||
let result: AggregatedHookResult;
|
||||
|
||||
// Route to appropriate event handler based on eventName
|
||||
switch (request.eventName) {
|
||||
case HookEventName.BeforeTool: {
|
||||
const { toolName, toolInput } =
|
||||
validateBeforeToolInput(enrichedInput);
|
||||
result = await this.fireBeforeToolEvent(toolName, toolInput);
|
||||
break;
|
||||
}
|
||||
case HookEventName.AfterTool: {
|
||||
const { toolName, toolInput, toolResponse } =
|
||||
validateAfterToolInput(enrichedInput);
|
||||
result = await this.fireAfterToolEvent(
|
||||
toolName,
|
||||
toolInput,
|
||||
toolResponse,
|
||||
);
|
||||
break;
|
||||
}
|
||||
case HookEventName.BeforeAgent: {
|
||||
const { prompt } = validateBeforeAgentInput(enrichedInput);
|
||||
result = await this.fireBeforeAgentEvent(prompt);
|
||||
break;
|
||||
}
|
||||
case HookEventName.AfterAgent: {
|
||||
const { prompt, promptResponse, stopHookActive } =
|
||||
validateAfterAgentInput(enrichedInput);
|
||||
result = await this.fireAfterAgentEvent(
|
||||
prompt,
|
||||
promptResponse,
|
||||
stopHookActive,
|
||||
);
|
||||
break;
|
||||
}
|
||||
case HookEventName.BeforeModel: {
|
||||
const { llmRequest } = validateModelInput(
|
||||
enrichedInput,
|
||||
'BeforeModel',
|
||||
);
|
||||
const translatedRequest =
|
||||
defaultHookTranslator.toHookLLMRequest(llmRequest);
|
||||
// Update the enrichedInput with translated request
|
||||
enrichedInput['llm_request'] = translatedRequest;
|
||||
result = await this.fireBeforeModelEvent(llmRequest);
|
||||
break;
|
||||
}
|
||||
case HookEventName.AfterModel: {
|
||||
const { llmRequest, llmResponse } =
|
||||
validateAfterModelInput(enrichedInput);
|
||||
const translatedRequest =
|
||||
defaultHookTranslator.toHookLLMRequest(llmRequest);
|
||||
const translatedResponse =
|
||||
defaultHookTranslator.toHookLLMResponse(llmResponse);
|
||||
// Update the enrichedInput with translated versions
|
||||
enrichedInput['llm_request'] = translatedRequest;
|
||||
enrichedInput['llm_response'] = translatedResponse;
|
||||
result = await this.fireAfterModelEvent(llmRequest, llmResponse);
|
||||
break;
|
||||
}
|
||||
case HookEventName.BeforeToolSelection: {
|
||||
const { llmRequest } = validateModelInput(
|
||||
enrichedInput,
|
||||
'BeforeToolSelection',
|
||||
);
|
||||
const translatedRequest =
|
||||
defaultHookTranslator.toHookLLMRequest(llmRequest);
|
||||
// Update the enrichedInput with translated request
|
||||
enrichedInput['llm_request'] = translatedRequest;
|
||||
result = await this.fireBeforeToolSelectionEvent(llmRequest);
|
||||
break;
|
||||
}
|
||||
case HookEventName.Notification: {
|
||||
const { notificationType, message, details } =
|
||||
validateNotificationInput(enrichedInput);
|
||||
result = await this.fireNotificationEvent(
|
||||
notificationType,
|
||||
message,
|
||||
details,
|
||||
);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw new Error(`Unsupported hook event: ${request.eventName}`);
|
||||
}
|
||||
|
||||
// Publish response through MessageBus
|
||||
if (this.messageBus) {
|
||||
this.messageBus.publish({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: request.correlationId,
|
||||
success: result.success,
|
||||
output: result.finalOutput as unknown as Record<string, unknown>,
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
// Publish error response
|
||||
if (this.messageBus) {
|
||||
this.messageBus.publish({
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: request.correlationId,
|
||||
success: false,
|
||||
error: error instanceof Error ? error : new Error(String(error)),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1250,4 +1250,291 @@ describe('PolicyEngine', () => {
|
||||
expect(result.decision).toBe(PolicyDecision.DENY);
|
||||
});
|
||||
});
|
||||
|
||||
describe('checkHook', () => {
|
||||
it('should allow hooks by default', async () => {
|
||||
engine = new PolicyEngine({}, mockCheckerRunner);
|
||||
const decision = await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'user',
|
||||
});
|
||||
expect(decision).toBe(PolicyDecision.ALLOW);
|
||||
});
|
||||
|
||||
it('should deny all hooks when allowHooks is false', async () => {
|
||||
engine = new PolicyEngine({ allowHooks: false }, mockCheckerRunner);
|
||||
const decision = await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'user',
|
||||
});
|
||||
expect(decision).toBe(PolicyDecision.DENY);
|
||||
});
|
||||
|
||||
it('should deny project hooks in untrusted folders', async () => {
|
||||
engine = new PolicyEngine({}, mockCheckerRunner);
|
||||
const decision = await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'project',
|
||||
trustedFolder: false,
|
||||
});
|
||||
expect(decision).toBe(PolicyDecision.DENY);
|
||||
});
|
||||
|
||||
it('should allow project hooks in trusted folders', async () => {
|
||||
engine = new PolicyEngine({}, mockCheckerRunner);
|
||||
const decision = await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'project',
|
||||
trustedFolder: true,
|
||||
});
|
||||
expect(decision).toBe(PolicyDecision.ALLOW);
|
||||
});
|
||||
|
||||
it('should allow user hooks in untrusted folders', async () => {
|
||||
engine = new PolicyEngine({}, mockCheckerRunner);
|
||||
const decision = await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'user',
|
||||
trustedFolder: false,
|
||||
});
|
||||
expect(decision).toBe(PolicyDecision.ALLOW);
|
||||
});
|
||||
|
||||
it('should run hook checkers and deny on DENY decision', async () => {
|
||||
const hookCheckers = [
|
||||
{
|
||||
eventName: 'BeforeTool',
|
||||
checker: { type: 'external' as const, name: 'test-hook-checker' },
|
||||
},
|
||||
];
|
||||
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
|
||||
|
||||
vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({
|
||||
decision: SafetyCheckDecision.DENY,
|
||||
reason: 'Hook checker denied',
|
||||
});
|
||||
|
||||
const decision = await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'user',
|
||||
});
|
||||
|
||||
expect(decision).toBe(PolicyDecision.DENY);
|
||||
expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ name: 'hook:BeforeTool' }),
|
||||
expect.objectContaining({ name: 'test-hook-checker' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should run hook checkers and allow on ALLOW decision', async () => {
|
||||
const hookCheckers = [
|
||||
{
|
||||
eventName: 'BeforeTool',
|
||||
checker: { type: 'external' as const, name: 'test-hook-checker' },
|
||||
},
|
||||
];
|
||||
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
|
||||
|
||||
vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({
|
||||
decision: SafetyCheckDecision.ALLOW,
|
||||
});
|
||||
|
||||
const decision = await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'user',
|
||||
});
|
||||
|
||||
expect(decision).toBe(PolicyDecision.ALLOW);
|
||||
});
|
||||
|
||||
it('should return ASK_USER when checker requests it', async () => {
|
||||
const hookCheckers = [
|
||||
{
|
||||
checker: { type: 'external' as const, name: 'test-hook-checker' },
|
||||
},
|
||||
];
|
||||
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
|
||||
|
||||
vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({
|
||||
decision: SafetyCheckDecision.ASK_USER,
|
||||
reason: 'Needs confirmation',
|
||||
});
|
||||
|
||||
const decision = await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'user',
|
||||
});
|
||||
|
||||
expect(decision).toBe(PolicyDecision.ASK_USER);
|
||||
});
|
||||
|
||||
it('should return DENY for ASK_USER in non-interactive mode', async () => {
|
||||
const hookCheckers = [
|
||||
{
|
||||
checker: { type: 'external' as const, name: 'test-hook-checker' },
|
||||
},
|
||||
];
|
||||
engine = new PolicyEngine(
|
||||
{ hookCheckers, nonInteractive: true },
|
||||
mockCheckerRunner,
|
||||
);
|
||||
|
||||
vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({
|
||||
decision: SafetyCheckDecision.ASK_USER,
|
||||
reason: 'Needs confirmation',
|
||||
});
|
||||
|
||||
const decision = await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'user',
|
||||
});
|
||||
|
||||
expect(decision).toBe(PolicyDecision.DENY);
|
||||
});
|
||||
|
||||
it('should match hook checkers by eventName', async () => {
|
||||
const hookCheckers = [
|
||||
{
|
||||
eventName: 'AfterTool',
|
||||
checker: { type: 'external' as const, name: 'after-tool-checker' },
|
||||
},
|
||||
{
|
||||
eventName: 'BeforeTool',
|
||||
checker: { type: 'external' as const, name: 'before-tool-checker' },
|
||||
},
|
||||
];
|
||||
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
|
||||
|
||||
vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({
|
||||
decision: SafetyCheckDecision.ALLOW,
|
||||
});
|
||||
|
||||
await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'user',
|
||||
});
|
||||
|
||||
expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({ name: 'before-tool-checker' }),
|
||||
);
|
||||
expect(mockCheckerRunner.runChecker).not.toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({ name: 'after-tool-checker' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should match hook checkers by hookSource', async () => {
|
||||
const hookCheckers = [
|
||||
{
|
||||
hookSource: 'project' as const,
|
||||
checker: { type: 'external' as const, name: 'project-checker' },
|
||||
},
|
||||
{
|
||||
hookSource: 'user' as const,
|
||||
checker: { type: 'external' as const, name: 'user-checker' },
|
||||
},
|
||||
];
|
||||
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
|
||||
|
||||
vi.mocked(mockCheckerRunner.runChecker).mockResolvedValue({
|
||||
decision: SafetyCheckDecision.ALLOW,
|
||||
});
|
||||
|
||||
await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'user',
|
||||
});
|
||||
|
||||
expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({ name: 'user-checker' }),
|
||||
);
|
||||
expect(mockCheckerRunner.runChecker).not.toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({ name: 'project-checker' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should deny when hook checker throws an error', async () => {
|
||||
const hookCheckers = [
|
||||
{
|
||||
checker: { type: 'external' as const, name: 'failing-checker' },
|
||||
},
|
||||
];
|
||||
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
|
||||
|
||||
vi.mocked(mockCheckerRunner.runChecker).mockRejectedValue(
|
||||
new Error('Checker failed'),
|
||||
);
|
||||
|
||||
const decision = await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'user',
|
||||
});
|
||||
|
||||
expect(decision).toBe(PolicyDecision.DENY);
|
||||
});
|
||||
|
||||
it('should run hook checkers in priority order', async () => {
|
||||
const hookCheckers = [
|
||||
{
|
||||
priority: 5,
|
||||
checker: { type: 'external' as const, name: 'low-priority' },
|
||||
},
|
||||
{
|
||||
priority: 20,
|
||||
checker: { type: 'external' as const, name: 'high-priority' },
|
||||
},
|
||||
{
|
||||
priority: 10,
|
||||
checker: { type: 'external' as const, name: 'medium-priority' },
|
||||
},
|
||||
];
|
||||
engine = new PolicyEngine({ hookCheckers }, mockCheckerRunner);
|
||||
|
||||
vi.mocked(mockCheckerRunner.runChecker).mockImplementation(
|
||||
async (_call, config) => {
|
||||
if (config.name === 'high-priority') {
|
||||
return { decision: SafetyCheckDecision.DENY, reason: 'denied' };
|
||||
}
|
||||
return { decision: SafetyCheckDecision.ALLOW };
|
||||
},
|
||||
);
|
||||
|
||||
await engine.checkHook({
|
||||
eventName: 'BeforeTool',
|
||||
hookSource: 'user',
|
||||
});
|
||||
|
||||
// Should only call the high-priority checker (first in sorted order)
|
||||
expect(mockCheckerRunner.runChecker).toHaveBeenCalledTimes(1);
|
||||
expect(mockCheckerRunner.runChecker).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({ name: 'high-priority' }),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('addHookChecker', () => {
|
||||
it('should add a new hook checker and maintain priority order', () => {
|
||||
engine = new PolicyEngine({}, mockCheckerRunner);
|
||||
|
||||
engine.addHookChecker({
|
||||
priority: 5,
|
||||
checker: { type: 'external', name: 'checker1' },
|
||||
});
|
||||
engine.addHookChecker({
|
||||
priority: 10,
|
||||
checker: { type: 'external', name: 'checker2' },
|
||||
});
|
||||
|
||||
const checkers = engine.getHookCheckers();
|
||||
expect(checkers).toHaveLength(2);
|
||||
expect(checkers[0].priority).toBe(10);
|
||||
expect(checkers[0].checker.name).toBe('checker2');
|
||||
expect(checkers[1].priority).toBe(5);
|
||||
expect(checkers[1].checker.name).toBe('checker1');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -10,11 +10,15 @@ import {
|
||||
type PolicyEngineConfig,
|
||||
type PolicyRule,
|
||||
type SafetyCheckerRule,
|
||||
type HookCheckerRule,
|
||||
type HookExecutionContext,
|
||||
getHookSource,
|
||||
} from './types.js';
|
||||
import { stableStringify } from './stable-stringify.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { CheckerRunner } from '../safety/checker-runner.js';
|
||||
import { SafetyCheckDecision } from '../safety/protocol.js';
|
||||
import type { HookExecutionRequest } from '../confirmation-bus/types.js';
|
||||
|
||||
function ruleMatches(
|
||||
rule: PolicyRule | SafetyCheckerRule,
|
||||
@@ -61,12 +65,34 @@ function ruleMatches(
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a hook checker rule matches a hook execution context.
|
||||
*/
|
||||
function hookCheckerMatches(
|
||||
rule: HookCheckerRule,
|
||||
context: HookExecutionContext,
|
||||
): boolean {
|
||||
// Check event name if specified
|
||||
if (rule.eventName && rule.eventName !== context.eventName) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check hook source if specified
|
||||
if (rule.hookSource && rule.hookSource !== context.hookSource) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
export class PolicyEngine {
|
||||
private rules: PolicyRule[];
|
||||
private checkers: SafetyCheckerRule[];
|
||||
private hookCheckers: HookCheckerRule[];
|
||||
private readonly defaultDecision: PolicyDecision;
|
||||
private readonly nonInteractive: boolean;
|
||||
private readonly checkerRunner?: CheckerRunner;
|
||||
private readonly allowHooks: boolean;
|
||||
|
||||
constructor(config: PolicyEngineConfig = {}, checkerRunner?: CheckerRunner) {
|
||||
this.rules = (config.rules ?? []).sort(
|
||||
@@ -75,9 +101,13 @@ export class PolicyEngine {
|
||||
this.checkers = (config.checkers ?? []).sort(
|
||||
(a, b) => (b.priority ?? 0) - (a.priority ?? 0),
|
||||
);
|
||||
this.hookCheckers = (config.hookCheckers ?? []).sort(
|
||||
(a, b) => (b.priority ?? 0) - (a.priority ?? 0),
|
||||
);
|
||||
this.defaultDecision = config.defaultDecision ?? PolicyDecision.ASK_USER;
|
||||
this.nonInteractive = config.nonInteractive ?? false;
|
||||
this.checkerRunner = checkerRunner;
|
||||
this.allowHooks = config.allowHooks ?? true;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -206,6 +236,99 @@ export class PolicyEngine {
|
||||
return this.checkers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Add a new hook checker to the policy engine.
|
||||
*/
|
||||
addHookChecker(checker: HookCheckerRule): void {
|
||||
this.hookCheckers.push(checker);
|
||||
this.hookCheckers.sort((a, b) => (b.priority ?? 0) - (a.priority ?? 0));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all current hook checkers.
|
||||
*/
|
||||
getHookCheckers(): readonly HookCheckerRule[] {
|
||||
return this.hookCheckers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a hook execution is allowed based on the configured policies.
|
||||
* Runs hook-specific safety checkers if configured.
|
||||
*/
|
||||
async checkHook(
|
||||
request: HookExecutionRequest | HookExecutionContext,
|
||||
): Promise<PolicyDecision> {
|
||||
// If hooks are globally disabled, deny all hook executions
|
||||
if (!this.allowHooks) {
|
||||
return PolicyDecision.DENY;
|
||||
}
|
||||
|
||||
const context: HookExecutionContext =
|
||||
'input' in request
|
||||
? {
|
||||
eventName: request.eventName,
|
||||
hookSource: getHookSource(request.input),
|
||||
trustedFolder:
|
||||
typeof request.input['trusted_folder'] === 'boolean'
|
||||
? request.input['trusted_folder']
|
||||
: undefined,
|
||||
}
|
||||
: request;
|
||||
|
||||
// In untrusted folders, deny project-level hooks
|
||||
if (context.trustedFolder === false && context.hookSource === 'project') {
|
||||
return PolicyDecision.DENY;
|
||||
}
|
||||
|
||||
// Run hook-specific safety checkers if configured
|
||||
if (this.checkerRunner && this.hookCheckers.length > 0) {
|
||||
for (const checkerRule of this.hookCheckers) {
|
||||
if (hookCheckerMatches(checkerRule, context)) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.checkHook] Running hook checker: ${checkerRule.checker.name} for event: ${context.eventName}`,
|
||||
);
|
||||
try {
|
||||
// Create a synthetic function call for the checker runner
|
||||
// This allows reusing the existing checker infrastructure
|
||||
const syntheticCall = {
|
||||
name: `hook:${context.eventName}`,
|
||||
args: {
|
||||
hookSource: context.hookSource,
|
||||
trustedFolder: context.trustedFolder,
|
||||
},
|
||||
};
|
||||
|
||||
const result = await this.checkerRunner.runChecker(
|
||||
syntheticCall,
|
||||
checkerRule.checker,
|
||||
);
|
||||
|
||||
if (result.decision === SafetyCheckDecision.DENY) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.checkHook] Hook checker denied: ${result.reason}`,
|
||||
);
|
||||
return PolicyDecision.DENY;
|
||||
} else if (result.decision === SafetyCheckDecision.ASK_USER) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.checkHook] Hook checker requested ASK_USER: ${result.reason}`,
|
||||
);
|
||||
// For hooks, ASK_USER is treated as DENY in non-interactive mode
|
||||
return this.applyNonInteractiveMode(PolicyDecision.ASK_USER);
|
||||
}
|
||||
} catch (error) {
|
||||
debugLogger.debug(
|
||||
`[PolicyEngine.checkHook] Hook checker failed: ${error}`,
|
||||
);
|
||||
return PolicyDecision.DENY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default: Allow hooks
|
||||
return PolicyDecision.ALLOW;
|
||||
}
|
||||
|
||||
private applyNonInteractiveMode(decision: PolicyDecision): PolicyDecision {
|
||||
// In non-interactive mode, ASK_USER becomes DENY
|
||||
if (this.nonInteractive && decision === PolicyDecision.ASK_USER) {
|
||||
|
||||
@@ -12,6 +12,36 @@ export enum PolicyDecision {
|
||||
ASK_USER = 'ask_user',
|
||||
}
|
||||
|
||||
/**
|
||||
* Valid sources for hook execution
|
||||
*/
|
||||
export type HookSource = 'project' | 'user' | 'system' | 'extension';
|
||||
|
||||
/**
|
||||
* Array of valid hook source values for runtime validation
|
||||
*/
|
||||
const VALID_HOOK_SOURCES: HookSource[] = [
|
||||
'project',
|
||||
'user',
|
||||
'system',
|
||||
'extension',
|
||||
];
|
||||
|
||||
/**
|
||||
* Safely extract and validate hook source from input
|
||||
* Returns 'project' as default if the value is invalid or missing
|
||||
*/
|
||||
export function getHookSource(input: Record<string, unknown>): HookSource {
|
||||
const source = input['hook_source'];
|
||||
if (
|
||||
typeof source === 'string' &&
|
||||
VALID_HOOK_SOURCES.includes(source as HookSource)
|
||||
) {
|
||||
return source as HookSource;
|
||||
}
|
||||
return 'project';
|
||||
}
|
||||
|
||||
export enum ApprovalMode {
|
||||
DEFAULT = 'default',
|
||||
AUTO_EDIT = 'autoEdit',
|
||||
@@ -115,6 +145,42 @@ export interface SafetyCheckerRule {
|
||||
checker: SafetyCheckerConfig;
|
||||
}
|
||||
|
||||
export interface HookExecutionContext {
|
||||
eventName: string;
|
||||
hookSource?: HookSource;
|
||||
trustedFolder?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Rule for applying safety checkers to hook executions.
|
||||
* Similar to SafetyCheckerRule but with hook-specific matching criteria.
|
||||
*/
|
||||
export interface HookCheckerRule {
|
||||
/**
|
||||
* The name of the hook event this rule applies to.
|
||||
* If undefined, the rule applies to all hook events.
|
||||
*/
|
||||
eventName?: string;
|
||||
|
||||
/**
|
||||
* The source of hooks this rule applies to.
|
||||
* If undefined, the rule applies to all hook sources.
|
||||
*/
|
||||
hookSource?: HookSource;
|
||||
|
||||
/**
|
||||
* Priority of this checker. Higher numbers run first.
|
||||
* Default is 0.
|
||||
*/
|
||||
priority?: number;
|
||||
|
||||
/**
|
||||
* Specifies an external or built-in safety checker to execute for
|
||||
* additional validation of a hook execution.
|
||||
*/
|
||||
checker: SafetyCheckerConfig;
|
||||
}
|
||||
|
||||
export interface PolicyEngineConfig {
|
||||
/**
|
||||
* List of policy rules to apply.
|
||||
@@ -122,10 +188,15 @@ export interface PolicyEngineConfig {
|
||||
rules?: PolicyRule[];
|
||||
|
||||
/**
|
||||
* List of safety checkers to apply.
|
||||
* List of safety checkers to apply to tool calls.
|
||||
*/
|
||||
checkers?: SafetyCheckerRule[];
|
||||
|
||||
/**
|
||||
* List of safety checkers to apply to hook executions.
|
||||
*/
|
||||
hookCheckers?: HookCheckerRule[];
|
||||
|
||||
/**
|
||||
* Default decision when no rules match.
|
||||
* Defaults to ASK_USER.
|
||||
@@ -137,6 +208,13 @@ export interface PolicyEngineConfig {
|
||||
* When true, ASK_USER decisions become DENY.
|
||||
*/
|
||||
nonInteractive?: boolean;
|
||||
|
||||
/**
|
||||
* Whether to allow hooks to execute.
|
||||
* When false, all hooks are denied.
|
||||
* Defaults to true.
|
||||
*/
|
||||
allowHooks?: boolean;
|
||||
}
|
||||
|
||||
export interface PolicySettings {
|
||||
|
||||
177
packages/core/src/test-utils/mock-message-bus.ts
Normal file
177
packages/core/src/test-utils/mock-message-bus.ts
Normal file
@@ -0,0 +1,177 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { vi } from 'vitest';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type Message,
|
||||
type HookExecutionRequest,
|
||||
type HookExecutionResponse,
|
||||
} from '../confirmation-bus/types.js';
|
||||
|
||||
/**
|
||||
* Mock MessageBus for testing hook execution through MessageBus
|
||||
*/
|
||||
export class MockMessageBus {
|
||||
private subscriptions = new Map<
|
||||
MessageBusType,
|
||||
Set<(message: Message) => void>
|
||||
>();
|
||||
publishedMessages: Message[] = [];
|
||||
hookRequests: HookExecutionRequest[] = [];
|
||||
hookResponses: HookExecutionResponse[] = [];
|
||||
|
||||
/**
|
||||
* Mock publish method that captures messages and simulates responses
|
||||
*/
|
||||
publish = vi.fn((message: Message) => {
|
||||
this.publishedMessages.push(message);
|
||||
|
||||
// Capture hook-specific messages
|
||||
if (message.type === MessageBusType.HOOK_EXECUTION_REQUEST) {
|
||||
this.hookRequests.push(message as HookExecutionRequest);
|
||||
|
||||
// Auto-respond with success for testing
|
||||
const response: HookExecutionResponse = {
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId: (message as HookExecutionRequest).correlationId,
|
||||
success: true,
|
||||
output: {
|
||||
decision: 'allow',
|
||||
reason: 'Mock hook execution successful',
|
||||
},
|
||||
};
|
||||
this.hookResponses.push(response);
|
||||
|
||||
// Emit response to subscribers
|
||||
this.emit(MessageBusType.HOOK_EXECUTION_RESPONSE, response);
|
||||
}
|
||||
});
|
||||
|
||||
/**
|
||||
* Mock subscribe method that stores listeners
|
||||
*/
|
||||
subscribe = vi.fn(
|
||||
<T extends Message>(type: T['type'], listener: (message: T) => void) => {
|
||||
if (!this.subscriptions.has(type)) {
|
||||
this.subscriptions.set(type, new Set());
|
||||
}
|
||||
this.subscriptions.get(type)!.add(listener as (message: Message) => void);
|
||||
},
|
||||
);
|
||||
|
||||
/**
|
||||
* Mock unsubscribe method
|
||||
*/
|
||||
unsubscribe = vi.fn(
|
||||
<T extends Message>(type: T['type'], listener: (message: T) => void) => {
|
||||
const listeners = this.subscriptions.get(type);
|
||||
if (listeners) {
|
||||
listeners.delete(listener as (message: Message) => void);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
/**
|
||||
* Emit a message to subscribers (for testing)
|
||||
*/
|
||||
private emit(type: MessageBusType, message: Message) {
|
||||
const listeners = this.subscriptions.get(type);
|
||||
if (listeners) {
|
||||
listeners.forEach((listener) => listener(message));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Manually trigger a hook response (for testing custom scenarios)
|
||||
*/
|
||||
triggerHookResponse(
|
||||
correlationId: string,
|
||||
success: boolean,
|
||||
output?: Record<string, unknown>,
|
||||
error?: Error,
|
||||
) {
|
||||
const response: HookExecutionResponse = {
|
||||
type: MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
correlationId,
|
||||
success,
|
||||
output,
|
||||
error,
|
||||
};
|
||||
this.hookResponses.push(response);
|
||||
this.emit(MessageBusType.HOOK_EXECUTION_RESPONSE, response);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the last hook request published
|
||||
*/
|
||||
getLastHookRequest(): HookExecutionRequest | undefined {
|
||||
return this.hookRequests[this.hookRequests.length - 1];
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all hook requests for a specific event
|
||||
*/
|
||||
getHookRequestsForEvent(eventName: string): HookExecutionRequest[] {
|
||||
return this.hookRequests.filter((req) => req.eventName === eventName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all captured messages (for test isolation)
|
||||
*/
|
||||
clear() {
|
||||
this.publishedMessages = [];
|
||||
this.hookRequests = [];
|
||||
this.hookResponses = [];
|
||||
this.subscriptions.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Verify that a hook execution request was published
|
||||
*/
|
||||
expectHookRequest(
|
||||
eventName: string,
|
||||
input?: Partial<Record<string, unknown>>,
|
||||
) {
|
||||
const request = this.hookRequests.find(
|
||||
(req) => req.eventName === eventName,
|
||||
);
|
||||
if (!request) {
|
||||
throw new Error(
|
||||
`Expected hook request for event "${eventName}" but none was found`,
|
||||
);
|
||||
}
|
||||
|
||||
if (input) {
|
||||
Object.entries(input).forEach(([key, value]) => {
|
||||
if (request.input[key] !== value) {
|
||||
throw new Error(
|
||||
`Expected hook input.${key} to be ${JSON.stringify(value)} but got ${JSON.stringify(request.input[key])}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return request;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a mock MessageBus for testing
|
||||
*/
|
||||
export function createMockMessageBus(): MessageBus {
|
||||
return new MockMessageBus() as unknown as MessageBus;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the MockMessageBus instance from a mocked MessageBus
|
||||
*/
|
||||
export function getMockMessageBusInstance(
|
||||
messageBus: MessageBus,
|
||||
): MockMessageBus {
|
||||
return messageBus as unknown as MockMessageBus;
|
||||
}
|
||||
Reference in New Issue
Block a user