From e92f60b4fc61a1b0deb7a761b6e401821a05bebe Mon Sep 17 00:00:00 2001 From: Vedant Mahajan Date: Tue, 20 Jan 2026 23:16:54 +0530 Subject: [PATCH] fix: migrate BeforeModel and AfterModel hooks to HookSystem (#16599) Co-authored-by: Tommaso Sciortino --- packages/core/src/core/geminiChat.test.ts | 40 +++--- packages/core/src/core/geminiChat.ts | 54 +++----- packages/core/src/hooks/hookSystem.ts | 147 +++++++++++++++++++--- 3 files changed, 157 insertions(+), 84 deletions(-) diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 1f60565f0d..2bf22d509c 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -22,24 +22,12 @@ import { AuthType } from './contentGenerator.js'; import { TerminalQuotaError } from '../utils/googleQuotaErrors.js'; import { type RetryOptions } from '../utils/retry.js'; import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; -import { HookSystem } from '../hooks/hookSystem.js'; import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; import { createAvailabilityServiceMock } from '../availability/testUtils.js'; import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js'; import * as policyHelpers from '../availability/policyHelpers.js'; import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js'; -import { - fireBeforeModelHook, - fireAfterModelHook, - fireBeforeToolSelectionHook, -} from './geminiChatHookTriggers.js'; - -// Mock hook triggers -vi.mock('./geminiChatHookTriggers.js', () => ({ - fireBeforeModelHook: vi.fn(), - fireAfterModelHook: vi.fn(), - fireBeforeToolSelectionHook: vi.fn().mockResolvedValue({}), -})); +import type { HookSystem } from '../hooks/hookSystem.js'; // Mock fs module to prevent actual file system operations during tests const mockFileSystem = new Map(); @@ -204,9 +192,7 @@ describe('GeminiChat', () => { setSimulate429(false); // Reset history for each test by creating a new instance chat = new GeminiChat(mockConfig); - mockConfig.getHookSystem = vi - .fn() - .mockReturnValue(new HookSystem(mockConfig)); + mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined); }); afterEach(() => { @@ -2283,18 +2269,20 @@ describe('GeminiChat', () => { }); describe('Hook execution control', () => { + let mockHookSystem: HookSystem; beforeEach(() => { vi.mocked(mockConfig.getEnableHooks).mockReturnValue(true); - // Default to allowing execution - vi.mocked(fireBeforeModelHook).mockResolvedValue({ blocked: false }); - vi.mocked(fireAfterModelHook).mockResolvedValue({ - response: {} as GenerateContentResponse, - }); - vi.mocked(fireBeforeToolSelectionHook).mockResolvedValue({}); + + mockHookSystem = { + fireBeforeModelEvent: vi.fn().mockResolvedValue({ blocked: false }), + fireAfterModelEvent: vi.fn().mockResolvedValue({ response: {} }), + fireBeforeToolSelectionEvent: vi.fn().mockResolvedValue({}), + } as unknown as HookSystem; + mockConfig.getHookSystem = vi.fn().mockReturnValue(mockHookSystem); }); it('should yield AGENT_EXECUTION_STOPPED when BeforeModel hook stops execution', async () => { - vi.mocked(fireBeforeModelHook).mockResolvedValue({ + vi.mocked(mockHookSystem.fireBeforeModelEvent).mockResolvedValue({ blocked: true, stopped: true, reason: 'stopped by hook', @@ -2324,7 +2312,7 @@ describe('GeminiChat', () => { candidates: [{ content: { parts: [{ text: 'blocked' }] } }], } as GenerateContentResponse; - vi.mocked(fireBeforeModelHook).mockResolvedValue({ + vi.mocked(mockHookSystem.fireBeforeModelEvent).mockResolvedValue({ blocked: true, reason: 'blocked by hook', syntheticResponse, @@ -2363,7 +2351,7 @@ describe('GeminiChat', () => { })(), ); - vi.mocked(fireAfterModelHook).mockResolvedValue({ + vi.mocked(mockHookSystem.fireAfterModelEvent).mockResolvedValue({ response: {} as GenerateContentResponse, stopped: true, reason: 'stopped by after hook', @@ -2399,7 +2387,7 @@ describe('GeminiChat', () => { })(), ); - vi.mocked(fireAfterModelHook).mockResolvedValue({ + vi.mocked(mockHookSystem.fireAfterModelEvent).mockResolvedValue({ response, blocked: true, reason: 'blocked by after hook', diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 25f32e09a7..8b75eedb1b 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -49,11 +49,6 @@ import { applyModelSelection, createAvailabilityContextProvider, } from '../availability/policyHelpers.js'; -import { - fireAfterModelHook, - fireBeforeModelHook, - fireBeforeToolSelectionHook, -} from './geminiChatHookTriggers.js'; import { coreEvents } from '../utils/events.js'; export enum StreamEventType { @@ -507,39 +502,26 @@ export class GeminiChat { ? contentsForPreviewModel : requestContents; - // Fire BeforeModel and BeforeToolSelection hooks if enabled - const hooksEnabled = this.config.getEnableHooks(); - const messageBus = this.config.getMessageBus(); - if (hooksEnabled && messageBus) { - // Fire BeforeModel hook - const beforeModelResult = await fireBeforeModelHook(messageBus, { + const hookSystem = this.config.getHookSystem(); + if (hookSystem) { + const beforeModelResult = await hookSystem.fireBeforeModelEvent({ model: modelToUse, config, contents: contentsToUse, }); - // Check if hook requested to stop execution if (beforeModelResult.stopped) { throw new AgentExecutionStoppedError( beforeModelResult.reason || 'Agent execution stopped by hook', ); } - // Check if hook blocked the model call if (beforeModelResult.blocked) { - // Return a synthetic response generator const syntheticResponse = beforeModelResult.syntheticResponse; - if (syntheticResponse) { - // Ensure synthetic response has a finish reason to prevent InvalidStreamError - if ( - syntheticResponse.candidates && - syntheticResponse.candidates.length > 0 - ) { - for (const candidate of syntheticResponse.candidates) { - if (!candidate.finishReason) { - candidate.finishReason = FinishReason.STOP; - } - } + + for (const candidate of syntheticResponse?.candidates ?? []) { + if (!candidate.finishReason) { + candidate.finishReason = FinishReason.STOP; } } @@ -549,7 +531,6 @@ export class GeminiChat { ); } - // Apply modifications from BeforeModel hook if (beforeModelResult.modifiedConfig) { Object.assign(config, beforeModelResult.modifiedConfig); } @@ -560,17 +541,13 @@ export class GeminiChat { contentsToUse = beforeModelResult.modifiedContents as Content[]; } - // Fire BeforeToolSelection hook - const toolSelectionResult = await fireBeforeToolSelectionHook( - messageBus, - { + const toolSelectionResult = + await hookSystem.fireBeforeToolSelectionEvent({ model: modelToUse, config, contents: contentsToUse, - }, - ); + }); - // Apply tool configuration modifications if (toolSelectionResult.toolConfig) { config.toolConfig = toolSelectionResult.toolConfig; } @@ -825,12 +802,9 @@ export class GeminiChat { } } - // Fire AfterModel hook through MessageBus (only if hooks are enabled) - const hooksEnabled = this.config.getEnableHooks(); - const messageBus = this.config.getMessageBus(); - if (hooksEnabled && messageBus && originalRequest && chunk) { - const hookResult = await fireAfterModelHook( - messageBus, + const hookSystem = this.config.getHookSystem(); + if (originalRequest && chunk && hookSystem) { + const hookResult = await hookSystem.fireAfterModelEvent( originalRequest, chunk, ); @@ -850,7 +824,7 @@ export class GeminiChat { yield hookResult.response; } else { - yield chunk; // Yield every chunk to the UI immediately. + yield chunk; } } diff --git a/packages/core/src/hooks/hookSystem.ts b/packages/core/src/hooks/hookSystem.ts index 74176be464..c10380be8e 100644 --- a/packages/core/src/hooks/hookSystem.ts +++ b/packages/core/src/hooks/hookSystem.ts @@ -19,13 +19,24 @@ import type { SessionEndReason, PreCompressTrigger, DefaultHookOutput, + BeforeModelHookOutput, + AfterModelHookOutput, + BeforeToolSelectionHookOutput, } from './types.js'; import type { AggregatedHookResult } from './hookAggregator.js'; +import type { + GenerateContentParameters, + GenerateContentResponse, +} from '@google/genai'; +import type { + AfterModelHookResult, + BeforeModelHookResult, + BeforeToolSelectionHookResult, +} from '../core/geminiChatHookTriggers.js'; /** * Main hook system that coordinates all hook-related functionality */ export class HookSystem { - private readonly config: Config; private readonly hookRegistry: HookRegistry; private readonly hookRunner: HookRunner; private readonly hookAggregator: HookAggregator; @@ -33,7 +44,6 @@ export class HookSystem { private readonly hookEventHandler: HookEventHandler; constructor(config: Config) { - this.config = config; const logger: Logger = logs.getLogger(SERVICE_NAME); const messageBus = config.getMessageBus(); @@ -90,14 +100,10 @@ export class HookSystem { /** * Fire hook events directly - * Returns undefined if hooks are disabled */ async fireSessionStartEvent( source: SessionStartSource, ): Promise { - if (!this.config.getEnableHooks()) { - return undefined; - } const result = await this.hookEventHandler.fireSessionStartEvent(source); return result.finalOutput; } @@ -105,27 +111,18 @@ export class HookSystem { async fireSessionEndEvent( reason: SessionEndReason, ): Promise { - if (!this.config.getEnableHooks()) { - return undefined; - } return this.hookEventHandler.fireSessionEndEvent(reason); } async firePreCompressEvent( trigger: PreCompressTrigger, ): Promise { - if (!this.config.getEnableHooks()) { - return undefined; - } return this.hookEventHandler.firePreCompressEvent(trigger); } async fireBeforeAgentEvent( prompt: string, ): Promise { - if (!this.config.getEnableHooks()) { - return undefined; - } const result = await this.hookEventHandler.fireBeforeAgentEvent(prompt); return result.finalOutput; } @@ -135,9 +132,6 @@ export class HookSystem { response: string, stopHookActive: boolean = false, ): Promise { - if (!this.config.getEnableHooks()) { - return undefined; - } const result = await this.hookEventHandler.fireAfterAgentEvent( prompt, response, @@ -145,4 +139,121 @@ export class HookSystem { ); return result.finalOutput; } + + async fireBeforeModelEvent( + llmRequest: GenerateContentParameters, + ): Promise { + try { + const result = + await this.hookEventHandler.fireBeforeModelEvent(llmRequest); + const hookOutput = result.finalOutput; + + if (hookOutput?.shouldStopExecution()) { + return { + blocked: true, + stopped: true, + reason: hookOutput.getEffectiveReason(), + }; + } + + const blockingError = hookOutput?.getBlockingError(); + if (blockingError?.blocked) { + const beforeModelOutput = hookOutput as BeforeModelHookOutput; + const syntheticResponse = beforeModelOutput.getSyntheticResponse(); + return { + blocked: true, + reason: + hookOutput?.getEffectiveReason() || 'Model call blocked by hook', + syntheticResponse, + }; + } + + if (hookOutput) { + const beforeModelOutput = hookOutput as BeforeModelHookOutput; + const modifiedRequest = + beforeModelOutput.applyLLMRequestModifications(llmRequest); + return { + blocked: false, + modifiedConfig: modifiedRequest?.config, + modifiedContents: modifiedRequest?.contents, + }; + } + + return { blocked: false }; + } catch (error) { + debugLogger.debug(`BeforeModelHookEvent failed:`, error); + return { blocked: false }; + } + } + + async fireAfterModelEvent( + originalRequest: GenerateContentParameters, + chunk: GenerateContentResponse, + ): Promise { + try { + const result = await this.hookEventHandler.fireAfterModelEvent( + originalRequest, + chunk, + ); + const hookOutput = result.finalOutput; + + if (hookOutput?.shouldStopExecution()) { + return { + response: chunk, + stopped: true, + reason: hookOutput.getEffectiveReason(), + }; + } + + const blockingError = hookOutput?.getBlockingError(); + if (blockingError?.blocked) { + return { + response: chunk, + blocked: true, + reason: hookOutput?.getEffectiveReason(), + }; + } + + if (hookOutput) { + const afterModelOutput = hookOutput as AfterModelHookOutput; + const modifiedResponse = afterModelOutput.getModifiedResponse(); + if (modifiedResponse) { + return { response: modifiedResponse }; + } + } + + return { response: chunk }; + } catch (error) { + debugLogger.debug(`AfterModelHookEvent failed:`, error); + return { response: chunk }; + } + } + + async fireBeforeToolSelectionEvent( + llmRequest: GenerateContentParameters, + ): Promise { + try { + const result = + await this.hookEventHandler.fireBeforeToolSelectionEvent(llmRequest); + const hookOutput = result.finalOutput; + + if (hookOutput) { + const toolSelectionOutput = hookOutput as BeforeToolSelectionHookOutput; + const modifiedConfig = toolSelectionOutput.applyToolConfigModifications( + { + toolConfig: llmRequest.config?.toolConfig, + tools: llmRequest.config?.tools, + }, + ); + return { + toolConfig: modifiedConfig.toolConfig, + tools: modifiedConfig.tools, + }; + } + return {}; + } catch (error) { + debugLogger.debug(`BeforeToolSelectionEvent failed:`, error); + return {}; + } + } }