feat(hooks): Hook Agent Lifecycle Integration (#9105)

This commit is contained in:
Edilmo Palencia
2025-11-24 14:31:48 -08:00
committed by GitHub
parent 2034098780
commit 5411f4a667
12 changed files with 631 additions and 3 deletions

View File

@@ -14,7 +14,9 @@ import {
DEFAULT_TRUNCATE_TOOL_OUTPUT_LINES,
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
GeminiClient,
HookSystem,
} from '@google/gemini-cli-core';
import { createMockMessageBus } from '@google/gemini-cli-core/src/test-utils/mock-message-bus.js';
import type { Config, Storage } from '@google/gemini-cli-core';
import { expect, vi } from 'vitest';
@@ -54,8 +56,13 @@ export function createMockConfig(
getMessageBus: vi.fn(),
getPolicyEngine: vi.fn(),
getEnableExtensionReloading: vi.fn().mockReturnValue(false),
getEnableHooks: vi.fn().mockReturnValue(false),
...overrides,
} as unknown as Config;
mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());
mockConfig.getHookSystem = vi
.fn()
.mockReturnValue(new HookSystem(mockConfig));
mockConfig.getGeminiClient = vi
.fn()

View File

@@ -32,7 +32,9 @@ import {
ToolConfirmationOutcome,
ApprovalMode,
MockTool,
HookSystem,
} from '@google/gemini-cli-core';
import { createMockMessageBus } from '@google/gemini-cli-core/src/test-utils/mock-message-bus.js';
import { ToolCallStatus } from '../types.js';
// Mocks
@@ -81,7 +83,10 @@ const mockConfig = {
getPolicyEngine: () => null,
isInteractive: () => false,
getExperiments: () => {},
getEnableHooks: () => false,
} as unknown as Config;
mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());
mockConfig.getHookSystem = vi.fn().mockReturnValue(new HookSystem(mockConfig));
const mockTool = new MockTool({
name: 'mockTool',

View File

@@ -76,6 +76,7 @@ import type { EventEmitter } from 'node:events';
import { MessageBus } from '../confirmation-bus/message-bus.js';
import { PolicyEngine } from '../policy/policy-engine.js';
import type { PolicyEngineConfig } from '../policy/types.js';
import { HookSystem } from '../hooks/index.js';
import type { UserTierId } from '../code_assist/types.js';
import { getCodeAssistServer } from '../code_assist/codeAssist.js';
import type { Experiments } from '../code_assist/experiments/experiments.js';
@@ -415,6 +416,7 @@ export class Config {
| undefined;
private experiments: Experiments | undefined;
private experimentsPromise: Promise<void> | undefined;
private hookSystem?: HookSystem;
private previewModelFallbackMode = false;
private previewModelBypassMode = false;
@@ -627,6 +629,12 @@ export class Config {
await this.getExtensionLoader().start(this),
]);
// Initialize hook system if enabled
if (this.enableHooks) {
this.hookSystem = new HookSystem(this);
await this.hookSystem.initialize();
}
await this.geminiClient.initialize();
}
@@ -1475,6 +1483,13 @@ export class Config {
return registry;
}
/**
* Get the hook system instance
*/
getHookSystem(): HookSystem | undefined {
return this.hookSystem;
}
/**
* Get hooks configuration
*/

View File

@@ -43,6 +43,7 @@ import type {
ResolvedModelConfig,
} from '../services/modelConfigService.js';
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
import { HookSystem } from '../hooks/hookSystem.js';
vi.mock('../services/chatCompressionService.js');
@@ -120,6 +121,7 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({
getLastPromptTokenCount: vi.fn(),
},
}));
vi.mock('../hooks/hookSystem.js');
/**
* Array.fromAsync ponyfill, which will be available in es 2024.
@@ -211,6 +213,8 @@ describe('Gemini Client (client.ts)', () => {
getModelRouterService: vi.fn().mockReturnValue({
route: vi.fn().mockResolvedValue({ model: 'default-routed-model' }),
}),
getMessageBus: vi.fn().mockReturnValue(undefined),
getEnableHooks: vi.fn().mockReturnValue(false),
isInFallbackMode: vi.fn().mockReturnValue(false),
setFallbackMode: vi.fn(),
getChatCompression: vi.fn().mockReturnValue(undefined),
@@ -243,6 +247,9 @@ describe('Gemini Client (client.ts)', () => {
isInteractive: vi.fn().mockReturnValue(false),
getExperiments: () => {},
} as unknown as Config;
mockConfig.getHookSystem = vi
.fn()
.mockReturnValue(new HookSystem(mockConfig));
client = new GeminiClient(mockConfig);
await client.initialize();

View File

@@ -42,6 +42,10 @@ import {
logContentRetryFailure,
logNextSpeakerCheck,
} from '../telemetry/loggers.js';
import {
fireBeforeAgentHook,
fireAfterAgentHook,
} from './clientHookTriggers.js';
import {
ContentRetryFailureEvent,
NextSpeakerCheckEvent,
@@ -438,6 +442,35 @@ export class GeminiClient {
turns: number = MAX_TURNS,
isInvalidStreamRetry: boolean = false,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
// Fire BeforeAgent hook through MessageBus (only if hooks are enabled)
const hooksEnabled = this.config.getEnableHooks();
const messageBus = this.config.getMessageBus();
if (hooksEnabled && messageBus) {
const hookOutput = await fireBeforeAgentHook(messageBus, request);
if (
hookOutput?.isBlockingDecision() ||
hookOutput?.shouldStopExecution()
) {
yield {
type: GeminiEventType.Error,
value: {
error: new Error(
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
),
},
};
return new Turn(this.getChat(), prompt_id);
}
// Add additional context from hooks to the request
const additionalContext = hookOutput?.getAdditionalContext();
if (additionalContext) {
const requestArray = Array.isArray(request) ? request : [request];
request = [...requestArray, { text: additionalContext }];
}
}
if (this.lastPromptId !== prompt_id) {
this.loopDetector.reset(prompt_id);
this.lastPromptId = prompt_id;
@@ -608,9 +641,9 @@ export class GeminiClient {
);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
// This recursive call's events will be yielded out, but the final
// turn object will be from the top-level call.
yield* this.sendMessageStream(
// This recursive call's events will be yielded out, and the final
// turn object from the recursive call will be returned.
return yield* this.sendMessageStream(
nextRequest,
signal,
prompt_id,
@@ -619,6 +652,32 @@ export class GeminiClient {
);
}
}
// Fire AfterAgent hook through MessageBus (only if hooks are enabled)
if (hooksEnabled && messageBus) {
const responseText = turn.getResponseText() || '[no response text]';
const hookOutput = await fireAfterAgentHook(
messageBus,
request,
responseText,
);
// For AfterAgent hooks, blocking/stop execution should force continuation
if (
hookOutput?.isBlockingDecision() ||
hookOutput?.shouldStopExecution()
) {
const continueReason = hookOutput.getEffectiveReason();
const continueRequest = [{ text: continueReason }];
yield* this.sendMessageStream(
continueRequest,
signal,
prompt_id,
boundedTurns - 1,
);
}
}
return turn;
}

View File

@@ -0,0 +1,105 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { PartListUnion } from '@google/genai';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import {
MessageBusType,
type HookExecutionRequest,
type HookExecutionResponse,
} from '../confirmation-bus/types.js';
import { createHookOutput, type DefaultHookOutput } from '../hooks/types.js';
import { partToString } from '../utils/partUtils.js';
import { debugLogger } from '../utils/debugLogger.js';
/**
* Fires the BeforeAgent hook and returns the hook output.
* This should be called before processing a user prompt.
*
* The caller can use the returned DefaultHookOutput methods:
* - isBlockingDecision() / shouldStopExecution() to check if blocked
* - getEffectiveReason() to get the blocking reason
* - getAdditionalContext() to get additional context to add
*
* @param messageBus The message bus to use for hook communication
* @param request The user's request (prompt)
* @returns The hook output, or undefined if no hook was executed or on error
*/
export async function fireBeforeAgentHook(
messageBus: MessageBus,
request: PartListUnion,
): Promise<DefaultHookOutput | undefined> {
try {
const promptText = partToString(request);
const response = await messageBus.request<
HookExecutionRequest,
HookExecutionResponse
>(
{
type: MessageBusType.HOOK_EXECUTION_REQUEST,
eventName: 'BeforeAgent',
input: {
prompt: promptText,
},
},
MessageBusType.HOOK_EXECUTION_RESPONSE,
);
return response.output
? createHookOutput('BeforeAgent', response.output)
: undefined;
} catch (error) {
debugLogger.warn(`BeforeAgent hook failed: ${error}`);
return undefined;
}
}
/**
* Fires the AfterAgent hook and returns the hook output.
* This should be called after the agent has generated a response.
*
* The caller can use the returned DefaultHookOutput methods:
* - isBlockingDecision() / shouldStopExecution() to check if continuation is requested
* - getEffectiveReason() to get the continuation reason
*
* @param messageBus The message bus to use for hook communication
* @param request The original user's request (prompt)
* @param responseText The agent's response text
* @returns The hook output, or undefined if no hook was executed or on error
*/
export async function fireAfterAgentHook(
messageBus: MessageBus,
request: PartListUnion,
responseText: string,
): Promise<DefaultHookOutput | undefined> {
try {
const promptText = partToString(request);
const response = await messageBus.request<
HookExecutionRequest,
HookExecutionResponse
>(
{
type: MessageBusType.HOOK_EXECUTION_REQUEST,
eventName: 'AfterAgent',
input: {
prompt: promptText,
prompt_response: responseText,
stop_hook_active: false,
},
},
MessageBusType.HOOK_EXECUTION_RESPONSE,
);
return response.output
? createHookOutput('AfterAgent', response.output)
: undefined;
} catch (error) {
debugLogger.warn(`AfterAgent hook failed: ${error}`);
return undefined;
}
}

