diff --git a/docs/hooks/reference.md b/docs/hooks/reference.md index 6f7a82ad09..2feeedf940 100644 --- a/docs/hooks/reference.md +++ b/docs/hooks/reference.md @@ -142,6 +142,8 @@ case is response validation and automatic retries. - `reason`: Required if denied. This text is sent **to the agent as a new prompt** to request a correction. - `continue`: Set to `false` to **stop the session** without retrying. + - `clearContext`: If `true`, clears conversation history (LLM memory) while + preserving UI display. - **Exit Code 2 (Retry)**: Rejects the response and triggers an automatic retry turn using `stderr` as the feedback prompt. diff --git a/integration-tests/hooks-agent-flow.test.ts b/integration-tests/hooks-agent-flow.test.ts index 462ec155b0..13eb0bcecc 100644 --- a/integration-tests/hooks-agent-flow.test.ts +++ b/integration-tests/hooks-agent-flow.test.ts @@ -155,6 +155,84 @@ describe('Hooks Agent Flow', () => { // The fake response contains "Hello World" expect(afterAgentLog?.hookCall.stdout).toContain('Hello World'); }); + + it('should process clearContext in AfterAgent hook output', async () => { + await rig.setup('should process clearContext in AfterAgent hook output', { + fakeResponsesPath: join( + import.meta.dirname, + 'hooks-system.after-agent.responses', + ), + }); + + // BeforeModel hook to track message counts across LLM calls + const messageCountFile = join(rig.testDir!, 'message-counts.json'); + const beforeModelScript = ` + const fs = require('fs'); + const input = JSON.parse(fs.readFileSync(0, 'utf-8')); + const messageCount = input.llm_request?.contents?.length || 0; + let counts = []; + try { counts = JSON.parse(fs.readFileSync('${messageCountFile}', 'utf-8')); } catch (e) {} + counts.push(messageCount); + fs.writeFileSync('${messageCountFile}', JSON.stringify(counts)); + console.log(JSON.stringify({ decision: 'allow' })); + `; + const beforeModelScriptPath = join( + rig.testDir!, + 'before_model_counter.cjs', + ); + writeFileSync(beforeModelScriptPath, beforeModelScript); + + await rig.setup('should process clearContext in AfterAgent hook output', { + settings: { + hooks: { + enabled: true, + BeforeModel: [ + { + hooks: [ + { + type: 'command', + command: `node "${beforeModelScriptPath}"`, + timeout: 5000, + }, + ], + }, + ], + AfterAgent: [ + { + hooks: [ + { + type: 'command', + command: `node -e "console.log(JSON.stringify({decision: 'block', reason: 'Security policy triggered', hookSpecificOutput: {hookEventName: 'AfterAgent', clearContext: true}}))"`, + timeout: 5000, + }, + ], + }, + ], + }, + }, + }); + + const result = await rig.run({ args: 'Hello test' }); + + const hookTelemetryFound = await rig.waitForTelemetryEvent('hook_call'); + expect(hookTelemetryFound).toBeTruthy(); + + const hookLogs = rig.readHookLogs(); + const afterAgentLog = hookLogs.find( + (log) => log.hookCall.hook_event_name === 'AfterAgent', + ); + + expect(afterAgentLog).toBeDefined(); + expect(afterAgentLog?.hookCall.stdout).toContain('clearContext'); + expect(afterAgentLog?.hookCall.stdout).toContain('true'); + expect(result).toContain('Security policy triggered'); + + // Verify context was cleared: second call should not have more messages than first + const countsRaw = rig.readFile('message-counts.json'); + const counts = JSON.parse(countsRaw) as number[]; + expect(counts.length).toBeGreaterThanOrEqual(2); + expect(counts[1]).toBeLessThanOrEqual(counts[0]); + }); }); describe('Multi-step Loops', () => { diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index a55d6b7fd7..9c44b0ee11 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -803,7 +803,12 @@ export const useGeminiStream = ( ); const handleAgentExecutionStoppedEvent = useCallback( - (reason: string, userMessageTimestamp: number, systemMessage?: string) => { + ( + reason: string, + userMessageTimestamp: number, + systemMessage?: string, + contextCleared?: boolean, + ) => { if (pendingHistoryItemRef.current) { addItem(pendingHistoryItemRef.current, userMessageTimestamp); setPendingHistoryItem(null); @@ -815,13 +820,27 @@ export const useGeminiStream = ( }, userMessageTimestamp, ); + if (contextCleared) { + addItem( + { + type: MessageType.INFO, + text: 'Conversation context has been cleared.', + }, + userMessageTimestamp, + ); + } setIsResponding(false); }, [addItem, pendingHistoryItemRef, setPendingHistoryItem, setIsResponding], ); const handleAgentExecutionBlockedEvent = useCallback( - (reason: string, userMessageTimestamp: number, systemMessage?: string) => { + ( + reason: string, + userMessageTimestamp: number, + systemMessage?: string, + contextCleared?: boolean, + ) => { if (pendingHistoryItemRef.current) { addItem(pendingHistoryItemRef.current, userMessageTimestamp); setPendingHistoryItem(null); @@ -833,6 +852,15 @@ export const useGeminiStream = ( }, userMessageTimestamp, ); + if (contextCleared) { + addItem( + { + type: MessageType.INFO, + text: 'Conversation context has been cleared.', + }, + userMessageTimestamp, + ); + } }, [addItem, pendingHistoryItemRef, setPendingHistoryItem], ); @@ -873,6 +901,7 @@ export const useGeminiStream = ( event.value.reason, userMessageTimestamp, event.value.systemMessage, + event.value.contextCleared, ); break; case ServerGeminiEventType.AgentExecutionBlocked: @@ -880,6 +909,7 @@ export const useGeminiStream = ( event.value.reason, userMessageTimestamp, event.value.systemMessage, + event.value.contextCleared, ); break; case ServerGeminiEventType.ChatCompressed: diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index cfe8bdf34b..49d71ce1a9 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -3118,6 +3118,7 @@ ${JSON.stringify( mockHookSystem.fireAfterAgentEvent.mockResolvedValue({ shouldStopExecution: () => true, getEffectiveReason: () => 'Stopped after agent', + shouldClearContext: () => false, systemMessage: undefined, }); @@ -3132,10 +3133,12 @@ ${JSON.stringify( ); const events = await fromAsync(stream); - expect(events).toContainEqual({ - type: GeminiEventType.AgentExecutionStopped, - value: { reason: 'Stopped after agent' }, - }); + expect(events).toContainEqual( + expect.objectContaining({ + type: GeminiEventType.AgentExecutionStopped, + value: expect.objectContaining({ reason: 'Stopped after agent' }), + }), + ); // sendMessageStream should not recurse expect(mockTurnRunFn).toHaveBeenCalledTimes(1); }); @@ -3146,11 +3149,60 @@ ${JSON.stringify( shouldStopExecution: () => false, isBlockingDecision: () => true, getEffectiveReason: () => 'Please explain', + shouldClearContext: () => false, systemMessage: undefined, }) .mockResolvedValueOnce({ shouldStopExecution: () => false, isBlockingDecision: () => false, + shouldClearContext: () => false, + systemMessage: undefined, + }); + + mockTurnRunFn.mockImplementation(async function* () { + yield { type: GeminiEventType.Content, value: 'Response' }; + }); + + const stream = client.sendMessageStream( + { text: 'Hi' }, + new AbortController().signal, + 'test-prompt', + ); + const events = await fromAsync(stream); + + expect(events).toContainEqual( + expect.objectContaining({ + type: GeminiEventType.AgentExecutionBlocked, + value: expect.objectContaining({ reason: 'Please explain' }), + }), + ); + // Should have called turn run twice (original + re-prompt) + expect(mockTurnRunFn).toHaveBeenCalledTimes(2); + expect(mockTurnRunFn).toHaveBeenNthCalledWith( + 2, + expect.anything(), + [{ text: 'Please explain' }], + expect.anything(), + ); + }); + + it('should call resetChat when AfterAgent hook returns shouldClearContext: true', async () => { + const resetChatSpy = vi + .spyOn(client, 'resetChat') + .mockResolvedValue(undefined); + + mockHookSystem.fireAfterAgentEvent + .mockResolvedValueOnce({ + shouldStopExecution: () => false, + isBlockingDecision: () => true, + getEffectiveReason: () => 'Blocked and clearing context', + shouldClearContext: () => true, + systemMessage: undefined, + }) + .mockResolvedValueOnce({ + shouldStopExecution: () => false, + isBlockingDecision: () => false, + shouldClearContext: () => false, systemMessage: undefined, }); @@ -3167,16 +3219,15 @@ ${JSON.stringify( expect(events).toContainEqual({ type: GeminiEventType.AgentExecutionBlocked, - value: { reason: 'Please explain' }, + value: { + reason: 'Blocked and clearing context', + systemMessage: undefined, + contextCleared: true, + }, }); - // Should have called turn run twice (original + re-prompt) - expect(mockTurnRunFn).toHaveBeenCalledTimes(2); - expect(mockTurnRunFn).toHaveBeenNthCalledWith( - 2, - expect.anything(), - [{ text: 'Please explain' }], - expect.anything(), - ); + expect(resetChatSpy).toHaveBeenCalledTimes(1); + + resetChatSpy.mockRestore(); }); }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 7d8c70b0b5..2adb5d8bad 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -40,7 +40,10 @@ import { logContentRetryFailure, logNextSpeakerCheck, } from '../telemetry/loggers.js'; -import type { DefaultHookOutput } from '../hooks/types.js'; +import type { + DefaultHookOutput, + AfterAgentHookOutput, +} from '../hooks/types.js'; import { ContentRetryFailureEvent, NextSpeakerCheckEvent, @@ -816,26 +819,41 @@ export class GeminiClient { turn, ); - if (hookOutput?.shouldStopExecution()) { + // Cast to AfterAgentHookOutput for access to shouldClearContext() + const afterAgentOutput = hookOutput as AfterAgentHookOutput | undefined; + + if (afterAgentOutput?.shouldStopExecution()) { + const contextCleared = afterAgentOutput.shouldClearContext(); yield { type: GeminiEventType.AgentExecutionStopped, value: { - reason: hookOutput.getEffectiveReason(), - systemMessage: hookOutput.systemMessage, + reason: afterAgentOutput.getEffectiveReason(), + systemMessage: afterAgentOutput.systemMessage, + contextCleared, }, }; + // Clear context if requested (honor both stop + clear) + if (contextCleared) { + await this.resetChat(); + } return turn; } - if (hookOutput?.isBlockingDecision()) { - const continueReason = hookOutput.getEffectiveReason(); + if (afterAgentOutput?.isBlockingDecision()) { + const continueReason = afterAgentOutput.getEffectiveReason(); + const contextCleared = afterAgentOutput.shouldClearContext(); yield { type: GeminiEventType.AgentExecutionBlocked, value: { reason: continueReason, - systemMessage: hookOutput.systemMessage, + systemMessage: afterAgentOutput.systemMessage, + contextCleared, }, }; + // Clear context if requested + if (contextCleared) { + await this.resetChat(); + } const continueRequest = [{ text: continueReason }]; yield* this.sendMessageStream( continueRequest, diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 099530c90a..8e6974704d 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -79,6 +79,7 @@ export type ServerGeminiAgentExecutionStoppedEvent = { value: { reason: string; systemMessage?: string; + contextCleared?: boolean; }; }; @@ -87,6 +88,7 @@ export type ServerGeminiAgentExecutionBlockedEvent = { value: { reason: string; systemMessage?: string; + contextCleared?: boolean; }; }; diff --git a/packages/core/src/hooks/hookAggregator.ts b/packages/core/src/hooks/hookAggregator.ts index 0163f21856..0583c08776 100644 --- a/packages/core/src/hooks/hookAggregator.ts +++ b/packages/core/src/hooks/hookAggregator.ts @@ -16,6 +16,7 @@ import { BeforeModelHookOutput, BeforeToolSelectionHookOutput, AfterModelHookOutput, + AfterAgentHookOutput, } from './types.js'; import { HookEventName } from './types.js'; @@ -158,11 +159,21 @@ export class HookAggregator { merged.suppressOutput = true; } - // Merge hookSpecificOutput - if (output.hookSpecificOutput) { + // Handle clearContext (any true wins) - for AfterAgent hooks + if (output.hookSpecificOutput?.['clearContext'] === true) { merged.hookSpecificOutput = { ...(merged.hookSpecificOutput || {}), - ...output.hookSpecificOutput, + clearContext: true, + }; + } + + // Merge hookSpecificOutput (excluding clearContext which is handled above) + if (output.hookSpecificOutput) { + const { clearContext: _clearContext, ...restSpecificOutput } = + output.hookSpecificOutput; + merged.hookSpecificOutput = { + ...(merged.hookSpecificOutput || {}), + ...restSpecificOutput, }; } @@ -323,6 +334,8 @@ export class HookAggregator { return new BeforeToolSelectionHookOutput(output); case HookEventName.AfterModel: return new AfterModelHookOutput(output); + case HookEventName.AfterAgent: + return new AfterAgentHookOutput(output); default: return new DefaultHookOutput(output); } diff --git a/packages/core/src/hooks/types.ts b/packages/core/src/hooks/types.ts index e115cc27cc..fbcb6dd51d 100644 --- a/packages/core/src/hooks/types.ts +++ b/packages/core/src/hooks/types.ts @@ -140,6 +140,8 @@ export function createHookOutput( return new BeforeToolSelectionHookOutput(data); case 'BeforeTool': return new BeforeToolHookOutput(data); + case 'AfterAgent': + return new AfterAgentHookOutput(data); default: return new DefaultHookOutput(data); } @@ -243,6 +245,13 @@ export class DefaultHookOutput implements HookOutput { } return { blocked: false, reason: '' }; } + + /** + * Check if context clearing was requested by hook. + */ + shouldClearContext(): boolean { + return false; + } } /** @@ -367,6 +376,21 @@ export class AfterModelHookOutput extends DefaultHookOutput { } } +/** + * Specific hook output class for AfterAgent events + */ +export class AfterAgentHookOutput extends DefaultHookOutput { + /** + * Check if context clearing was requested by hook + */ + override shouldClearContext(): boolean { + if (this.hookSpecificOutput && 'clearContext' in this.hookSpecificOutput) { + return this.hookSpecificOutput['clearContext'] === true; + } + return false; + } +} + /** * Context for MCP tool executions. * Contains non-sensitive connection information about the MCP server @@ -480,6 +504,16 @@ export interface AfterAgentInput extends HookInput { stop_hook_active: boolean; } +/** + * AfterAgent hook output + */ +export interface AfterAgentOutput extends HookOutput { + hookSpecificOutput?: { + hookEventName: 'AfterAgent'; + clearContext?: boolean; + }; +} + /** * SessionStart source types */