From 8faf2e31eaccf94fdd81493f53ef2a0f30e96440 Mon Sep 17 00:00:00 2001 From: Michael Bleigh Date: Thu, 26 Mar 2026 16:05:21 -0700 Subject: [PATCH] feat(core): implement additionalContext for BeforeModel hooks and aggregate context from multiple hooks --- docs/hooks/reference.md | 4 ++ .../core/src/hooks/hookAggregator.test.ts | 46 ++++++++++++++ packages/core/src/hooks/hookAggregator.ts | 12 ++++ packages/core/src/hooks/hookRunner.test.ts | 5 +- packages/core/src/hooks/hookRunner.ts | 17 ------ packages/core/src/hooks/types.test.ts | 61 +++++++++++++++++++ packages/core/src/hooks/types.ts | 57 ++++++++++++++++- 7 files changed, 180 insertions(+), 22 deletions(-) diff --git a/docs/hooks/reference.md b/docs/hooks/reference.md index 5242c3a13d..c534922818 100644 --- a/docs/hooks/reference.md +++ b/docs/hooks/reference.md @@ -197,6 +197,10 @@ request format. outgoing request (e.g., changing models or temperature). - `hookSpecificOutput.llm_response`: A **Synthetic Response** object. If provided, the CLI skips the LLM call entirely and uses this as the response. + - `hookSpecificOutput.additionalContext`: (`string`) Text that is **appended** + to the end of the user message (wrapped in `` tags). If + `llm_request` also modifies contents, it will be appended to the modified + contents. - `decision`: Set to `"deny"` to block the request and abort the turn. - **Exit Code 2 (Block Turn)**: Aborts the turn and skips the LLM call. Uses `stderr` as the error message. diff --git a/packages/core/src/hooks/hookAggregator.test.ts b/packages/core/src/hooks/hookAggregator.test.ts index ee9ade9a87..8f22919656 100644 --- a/packages/core/src/hooks/hookAggregator.test.ts +++ b/packages/core/src/hooks/hookAggregator.test.ts @@ -423,6 +423,52 @@ describe('HookAggregator', () => { const llmRequest = output.hookSpecificOutput?.llm_request; expect(llmRequest?.['model']).toBe('model2'); // Later value wins }); + + it('should aggregate additionalContext for BeforeModel hooks', () => { + const results: HookExecutionResult[] = [ + { + hookConfig: { + type: HookType.Command, + command: 'h1', + timeout: 30000, + }, + eventName: HookEventName.BeforeModel, + success: true, + output: { + hookSpecificOutput: { + hookEventName: 'BeforeModel', + additionalContext: 'Context 1', + }, + }, + duration: 10, + }, + { + hookConfig: { + type: HookType.Command, + command: 'h2', + timeout: 30000, + }, + eventName: HookEventName.BeforeModel, + success: true, + output: { + hookSpecificOutput: { + hookEventName: 'BeforeModel', + additionalContext: 'Context 2', + }, + }, + duration: 10, + }, + ]; + + const aggregated = aggregator.aggregateResults( + results, + HookEventName.BeforeModel, + ); + + expect( + aggregated.finalOutput?.hookSpecificOutput?.['additionalContext'], + ).toBe('Context 1\nContext 2'); + }); }); describe('extractAdditionalContext', () => { diff --git a/packages/core/src/hooks/hookAggregator.ts b/packages/core/src/hooks/hookAggregator.ts index b67266edf5..31ee2b7113 100644 --- a/packages/core/src/hooks/hookAggregator.ts +++ b/packages/core/src/hooks/hookAggregator.ts @@ -221,8 +221,12 @@ export class HookAggregator { */ private mergeWithFieldReplacement(outputs: HookOutput[]): HookOutput { let merged: HookOutput = {}; + const additionalContexts: string[] = []; for (const output of outputs) { + // Collect additional context + this.extractAdditionalContext(output, additionalContexts); + // Later outputs override earlier ones merged = { ...merged, @@ -234,6 +238,14 @@ export class HookAggregator { }; } + // Add merged additional context + if (additionalContexts.length > 0) { + merged.hookSpecificOutput = { + ...(merged.hookSpecificOutput || {}), + additionalContext: additionalContexts.join('\n'), + }; + } + return merged; } diff --git a/packages/core/src/hooks/hookRunner.test.ts b/packages/core/src/hooks/hookRunner.test.ts index eb806aba3d..95c04d39a0 100644 --- a/packages/core/src/hooks/hookRunner.test.ts +++ b/packages/core/src/hooks/hookRunner.test.ts @@ -655,10 +655,9 @@ describe('HookRunner', () => { // Verify that the second hook received modified input const secondHookInput = JSON.parse( - vi.mocked(mockSpawn.stdin.write).mock.calls[1][0], + vi.mocked(mockSpawn.stdin.write).mock.calls[1][0] as string, ); - expect(secondHookInput.prompt).toContain('Original prompt'); - expect(secondHookInput.prompt).toContain('Context from hook 1'); + expect(secondHookInput.prompt).toBe('Original prompt'); }); it('should pass modified LLM request from one hook to the next for BeforeModel', async () => { diff --git a/packages/core/src/hooks/hookRunner.ts b/packages/core/src/hooks/hookRunner.ts index 4f44958787..82e9acdccb 100644 --- a/packages/core/src/hooks/hookRunner.ts +++ b/packages/core/src/hooks/hookRunner.ts @@ -15,7 +15,6 @@ import { type HookInput, type HookOutput, type HookExecutionResult, - type BeforeAgentInput, type BeforeModelInput, type BeforeModelOutput, type BeforeToolInput, @@ -180,22 +179,6 @@ export class HookRunner { // Apply modifications based on hook output and event type if (hookOutput.hookSpecificOutput) { switch (eventName) { - case HookEventName.BeforeAgent: - if ('additionalContext' in hookOutput.hookSpecificOutput) { - // For BeforeAgent, we could modify the prompt with additional context - const additionalContext = - hookOutput.hookSpecificOutput['additionalContext']; - if ( - typeof additionalContext === 'string' && - 'prompt' in modifiedInput - ) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - (modifiedInput as BeforeAgentInput).prompt += - '\n\n' + additionalContext; - } - } - break; - case HookEventName.BeforeModel: if ('llm_request' in hookOutput.hookSpecificOutput) { // For BeforeModel, we update the LLM request diff --git a/packages/core/src/hooks/types.test.ts b/packages/core/src/hooks/types.test.ts index ab809cbec7..5c5d218ee4 100644 --- a/packages/core/src/hooks/types.test.ts +++ b/packages/core/src/hooks/types.test.ts @@ -24,6 +24,7 @@ import { import type { GenerateContentParameters, GenerateContentResponse, + Content, ToolConfig, } from '@google/genai'; @@ -271,6 +272,66 @@ describe('Hook Output Classes', () => { const output = new BeforeModelHookOutput({}); expect(output.applyLLMRequestModifications(target)).toBe(target); }); + + it('applyLLMRequestModifications should append additionalContext to contents', () => { + const target: GenerateContentParameters = { + model: 'gemini-pro', + contents: [{ role: 'user', parts: [{ text: 'Hello' }] }], + }; + const output = new BeforeModelHookOutput({ + hookSpecificOutput: { additionalContext: 'New Context' }, + }); + const result = output.applyLLMRequestModifications(target); + expect(result.contents).toHaveLength(1); + const contents = result.contents as Content[]; + expect(contents[0].parts!).toHaveLength(2); + expect(contents[0].parts![1]).toEqual({ + text: '\n\nNew Context', + }); + }); + + it('applyLLMRequestModifications should create new user message if none exists', () => { + const target: GenerateContentParameters = { + model: 'gemini-pro', + contents: [{ role: 'model', parts: [{ text: 'Hi' }] }], + }; + const output = new BeforeModelHookOutput({ + hookSpecificOutput: { additionalContext: 'New Context' }, + }); + const result = output.applyLLMRequestModifications(target); + expect(result.contents).toHaveLength(2); + const contents = result.contents as Content[]; + expect(contents[1].role).toBe('user'); + expect(contents[1].parts![0]).toEqual({ + text: '\n\nNew Context', + }); + }); + + it('applyLLMRequestModifications should handle both llm_request and additionalContext', () => { + const target: GenerateContentParameters = { + model: 'gemini-pro', + contents: [{ role: 'user', parts: [{ text: 'original' }] }], + }; + const mockRequest = { + // Our mock fromHookLLMRequest just spreads the request, so we can use contents here for convenience in tests + // even though LLMRequest uses messages. + contents: [{ role: 'user', parts: [{ text: 'modified' }] }], + } as unknown as Partial; + const output = new BeforeModelHookOutput({ + hookSpecificOutput: { + llm_request: mockRequest as LLMRequest, + additionalContext: 'New Context', + }, + }); + const result = output.applyLLMRequestModifications(target); + expect(result.contents).toHaveLength(1); + const contents = result.contents as Content[]; + expect(contents[0].parts!).toHaveLength(2); + expect(contents[0].parts![0]).toEqual({ text: 'modified' }); + expect(contents[0].parts![1]).toEqual({ + text: '\n\nNew Context', + }); + }); }); describe('BeforeToolSelectionHookOutput', () => { diff --git a/packages/core/src/hooks/types.ts b/packages/core/src/hooks/types.ts index 11dbe874e5..2f039f56cb 100644 --- a/packages/core/src/hooks/types.ts +++ b/packages/core/src/hooks/types.ts @@ -7,6 +7,7 @@ import type { GenerateContentResponse, GenerateContentParameters, + Content, ToolConfig as GenAIToolConfig, ToolListUnion, } from '@google/genai'; @@ -374,6 +375,8 @@ export class BeforeModelHookOutput extends DefaultHookOutput { override applyLLMRequestModifications( target: GenerateContentParameters, ): GenerateContentParameters { + let resultTarget = target; + if (this.hookSpecificOutput && 'llm_request' in this.hookSpecificOutput) { // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion const hookRequest = this.hookSpecificOutput[ @@ -386,13 +389,62 @@ export class BeforeModelHookOutput extends DefaultHookOutput { hookRequest as LLMRequest, target, ); - return { + resultTarget = { ...target, ...sdkRequest, }; } } - return target; + + const additionalContext = this.getAdditionalContext(); + if (additionalContext) { + const originalContents = resultTarget.contents; + let contents: Content[]; + + if (Array.isArray(originalContents)) { + contents = (originalContents as unknown[]).map((c) => { + if (typeof c === 'string') { + return { role: 'user', parts: [{ text: c }] }; + } + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + const content = c as Content; + return { ...content, parts: [...(content.parts || [])] }; + }); + } else if (typeof originalContents === 'string') { + contents = [{ role: 'user', parts: [{ text: originalContents }] }]; + } else if (originalContents && 'role' in originalContents) { + const c = originalContents; + contents = [{ ...c, parts: [...(c.parts || [])] }]; + } else { + contents = []; + } + + const wrappedContext = `\n\n${additionalContext}`; + + let lastUserMessageIndex = -1; + for (let i = contents.length - 1; i >= 0; i--) { + if (contents[i].role === 'user') { + lastUserMessageIndex = i; + break; + } + } + + if (lastUserMessageIndex !== -1) { + if (!contents[lastUserMessageIndex].parts) { + contents[lastUserMessageIndex].parts = []; + } + contents[lastUserMessageIndex].parts!.push({ text: wrappedContext }); + } else { + contents.push({ role: 'user', parts: [{ text: wrappedContext }] }); + } + + resultTarget = { + ...resultTarget, + contents, + }; + } + + return resultTarget; } } @@ -683,6 +735,7 @@ export interface BeforeModelOutput extends HookOutput { hookEventName: 'BeforeModel'; llm_request?: Partial; llm_response?: LLMResponse; + additionalContext?: string; }; }