Add other hook wrapper methods to hooksystem (#16361)

This commit is contained in:
Vedant Mahajan
2026-01-12 23:08:45 +05:30
committed by GitHub
parent 8656ce8a27
commit 8a2e0fac0d
3 changed files with 118 additions and 77 deletions
+83 -58
View File
@@ -46,9 +46,8 @@ import type {
ResolvedModelConfig, ResolvedModelConfig,
} from '../services/modelConfigService.js'; } from '../services/modelConfigService.js';
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js'; import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
import { HookSystem } from '../hooks/hookSystem.js';
import type { DefaultHookOutput } from '../hooks/types.js';
import * as policyCatalog from '../availability/policyCatalog.js'; import * as policyCatalog from '../availability/policyCatalog.js';
import { partToString } from '../utils/partUtils.js';
vi.mock('../services/chatCompressionService.js'); vi.mock('../services/chatCompressionService.js');
@@ -137,15 +136,22 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({
}, },
})); }));
vi.mock('../hooks/hookSystem.js'); vi.mock('../hooks/hookSystem.js');
vi.mock('./clientHookTriggers.js', () => ({ const mockHookSystem = {
fireBeforeAgentHook: vi.fn(), fireBeforeAgentEvent: vi.fn().mockResolvedValue({
fireAfterAgentHook: vi.fn().mockResolvedValue({ success: true,
decision: 'allow', finalOutput: undefined,
continue: false, allOutputs: [],
suppressOutput: false, errors: [],
systemMessage: undefined, totalDuration: 0,
}), }),
})); fireAfterAgentEvent: vi.fn().mockResolvedValue({
success: true,
finalOutput: undefined,
allOutputs: [],
errors: [],
totalDuration: 0,
}),
};
/** /**
* Array.fromAsync ponyfill, which will be available in es 2024. * Array.fromAsync ponyfill, which will be available in es 2024.
@@ -286,9 +292,7 @@ describe('Gemini Client (client.ts)', () => {
.fn() .fn()
.mockReturnValue(createAvailabilityServiceMock()), .mockReturnValue(createAvailabilityServiceMock()),
} as unknown as Config; } as unknown as Config;
mockConfig.getHookSystem = vi mockConfig.getHookSystem = vi.fn().mockReturnValue(mockHookSystem);
.fn()
.mockReturnValue(new HookSystem(mockConfig));
client = new GeminiClient(mockConfig); client = new GeminiClient(mockConfig);
await client.initialize(); await client.initialize();
@@ -2688,9 +2692,6 @@ ${JSON.stringify(
const promptId = 'test-prompt-hook-1'; const promptId = 'test-prompt-hook-1';
const request = { text: 'Hello Hooks' }; const request = { text: 'Hello Hooks' };
const signal = new AbortController().signal; const signal = new AbortController().signal;
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
'./clientHookTriggers.js'
);
mockTurnRunFn.mockImplementation(async function* ( mockTurnRunFn.mockImplementation(async function* (
this: MockTurnContext, this: MockTurnContext,
@@ -2702,11 +2703,10 @@ ${JSON.stringify(
const stream = client.sendMessageStream(request, signal, promptId); const stream = client.sendMessageStream(request, signal, promptId);
while (!(await stream.next()).done); while (!(await stream.next()).done);
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1); expect(mockHookSystem.fireBeforeAgentEvent).toHaveBeenCalledTimes(1);
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1); expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledTimes(1);
expect(fireAfterAgentHook).toHaveBeenCalledWith( expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith(
expect.anything(), partToString(request),
request,
'Hook Response', 'Hook Response',
); );
@@ -2725,9 +2725,6 @@ ${JSON.stringify(
const promptId = 'test-prompt-hook-recursive'; const promptId = 'test-prompt-hook-recursive';
const request = { text: 'Recursion Test' }; const request = { text: 'Recursion Test' };
const signal = new AbortController().signal; const signal = new AbortController().signal;
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
'./clientHookTriggers.js'
);
let callCount = 0; let callCount = 0;
mockTurnRunFn.mockImplementation(async function* ( mockTurnRunFn.mockImplementation(async function* (
@@ -2743,15 +2740,14 @@ ${JSON.stringify(
while (!(await stream.next()).done); while (!(await stream.next()).done);
// BeforeAgent should fire ONLY once despite multiple internal turns // BeforeAgent should fire ONLY once despite multiple internal turns
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1); expect(mockHookSystem.fireBeforeAgentEvent).toHaveBeenCalledTimes(1);
// AfterAgent should fire ONLY when the stack unwinds // AfterAgent should fire ONLY when the stack unwinds
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1); expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledTimes(1);
// Check cumulative response (separated by newline) // Check cumulative response (separated by newline)
expect(fireAfterAgentHook).toHaveBeenCalledWith( expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith(
expect.anything(), partToString(request),
request,
'Response 1\nResponse 2', 'Response 1\nResponse 2',
); );
@@ -2769,7 +2765,6 @@ ${JSON.stringify(
const promptId = 'test-prompt-hook-original-req'; const promptId = 'test-prompt-hook-original-req';
const request = { text: 'Do something' }; const request = { text: 'Do something' };
const signal = new AbortController().signal; const signal = new AbortController().signal;
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
mockTurnRunFn.mockImplementation(async function* ( mockTurnRunFn.mockImplementation(async function* (
this: MockTurnContext, this: MockTurnContext,
@@ -2781,9 +2776,8 @@ ${JSON.stringify(
const stream = client.sendMessageStream(request, signal, promptId); const stream = client.sendMessageStream(request, signal, promptId);
while (!(await stream.next()).done); while (!(await stream.next()).done);
expect(fireAfterAgentHook).toHaveBeenCalledWith( expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith(
expect.anything(), partToString(request), // Should be 'Do something'
request, // Should be 'Do something'
expect.stringContaining('Ok'), expect.stringContaining('Ok'),
); );
}); });
@@ -2817,11 +2811,17 @@ ${JSON.stringify(
}); });
it('should stop execution in BeforeAgent when hook returns continue: false', async () => { it('should stop execution in BeforeAgent when hook returns continue: false', async () => {
const { fireBeforeAgentHook } = await import('./clientHookTriggers.js'); mockHookSystem.fireBeforeAgentEvent.mockResolvedValue({
vi.mocked(fireBeforeAgentHook).mockResolvedValue({ success: true,
shouldStopExecution: () => true, finalOutput: {
getEffectiveReason: () => 'Stopped by hook', shouldStopExecution: () => true,
} as DefaultHookOutput); getEffectiveReason: () => 'Stopped by hook',
systemMessage: undefined,
},
allOutputs: [],
errors: [],
totalDuration: 0,
});
const mockChat: Partial<GeminiChat> = { const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(), addHistory: vi.fn(),
@@ -2850,12 +2850,18 @@ ${JSON.stringify(
}); });
it('should block execution in BeforeAgent when hook returns decision: block', async () => { it('should block execution in BeforeAgent when hook returns decision: block', async () => {
const { fireBeforeAgentHook } = await import('./clientHookTriggers.js'); mockHookSystem.fireBeforeAgentEvent.mockResolvedValue({
vi.mocked(fireBeforeAgentHook).mockResolvedValue({ success: true,
shouldStopExecution: () => false, finalOutput: {
isBlockingDecision: () => true, shouldStopExecution: () => false,
getEffectiveReason: () => 'Blocked by hook', isBlockingDecision: () => true,
} as DefaultHookOutput); getEffectiveReason: () => 'Blocked by hook',
systemMessage: undefined,
},
allOutputs: [],
errors: [],
totalDuration: 0,
});
const mockChat: Partial<GeminiChat> = { const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(), addHistory: vi.fn(),
@@ -2883,11 +2889,17 @@ ${JSON.stringify(
}); });
it('should stop execution in AfterAgent when hook returns continue: false', async () => { it('should stop execution in AfterAgent when hook returns continue: false', async () => {
const { fireAfterAgentHook } = await import('./clientHookTriggers.js'); mockHookSystem.fireAfterAgentEvent.mockResolvedValue({
vi.mocked(fireAfterAgentHook).mockResolvedValue({ success: true,
shouldStopExecution: () => true, finalOutput: {
getEffectiveReason: () => 'Stopped after agent', shouldStopExecution: () => true,
} as DefaultHookOutput); getEffectiveReason: () => 'Stopped after agent',
systemMessage: undefined,
},
allOutputs: [],
errors: [],
totalDuration: 0,
});
mockTurnRunFn.mockImplementation(async function* () { mockTurnRunFn.mockImplementation(async function* () {
yield { type: GeminiEventType.Content, value: 'Hello' }; yield { type: GeminiEventType.Content, value: 'Hello' };
@@ -2909,17 +2921,30 @@ ${JSON.stringify(
}); });
it('should yield AgentExecutionBlocked and recurse in AfterAgent when hook returns decision: block', async () => { it('should yield AgentExecutionBlocked and recurse in AfterAgent when hook returns decision: block', async () => {
const { fireAfterAgentHook } = await import('./clientHookTriggers.js'); mockHookSystem.fireAfterAgentEvent
vi.mocked(fireAfterAgentHook)
.mockResolvedValueOnce({ .mockResolvedValueOnce({
shouldStopExecution: () => false, success: true,
isBlockingDecision: () => true, finalOutput: {
getEffectiveReason: () => 'Please explain', shouldStopExecution: () => false,
} as DefaultHookOutput) isBlockingDecision: () => true,
getEffectiveReason: () => 'Please explain',
systemMessage: undefined,
},
allOutputs: [],
errors: [],
totalDuration: 0,
})
.mockResolvedValueOnce({ .mockResolvedValueOnce({
shouldStopExecution: () => false, success: true,
isBlockingDecision: () => false, finalOutput: {
} as DefaultHookOutput); shouldStopExecution: () => false,
isBlockingDecision: () => false,
systemMessage: undefined,
},
allOutputs: [],
errors: [],
totalDuration: 0,
});
mockTurnRunFn.mockImplementation(async function* () { mockTurnRunFn.mockImplementation(async function* () {
yield { type: GeminiEventType.Content, value: 'Response' }; yield { type: GeminiEventType.Content, value: 'Response' };
+11 -19
View File
@@ -12,7 +12,6 @@ import type {
GenerateContentResponse, GenerateContentResponse,
} from '@google/genai'; } from '@google/genai';
import { createUserContent } from '@google/genai'; import { createUserContent } from '@google/genai';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { import {
getDirectoryContextString, getDirectoryContextString,
getInitialChatHistory, getInitialChatHistory,
@@ -40,10 +39,6 @@ import {
logContentRetryFailure, logContentRetryFailure,
logNextSpeakerCheck, logNextSpeakerCheck,
} from '../telemetry/loggers.js'; } from '../telemetry/loggers.js';
import {
fireBeforeAgentHook,
fireAfterAgentHook,
} from './clientHookTriggers.js';
import type { DefaultHookOutput } from '../hooks/types.js'; import type { DefaultHookOutput } from '../hooks/types.js';
import { import {
ContentRetryFailureEvent, ContentRetryFailureEvent,
@@ -62,6 +57,7 @@ import {
} from '../availability/policyHelpers.js'; } from '../availability/policyHelpers.js';
import { resolveModel } from '../config/models.js'; import { resolveModel } from '../config/models.js';
import type { RetryAvailabilityContext } from '../utils/retry.js'; import type { RetryAvailabilityContext } from '../utils/retry.js';
import { partToString } from '../utils/partUtils.js';
const MAX_TURNS = 100; const MAX_TURNS = 100;
@@ -113,7 +109,6 @@ export class GeminiClient {
>(); >();
private async fireBeforeAgentHookSafe( private async fireBeforeAgentHookSafe(
messageBus: MessageBus,
request: PartListUnion, request: PartListUnion,
prompt_id: string, prompt_id: string,
): Promise<BeforeAgentHookReturn> { ): Promise<BeforeAgentHookReturn> {
@@ -138,7 +133,10 @@ export class GeminiClient {
return undefined; return undefined;
} }
const hookOutput = await fireBeforeAgentHook(messageBus, request); const hookResult = await this.config
.getHookSystem()
?.fireBeforeAgentEvent(partToString(request));
const hookOutput = hookResult?.finalOutput;
hookState.hasFiredBeforeAgent = true; hookState.hasFiredBeforeAgent = true;
if (hookOutput?.shouldStopExecution()) { if (hookOutput?.shouldStopExecution()) {
@@ -169,7 +167,6 @@ export class GeminiClient {
} }
private async fireAfterAgentHookSafe( private async fireAfterAgentHookSafe(
messageBus: MessageBus,
currentRequest: PartListUnion, currentRequest: PartListUnion,
prompt_id: string, prompt_id: string,
turn?: Turn, turn?: Turn,
@@ -190,11 +187,11 @@ export class GeminiClient {
'[no response text]'; '[no response text]';
const finalRequest = hookState.originalRequest || currentRequest; const finalRequest = hookState.originalRequest || currentRequest;
const hookOutput = await fireAfterAgentHook( const hookResult = await this.config
messageBus, .getHookSystem()
finalRequest, ?.fireAfterAgentEvent(partToString(finalRequest), finalResponseText);
finalResponseText, const hookOutput = hookResult?.finalOutput;
);
return hookOutput; return hookOutput;
} }
@@ -757,11 +754,7 @@ export class GeminiClient {
} }
if (hooksEnabled && messageBus) { if (hooksEnabled && messageBus) {
const hookResult = await this.fireBeforeAgentHookSafe( const hookResult = await this.fireBeforeAgentHookSafe(request, prompt_id);
messageBus,
request,
prompt_id,
);
if (hookResult) { if (hookResult) {
if ( if (
'type' in hookResult && 'type' in hookResult &&
@@ -802,7 +795,6 @@ export class GeminiClient {
// Fire AfterAgent hook if we have a turn and no pending tools // Fire AfterAgent hook if we have a turn and no pending tools
if (hooksEnabled && messageBus) { if (hooksEnabled && messageBus) {
const hookOutput = await this.fireAfterAgentHookSafe( const hookOutput = await this.fireAfterAgentHookSafe(
messageBus,
request, request,
prompt_id, prompt_id,
turn, turn,
+24
View File
@@ -117,4 +117,28 @@ export class HookSystem {
} }
return this.hookEventHandler.firePreCompressEvent(trigger); return this.hookEventHandler.firePreCompressEvent(trigger);
} }
async fireBeforeAgentEvent(
prompt: string,
): Promise<AggregatedHookResult | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
return this.hookEventHandler.fireBeforeAgentEvent(prompt);
}
async fireAfterAgentEvent(
prompt: string,
response: string,
stopHookActive: boolean = false,
): Promise<AggregatedHookResult | undefined> {
if (!this.config.getEnableHooks()) {
return undefined;
}
return this.hookEventHandler.fireAfterAgentEvent(
prompt,
response,
stopHookActive,
);
}
} }