From 8a2e0fac0d8c3499e628ed52d517987163f32541 Mon Sep 17 00:00:00 2001 From: Vedant Mahajan Date: Mon, 12 Jan 2026 23:08:45 +0530 Subject: [PATCH] Add other hook wrapper methods to hooksystem (#16361) --- packages/core/src/core/client.test.ts | 141 +++++++++++++++----------- packages/core/src/core/client.ts | 30 ++---- packages/core/src/hooks/hookSystem.ts | 24 +++++ 3 files changed, 118 insertions(+), 77 deletions(-) diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 16f78d40d8..84418c9855 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -46,9 +46,8 @@ import type { ResolvedModelConfig, } from '../services/modelConfigService.js'; import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js'; -import { HookSystem } from '../hooks/hookSystem.js'; -import type { DefaultHookOutput } from '../hooks/types.js'; import * as policyCatalog from '../availability/policyCatalog.js'; +import { partToString } from '../utils/partUtils.js'; vi.mock('../services/chatCompressionService.js'); @@ -137,15 +136,22 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({ }, })); vi.mock('../hooks/hookSystem.js'); -vi.mock('./clientHookTriggers.js', () => ({ - fireBeforeAgentHook: vi.fn(), - fireAfterAgentHook: vi.fn().mockResolvedValue({ - decision: 'allow', - continue: false, - suppressOutput: false, - systemMessage: undefined, +const mockHookSystem = { + fireBeforeAgentEvent: vi.fn().mockResolvedValue({ + success: true, + finalOutput: undefined, + allOutputs: [], + errors: [], + totalDuration: 0, }), -})); + fireAfterAgentEvent: vi.fn().mockResolvedValue({ + success: true, + finalOutput: undefined, + allOutputs: [], + errors: [], + totalDuration: 0, + }), +}; /** * Array.fromAsync ponyfill, which will be available in es 2024. @@ -286,9 +292,7 @@ describe('Gemini Client (client.ts)', () => { .fn() .mockReturnValue(createAvailabilityServiceMock()), } as unknown as Config; - mockConfig.getHookSystem = vi - .fn() - .mockReturnValue(new HookSystem(mockConfig)); + mockConfig.getHookSystem = vi.fn().mockReturnValue(mockHookSystem); client = new GeminiClient(mockConfig); await client.initialize(); @@ -2688,9 +2692,6 @@ ${JSON.stringify( const promptId = 'test-prompt-hook-1'; const request = { text: 'Hello Hooks' }; const signal = new AbortController().signal; - const { fireBeforeAgentHook, fireAfterAgentHook } = await import( - './clientHookTriggers.js' - ); mockTurnRunFn.mockImplementation(async function* ( this: MockTurnContext, @@ -2702,11 +2703,10 @@ ${JSON.stringify( const stream = client.sendMessageStream(request, signal, promptId); while (!(await stream.next()).done); - expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1); - expect(fireAfterAgentHook).toHaveBeenCalledTimes(1); - expect(fireAfterAgentHook).toHaveBeenCalledWith( - expect.anything(), - request, + expect(mockHookSystem.fireBeforeAgentEvent).toHaveBeenCalledTimes(1); + expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledTimes(1); + expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith( + partToString(request), 'Hook Response', ); @@ -2725,9 +2725,6 @@ ${JSON.stringify( const promptId = 'test-prompt-hook-recursive'; const request = { text: 'Recursion Test' }; const signal = new AbortController().signal; - const { fireBeforeAgentHook, fireAfterAgentHook } = await import( - './clientHookTriggers.js' - ); let callCount = 0; mockTurnRunFn.mockImplementation(async function* ( @@ -2743,15 +2740,14 @@ ${JSON.stringify( while (!(await stream.next()).done); // BeforeAgent should fire ONLY once despite multiple internal turns - expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1); + expect(mockHookSystem.fireBeforeAgentEvent).toHaveBeenCalledTimes(1); // AfterAgent should fire ONLY when the stack unwinds - expect(fireAfterAgentHook).toHaveBeenCalledTimes(1); + expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledTimes(1); // Check cumulative response (separated by newline) - expect(fireAfterAgentHook).toHaveBeenCalledWith( - expect.anything(), - request, + expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith( + partToString(request), 'Response 1\nResponse 2', ); @@ -2769,7 +2765,6 @@ ${JSON.stringify( const promptId = 'test-prompt-hook-original-req'; const request = { text: 'Do something' }; const signal = new AbortController().signal; - const { fireAfterAgentHook } = await import('./clientHookTriggers.js'); mockTurnRunFn.mockImplementation(async function* ( this: MockTurnContext, @@ -2781,9 +2776,8 @@ ${JSON.stringify( const stream = client.sendMessageStream(request, signal, promptId); while (!(await stream.next()).done); - expect(fireAfterAgentHook).toHaveBeenCalledWith( - expect.anything(), - request, // Should be 'Do something' + expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith( + partToString(request), // Should be 'Do something' expect.stringContaining('Ok'), ); }); @@ -2817,11 +2811,17 @@ ${JSON.stringify( }); it('should stop execution in BeforeAgent when hook returns continue: false', async () => { - const { fireBeforeAgentHook } = await import('./clientHookTriggers.js'); - vi.mocked(fireBeforeAgentHook).mockResolvedValue({ - shouldStopExecution: () => true, - getEffectiveReason: () => 'Stopped by hook', - } as DefaultHookOutput); + mockHookSystem.fireBeforeAgentEvent.mockResolvedValue({ + success: true, + finalOutput: { + shouldStopExecution: () => true, + getEffectiveReason: () => 'Stopped by hook', + systemMessage: undefined, + }, + allOutputs: [], + errors: [], + totalDuration: 0, + }); const mockChat: Partial = { addHistory: vi.fn(), @@ -2850,12 +2850,18 @@ ${JSON.stringify( }); it('should block execution in BeforeAgent when hook returns decision: block', async () => { - const { fireBeforeAgentHook } = await import('./clientHookTriggers.js'); - vi.mocked(fireBeforeAgentHook).mockResolvedValue({ - shouldStopExecution: () => false, - isBlockingDecision: () => true, - getEffectiveReason: () => 'Blocked by hook', - } as DefaultHookOutput); + mockHookSystem.fireBeforeAgentEvent.mockResolvedValue({ + success: true, + finalOutput: { + shouldStopExecution: () => false, + isBlockingDecision: () => true, + getEffectiveReason: () => 'Blocked by hook', + systemMessage: undefined, + }, + allOutputs: [], + errors: [], + totalDuration: 0, + }); const mockChat: Partial = { addHistory: vi.fn(), @@ -2883,11 +2889,17 @@ ${JSON.stringify( }); it('should stop execution in AfterAgent when hook returns continue: false', async () => { - const { fireAfterAgentHook } = await import('./clientHookTriggers.js'); - vi.mocked(fireAfterAgentHook).mockResolvedValue({ - shouldStopExecution: () => true, - getEffectiveReason: () => 'Stopped after agent', - } as DefaultHookOutput); + mockHookSystem.fireAfterAgentEvent.mockResolvedValue({ + success: true, + finalOutput: { + shouldStopExecution: () => true, + getEffectiveReason: () => 'Stopped after agent', + systemMessage: undefined, + }, + allOutputs: [], + errors: [], + totalDuration: 0, + }); mockTurnRunFn.mockImplementation(async function* () { yield { type: GeminiEventType.Content, value: 'Hello' }; @@ -2909,17 +2921,30 @@ ${JSON.stringify( }); it('should yield AgentExecutionBlocked and recurse in AfterAgent when hook returns decision: block', async () => { - const { fireAfterAgentHook } = await import('./clientHookTriggers.js'); - vi.mocked(fireAfterAgentHook) + mockHookSystem.fireAfterAgentEvent .mockResolvedValueOnce({ - shouldStopExecution: () => false, - isBlockingDecision: () => true, - getEffectiveReason: () => 'Please explain', - } as DefaultHookOutput) + success: true, + finalOutput: { + shouldStopExecution: () => false, + isBlockingDecision: () => true, + getEffectiveReason: () => 'Please explain', + systemMessage: undefined, + }, + allOutputs: [], + errors: [], + totalDuration: 0, + }) .mockResolvedValueOnce({ - shouldStopExecution: () => false, - isBlockingDecision: () => false, - } as DefaultHookOutput); + success: true, + finalOutput: { + shouldStopExecution: () => false, + isBlockingDecision: () => false, + systemMessage: undefined, + }, + allOutputs: [], + errors: [], + totalDuration: 0, + }); mockTurnRunFn.mockImplementation(async function* () { yield { type: GeminiEventType.Content, value: 'Response' }; diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 67dae3f927..535bf15ce7 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -12,7 +12,6 @@ import type { GenerateContentResponse, } from '@google/genai'; import { createUserContent } from '@google/genai'; -import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { getDirectoryContextString, getInitialChatHistory, @@ -40,10 +39,6 @@ import { logContentRetryFailure, logNextSpeakerCheck, } from '../telemetry/loggers.js'; -import { - fireBeforeAgentHook, - fireAfterAgentHook, -} from './clientHookTriggers.js'; import type { DefaultHookOutput } from '../hooks/types.js'; import { ContentRetryFailureEvent, @@ -62,6 +57,7 @@ import { } from '../availability/policyHelpers.js'; import { resolveModel } from '../config/models.js'; import type { RetryAvailabilityContext } from '../utils/retry.js'; +import { partToString } from '../utils/partUtils.js'; const MAX_TURNS = 100; @@ -113,7 +109,6 @@ export class GeminiClient { >(); private async fireBeforeAgentHookSafe( - messageBus: MessageBus, request: PartListUnion, prompt_id: string, ): Promise { @@ -138,7 +133,10 @@ export class GeminiClient { return undefined; } - const hookOutput = await fireBeforeAgentHook(messageBus, request); + const hookResult = await this.config + .getHookSystem() + ?.fireBeforeAgentEvent(partToString(request)); + const hookOutput = hookResult?.finalOutput; hookState.hasFiredBeforeAgent = true; if (hookOutput?.shouldStopExecution()) { @@ -169,7 +167,6 @@ export class GeminiClient { } private async fireAfterAgentHookSafe( - messageBus: MessageBus, currentRequest: PartListUnion, prompt_id: string, turn?: Turn, @@ -190,11 +187,11 @@ export class GeminiClient { '[no response text]'; const finalRequest = hookState.originalRequest || currentRequest; - const hookOutput = await fireAfterAgentHook( - messageBus, - finalRequest, - finalResponseText, - ); + const hookResult = await this.config + .getHookSystem() + ?.fireAfterAgentEvent(partToString(finalRequest), finalResponseText); + const hookOutput = hookResult?.finalOutput; + return hookOutput; } @@ -757,11 +754,7 @@ export class GeminiClient { } if (hooksEnabled && messageBus) { - const hookResult = await this.fireBeforeAgentHookSafe( - messageBus, - request, - prompt_id, - ); + const hookResult = await this.fireBeforeAgentHookSafe(request, prompt_id); if (hookResult) { if ( 'type' in hookResult && @@ -802,7 +795,6 @@ export class GeminiClient { // Fire AfterAgent hook if we have a turn and no pending tools if (hooksEnabled && messageBus) { const hookOutput = await this.fireAfterAgentHookSafe( - messageBus, request, prompt_id, turn, diff --git a/packages/core/src/hooks/hookSystem.ts b/packages/core/src/hooks/hookSystem.ts index 98f7baf817..547b44a923 100644 --- a/packages/core/src/hooks/hookSystem.ts +++ b/packages/core/src/hooks/hookSystem.ts @@ -117,4 +117,28 @@ export class HookSystem { } return this.hookEventHandler.firePreCompressEvent(trigger); } + + async fireBeforeAgentEvent( + prompt: string, + ): Promise { + if (!this.config.getEnableHooks()) { + return undefined; + } + return this.hookEventHandler.fireBeforeAgentEvent(prompt); + } + + async fireAfterAgentEvent( + prompt: string, + response: string, + stopHookActive: boolean = false, + ): Promise { + if (!this.config.getEnableHooks()) { + return undefined; + } + return this.hookEventHandler.fireAfterAgentEvent( + prompt, + response, + stopHookActive, + ); + } }