feat: add clearContext to AfterAgent hooks (#16574)

This commit is contained in:
Jack Wotherspoon
2026-01-23 17:14:30 -05:00
committed by GitHub
parent 5c649d8db1
commit da1664c7a0
8 changed files with 253 additions and 25 deletions
+2
View File
@@ -142,6 +142,8 @@ case is response validation and automatic retries.
- `reason`: Required if denied. This text is sent **to the agent as a new - `reason`: Required if denied. This text is sent **to the agent as a new
prompt** to request a correction. prompt** to request a correction.
- `continue`: Set to `false` to **stop the session** without retrying. - `continue`: Set to `false` to **stop the session** without retrying.
- `clearContext`: If `true`, clears conversation history (LLM memory) while
preserving UI display.
- **Exit Code 2 (Retry)**: Rejects the response and triggers an automatic retry - **Exit Code 2 (Retry)**: Rejects the response and triggers an automatic retry
turn using `stderr` as the feedback prompt. turn using `stderr` as the feedback prompt.
@@ -155,6 +155,84 @@ describe('Hooks Agent Flow', () => {
// The fake response contains "Hello World" // The fake response contains "Hello World"
expect(afterAgentLog?.hookCall.stdout).toContain('Hello World'); expect(afterAgentLog?.hookCall.stdout).toContain('Hello World');
}); });
it('should process clearContext in AfterAgent hook output', async () => {
await rig.setup('should process clearContext in AfterAgent hook output', {
fakeResponsesPath: join(
import.meta.dirname,
'hooks-system.after-agent.responses',
),
});
// BeforeModel hook to track message counts across LLM calls
const messageCountFile = join(rig.testDir!, 'message-counts.json');
const beforeModelScript = `
const fs = require('fs');
const input = JSON.parse(fs.readFileSync(0, 'utf-8'));
const messageCount = input.llm_request?.contents?.length || 0;
let counts = [];
try { counts = JSON.parse(fs.readFileSync('${messageCountFile}', 'utf-8')); } catch (e) {}
counts.push(messageCount);
fs.writeFileSync('${messageCountFile}', JSON.stringify(counts));
console.log(JSON.stringify({ decision: 'allow' }));
`;
const beforeModelScriptPath = join(
rig.testDir!,
'before_model_counter.cjs',
);
writeFileSync(beforeModelScriptPath, beforeModelScript);
await rig.setup('should process clearContext in AfterAgent hook output', {
settings: {
hooks: {
enabled: true,
BeforeModel: [
{
hooks: [
{
type: 'command',
command: `node "${beforeModelScriptPath}"`,
timeout: 5000,
},
],
},
],
AfterAgent: [
{
hooks: [
{
type: 'command',
command: `node -e "console.log(JSON.stringify({decision: 'block', reason: 'Security policy triggered', hookSpecificOutput: {hookEventName: 'AfterAgent', clearContext: true}}))"`,
timeout: 5000,
},
],
},
],
},
},
});
const result = await rig.run({ args: 'Hello test' });
const hookTelemetryFound = await rig.waitForTelemetryEvent('hook_call');
expect(hookTelemetryFound).toBeTruthy();
const hookLogs = rig.readHookLogs();
const afterAgentLog = hookLogs.find(
(log) => log.hookCall.hook_event_name === 'AfterAgent',
);
expect(afterAgentLog).toBeDefined();
expect(afterAgentLog?.hookCall.stdout).toContain('clearContext');
expect(afterAgentLog?.hookCall.stdout).toContain('true');
expect(result).toContain('Security policy triggered');
// Verify context was cleared: second call should not have more messages than first
const countsRaw = rig.readFile('message-counts.json');
const counts = JSON.parse(countsRaw) as number[];
expect(counts.length).toBeGreaterThanOrEqual(2);
expect(counts[1]).toBeLessThanOrEqual(counts[0]);
});
}); });
describe('Multi-step Loops', () => { describe('Multi-step Loops', () => {
+32 -2
View File
@@ -803,7 +803,12 @@ export const useGeminiStream = (
); );
const handleAgentExecutionStoppedEvent = useCallback( const handleAgentExecutionStoppedEvent = useCallback(
(reason: string, userMessageTimestamp: number, systemMessage?: string) => { (
reason: string,
userMessageTimestamp: number,
systemMessage?: string,
contextCleared?: boolean,
) => {
if (pendingHistoryItemRef.current) { if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, userMessageTimestamp); addItem(pendingHistoryItemRef.current, userMessageTimestamp);
setPendingHistoryItem(null); setPendingHistoryItem(null);
@@ -815,13 +820,27 @@ export const useGeminiStream = (
}, },
userMessageTimestamp, userMessageTimestamp,
); );
if (contextCleared) {
addItem(
{
type: MessageType.INFO,
text: 'Conversation context has been cleared.',
},
userMessageTimestamp,
);
}
setIsResponding(false); setIsResponding(false);
}, },
[addItem, pendingHistoryItemRef, setPendingHistoryItem, setIsResponding], [addItem, pendingHistoryItemRef, setPendingHistoryItem, setIsResponding],
); );
const handleAgentExecutionBlockedEvent = useCallback( const handleAgentExecutionBlockedEvent = useCallback(
(reason: string, userMessageTimestamp: number, systemMessage?: string) => { (
reason: string,
userMessageTimestamp: number,
systemMessage?: string,
contextCleared?: boolean,
) => {
if (pendingHistoryItemRef.current) { if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, userMessageTimestamp); addItem(pendingHistoryItemRef.current, userMessageTimestamp);
setPendingHistoryItem(null); setPendingHistoryItem(null);
@@ -833,6 +852,15 @@ export const useGeminiStream = (
}, },
userMessageTimestamp, userMessageTimestamp,
); );
if (contextCleared) {
addItem(
{
type: MessageType.INFO,
text: 'Conversation context has been cleared.',
},
userMessageTimestamp,
);
}
}, },
[addItem, pendingHistoryItemRef, setPendingHistoryItem], [addItem, pendingHistoryItemRef, setPendingHistoryItem],
); );
@@ -873,6 +901,7 @@ export const useGeminiStream = (
event.value.reason, event.value.reason,
userMessageTimestamp, userMessageTimestamp,
event.value.systemMessage, event.value.systemMessage,
event.value.contextCleared,
); );
break; break;
case ServerGeminiEventType.AgentExecutionBlocked: case ServerGeminiEventType.AgentExecutionBlocked:
@@ -880,6 +909,7 @@ export const useGeminiStream = (
event.value.reason, event.value.reason,
userMessageTimestamp, userMessageTimestamp,
event.value.systemMessage, event.value.systemMessage,
event.value.contextCleared,
); );
break; break;
case ServerGeminiEventType.ChatCompressed: case ServerGeminiEventType.ChatCompressed:
+64 -13
View File
@@ -3118,6 +3118,7 @@ ${JSON.stringify(
mockHookSystem.fireAfterAgentEvent.mockResolvedValue({ mockHookSystem.fireAfterAgentEvent.mockResolvedValue({
shouldStopExecution: () => true, shouldStopExecution: () => true,
getEffectiveReason: () => 'Stopped after agent', getEffectiveReason: () => 'Stopped after agent',
shouldClearContext: () => false,
systemMessage: undefined, systemMessage: undefined,
}); });
@@ -3132,10 +3133,12 @@ ${JSON.stringify(
); );
const events = await fromAsync(stream); const events = await fromAsync(stream);
expect(events).toContainEqual({ expect(events).toContainEqual(
type: GeminiEventType.AgentExecutionStopped, expect.objectContaining({
value: { reason: 'Stopped after agent' }, type: GeminiEventType.AgentExecutionStopped,
}); value: expect.objectContaining({ reason: 'Stopped after agent' }),
}),
);
// sendMessageStream should not recurse // sendMessageStream should not recurse
expect(mockTurnRunFn).toHaveBeenCalledTimes(1); expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
}); });
@@ -3146,11 +3149,60 @@ ${JSON.stringify(
shouldStopExecution: () => false, shouldStopExecution: () => false,
isBlockingDecision: () => true, isBlockingDecision: () => true,
getEffectiveReason: () => 'Please explain', getEffectiveReason: () => 'Please explain',
shouldClearContext: () => false,
systemMessage: undefined, systemMessage: undefined,
}) })
.mockResolvedValueOnce({ .mockResolvedValueOnce({
shouldStopExecution: () => false, shouldStopExecution: () => false,
isBlockingDecision: () => false, isBlockingDecision: () => false,
shouldClearContext: () => false,
systemMessage: undefined,
});
mockTurnRunFn.mockImplementation(async function* () {
yield { type: GeminiEventType.Content, value: 'Response' };
});
const stream = client.sendMessageStream(
{ text: 'Hi' },
new AbortController().signal,
'test-prompt',
);
const events = await fromAsync(stream);
expect(events).toContainEqual(
expect.objectContaining({
type: GeminiEventType.AgentExecutionBlocked,
value: expect.objectContaining({ reason: 'Please explain' }),
}),
);
// Should have called turn run twice (original + re-prompt)
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
2,
expect.anything(),
[{ text: 'Please explain' }],
expect.anything(),
);
});
it('should call resetChat when AfterAgent hook returns shouldClearContext: true', async () => {
const resetChatSpy = vi
.spyOn(client, 'resetChat')
.mockResolvedValue(undefined);
mockHookSystem.fireAfterAgentEvent
.mockResolvedValueOnce({
shouldStopExecution: () => false,
isBlockingDecision: () => true,
getEffectiveReason: () => 'Blocked and clearing context',
shouldClearContext: () => true,
systemMessage: undefined,
})
.mockResolvedValueOnce({
shouldStopExecution: () => false,
isBlockingDecision: () => false,
shouldClearContext: () => false,
systemMessage: undefined, systemMessage: undefined,
}); });
@@ -3167,16 +3219,15 @@ ${JSON.stringify(
expect(events).toContainEqual({ expect(events).toContainEqual({
type: GeminiEventType.AgentExecutionBlocked, type: GeminiEventType.AgentExecutionBlocked,
value: { reason: 'Please explain' }, value: {
reason: 'Blocked and clearing context',
systemMessage: undefined,
contextCleared: true,
},
}); });
// Should have called turn run twice (original + re-prompt) expect(resetChatSpy).toHaveBeenCalledTimes(1);
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
expect(mockTurnRunFn).toHaveBeenNthCalledWith( resetChatSpy.mockRestore();
2,
expect.anything(),
[{ text: 'Please explain' }],
expect.anything(),
);
}); });
}); });
}); });
+25 -7
View File
@@ -40,7 +40,10 @@ import {
logContentRetryFailure, logContentRetryFailure,
logNextSpeakerCheck, logNextSpeakerCheck,
} from '../telemetry/loggers.js'; } from '../telemetry/loggers.js';
import type { DefaultHookOutput } from '../hooks/types.js'; import type {
DefaultHookOutput,
AfterAgentHookOutput,
} from '../hooks/types.js';
import { import {
ContentRetryFailureEvent, ContentRetryFailureEvent,
NextSpeakerCheckEvent, NextSpeakerCheckEvent,
@@ -816,26 +819,41 @@ export class GeminiClient {
turn, turn,
); );
if (hookOutput?.shouldStopExecution()) { // Cast to AfterAgentHookOutput for access to shouldClearContext()
const afterAgentOutput = hookOutput as AfterAgentHookOutput | undefined;
if (afterAgentOutput?.shouldStopExecution()) {
const contextCleared = afterAgentOutput.shouldClearContext();
yield { yield {
type: GeminiEventType.AgentExecutionStopped, type: GeminiEventType.AgentExecutionStopped,
value: { value: {
reason: hookOutput.getEffectiveReason(), reason: afterAgentOutput.getEffectiveReason(),
systemMessage: hookOutput.systemMessage, systemMessage: afterAgentOutput.systemMessage,
contextCleared,
}, },
}; };
// Clear context if requested (honor both stop + clear)
if (contextCleared) {
await this.resetChat();
}
return turn; return turn;
} }
if (hookOutput?.isBlockingDecision()) { if (afterAgentOutput?.isBlockingDecision()) {
const continueReason = hookOutput.getEffectiveReason(); const continueReason = afterAgentOutput.getEffectiveReason();
const contextCleared = afterAgentOutput.shouldClearContext();
yield { yield {
type: GeminiEventType.AgentExecutionBlocked, type: GeminiEventType.AgentExecutionBlocked,
value: { value: {
reason: continueReason, reason: continueReason,
systemMessage: hookOutput.systemMessage, systemMessage: afterAgentOutput.systemMessage,
contextCleared,
}, },
}; };
// Clear context if requested
if (contextCleared) {
await this.resetChat();
}
const continueRequest = [{ text: continueReason }]; const continueRequest = [{ text: continueReason }];
yield* this.sendMessageStream( yield* this.sendMessageStream(
continueRequest, continueRequest,
+2
View File
@@ -79,6 +79,7 @@ export type ServerGeminiAgentExecutionStoppedEvent = {
value: { value: {
reason: string; reason: string;
systemMessage?: string; systemMessage?: string;
contextCleared?: boolean;
}; };
}; };
@@ -87,6 +88,7 @@ export type ServerGeminiAgentExecutionBlockedEvent = {
value: { value: {
reason: string; reason: string;
systemMessage?: string; systemMessage?: string;
contextCleared?: boolean;
}; };
}; };
+16 -3
View File
@@ -16,6 +16,7 @@ import {
BeforeModelHookOutput, BeforeModelHookOutput,
BeforeToolSelectionHookOutput, BeforeToolSelectionHookOutput,
AfterModelHookOutput, AfterModelHookOutput,
AfterAgentHookOutput,
} from './types.js'; } from './types.js';
import { HookEventName } from './types.js'; import { HookEventName } from './types.js';
@@ -158,11 +159,21 @@ export class HookAggregator {
merged.suppressOutput = true; merged.suppressOutput = true;
} }
// Merge hookSpecificOutput // Handle clearContext (any true wins) - for AfterAgent hooks
if (output.hookSpecificOutput) { if (output.hookSpecificOutput?.['clearContext'] === true) {
merged.hookSpecificOutput = { merged.hookSpecificOutput = {
...(merged.hookSpecificOutput || {}), ...(merged.hookSpecificOutput || {}),
...output.hookSpecificOutput, clearContext: true,
};
}
// Merge hookSpecificOutput (excluding clearContext which is handled above)
if (output.hookSpecificOutput) {
const { clearContext: _clearContext, ...restSpecificOutput } =
output.hookSpecificOutput;
merged.hookSpecificOutput = {
...(merged.hookSpecificOutput || {}),
...restSpecificOutput,
}; };
} }
@@ -323,6 +334,8 @@ export class HookAggregator {
return new BeforeToolSelectionHookOutput(output); return new BeforeToolSelectionHookOutput(output);
case HookEventName.AfterModel: case HookEventName.AfterModel:
return new AfterModelHookOutput(output); return new AfterModelHookOutput(output);
case HookEventName.AfterAgent:
return new AfterAgentHookOutput(output);
default: default:
return new DefaultHookOutput(output); return new DefaultHookOutput(output);
} }
+34
View File
@@ -140,6 +140,8 @@ export function createHookOutput(
return new BeforeToolSelectionHookOutput(data); return new BeforeToolSelectionHookOutput(data);
case 'BeforeTool': case 'BeforeTool':
return new BeforeToolHookOutput(data); return new BeforeToolHookOutput(data);
case 'AfterAgent':
return new AfterAgentHookOutput(data);
default: default:
return new DefaultHookOutput(data); return new DefaultHookOutput(data);
} }
@@ -243,6 +245,13 @@ export class DefaultHookOutput implements HookOutput {
} }
return { blocked: false, reason: '' }; return { blocked: false, reason: '' };
} }
/**
* Check if context clearing was requested by hook.
*/
shouldClearContext(): boolean {
return false;
}
} }
/** /**
@@ -367,6 +376,21 @@ export class AfterModelHookOutput extends DefaultHookOutput {
} }
} }
/**
* Specific hook output class for AfterAgent events
*/
export class AfterAgentHookOutput extends DefaultHookOutput {
/**
* Check if context clearing was requested by hook
*/
override shouldClearContext(): boolean {
if (this.hookSpecificOutput && 'clearContext' in this.hookSpecificOutput) {
return this.hookSpecificOutput['clearContext'] === true;
}
return false;
}
}
/** /**
* Context for MCP tool executions. * Context for MCP tool executions.
* Contains non-sensitive connection information about the MCP server * Contains non-sensitive connection information about the MCP server
@@ -480,6 +504,16 @@ export interface AfterAgentInput extends HookInput {
stop_hook_active: boolean; stop_hook_active: boolean;
} }
/**
* AfterAgent hook output
*/
export interface AfterAgentOutput extends HookOutput {
hookSpecificOutput?: {
hookEventName: 'AfterAgent';
clearContext?: boolean;
};
}
/** /**
* SessionStart source types * SessionStart source types
*/ */