fix: migrate BeforeModel and AfterModel hooks to HookSystem (#16599)

Co-authored-by: Tommaso Sciortino <sciortino@gmail.com>
This commit is contained in:
Vedant Mahajan
2026-01-20 23:16:54 +05:30
committed by GitHub
parent 15f26175b8
commit e92f60b4fc
3 changed files with 157 additions and 84 deletions
+14 -26
View File
@@ -22,24 +22,12 @@ import { AuthType } from './contentGenerator.js';
import { TerminalQuotaError } from '../utils/googleQuotaErrors.js'; import { TerminalQuotaError } from '../utils/googleQuotaErrors.js';
import { type RetryOptions } from '../utils/retry.js'; import { type RetryOptions } from '../utils/retry.js';
import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
import { HookSystem } from '../hooks/hookSystem.js';
import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
import { createAvailabilityServiceMock } from '../availability/testUtils.js'; import { createAvailabilityServiceMock } from '../availability/testUtils.js';
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js'; import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
import * as policyHelpers from '../availability/policyHelpers.js'; import * as policyHelpers from '../availability/policyHelpers.js';
import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js'; import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js';
import { import type { HookSystem } from '../hooks/hookSystem.js';
fireBeforeModelHook,
fireAfterModelHook,
fireBeforeToolSelectionHook,
} from './geminiChatHookTriggers.js';
// Mock hook triggers
vi.mock('./geminiChatHookTriggers.js', () => ({
fireBeforeModelHook: vi.fn(),
fireAfterModelHook: vi.fn(),
fireBeforeToolSelectionHook: vi.fn().mockResolvedValue({}),
}));
// Mock fs module to prevent actual file system operations during tests // Mock fs module to prevent actual file system operations during tests
const mockFileSystem = new Map<string, string>(); const mockFileSystem = new Map<string, string>();
@@ -204,9 +192,7 @@ describe('GeminiChat', () => {
setSimulate429(false); setSimulate429(false);
// Reset history for each test by creating a new instance // Reset history for each test by creating a new instance
chat = new GeminiChat(mockConfig); chat = new GeminiChat(mockConfig);
mockConfig.getHookSystem = vi mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined);
.fn()
.mockReturnValue(new HookSystem(mockConfig));
}); });
afterEach(() => { afterEach(() => {
@@ -2283,18 +2269,20 @@ describe('GeminiChat', () => {
}); });
describe('Hook execution control', () => { describe('Hook execution control', () => {
let mockHookSystem: HookSystem;
beforeEach(() => { beforeEach(() => {
vi.mocked(mockConfig.getEnableHooks).mockReturnValue(true); vi.mocked(mockConfig.getEnableHooks).mockReturnValue(true);
// Default to allowing execution
vi.mocked(fireBeforeModelHook).mockResolvedValue({ blocked: false }); mockHookSystem = {
vi.mocked(fireAfterModelHook).mockResolvedValue({ fireBeforeModelEvent: vi.fn().mockResolvedValue({ blocked: false }),
response: {} as GenerateContentResponse, fireAfterModelEvent: vi.fn().mockResolvedValue({ response: {} }),
}); fireBeforeToolSelectionEvent: vi.fn().mockResolvedValue({}),
vi.mocked(fireBeforeToolSelectionHook).mockResolvedValue({}); } as unknown as HookSystem;
mockConfig.getHookSystem = vi.fn().mockReturnValue(mockHookSystem);
}); });
it('should yield AGENT_EXECUTION_STOPPED when BeforeModel hook stops execution', async () => { it('should yield AGENT_EXECUTION_STOPPED when BeforeModel hook stops execution', async () => {
vi.mocked(fireBeforeModelHook).mockResolvedValue({ vi.mocked(mockHookSystem.fireBeforeModelEvent).mockResolvedValue({
blocked: true, blocked: true,
stopped: true, stopped: true,
reason: 'stopped by hook', reason: 'stopped by hook',
@@ -2324,7 +2312,7 @@ describe('GeminiChat', () => {
candidates: [{ content: { parts: [{ text: 'blocked' }] } }], candidates: [{ content: { parts: [{ text: 'blocked' }] } }],
} as GenerateContentResponse; } as GenerateContentResponse;
vi.mocked(fireBeforeModelHook).mockResolvedValue({ vi.mocked(mockHookSystem.fireBeforeModelEvent).mockResolvedValue({
blocked: true, blocked: true,
reason: 'blocked by hook', reason: 'blocked by hook',
syntheticResponse, syntheticResponse,
@@ -2363,7 +2351,7 @@ describe('GeminiChat', () => {
})(), })(),
); );
vi.mocked(fireAfterModelHook).mockResolvedValue({ vi.mocked(mockHookSystem.fireAfterModelEvent).mockResolvedValue({
response: {} as GenerateContentResponse, response: {} as GenerateContentResponse,
stopped: true, stopped: true,
reason: 'stopped by after hook', reason: 'stopped by after hook',
@@ -2399,7 +2387,7 @@ describe('GeminiChat', () => {
})(), })(),
); );
vi.mocked(fireAfterModelHook).mockResolvedValue({ vi.mocked(mockHookSystem.fireAfterModelEvent).mockResolvedValue({
response, response,
blocked: true, blocked: true,
reason: 'blocked by after hook', reason: 'blocked by after hook',
+14 -40
View File
@@ -49,11 +49,6 @@ import {
applyModelSelection, applyModelSelection,
createAvailabilityContextProvider, createAvailabilityContextProvider,
} from '../availability/policyHelpers.js'; } from '../availability/policyHelpers.js';
import {
fireAfterModelHook,
fireBeforeModelHook,
fireBeforeToolSelectionHook,
} from './geminiChatHookTriggers.js';
import { coreEvents } from '../utils/events.js'; import { coreEvents } from '../utils/events.js';
export enum StreamEventType { export enum StreamEventType {
@@ -507,39 +502,26 @@ export class GeminiChat {
? contentsForPreviewModel ? contentsForPreviewModel
: requestContents; : requestContents;
// Fire BeforeModel and BeforeToolSelection hooks if enabled const hookSystem = this.config.getHookSystem();
const hooksEnabled = this.config.getEnableHooks(); if (hookSystem) {
const messageBus = this.config.getMessageBus(); const beforeModelResult = await hookSystem.fireBeforeModelEvent({
if (hooksEnabled && messageBus) {
// Fire BeforeModel hook
const beforeModelResult = await fireBeforeModelHook(messageBus, {
model: modelToUse, model: modelToUse,
config, config,
contents: contentsToUse, contents: contentsToUse,
}); });
// Check if hook requested to stop execution
if (beforeModelResult.stopped) { if (beforeModelResult.stopped) {
throw new AgentExecutionStoppedError( throw new AgentExecutionStoppedError(
beforeModelResult.reason || 'Agent execution stopped by hook', beforeModelResult.reason || 'Agent execution stopped by hook',
); );
} }
// Check if hook blocked the model call
if (beforeModelResult.blocked) { if (beforeModelResult.blocked) {
// Return a synthetic response generator
const syntheticResponse = beforeModelResult.syntheticResponse; const syntheticResponse = beforeModelResult.syntheticResponse;
if (syntheticResponse) {
// Ensure synthetic response has a finish reason to prevent InvalidStreamError for (const candidate of syntheticResponse?.candidates ?? []) {
if ( if (!candidate.finishReason) {
syntheticResponse.candidates && candidate.finishReason = FinishReason.STOP;
syntheticResponse.candidates.length > 0
) {
for (const candidate of syntheticResponse.candidates) {
if (!candidate.finishReason) {
candidate.finishReason = FinishReason.STOP;
}
}
} }
} }
@@ -549,7 +531,6 @@ export class GeminiChat {
); );
} }
// Apply modifications from BeforeModel hook
if (beforeModelResult.modifiedConfig) { if (beforeModelResult.modifiedConfig) {
Object.assign(config, beforeModelResult.modifiedConfig); Object.assign(config, beforeModelResult.modifiedConfig);
} }
@@ -560,17 +541,13 @@ export class GeminiChat {
contentsToUse = beforeModelResult.modifiedContents as Content[]; contentsToUse = beforeModelResult.modifiedContents as Content[];
} }
// Fire BeforeToolSelection hook const toolSelectionResult =
const toolSelectionResult = await fireBeforeToolSelectionHook( await hookSystem.fireBeforeToolSelectionEvent({
messageBus,
{
model: modelToUse, model: modelToUse,
config, config,
contents: contentsToUse, contents: contentsToUse,
}, });
);
// Apply tool configuration modifications
if (toolSelectionResult.toolConfig) { if (toolSelectionResult.toolConfig) {
config.toolConfig = toolSelectionResult.toolConfig; config.toolConfig = toolSelectionResult.toolConfig;
} }
@@ -825,12 +802,9 @@ export class GeminiChat {
} }
} }
// Fire AfterModel hook through MessageBus (only if hooks are enabled) const hookSystem = this.config.getHookSystem();
const hooksEnabled = this.config.getEnableHooks(); if (originalRequest && chunk && hookSystem) {
const messageBus = this.config.getMessageBus(); const hookResult = await hookSystem.fireAfterModelEvent(
if (hooksEnabled && messageBus && originalRequest && chunk) {
const hookResult = await fireAfterModelHook(
messageBus,
originalRequest, originalRequest,
chunk, chunk,
); );
@@ -850,7 +824,7 @@ export class GeminiChat {
yield hookResult.response; yield hookResult.response;
} else { } else {
yield chunk; // Yield every chunk to the UI immediately. yield chunk;
} }
} }
+129 -18
View File
@@ -19,13 +19,24 @@ import type {
SessionEndReason, SessionEndReason,
PreCompressTrigger, PreCompressTrigger,
DefaultHookOutput, DefaultHookOutput,
BeforeModelHookOutput,
AfterModelHookOutput,
BeforeToolSelectionHookOutput,
} from './types.js'; } from './types.js';
import type { AggregatedHookResult } from './hookAggregator.js'; import type { AggregatedHookResult } from './hookAggregator.js';
import type {
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
import type {
AfterModelHookResult,
BeforeModelHookResult,
BeforeToolSelectionHookResult,
} from '../core/geminiChatHookTriggers.js';
/** /**
* Main hook system that coordinates all hook-related functionality * Main hook system that coordinates all hook-related functionality
*/ */
export class HookSystem { export class HookSystem {
private readonly config: Config;
private readonly hookRegistry: HookRegistry; private readonly hookRegistry: HookRegistry;
private readonly hookRunner: HookRunner; private readonly hookRunner: HookRunner;
private readonly hookAggregator: HookAggregator; private readonly hookAggregator: HookAggregator;
@@ -33,7 +44,6 @@ export class HookSystem {
private readonly hookEventHandler: HookEventHandler; private readonly hookEventHandler: HookEventHandler;
constructor(config: Config) { constructor(config: Config) {
this.config = config;
const logger: Logger = logs.getLogger(SERVICE_NAME); const logger: Logger = logs.getLogger(SERVICE_NAME);
const messageBus = config.getMessageBus(); const messageBus = config.getMessageBus();
@@ -90,14 +100,10 @@ export class HookSystem {
/** /**
* Fire hook events directly * Fire hook events directly
* Returns undefined if hooks are disabled
*/ */
async fireSessionStartEvent( async fireSessionStartEvent(
source: SessionStartSource, source: SessionStartSource,
): Promise<DefaultHookOutput | undefined> { ): Promise<DefaultHookOutput | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
const result = await this.hookEventHandler.fireSessionStartEvent(source); const result = await this.hookEventHandler.fireSessionStartEvent(source);
return result.finalOutput; return result.finalOutput;
} }
@@ -105,27 +111,18 @@ export class HookSystem {
async fireSessionEndEvent( async fireSessionEndEvent(
reason: SessionEndReason, reason: SessionEndReason,
): Promise<AggregatedHookResult | undefined> { ): Promise<AggregatedHookResult | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
return this.hookEventHandler.fireSessionEndEvent(reason); return this.hookEventHandler.fireSessionEndEvent(reason);
} }
async firePreCompressEvent( async firePreCompressEvent(
trigger: PreCompressTrigger, trigger: PreCompressTrigger,
): Promise<AggregatedHookResult | undefined> { ): Promise<AggregatedHookResult | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
return this.hookEventHandler.firePreCompressEvent(trigger); return this.hookEventHandler.firePreCompressEvent(trigger);
} }
async fireBeforeAgentEvent( async fireBeforeAgentEvent(
prompt: string, prompt: string,
): Promise<DefaultHookOutput | undefined> { ): Promise<DefaultHookOutput | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
const result = await this.hookEventHandler.fireBeforeAgentEvent(prompt); const result = await this.hookEventHandler.fireBeforeAgentEvent(prompt);
return result.finalOutput; return result.finalOutput;
} }
@@ -135,9 +132,6 @@ export class HookSystem {
response: string, response: string,
stopHookActive: boolean = false, stopHookActive: boolean = false,
): Promise<DefaultHookOutput | undefined> { ): Promise<DefaultHookOutput | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
const result = await this.hookEventHandler.fireAfterAgentEvent( const result = await this.hookEventHandler.fireAfterAgentEvent(
prompt, prompt,
response, response,
@@ -145,4 +139,121 @@ export class HookSystem {
); );
return result.finalOutput; return result.finalOutput;
} }
async fireBeforeModelEvent(
llmRequest: GenerateContentParameters,
): Promise<BeforeModelHookResult> {
try {
const result =
await this.hookEventHandler.fireBeforeModelEvent(llmRequest);
const hookOutput = result.finalOutput;
if (hookOutput?.shouldStopExecution()) {
return {
blocked: true,
stopped: true,
reason: hookOutput.getEffectiveReason(),
};
}
const blockingError = hookOutput?.getBlockingError();
if (blockingError?.blocked) {
const beforeModelOutput = hookOutput as BeforeModelHookOutput;
const syntheticResponse = beforeModelOutput.getSyntheticResponse();
return {
blocked: true,
reason:
hookOutput?.getEffectiveReason() || 'Model call blocked by hook',
syntheticResponse,
};
}
if (hookOutput) {
const beforeModelOutput = hookOutput as BeforeModelHookOutput;
const modifiedRequest =
beforeModelOutput.applyLLMRequestModifications(llmRequest);
return {
blocked: false,
modifiedConfig: modifiedRequest?.config,
modifiedContents: modifiedRequest?.contents,
};
}
return { blocked: false };
} catch (error) {
debugLogger.debug(`BeforeModelHookEvent failed:`, error);
return { blocked: false };
}
}
async fireAfterModelEvent(
originalRequest: GenerateContentParameters,
chunk: GenerateContentResponse,
): Promise<AfterModelHookResult> {
try {
const result = await this.hookEventHandler.fireAfterModelEvent(
originalRequest,
chunk,
);
const hookOutput = result.finalOutput;
if (hookOutput?.shouldStopExecution()) {
return {
response: chunk,
stopped: true,
reason: hookOutput.getEffectiveReason(),
};
}
const blockingError = hookOutput?.getBlockingError();
if (blockingError?.blocked) {
return {
response: chunk,
blocked: true,
reason: hookOutput?.getEffectiveReason(),
};
}
if (hookOutput) {
const afterModelOutput = hookOutput as AfterModelHookOutput;
const modifiedResponse = afterModelOutput.getModifiedResponse();
if (modifiedResponse) {
return { response: modifiedResponse };
}
}
return { response: chunk };
} catch (error) {
debugLogger.debug(`AfterModelHookEvent failed:`, error);
return { response: chunk };
}
}
async fireBeforeToolSelectionEvent(
llmRequest: GenerateContentParameters,
): Promise<BeforeToolSelectionHookResult> {
try {
const result =
await this.hookEventHandler.fireBeforeToolSelectionEvent(llmRequest);
const hookOutput = result.finalOutput;
if (hookOutput) {
const toolSelectionOutput = hookOutput as BeforeToolSelectionHookOutput;
const modifiedConfig = toolSelectionOutput.applyToolConfigModifications(
{
toolConfig: llmRequest.config?.toolConfig,
tools: llmRequest.config?.tools,
},
);
return {
toolConfig: modifiedConfig.toolConfig,
tools: modifiedConfig.tools,
};
}
return {};
} catch (error) {
debugLogger.debug(`BeforeToolSelectionEvent failed:`, error);
return {};
}
}
} }