View File

@@ -18,9 +18,11 @@ import {
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
ToolErrorType,
ApprovalMode,
HookSystem,
} from '../index.js';
import type { Part } from '@google/genai';
import { MockTool } from '../test-utils/mock-tool.js';
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
describe('executeToolCall', () => {
let mockToolRegistry: ToolRegistry;
@@ -66,8 +68,15 @@ describe('executeToolCall', () => {
getPolicyEngine: () => null,
isInteractive: () => false,
getExperiments: () => {},
getEnableHooks: () => false,
} as unknown as Config;
// Use proper MessageBus mocking for Phase 3 preparation
const mockMessageBus = createMockMessageBus();
mockConfig.getMessageBus = vi.fn().mockReturnValue(mockMessageBus);
mockConfig.getHookSystem = vi
.fn()
.mockReturnValue(new HookSystem(mockConfig));
abortController = new AbortController();
});

View File

@@ -392,6 +392,17 @@ export class Turn {
getDebugResponses(): GenerateContentResponse[] {
return this.debugResponses;
}
/**
* Get the concatenated response text from all responses in this turn.
* This extracts and joins all text content from the model's responses.
*/
getResponseText(): string {
return this.debugResponses
.map((response) => getResponseText(response))
.filter((text): text is string => text !== null)
.join(' ');
}
}
function getCitations(resp: GenerateContentResponse): string[] {

View File

@@ -0,0 +1,280 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { HookSystem } from './hookSystem.js';
import { Config } from '../config/config.js';
import { HookType } from './types.js';
import { spawn } from 'node:child_process';
import type { ChildProcessWithoutNullStreams } from 'node:child_process';
import type { Readable, Writable } from 'node:stream';
// Mock type for the child_process spawn
type MockChildProcessWithoutNullStreams = ChildProcessWithoutNullStreams & {
mockStdoutOn: ReturnType<typeof vi.fn>;
mockStderrOn: ReturnType<typeof vi.fn>;
mockProcessOn: ReturnType<typeof vi.fn>;
};
// Mock child_process with importOriginal for partial mocking
vi.mock('node:child_process', async (importOriginal) => {
const actual = (await importOriginal()) as object;
return {
...actual,
spawn: vi.fn(),
};
});
// Mock debugLogger - use vi.hoisted to define mock before it's used in vi.mock
const mockDebugLogger = vi.hoisted(() => ({
debug: vi.fn(),
log: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
}));
vi.mock('../utils/debugLogger.js', () => ({
debugLogger: mockDebugLogger,
}));
// Mock console methods
const mockConsole = {
log: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
debug: vi.fn(),
};
vi.stubGlobal('console', mockConsole);
describe('HookSystem Integration', () => {
let hookSystem: HookSystem;
let config: Config;
let mockSpawn: MockChildProcessWithoutNullStreams;
beforeEach(() => {
vi.resetAllMocks();
// Create a real config with simple command hook configurations for testing
config = new Config({
model: 'gemini-1.5-flash',
targetDir: '/tmp/test-hooks',
sessionId: 'test-session',
debugMode: false,
cwd: '/tmp/test-hooks',
hooks: {
BeforeTool: [
{
matcher: 'TestTool',
hooks: [
{
type: HookType.Command,
command: 'echo',
timeout: 5000,
},
],
},
],
},
});
// Provide getMessageBus mock for MessageBus integration tests
(config as unknown as { getMessageBus: () => unknown }).getMessageBus =
() => undefined;
hookSystem = new HookSystem(config);
// Set up spawn mock with accessible mock functions
const mockStdoutOn = vi.fn();
const mockStderrOn = vi.fn();
const mockProcessOn = vi.fn();
mockSpawn = {
stdin: {
write: vi.fn(),
end: vi.fn(),
} as unknown as Writable,
stdout: {
on: mockStdoutOn,
} as unknown as Readable,
stderr: {
on: mockStderrOn,
} as unknown as Readable,
on: mockProcessOn,
kill: vi.fn(),
killed: false,
mockStdoutOn,
mockStderrOn,
mockProcessOn,
} as unknown as MockChildProcessWithoutNullStreams;
vi.mocked(spawn).mockReturnValue(mockSpawn);
});
afterEach(async () => {
// No cleanup needed
});
describe('initialize', () => {
it('should initialize successfully', async () => {
await hookSystem.initialize();
expect(mockDebugLogger.debug).toHaveBeenCalledWith(
'Hook system initialized successfully',
);
// Verify system is initialized
const status = hookSystem.getStatus();
expect(status.initialized).toBe(true);
// Note: totalHooks might be 0 if hook validation rejects the test hooks
});
it('should not initialize twice', async () => {
await hookSystem.initialize();
await hookSystem.initialize(); // Second call should be no-op
// The system logs both registry initialization and system initialization
expect(mockDebugLogger.debug).toHaveBeenCalledWith(
'Hook system initialized successfully',
);
});
it('should handle initialization errors gracefully', async () => {
// Create a config with invalid hooks to trigger initialization errors
const invalidConfig = new Config({
model: 'gemini-1.5-flash',
targetDir: '/tmp/test-hooks-invalid',
sessionId: 'test-session-invalid',
debugMode: false,
cwd: '/tmp/test-hooks-invalid',
hooks: {
BeforeTool: [
{
hooks: [
{
type: 'invalid-type' as HookType, // Invalid hook type for testing
command: './test.sh',
},
],
},
],
},
});
const invalidHookSystem = new HookSystem(invalidConfig);
// Should not throw, but should log warnings via debugLogger
await invalidHookSystem.initialize();
expect(mockDebugLogger.warn).toHaveBeenCalled();
});
});
describe('getEventHandler', () => {
it('should return event bus when initialized', async () => {
await hookSystem.initialize();
// Set up spawn mock behavior for successful execution
mockSpawn.mockStdoutOn.mockImplementation(
(event: string, callback: (data: Buffer) => void) => {
if (event === 'data') {
setTimeout(() => callback(Buffer.from('')), 5); // echo outputs empty
}
},
);
mockSpawn.mockProcessOn.mockImplementation(
(event: string, callback: (code: number) => void) => {
if (event === 'close') {
setTimeout(() => callback(0), 10);
}
},
);
const eventBus = hookSystem.getEventHandler();
expect(eventBus).toBeDefined();
// Test that the event bus can actually fire events
const result = await eventBus.fireBeforeToolEvent('TestTool', {
test: 'data',
});
expect(result.success).toBe(true);
});
it('should throw error when not initialized', () => {
expect(() => hookSystem.getEventHandler()).toThrow(
'Hook system not initialized',
);
});
});
describe('hook execution', () => {
it('should execute hooks and return results', async () => {
await hookSystem.initialize();
// Set up spawn mock behavior for successful execution
mockSpawn.mockStdoutOn.mockImplementation(
(event: string, callback: (data: Buffer) => void) => {
if (event === 'data') {
setTimeout(() => callback(Buffer.from('')), 5); // echo outputs empty
}
},
);
mockSpawn.mockProcessOn.mockImplementation(
(event: string, callback: (code: number) => void) => {
if (event === 'close') {
setTimeout(() => callback(0), 10);
}
},
);
const eventBus = hookSystem.getEventHandler();
// Test BeforeTool event with command hook
const result = await eventBus.fireBeforeToolEvent('TestTool', {
test: 'data',
});
expect(result.success).toBe(true);
// Command hooks with echo should succeed but may not have specific decisions
expect(result.errors).toHaveLength(0);
});
it('should handle no matching hooks', async () => {
await hookSystem.initialize();
const eventBus = hookSystem.getEventHandler();
// Test with a tool that doesn't match any hooks
const result = await eventBus.fireBeforeToolEvent('UnmatchedTool', {
test: 'data',
});
expect(result.success).toBe(true);
expect(result.allOutputs).toHaveLength(0);
expect(result.finalOutput).toBeUndefined();
});
});
describe('system management', () => {
it('should return correct status when initialized', async () => {
await hookSystem.initialize();
const status = hookSystem.getStatus();
expect(status.initialized).toBe(true);
// Note: totalHooks might be 0 if hook validation rejects the test hooks
expect(typeof status.totalHooks).toBe('number');
});
it('should return uninitialized status', () => {
const status = hookSystem.getStatus();
expect(status.initialized).toBe(false);
});
});
});

View File

@@ -0,0 +1,106 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '../config/config.js';
import { HookRegistry } from './hookRegistry.js';
import { HookRunner } from './hookRunner.js';
import { HookAggregator } from './hookAggregator.js';
import { HookPlanner } from './hookPlanner.js';
import { HookEventHandler } from './hookEventHandler.js';
import type { HookRegistryEntry } from './hookRegistry.js';
import { logs, type Logger } from '@opentelemetry/api-logs';
import { SERVICE_NAME } from '../telemetry/constants.js';
import { debugLogger } from '../utils/debugLogger.js';
/**
* Main hook system that coordinates all hook-related functionality
*/
export class HookSystem {
private readonly hookRegistry: HookRegistry;
private readonly hookRunner: HookRunner;
private readonly hookAggregator: HookAggregator;
private readonly hookPlanner: HookPlanner;
private readonly hookEventHandler: HookEventHandler;
private initialized = false;
constructor(config: Config) {
const logger: Logger = logs.getLogger(SERVICE_NAME);
const messageBus = config.getMessageBus();
// Initialize components
this.hookRegistry = new HookRegistry(config);
this.hookRunner = new HookRunner();
this.hookAggregator = new HookAggregator();
this.hookPlanner = new HookPlanner(this.hookRegistry);
this.hookEventHandler = new HookEventHandler(
config,
logger,
this.hookPlanner,
this.hookRunner,
this.hookAggregator,
messageBus, // Pass MessageBus to enable mediated hook execution
);
}
/**
* Initialize the hook system
*/
async initialize(): Promise<void> {
if (this.initialized) {
return;
}
await this.hookRegistry.initialize();
this.initialized = true;
debugLogger.debug('Hook system initialized successfully');
}
/**
* Get the hook event bus for firing events
*/
getEventHandler(): HookEventHandler {
if (!this.initialized) {
throw new Error('Hook system not initialized');
}
return this.hookEventHandler;
}
/**
* Get hook registry for management operations
*/
getRegistry(): HookRegistry {
return this.hookRegistry;
}
/**
* Enable or disable a hook
*/
setHookEnabled(hookName: string, enabled: boolean): void {
this.hookRegistry.setHookEnabled(hookName, enabled);
}
/**
* Get all registered hooks for display/management
*/
getAllHooks(): HookRegistryEntry[] {
return this.hookRegistry.getAllHooks();
}
/**
* Get hook system status for debugging
*/
getStatus(): {
initialized: boolean;
totalHooks: number;
} {
const allHooks = this.initialized ? this.hookRegistry.getAllHooks() : [];
return {
initialized: this.initialized,
totalHooks: allHooks.length,
};
}
}

View File

@@ -0,0 +1,21 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
// Export types
export * from './types.js';
// Export core components
export { HookSystem } from './hookSystem.js';
export { HookRegistry } from './hookRegistry.js';
export { HookRunner } from './hookRunner.js';
export { HookAggregator } from './hookAggregator.js';
export { HookPlanner } from './hookPlanner.js';
export { HookEventHandler } from './hookEventHandler.js';
// Export interfaces
export type { HookRegistryEntry, ConfigSource } from './hookRegistry.js';
export type { AggregatedHookResult } from './hookAggregator.js';
export type { HookEventContext } from './hookPlanner.js';

View File

@@ -139,6 +139,9 @@ export { sessionId } from './utils/session.js';
export * from './utils/browser.js';
export { Storage } from './config/storage.js';
// Export hooks system
export * from './hooks/index.js';
// Export test utils
export * from './test-utils/index.js';