mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-28 06:50:35 -07:00
feat(hooks): Hook Agent Lifecycle Integration (#9105)
This commit is contained in:
@@ -14,7 +14,9 @@ import {
|
||||
DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
|
||||
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
|
||||
GeminiClient,
|
||||
HookSystem,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { createMockMessageBus } from '@google/gemini-cli-core/src/test-utils/mock-message-bus.js';
|
||||
import type { Config, Storage } from '@google/gemini-cli-core';
|
||||
import { expect, vi } from 'vitest';
|
||||
|
||||
@@ -54,8 +56,13 @@ export function createMockConfig(
|
||||
getMessageBus: vi.fn(),
|
||||
getPolicyEngine: vi.fn(),
|
||||
getEnableExtensionReloading: vi.fn().mockReturnValue(false),
|
||||
getEnableHooks: vi.fn().mockReturnValue(false),
|
||||
...overrides,
|
||||
} as unknown as Config;
|
||||
mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());
|
||||
mockConfig.getHookSystem = vi
|
||||
.fn()
|
||||
.mockReturnValue(new HookSystem(mockConfig));
|
||||
|
||||
mockConfig.getGeminiClient = vi
|
||||
.fn()
|
||||
|
||||
@@ -32,7 +32,9 @@ import {
|
||||
ToolConfirmationOutcome,
|
||||
ApprovalMode,
|
||||
MockTool,
|
||||
HookSystem,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { createMockMessageBus } from '@google/gemini-cli-core/src/test-utils/mock-message-bus.js';
|
||||
import { ToolCallStatus } from '../types.js';
|
||||
|
||||
// Mocks
|
||||
@@ -81,7 +83,10 @@ const mockConfig = {
|
||||
getPolicyEngine: () => null,
|
||||
isInteractive: () => false,
|
||||
getExperiments: () => {},
|
||||
getEnableHooks: () => false,
|
||||
} as unknown as Config;
|
||||
mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());
|
||||
mockConfig.getHookSystem = vi.fn().mockReturnValue(new HookSystem(mockConfig));
|
||||
|
||||
const mockTool = new MockTool({
|
||||
name: 'mockTool',
|
||||
|
||||
@@ -76,6 +76,7 @@ import type { EventEmitter } from 'node:events';
|
||||
import { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import { PolicyEngine } from '../policy/policy-engine.js';
|
||||
import type { PolicyEngineConfig } from '../policy/types.js';
|
||||
import { HookSystem } from '../hooks/index.js';
|
||||
import type { UserTierId } from '../code_assist/types.js';
|
||||
import { getCodeAssistServer } from '../code_assist/codeAssist.js';
|
||||
import type { Experiments } from '../code_assist/experiments/experiments.js';
|
||||
@@ -415,6 +416,7 @@ export class Config {
|
||||
| undefined;
|
||||
private experiments: Experiments | undefined;
|
||||
private experimentsPromise: Promise<void> | undefined;
|
||||
private hookSystem?: HookSystem;
|
||||
|
||||
private previewModelFallbackMode = false;
|
||||
private previewModelBypassMode = false;
|
||||
@@ -627,6 +629,12 @@ export class Config {
|
||||
await this.getExtensionLoader().start(this),
|
||||
]);
|
||||
|
||||
// Initialize hook system if enabled
|
||||
if (this.enableHooks) {
|
||||
this.hookSystem = new HookSystem(this);
|
||||
await this.hookSystem.initialize();
|
||||
}
|
||||
|
||||
await this.geminiClient.initialize();
|
||||
}
|
||||
|
||||
@@ -1475,6 +1483,13 @@ export class Config {
|
||||
return registry;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the hook system instance
|
||||
*/
|
||||
getHookSystem(): HookSystem | undefined {
|
||||
return this.hookSystem;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get hooks configuration
|
||||
*/
|
||||
|
||||
@@ -43,6 +43,7 @@ import type {
|
||||
ResolvedModelConfig,
|
||||
} from '../services/modelConfigService.js';
|
||||
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
||||
import { HookSystem } from '../hooks/hookSystem.js';
|
||||
|
||||
vi.mock('../services/chatCompressionService.js');
|
||||
|
||||
@@ -120,6 +121,7 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({
|
||||
getLastPromptTokenCount: vi.fn(),
|
||||
},
|
||||
}));
|
||||
vi.mock('../hooks/hookSystem.js');
|
||||
|
||||
/**
|
||||
* Array.fromAsync ponyfill, which will be available in es 2024.
|
||||
@@ -211,6 +213,8 @@ describe('Gemini Client (client.ts)', () => {
|
||||
getModelRouterService: vi.fn().mockReturnValue({
|
||||
route: vi.fn().mockResolvedValue({ model: 'default-routed-model' }),
|
||||
}),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
getEnableHooks: vi.fn().mockReturnValue(false),
|
||||
isInFallbackMode: vi.fn().mockReturnValue(false),
|
||||
setFallbackMode: vi.fn(),
|
||||
getChatCompression: vi.fn().mockReturnValue(undefined),
|
||||
@@ -243,6 +247,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
isInteractive: vi.fn().mockReturnValue(false),
|
||||
getExperiments: () => {},
|
||||
} as unknown as Config;
|
||||
mockConfig.getHookSystem = vi
|
||||
.fn()
|
||||
.mockReturnValue(new HookSystem(mockConfig));
|
||||
|
||||
client = new GeminiClient(mockConfig);
|
||||
await client.initialize();
|
||||
|
||||
@@ -42,6 +42,10 @@ import {
|
||||
logContentRetryFailure,
|
||||
logNextSpeakerCheck,
|
||||
} from '../telemetry/loggers.js';
|
||||
import {
|
||||
fireBeforeAgentHook,
|
||||
fireAfterAgentHook,
|
||||
} from './clientHookTriggers.js';
|
||||
import {
|
||||
ContentRetryFailureEvent,
|
||||
NextSpeakerCheckEvent,
|
||||
@@ -438,6 +442,35 @@ export class GeminiClient {
|
||||
turns: number = MAX_TURNS,
|
||||
isInvalidStreamRetry: boolean = false,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
// Fire BeforeAgent hook through MessageBus (only if hooks are enabled)
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
const messageBus = this.config.getMessageBus();
|
||||
if (hooksEnabled && messageBus) {
|
||||
const hookOutput = await fireBeforeAgentHook(messageBus, request);
|
||||
|
||||
if (
|
||||
hookOutput?.isBlockingDecision() ||
|
||||
hookOutput?.shouldStopExecution()
|
||||
) {
|
||||
yield {
|
||||
type: GeminiEventType.Error,
|
||||
value: {
|
||||
error: new Error(
|
||||
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
|
||||
),
|
||||
},
|
||||
};
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
}
|
||||
|
||||
// Add additional context from hooks to the request
|
||||
const additionalContext = hookOutput?.getAdditionalContext();
|
||||
if (additionalContext) {
|
||||
const requestArray = Array.isArray(request) ? request : [request];
|
||||
request = [...requestArray, { text: additionalContext }];
|
||||
}
|
||||
}
|
||||
|
||||
if (this.lastPromptId !== prompt_id) {
|
||||
this.loopDetector.reset(prompt_id);
|
||||
this.lastPromptId = prompt_id;
|
||||
@@ -608,9 +641,9 @@ export class GeminiClient {
|
||||
);
|
||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||
const nextRequest = [{ text: 'Please continue.' }];
|
||||
// This recursive call's events will be yielded out, but the final
|
||||
// turn object will be from the top-level call.
|
||||
yield* this.sendMessageStream(
|
||||
// This recursive call's events will be yielded out, and the final
|
||||
// turn object from the recursive call will be returned.
|
||||
return yield* this.sendMessageStream(
|
||||
nextRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
@@ -619,6 +652,32 @@ export class GeminiClient {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Fire AfterAgent hook through MessageBus (only if hooks are enabled)
|
||||
if (hooksEnabled && messageBus) {
|
||||
const responseText = turn.getResponseText() || '[no response text]';
|
||||
const hookOutput = await fireAfterAgentHook(
|
||||
messageBus,
|
||||
request,
|
||||
responseText,
|
||||
);
|
||||
|
||||
// For AfterAgent hooks, blocking/stop execution should force continuation
|
||||
if (
|
||||
hookOutput?.isBlockingDecision() ||
|
||||
hookOutput?.shouldStopExecution()
|
||||
) {
|
||||
const continueReason = hookOutput.getEffectiveReason();
|
||||
const continueRequest = [{ text: continueReason }];
|
||||
yield* this.sendMessageStream(
|
||||
continueRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return turn;
|
||||
}
|
||||
|
||||
|
||||
105
packages/core/src/core/clientHookTriggers.ts
Normal file
105
packages/core/src/core/clientHookTriggers.ts
Normal file
@@ -0,0 +1,105 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { PartListUnion } from '@google/genai';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type HookExecutionRequest,
|
||||
type HookExecutionResponse,
|
||||
} from '../confirmation-bus/types.js';
|
||||
import { createHookOutput, type DefaultHookOutput } from '../hooks/types.js';
|
||||
import { partToString } from '../utils/partUtils.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
/**
|
||||
* Fires the BeforeAgent hook and returns the hook output.
|
||||
* This should be called before processing a user prompt.
|
||||
*
|
||||
* The caller can use the returned DefaultHookOutput methods:
|
||||
* - isBlockingDecision() / shouldStopExecution() to check if blocked
|
||||
* - getEffectiveReason() to get the blocking reason
|
||||
* - getAdditionalContext() to get additional context to add
|
||||
*
|
||||
* @param messageBus The message bus to use for hook communication
|
||||
* @param request The user's request (prompt)
|
||||
* @returns The hook output, or undefined if no hook was executed or on error
|
||||
*/
|
||||
export async function fireBeforeAgentHook(
|
||||
messageBus: MessageBus,
|
||||
request: PartListUnion,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
try {
|
||||
const promptText = partToString(request);
|
||||
|
||||
const response = await messageBus.request<
|
||||
HookExecutionRequest,
|
||||
HookExecutionResponse
|
||||
>(
|
||||
{
|
||||
type: MessageBusType.HOOK_EXECUTION_REQUEST,
|
||||
eventName: 'BeforeAgent',
|
||||
input: {
|
||||
prompt: promptText,
|
||||
},
|
||||
},
|
||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
);
|
||||
|
||||
return response.output
|
||||
? createHookOutput('BeforeAgent', response.output)
|
||||
: undefined;
|
||||
} catch (error) {
|
||||
debugLogger.warn(`BeforeAgent hook failed: ${error}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fires the AfterAgent hook and returns the hook output.
|
||||
* This should be called after the agent has generated a response.
|
||||
*
|
||||
* The caller can use the returned DefaultHookOutput methods:
|
||||
* - isBlockingDecision() / shouldStopExecution() to check if continuation is requested
|
||||
* - getEffectiveReason() to get the continuation reason
|
||||
*
|
||||
* @param messageBus The message bus to use for hook communication
|
||||
* @param request The original user's request (prompt)
|
||||
* @param responseText The agent's response text
|
||||
* @returns The hook output, or undefined if no hook was executed or on error
|
||||
*/
|
||||
export async function fireAfterAgentHook(
|
||||
messageBus: MessageBus,
|
||||
request: PartListUnion,
|
||||
responseText: string,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
try {
|
||||
const promptText = partToString(request);
|
||||
|
||||
const response = await messageBus.request<
|
||||
HookExecutionRequest,
|
||||
HookExecutionResponse
|
||||
>(
|
||||
{
|
||||
type: MessageBusType.HOOK_EXECUTION_REQUEST,
|
||||
eventName: 'AfterAgent',
|
||||
input: {
|
||||
prompt: promptText,
|
||||
prompt_response: responseText,
|
||||
stop_hook_active: false,
|
||||
},
|
||||
},
|
||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
);
|
||||
|
||||
return response.output
|
||||
? createHookOutput('AfterAgent', response.output)
|
||||
: undefined;
|
||||
} catch (error) {
|
||||
debugLogger.warn(`AfterAgent hook failed: ${error}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
@@ -18,9 +18,11 @@ import {
|
||||
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
|
||||
ToolErrorType,
|
||||
ApprovalMode,
|
||||
HookSystem,
|
||||
} from '../index.js';
|
||||
import type { Part } from '@google/genai';
|
||||
import { MockTool } from '../test-utils/mock-tool.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
|
||||
describe('executeToolCall', () => {
|
||||
let mockToolRegistry: ToolRegistry;
|
||||
@@ -66,8 +68,15 @@ describe('executeToolCall', () => {
|
||||
getPolicyEngine: () => null,
|
||||
isInteractive: () => false,
|
||||
getExperiments: () => {},
|
||||
getEnableHooks: () => false,
|
||||
} as unknown as Config;
|
||||
|
||||
// Use proper MessageBus mocking for Phase 3 preparation
|
||||
const mockMessageBus = createMockMessageBus();
|
||||
mockConfig.getMessageBus = vi.fn().mockReturnValue(mockMessageBus);
|
||||
mockConfig.getHookSystem = vi
|
||||
.fn()
|
||||
.mockReturnValue(new HookSystem(mockConfig));
|
||||
abortController = new AbortController();
|
||||
});
|
||||
|
||||
|
||||
@@ -392,6 +392,17 @@ export class Turn {
|
||||
getDebugResponses(): GenerateContentResponse[] {
|
||||
return this.debugResponses;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the concatenated response text from all responses in this turn.
|
||||
* This extracts and joins all text content from the model's responses.
|
||||
*/
|
||||
getResponseText(): string {
|
||||
return this.debugResponses
|
||||
.map((response) => getResponseText(response))
|
||||
.filter((text): text is string => text !== null)
|
||||
.join(' ');
|
||||
}
|
||||
}
|
||||
|
||||
function getCitations(resp: GenerateContentResponse): string[] {
|
||||
|
||||
280
packages/core/src/hooks/hookSystem.test.ts
Normal file
280
packages/core/src/hooks/hookSystem.test.ts
Normal file
@@ -0,0 +1,280 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { HookSystem } from './hookSystem.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { HookType } from './types.js';
|
||||
import { spawn } from 'node:child_process';
|
||||
import type { ChildProcessWithoutNullStreams } from 'node:child_process';
|
||||
import type { Readable, Writable } from 'node:stream';
|
||||
|
||||
// Mock type for the child_process spawn
|
||||
type MockChildProcessWithoutNullStreams = ChildProcessWithoutNullStreams & {
|
||||
mockStdoutOn: ReturnType<typeof vi.fn>;
|
||||
mockStderrOn: ReturnType<typeof vi.fn>;
|
||||
mockProcessOn: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
|
||||
// Mock child_process with importOriginal for partial mocking
|
||||
vi.mock('node:child_process', async (importOriginal) => {
|
||||
const actual = (await importOriginal()) as object;
|
||||
return {
|
||||
...actual,
|
||||
spawn: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
// Mock debugLogger - use vi.hoisted to define mock before it's used in vi.mock
|
||||
const mockDebugLogger = vi.hoisted(() => ({
|
||||
debug: vi.fn(),
|
||||
log: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/debugLogger.js', () => ({
|
||||
debugLogger: mockDebugLogger,
|
||||
}));
|
||||
|
||||
// Mock console methods
|
||||
const mockConsole = {
|
||||
log: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
};
|
||||
|
||||
vi.stubGlobal('console', mockConsole);
|
||||
|
||||
describe('HookSystem Integration', () => {
|
||||
let hookSystem: HookSystem;
|
||||
let config: Config;
|
||||
let mockSpawn: MockChildProcessWithoutNullStreams;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
|
||||
// Create a real config with simple command hook configurations for testing
|
||||
config = new Config({
|
||||
model: 'gemini-1.5-flash',
|
||||
targetDir: '/tmp/test-hooks',
|
||||
sessionId: 'test-session',
|
||||
debugMode: false,
|
||||
cwd: '/tmp/test-hooks',
|
||||
hooks: {
|
||||
BeforeTool: [
|
||||
{
|
||||
matcher: 'TestTool',
|
||||
hooks: [
|
||||
{
|
||||
type: HookType.Command,
|
||||
command: 'echo',
|
||||
timeout: 5000,
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
// Provide getMessageBus mock for MessageBus integration tests
|
||||
(config as unknown as { getMessageBus: () => unknown }).getMessageBus =
|
||||
() => undefined;
|
||||
|
||||
hookSystem = new HookSystem(config);
|
||||
|
||||
// Set up spawn mock with accessible mock functions
|
||||
const mockStdoutOn = vi.fn();
|
||||
const mockStderrOn = vi.fn();
|
||||
const mockProcessOn = vi.fn();
|
||||
|
||||
mockSpawn = {
|
||||
stdin: {
|
||||
write: vi.fn(),
|
||||
end: vi.fn(),
|
||||
} as unknown as Writable,
|
||||
stdout: {
|
||||
on: mockStdoutOn,
|
||||
} as unknown as Readable,
|
||||
stderr: {
|
||||
on: mockStderrOn,
|
||||
} as unknown as Readable,
|
||||
on: mockProcessOn,
|
||||
kill: vi.fn(),
|
||||
killed: false,
|
||||
mockStdoutOn,
|
||||
mockStderrOn,
|
||||
mockProcessOn,
|
||||
} as unknown as MockChildProcessWithoutNullStreams;
|
||||
|
||||
vi.mocked(spawn).mockReturnValue(mockSpawn);
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
// No cleanup needed
|
||||
});
|
||||
|
||||
describe('initialize', () => {
|
||||
it('should initialize successfully', async () => {
|
||||
await hookSystem.initialize();
|
||||
|
||||
expect(mockDebugLogger.debug).toHaveBeenCalledWith(
|
||||
'Hook system initialized successfully',
|
||||
);
|
||||
|
||||
// Verify system is initialized
|
||||
const status = hookSystem.getStatus();
|
||||
expect(status.initialized).toBe(true);
|
||||
// Note: totalHooks might be 0 if hook validation rejects the test hooks
|
||||
});
|
||||
|
||||
it('should not initialize twice', async () => {
|
||||
await hookSystem.initialize();
|
||||
await hookSystem.initialize(); // Second call should be no-op
|
||||
|
||||
// The system logs both registry initialization and system initialization
|
||||
expect(mockDebugLogger.debug).toHaveBeenCalledWith(
|
||||
'Hook system initialized successfully',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle initialization errors gracefully', async () => {
|
||||
// Create a config with invalid hooks to trigger initialization errors
|
||||
const invalidConfig = new Config({
|
||||
model: 'gemini-1.5-flash',
|
||||
targetDir: '/tmp/test-hooks-invalid',
|
||||
sessionId: 'test-session-invalid',
|
||||
debugMode: false,
|
||||
cwd: '/tmp/test-hooks-invalid',
|
||||
hooks: {
|
||||
BeforeTool: [
|
||||
{
|
||||
hooks: [
|
||||
{
|
||||
type: 'invalid-type' as HookType, // Invalid hook type for testing
|
||||
command: './test.sh',
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
const invalidHookSystem = new HookSystem(invalidConfig);
|
||||
|
||||
// Should not throw, but should log warnings via debugLogger
|
||||
await invalidHookSystem.initialize();
|
||||
|
||||
expect(mockDebugLogger.warn).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getEventHandler', () => {
|
||||
it('should return event bus when initialized', async () => {
|
||||
await hookSystem.initialize();
|
||||
|
||||
// Set up spawn mock behavior for successful execution
|
||||
mockSpawn.mockStdoutOn.mockImplementation(
|
||||
(event: string, callback: (data: Buffer) => void) => {
|
||||
if (event === 'data') {
|
||||
setTimeout(() => callback(Buffer.from('')), 5); // echo outputs empty
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
mockSpawn.mockProcessOn.mockImplementation(
|
||||
(event: string, callback: (code: number) => void) => {
|
||||
if (event === 'close') {
|
||||
setTimeout(() => callback(0), 10);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
const eventBus = hookSystem.getEventHandler();
|
||||
expect(eventBus).toBeDefined();
|
||||
|
||||
// Test that the event bus can actually fire events
|
||||
const result = await eventBus.fireBeforeToolEvent('TestTool', {
|
||||
test: 'data',
|
||||
});
|
||||
expect(result.success).toBe(true);
|
||||
});
|
||||
|
||||
it('should throw error when not initialized', () => {
|
||||
expect(() => hookSystem.getEventHandler()).toThrow(
|
||||
'Hook system not initialized',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('hook execution', () => {
|
||||
it('should execute hooks and return results', async () => {
|
||||
await hookSystem.initialize();
|
||||
|
||||
// Set up spawn mock behavior for successful execution
|
||||
mockSpawn.mockStdoutOn.mockImplementation(
|
||||
(event: string, callback: (data: Buffer) => void) => {
|
||||
if (event === 'data') {
|
||||
setTimeout(() => callback(Buffer.from('')), 5); // echo outputs empty
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
mockSpawn.mockProcessOn.mockImplementation(
|
||||
(event: string, callback: (code: number) => void) => {
|
||||
if (event === 'close') {
|
||||
setTimeout(() => callback(0), 10);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
const eventBus = hookSystem.getEventHandler();
|
||||
|
||||
// Test BeforeTool event with command hook
|
||||
const result = await eventBus.fireBeforeToolEvent('TestTool', {
|
||||
test: 'data',
|
||||
});
|
||||
|
||||
expect(result.success).toBe(true);
|
||||
// Command hooks with echo should succeed but may not have specific decisions
|
||||
expect(result.errors).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('should handle no matching hooks', async () => {
|
||||
await hookSystem.initialize();
|
||||
|
||||
const eventBus = hookSystem.getEventHandler();
|
||||
|
||||
// Test with a tool that doesn't match any hooks
|
||||
const result = await eventBus.fireBeforeToolEvent('UnmatchedTool', {
|
||||
test: 'data',
|
||||
});
|
||||
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.allOutputs).toHaveLength(0);
|
||||
expect(result.finalOutput).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('system management', () => {
|
||||
it('should return correct status when initialized', async () => {
|
||||
await hookSystem.initialize();
|
||||
|
||||
const status = hookSystem.getStatus();
|
||||
|
||||
expect(status.initialized).toBe(true);
|
||||
// Note: totalHooks might be 0 if hook validation rejects the test hooks
|
||||
expect(typeof status.totalHooks).toBe('number');
|
||||
});
|
||||
|
||||
it('should return uninitialized status', () => {
|
||||
const status = hookSystem.getStatus();
|
||||
|
||||
expect(status.initialized).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
106
packages/core/src/hooks/hookSystem.ts
Normal file
106
packages/core/src/hooks/hookSystem.ts
Normal file
@@ -0,0 +1,106 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import { HookRegistry } from './hookRegistry.js';
|
||||
import { HookRunner } from './hookRunner.js';
|
||||
import { HookAggregator } from './hookAggregator.js';
|
||||
import { HookPlanner } from './hookPlanner.js';
|
||||
import { HookEventHandler } from './hookEventHandler.js';
|
||||
import type { HookRegistryEntry } from './hookRegistry.js';
|
||||
import { logs, type Logger } from '@opentelemetry/api-logs';
|
||||
import { SERVICE_NAME } from '../telemetry/constants.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
/**
|
||||
* Main hook system that coordinates all hook-related functionality
|
||||
*/
|
||||
export class HookSystem {
|
||||
private readonly hookRegistry: HookRegistry;
|
||||
private readonly hookRunner: HookRunner;
|
||||
private readonly hookAggregator: HookAggregator;
|
||||
private readonly hookPlanner: HookPlanner;
|
||||
private readonly hookEventHandler: HookEventHandler;
|
||||
private initialized = false;
|
||||
|
||||
constructor(config: Config) {
|
||||
const logger: Logger = logs.getLogger(SERVICE_NAME);
|
||||
const messageBus = config.getMessageBus();
|
||||
|
||||
// Initialize components
|
||||
this.hookRegistry = new HookRegistry(config);
|
||||
this.hookRunner = new HookRunner();
|
||||
this.hookAggregator = new HookAggregator();
|
||||
this.hookPlanner = new HookPlanner(this.hookRegistry);
|
||||
this.hookEventHandler = new HookEventHandler(
|
||||
config,
|
||||
logger,
|
||||
this.hookPlanner,
|
||||
this.hookRunner,
|
||||
this.hookAggregator,
|
||||
messageBus, // Pass MessageBus to enable mediated hook execution
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the hook system
|
||||
*/
|
||||
async initialize(): Promise<void> {
|
||||
if (this.initialized) {
|
||||
return;
|
||||
}
|
||||
|
||||
await this.hookRegistry.initialize();
|
||||
this.initialized = true;
|
||||
debugLogger.debug('Hook system initialized successfully');
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the hook event bus for firing events
|
||||
*/
|
||||
getEventHandler(): HookEventHandler {
|
||||
if (!this.initialized) {
|
||||
throw new Error('Hook system not initialized');
|
||||
}
|
||||
return this.hookEventHandler;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get hook registry for management operations
|
||||
*/
|
||||
getRegistry(): HookRegistry {
|
||||
return this.hookRegistry;
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable or disable a hook
|
||||
*/
|
||||
setHookEnabled(hookName: string, enabled: boolean): void {
|
||||
this.hookRegistry.setHookEnabled(hookName, enabled);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all registered hooks for display/management
|
||||
*/
|
||||
getAllHooks(): HookRegistryEntry[] {
|
||||
return this.hookRegistry.getAllHooks();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get hook system status for debugging
|
||||
*/
|
||||
getStatus(): {
|
||||
initialized: boolean;
|
||||
totalHooks: number;
|
||||
} {
|
||||
const allHooks = this.initialized ? this.hookRegistry.getAllHooks() : [];
|
||||
|
||||
return {
|
||||
initialized: this.initialized,
|
||||
totalHooks: allHooks.length,
|
||||
};
|
||||
}
|
||||
}
|
||||
21
packages/core/src/hooks/index.ts
Normal file
21
packages/core/src/hooks/index.ts
Normal file
@@ -0,0 +1,21 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
// Export types
|
||||
export * from './types.js';
|
||||
|
||||
// Export core components
|
||||
export { HookSystem } from './hookSystem.js';
|
||||
export { HookRegistry } from './hookRegistry.js';
|
||||
export { HookRunner } from './hookRunner.js';
|
||||
export { HookAggregator } from './hookAggregator.js';
|
||||
export { HookPlanner } from './hookPlanner.js';
|
||||
export { HookEventHandler } from './hookEventHandler.js';
|
||||
|
||||
// Export interfaces
|
||||
export type { HookRegistryEntry, ConfigSource } from './hookRegistry.js';
|
||||
export type { AggregatedHookResult } from './hookAggregator.js';
|
||||
export type { HookEventContext } from './hookPlanner.js';
|
||||
@@ -139,6 +139,9 @@ export { sessionId } from './utils/session.js';
|
||||
export * from './utils/browser.js';
|
||||
export { Storage } from './config/storage.js';
|
||||
|
||||
// Export hooks system
|
||||
export * from './hooks/index.js';
|
||||
|
||||
// Export test utils
|
||||
export * from './test-utils/index.js';
|
||||
|
||||
|
||||
Reference in New Issue
Block a user