feat: add clearContext to AfterAgent hooks (#16574)

This commit is contained in:
Jack Wotherspoon
2026-01-23 17:14:30 -05:00
committed by Sandy Tao
parent 958cc45937
commit 2a3c879782
8 changed files with 462 additions and 32 deletions

View File

@@ -3107,6 +3107,7 @@ ${JSON.stringify(
mockHookSystem.fireAfterAgentEvent.mockResolvedValue({
shouldStopExecution: () => true,
getEffectiveReason: () => 'Stopped after agent',
shouldClearContext: () => false,
systemMessage: undefined,
});
@@ -3121,10 +3122,12 @@ ${JSON.stringify(
);
const events = await fromAsync(stream);
expect(events).toContainEqual({
type: GeminiEventType.AgentExecutionStopped,
value: { reason: 'Stopped after agent' },
});
expect(events).toContainEqual(
expect.objectContaining({
type: GeminiEventType.AgentExecutionStopped,
value: expect.objectContaining({ reason: 'Stopped after agent' }),
}),
);
// sendMessageStream should not recurse
expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
});
@@ -3135,11 +3138,60 @@ ${JSON.stringify(
shouldStopExecution: () => false,
isBlockingDecision: () => true,
getEffectiveReason: () => 'Please explain',
shouldClearContext: () => false,
systemMessage: undefined,
})
.mockResolvedValueOnce({
shouldStopExecution: () => false,
isBlockingDecision: () => false,
shouldClearContext: () => false,
systemMessage: undefined,
});
mockTurnRunFn.mockImplementation(async function* () {
yield { type: GeminiEventType.Content, value: 'Response' };
});
const stream = client.sendMessageStream(
{ text: 'Hi' },
new AbortController().signal,
'test-prompt',
);
const events = await fromAsync(stream);
expect(events).toContainEqual(
expect.objectContaining({
type: GeminiEventType.AgentExecutionBlocked,
value: expect.objectContaining({ reason: 'Please explain' }),
}),
);
// Should have called turn run twice (original + re-prompt)
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
2,
expect.anything(),
[{ text: 'Please explain' }],
expect.anything(),
);
});
it('should call resetChat when AfterAgent hook returns shouldClearContext: true', async () => {
const resetChatSpy = vi
.spyOn(client, 'resetChat')
.mockResolvedValue(undefined);
mockHookSystem.fireAfterAgentEvent
.mockResolvedValueOnce({
shouldStopExecution: () => false,
isBlockingDecision: () => true,
getEffectiveReason: () => 'Blocked and clearing context',
shouldClearContext: () => true,
systemMessage: undefined,
})
.mockResolvedValueOnce({
shouldStopExecution: () => false,
isBlockingDecision: () => false,
shouldClearContext: () => false,
systemMessage: undefined,
});
@@ -3156,16 +3208,15 @@ ${JSON.stringify(
expect(events).toContainEqual({
type: GeminiEventType.AgentExecutionBlocked,
value: { reason: 'Please explain' },
value: {
reason: 'Blocked and clearing context',
systemMessage: undefined,
contextCleared: true,
},
});
// Should have called turn run twice (original + re-prompt)
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
2,
expect.anything(),
[{ text: 'Please explain' }],
expect.anything(),
);
expect(resetChatSpy).toHaveBeenCalledTimes(1);
resetChatSpy.mockRestore();
});
});
});

View File

