Add other hook wrapper methods to hooksystem (#16361)

This commit is contained in:
Vedant Mahajan
2026-01-12 23:08:45 +05:30
committed by GitHub
parent 8656ce8a27
commit 8a2e0fac0d
3 changed files with 118 additions and 77 deletions

View File

@@ -46,9 +46,8 @@ import type {
ResolvedModelConfig,
} from '../services/modelConfigService.js';
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
import { HookSystem } from '../hooks/hookSystem.js';
import type { DefaultHookOutput } from '../hooks/types.js';
import * as policyCatalog from '../availability/policyCatalog.js';
import { partToString } from '../utils/partUtils.js';
vi.mock('../services/chatCompressionService.js');
@@ -137,15 +136,22 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({
},
}));
vi.mock('../hooks/hookSystem.js');
vi.mock('./clientHookTriggers.js', () => ({
fireBeforeAgentHook: vi.fn(),
fireAfterAgentHook: vi.fn().mockResolvedValue({
decision: 'allow',
continue: false,
suppressOutput: false,
systemMessage: undefined,
const mockHookSystem = {
fireBeforeAgentEvent: vi.fn().mockResolvedValue({
success: true,
finalOutput: undefined,
allOutputs: [],
errors: [],
totalDuration: 0,
}),
}));
fireAfterAgentEvent: vi.fn().mockResolvedValue({
success: true,
finalOutput: undefined,
allOutputs: [],
errors: [],
totalDuration: 0,
}),
};
/**
* Array.fromAsync ponyfill, which will be available in es 2024.
@@ -286,9 +292,7 @@ describe('Gemini Client (client.ts)', () => {
.fn()
.mockReturnValue(createAvailabilityServiceMock()),
} as unknown as Config;
mockConfig.getHookSystem = vi
.fn()
.mockReturnValue(new HookSystem(mockConfig));
mockConfig.getHookSystem = vi.fn().mockReturnValue(mockHookSystem);
client = new GeminiClient(mockConfig);
await client.initialize();
@@ -2688,9 +2692,6 @@ ${JSON.stringify(
const promptId = 'test-prompt-hook-1';
const request = { text: 'Hello Hooks' };
const signal = new AbortController().signal;
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
'./clientHookTriggers.js'
);
mockTurnRunFn.mockImplementation(async function* (
this: MockTurnContext,
@@ -2702,11 +2703,10 @@ ${JSON.stringify(
const stream = client.sendMessageStream(request, signal, promptId);
while (!(await stream.next()).done);
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
expect(fireAfterAgentHook).toHaveBeenCalledWith(
expect.anything(),
request,
expect(mockHookSystem.fireBeforeAgentEvent).toHaveBeenCalledTimes(1);
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledTimes(1);
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith(
partToString(request),
'Hook Response',
);
@@ -2725,9 +2725,6 @@ ${JSON.stringify(
const promptId = 'test-prompt-hook-recursive';
const request = { text: 'Recursion Test' };
const signal = new AbortController().signal;
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
'./clientHookTriggers.js'
);
let callCount = 0;
mockTurnRunFn.mockImplementation(async function* (
@@ -2743,15 +2740,14 @@ ${JSON.stringify(
while (!(await stream.next()).done);
// BeforeAgent should fire ONLY once despite multiple internal turns
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
expect(mockHookSystem.fireBeforeAgentEvent).toHaveBeenCalledTimes(1);
// AfterAgent should fire ONLY when the stack unwinds
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledTimes(1);
// Check cumulative response (separated by newline)
expect(fireAfterAgentHook).toHaveBeenCalledWith(
expect.anything(),
request,
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith(
partToString(request),
'Response 1\nResponse 2',
);
@@ -2769,7 +2765,6 @@ ${JSON.stringify(
const promptId = 'test-prompt-hook-original-req';
const request = { text: 'Do something' };
const signal = new AbortController().signal;
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
mockTurnRunFn.mockImplementation(async function* (
this: MockTurnContext,
@@ -2781,9 +2776,8 @@ ${JSON.stringify(
const stream = client.sendMessageStream(request, signal, promptId);
while (!(await stream.next()).done);
expect(fireAfterAgentHook).toHaveBeenCalledWith(
expect.anything(),
request, // Should be 'Do something'
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith(
partToString(request), // Should be 'Do something'
expect.stringContaining('Ok'),
);
});
@@ -2817,11 +2811,17 @@ ${JSON.stringify(
});
it('should stop execution in BeforeAgent when hook returns continue: false', async () => {
const { fireBeforeAgentHook } = await import('./clientHookTriggers.js');
vi.mocked(fireBeforeAgentHook).mockResolvedValue({
shouldStopExecution: () => true,
getEffectiveReason: () => 'Stopped by hook',
} as DefaultHookOutput);
mockHookSystem.fireBeforeAgentEvent.mockResolvedValue({
success: true,
finalOutput: {
shouldStopExecution: () => true,
getEffectiveReason: () => 'Stopped by hook',
systemMessage: undefined,
},
allOutputs: [],
errors: [],
totalDuration: 0,
});
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
@@ -2850,12 +2850,18 @@ ${JSON.stringify(
});
it('should block execution in BeforeAgent when hook returns decision: block', async () => {
const { fireBeforeAgentHook } = await import('./clientHookTriggers.js');
vi.mocked(fireBeforeAgentHook).mockResolvedValue({
shouldStopExecution: () => false,
isBlockingDecision: () => true,
getEffectiveReason: () => 'Blocked by hook',
} as DefaultHookOutput);
mockHookSystem.fireBeforeAgentEvent.mockResolvedValue({
success: true,
finalOutput: {
shouldStopExecution: () => false,
isBlockingDecision: () => true,
getEffectiveReason: () => 'Blocked by hook',
systemMessage: undefined,
},
allOutputs: [],
errors: [],
totalDuration: 0,
});
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
@@ -2883,11 +2889,17 @@ ${JSON.stringify(
});
it('should stop execution in AfterAgent when hook returns continue: false', async () => {
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
vi.mocked(fireAfterAgentHook).mockResolvedValue({
shouldStopExecution: () => true,
getEffectiveReason: () => 'Stopped after agent',
} as DefaultHookOutput);
mockHookSystem.fireAfterAgentEvent.mockResolvedValue({
success: true,
finalOutput: {
shouldStopExecution: () => true,
getEffectiveReason: () => 'Stopped after agent',
systemMessage: undefined,
},
allOutputs: [],
errors: [],
totalDuration: 0,
});
mockTurnRunFn.mockImplementation(async function* () {
yield { type: GeminiEventType.Content, value: 'Hello' };
@@ -2909,17 +2921,30 @@ ${JSON.stringify(
});
it('should yield AgentExecutionBlocked and recurse in AfterAgent when hook returns decision: block', async () => {
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
vi.mocked(fireAfterAgentHook)
mockHookSystem.fireAfterAgentEvent
.mockResolvedValueOnce({
shouldStopExecution: () => false,
isBlockingDecision: () => true,
getEffectiveReason: () => 'Please explain',
} as DefaultHookOutput)
success: true,
finalOutput: {
shouldStopExecution: () => false,
isBlockingDecision: () => true,
getEffectiveReason: () => 'Please explain',
systemMessage: undefined,
},
allOutputs: [],
errors: [],
totalDuration: 0,
})
.mockResolvedValueOnce({
shouldStopExecution: () => false,
isBlockingDecision: () => false,
} as DefaultHookOutput);
success: true,
finalOutput: {
shouldStopExecution: () => false,
isBlockingDecision: () => false,
systemMessage: undefined,
},
allOutputs: [],
errors: [],
totalDuration: 0,
});
mockTurnRunFn.mockImplementation(async function* () {
yield { type: GeminiEventType.Content, value: 'Response' };

View File

@@ -12,7 +12,6 @@ import type {
GenerateContentResponse,
} from '@google/genai';
import { createUserContent } from '@google/genai';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import {
getDirectoryContextString,
getInitialChatHistory,
@@ -40,10 +39,6 @@ import {
logContentRetryFailure,
logNextSpeakerCheck,
} from '../telemetry/loggers.js';
import {
fireBeforeAgentHook,
fireAfterAgentHook,
} from './clientHookTriggers.js';
import type { DefaultHookOutput } from '../hooks/types.js';
import {
ContentRetryFailureEvent,
@@ -62,6 +57,7 @@ import {
} from '../availability/policyHelpers.js';
import { resolveModel } from '../config/models.js';
import type { RetryAvailabilityContext } from '../utils/retry.js';
import { partToString } from '../utils/partUtils.js';
const MAX_TURNS = 100;
@@ -113,7 +109,6 @@ export class GeminiClient {
>();
private async fireBeforeAgentHookSafe(
messageBus: MessageBus,
request: PartListUnion,
prompt_id: string,
): Promise<BeforeAgentHookReturn> {
@@ -138,7 +133,10 @@ export class GeminiClient {
return undefined;
}
const hookOutput = await fireBeforeAgentHook(messageBus, request);
const hookResult = await this.config
.getHookSystem()
?.fireBeforeAgentEvent(partToString(request));
const hookOutput = hookResult?.finalOutput;
hookState.hasFiredBeforeAgent = true;
if (hookOutput?.shouldStopExecution()) {
@@ -169,7 +167,6 @@ export class GeminiClient {
}
private async fireAfterAgentHookSafe(
messageBus: MessageBus,
currentRequest: PartListUnion,
prompt_id: string,
turn?: Turn,
@@ -190,11 +187,11 @@ export class GeminiClient {
'[no response text]';
const finalRequest = hookState.originalRequest || currentRequest;
const hookOutput = await fireAfterAgentHook(
messageBus,
finalRequest,
finalResponseText,
);
const hookResult = await this.config
.getHookSystem()
?.fireAfterAgentEvent(partToString(finalRequest), finalResponseText);
const hookOutput = hookResult?.finalOutput;
return hookOutput;
}
@@ -757,11 +754,7 @@ export class GeminiClient {
}
if (hooksEnabled && messageBus) {
const hookResult = await this.fireBeforeAgentHookSafe(
messageBus,
request,
prompt_id,
);
const hookResult = await this.fireBeforeAgentHookSafe(request, prompt_id);
if (hookResult) {
if (
'type' in hookResult &&
@@ -802,7 +795,6 @@ export class GeminiClient {
// Fire AfterAgent hook if we have a turn and no pending tools
if (hooksEnabled && messageBus) {
const hookOutput = await this.fireAfterAgentHookSafe(
messageBus,
request,
prompt_id,
turn,

View File

@@ -117,4 +117,28 @@ export class HookSystem {
}
return this.hookEventHandler.firePreCompressEvent(trigger);
}
async fireBeforeAgentEvent(
prompt: string,
): Promise<AggregatedHookResult | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
return this.hookEventHandler.fireBeforeAgentEvent(prompt);
}
async fireAfterAgentEvent(
prompt: string,
response: string,
stopHookActive: boolean = false,
): Promise<AggregatedHookResult | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
return this.hookEventHandler.fireAfterAgentEvent(
prompt,
response,
stopHookActive,
);
}
}