fix: migrate BeforeModel and AfterModel hooks to HookSystem (#16599)

Co-authored-by: Tommaso Sciortino <sciortino@gmail.com>
This commit is contained in:
Vedant Mahajan
2026-01-20 23:16:54 +05:30
committed by GitHub
parent 15f26175b8
commit e92f60b4fc
3 changed files with 157 additions and 84 deletions

View File

@@ -22,24 +22,12 @@ import { AuthType } from './contentGenerator.js';
import { TerminalQuotaError } from '../utils/googleQuotaErrors.js';
import { type RetryOptions } from '../utils/retry.js';
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
import { HookSystem } from '../hooks/hookSystem.js';
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
import * as policyHelpers from '../availability/policyHelpers.js';
import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js';
import {
fireBeforeModelHook,
fireAfterModelHook,
fireBeforeToolSelectionHook,
} from './geminiChatHookTriggers.js';
// Mock hook triggers
vi.mock('./geminiChatHookTriggers.js', () => ({
fireBeforeModelHook: vi.fn(),
fireAfterModelHook: vi.fn(),
fireBeforeToolSelectionHook: vi.fn().mockResolvedValue({}),
}));
import type { HookSystem } from '../hooks/hookSystem.js';
// Mock fs module to prevent actual file system operations during tests
const mockFileSystem = new Map<string, string>();
@@ -204,9 +192,7 @@ describe('GeminiChat', () => {
setSimulate429(false);
// Reset history for each test by creating a new instance
chat = new GeminiChat(mockConfig);
mockConfig.getHookSystem = vi
.fn()
.mockReturnValue(new HookSystem(mockConfig));
mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined);
});
afterEach(() => {
@@ -2283,18 +2269,20 @@ describe('GeminiChat', () => {
});
describe('Hook execution control', () => {
let mockHookSystem: HookSystem;
beforeEach(() => {
vi.mocked(mockConfig.getEnableHooks).mockReturnValue(true);
// Default to allowing execution
vi.mocked(fireBeforeModelHook).mockResolvedValue({ blocked: false });
vi.mocked(fireAfterModelHook).mockResolvedValue({
response: {} as GenerateContentResponse,
});
vi.mocked(fireBeforeToolSelectionHook).mockResolvedValue({});
mockHookSystem = {
fireBeforeModelEvent: vi.fn().mockResolvedValue({ blocked: false }),
fireAfterModelEvent: vi.fn().mockResolvedValue({ response: {} }),
fireBeforeToolSelectionEvent: vi.fn().mockResolvedValue({}),
} as unknown as HookSystem;
mockConfig.getHookSystem = vi.fn().mockReturnValue(mockHookSystem);
});
it('should yield AGENT_EXECUTION_STOPPED when BeforeModel hook stops execution', async () => {
vi.mocked(fireBeforeModelHook).mockResolvedValue({
vi.mocked(mockHookSystem.fireBeforeModelEvent).mockResolvedValue({
blocked: true,
stopped: true,
reason: 'stopped by hook',
@@ -2324,7 +2312,7 @@ describe('GeminiChat', () => {
candidates: [{ content: { parts: [{ text: 'blocked' }] } }],
} as GenerateContentResponse;
vi.mocked(fireBeforeModelHook).mockResolvedValue({
vi.mocked(mockHookSystem.fireBeforeModelEvent).mockResolvedValue({
blocked: true,
reason: 'blocked by hook',
syntheticResponse,
@@ -2363,7 +2351,7 @@ describe('GeminiChat', () => {
})(),
);
vi.mocked(fireAfterModelHook).mockResolvedValue({
vi.mocked(mockHookSystem.fireAfterModelEvent).mockResolvedValue({
response: {} as GenerateContentResponse,
stopped: true,
reason: 'stopped by after hook',
@@ -2399,7 +2387,7 @@ describe('GeminiChat', () => {
})(),
);
vi.mocked(fireAfterModelHook).mockResolvedValue({
vi.mocked(mockHookSystem.fireAfterModelEvent).mockResolvedValue({
response,
blocked: true,
reason: 'blocked by after hook',

View File

@@ -49,11 +49,6 @@ import {
applyModelSelection,
createAvailabilityContextProvider,
} from '../availability/policyHelpers.js';
import {
fireAfterModelHook,
fireBeforeModelHook,
fireBeforeToolSelectionHook,
} from './geminiChatHookTriggers.js';
import { coreEvents } from '../utils/events.js';
export enum StreamEventType {
@@ -507,39 +502,26 @@ export class GeminiChat {
? contentsForPreviewModel
: requestContents;
// Fire BeforeModel and BeforeToolSelection hooks if enabled
const hooksEnabled = this.config.getEnableHooks();
const messageBus = this.config.getMessageBus();
if (hooksEnabled && messageBus) {
// Fire BeforeModel hook
const beforeModelResult = await fireBeforeModelHook(messageBus, {
const hookSystem = this.config.getHookSystem();
if (hookSystem) {
const beforeModelResult = await hookSystem.fireBeforeModelEvent({
model: modelToUse,
config,
contents: contentsToUse,
});
// Check if hook requested to stop execution
if (beforeModelResult.stopped) {
throw new AgentExecutionStoppedError(
beforeModelResult.reason || 'Agent execution stopped by hook',
);
}
// Check if hook blocked the model call
if (beforeModelResult.blocked) {
// Return a synthetic response generator
const syntheticResponse = beforeModelResult.syntheticResponse;
if (syntheticResponse) {
// Ensure synthetic response has a finish reason to prevent InvalidStreamError
if (
syntheticResponse.candidates &&
syntheticResponse.candidates.length > 0
) {
for (const candidate of syntheticResponse.candidates) {
if (!candidate.finishReason) {
candidate.finishReason = FinishReason.STOP;
}
}
for (const candidate of syntheticResponse?.candidates ?? []) {
if (!candidate.finishReason) {
candidate.finishReason = FinishReason.STOP;
}
}
@@ -549,7 +531,6 @@ export class GeminiChat {
);
}
// Apply modifications from BeforeModel hook
if (beforeModelResult.modifiedConfig) {
Object.assign(config, beforeModelResult.modifiedConfig);
}
@@ -560,17 +541,13 @@ export class GeminiChat {
contentsToUse = beforeModelResult.modifiedContents as Content[];
}
// Fire BeforeToolSelection hook
const toolSelectionResult = await fireBeforeToolSelectionHook(
messageBus,
{
const toolSelectionResult =
await hookSystem.fireBeforeToolSelectionEvent({
model: modelToUse,
config,
contents: contentsToUse,
},
);
});
// Apply tool configuration modifications
if (toolSelectionResult.toolConfig) {
config.toolConfig = toolSelectionResult.toolConfig;
}
@@ -825,12 +802,9 @@ export class GeminiChat {
}
}
// Fire AfterModel hook through MessageBus (only if hooks are enabled)
const hooksEnabled = this.config.getEnableHooks();
const messageBus = this.config.getMessageBus();
if (hooksEnabled && messageBus && originalRequest && chunk) {
const hookResult = await fireAfterModelHook(
messageBus,
const hookSystem = this.config.getHookSystem();
if (originalRequest && chunk && hookSystem) {
const hookResult = await hookSystem.fireAfterModelEvent(
originalRequest,
chunk,
);
@@ -850,7 +824,7 @@ export class GeminiChat {
yield hookResult.response;
} else {
yield chunk; // Yield every chunk to the UI immediately.
yield chunk;
}
}

View File

@@ -19,13 +19,24 @@ import type {
SessionEndReason,
PreCompressTrigger,
DefaultHookOutput,
BeforeModelHookOutput,
AfterModelHookOutput,
BeforeToolSelectionHookOutput,
} from './types.js';
import type { AggregatedHookResult } from './hookAggregator.js';
import type {
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
import type {
AfterModelHookResult,
BeforeModelHookResult,
BeforeToolSelectionHookResult,
} from '../core/geminiChatHookTriggers.js';
/**
* Main hook system that coordinates all hook-related functionality
*/
export class HookSystem {
private readonly config: Config;
private readonly hookRegistry: HookRegistry;
private readonly hookRunner: HookRunner;
private readonly hookAggregator: HookAggregator;
@@ -33,7 +44,6 @@ export class HookSystem {
private readonly hookEventHandler: HookEventHandler;
constructor(config: Config) {
this.config = config;
const logger: Logger = logs.getLogger(SERVICE_NAME);
const messageBus = config.getMessageBus();
@@ -90,14 +100,10 @@ export class HookSystem {
/**
* Fire hook events directly
* Returns undefined if hooks are disabled
*/
async fireSessionStartEvent(
source: SessionStartSource,
): Promise<DefaultHookOutput | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
const result = await this.hookEventHandler.fireSessionStartEvent(source);
return result.finalOutput;
}
@@ -105,27 +111,18 @@ export class HookSystem {
async fireSessionEndEvent(
reason: SessionEndReason,
): Promise<AggregatedHookResult | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
return this.hookEventHandler.fireSessionEndEvent(reason);
}
async firePreCompressEvent(
trigger: PreCompressTrigger,
): Promise<AggregatedHookResult | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
return this.hookEventHandler.firePreCompressEvent(trigger);
}
async fireBeforeAgentEvent(
prompt: string,
): Promise<DefaultHookOutput | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
const result = await this.hookEventHandler.fireBeforeAgentEvent(prompt);
return result.finalOutput;
}
@@ -135,9 +132,6 @@ export class HookSystem {
response: string,
stopHookActive: boolean = false,
): Promise<DefaultHookOutput | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
const result = await this.hookEventHandler.fireAfterAgentEvent(
prompt,
response,
@@ -145,4 +139,121 @@ export class HookSystem {
);
return result.finalOutput;
}
async fireBeforeModelEvent(
llmRequest: GenerateContentParameters,
): Promise<BeforeModelHookResult> {
try {
const result =
await this.hookEventHandler.fireBeforeModelEvent(llmRequest);
const hookOutput = result.finalOutput;
if (hookOutput?.shouldStopExecution()) {
return {
blocked: true,
stopped: true,
reason: hookOutput.getEffectiveReason(),
};
}
const blockingError = hookOutput?.getBlockingError();
if (blockingError?.blocked) {
const beforeModelOutput = hookOutput as BeforeModelHookOutput;
const syntheticResponse = beforeModelOutput.getSyntheticResponse();
return {
blocked: true,
reason:
hookOutput?.getEffectiveReason() || 'Model call blocked by hook',
syntheticResponse,
};
}
if (hookOutput) {
const beforeModelOutput = hookOutput as BeforeModelHookOutput;
const modifiedRequest =
beforeModelOutput.applyLLMRequestModifications(llmRequest);
return {
blocked: false,
modifiedConfig: modifiedRequest?.config,
modifiedContents: modifiedRequest?.contents,
};
}
return { blocked: false };
} catch (error) {
debugLogger.debug(`BeforeModelHookEvent failed:`, error);
return { blocked: false };
}
}
async fireAfterModelEvent(
originalRequest: GenerateContentParameters,
chunk: GenerateContentResponse,
): Promise<AfterModelHookResult> {
try {
const result = await this.hookEventHandler.fireAfterModelEvent(
originalRequest,
chunk,
);
const hookOutput = result.finalOutput;
if (hookOutput?.shouldStopExecution()) {
return {
response: chunk,
stopped: true,
reason: hookOutput.getEffectiveReason(),
};
}
const blockingError = hookOutput?.getBlockingError();
if (blockingError?.blocked) {
return {
response: chunk,
blocked: true,
reason: hookOutput?.getEffectiveReason(),
};
}
if (hookOutput) {
const afterModelOutput = hookOutput as AfterModelHookOutput;
const modifiedResponse = afterModelOutput.getModifiedResponse();
if (modifiedResponse) {
return { response: modifiedResponse };
}
}
return { response: chunk };
} catch (error) {
debugLogger.debug(`AfterModelHookEvent failed:`, error);
return { response: chunk };
}
}
async fireBeforeToolSelectionEvent(
llmRequest: GenerateContentParameters,
): Promise<BeforeToolSelectionHookResult> {
try {
const result =
await this.hookEventHandler.fireBeforeToolSelectionEvent(llmRequest);
const hookOutput = result.finalOutput;
if (hookOutput) {
const toolSelectionOutput = hookOutput as BeforeToolSelectionHookOutput;
const modifiedConfig = toolSelectionOutput.applyToolConfigModifications(
{
toolConfig: llmRequest.config?.toolConfig,
tools: llmRequest.config?.tools,
},
);
return {
toolConfig: modifiedConfig.toolConfig,
tools: modifiedConfig.tools,
};
}
return {};
} catch (error) {
debugLogger.debug(`BeforeToolSelectionEvent failed:`, error);
return {};
}
}
}