@@ -40,7 +40,10 @@ import {
logContentRetryFailure,
logNextSpeakerCheck,
} from '../telemetry/loggers.js';
import type { DefaultHookOutput } from '../hooks/types.js';
import type {
DefaultHookOutput,
AfterAgentHookOutput,
} from '../hooks/types.js';
import {
ContentRetryFailureEvent,
NextSpeakerCheckEvent,
@@ -812,26 +815,41 @@ export class GeminiClient {
turn,
);
if (hookOutput?.shouldStopExecution()) {
// Cast to AfterAgentHookOutput for access to shouldClearContext()
const afterAgentOutput = hookOutput as AfterAgentHookOutput | undefined;
if (afterAgentOutput?.shouldStopExecution()) {
const contextCleared = afterAgentOutput.shouldClearContext();
yield {
type: GeminiEventType.AgentExecutionStopped,
value: {
reason: hookOutput.getEffectiveReason(),
systemMessage: hookOutput.systemMessage,
reason: afterAgentOutput.getEffectiveReason(),
systemMessage: afterAgentOutput.systemMessage,
contextCleared,
},
};
// Clear context if requested (honor both stop + clear)
if (contextCleared) {
await this.resetChat();
}
return turn;
}
if (hookOutput?.isBlockingDecision()) {
const continueReason = hookOutput.getEffectiveReason();
if (afterAgentOutput?.isBlockingDecision()) {
const continueReason = afterAgentOutput.getEffectiveReason();
const contextCleared = afterAgentOutput.shouldClearContext();
yield {
type: GeminiEventType.AgentExecutionBlocked,
value: {
reason: continueReason,
systemMessage: hookOutput.systemMessage,
systemMessage: afterAgentOutput.systemMessage,
contextCleared,
},
};
// Clear context if requested
if (contextCleared) {
await this.resetChat();
}
const continueRequest = [{ text: continueReason }];
yield* this.sendMessageStream(
continueRequest,

View File

@@ -79,6 +79,7 @@ export type ServerGeminiAgentExecutionStoppedEvent = {
value: {
reason: string;
systemMessage?: string;
contextCleared?: boolean;
};
};
@@ -87,6 +88,7 @@ export type ServerGeminiAgentExecutionBlockedEvent = {
value: {
reason: string;
systemMessage?: string;
contextCleared?: boolean;
};
};

View File

@@ -16,6 +16,7 @@ import {
BeforeModelHookOutput,
BeforeToolSelectionHookOutput,
AfterModelHookOutput,
AfterAgentHookOutput,
} from './types.js';
import { HookEventName } from './types.js';
@@ -158,11 +159,21 @@ export class HookAggregator {
merged.suppressOutput = true;
}
// Merge hookSpecificOutput
if (output.hookSpecificOutput) {
// Handle clearContext (any true wins) - for AfterAgent hooks
if (output.hookSpecificOutput?.['clearContext'] === true) {
merged.hookSpecificOutput = {
...(merged.hookSpecificOutput || {}),
...output.hookSpecificOutput,
clearContext: true,
};
}
// Merge hookSpecificOutput (excluding clearContext which is handled above)
if (output.hookSpecificOutput) {
const { clearContext: _clearContext, ...restSpecificOutput } =
output.hookSpecificOutput;
merged.hookSpecificOutput = {
...(merged.hookSpecificOutput || {}),
...restSpecificOutput,
};
}
@@ -323,6 +334,8 @@ export class HookAggregator {
return new BeforeToolSelectionHookOutput(output);
case HookEventName.AfterModel:
return new AfterModelHookOutput(output);
case HookEventName.AfterAgent:
return new AfterAgentHookOutput(output);
default:
return new DefaultHookOutput(output);
}

View File

@@ -140,6 +140,8 @@ export function createHookOutput(
return new BeforeToolSelectionHookOutput(data);
case 'BeforeTool':
return new BeforeToolHookOutput(data);
case 'AfterAgent':
return new AfterAgentHookOutput(data);
default:
return new DefaultHookOutput(data);
}
@@ -238,6 +240,13 @@ export class DefaultHookOutput implements HookOutput {
}
return { blocked: false, reason: '' };
}
/**
* Check if context clearing was requested by hook.
*/
shouldClearContext(): boolean {
return false;
}
}
/**
@@ -362,6 +371,21 @@ export class AfterModelHookOutput extends DefaultHookOutput {
}
}
/**
* Specific hook output class for AfterAgent events
*/
export class AfterAgentHookOutput extends DefaultHookOutput {
/**
* Check if context clearing was requested by hook
*/
override shouldClearContext(): boolean {
if (this.hookSpecificOutput && 'clearContext' in this.hookSpecificOutput) {
return this.hookSpecificOutput['clearContext'] === true;
}
return false;
}
}
/**
* Context for MCP tool executions.
* Contains non-sensitive connection information about the MCP server
@@ -475,6 +499,16 @@ export interface AfterAgentInput extends HookInput {
stop_hook_active: boolean;
}
/**
* AfterAgent hook output
*/
export interface AfterAgentOutput extends HookOutput {
hookSpecificOutput?: {
hookEventName: 'AfterAgent';
clearContext?: boolean;
};
}
/**
* SessionStart source types
*/