mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-09 04:41:19 -07:00
feat: add clearContext to AfterAgent hooks (#16574)
This commit is contained in:
committed by
Sandy Tao
parent
958cc45937
commit
2a3c879782
@@ -3107,6 +3107,7 @@ ${JSON.stringify(
|
||||
mockHookSystem.fireAfterAgentEvent.mockResolvedValue({
|
||||
shouldStopExecution: () => true,
|
||||
getEffectiveReason: () => 'Stopped after agent',
|
||||
shouldClearContext: () => false,
|
||||
systemMessage: undefined,
|
||||
});
|
||||
|
||||
@@ -3121,10 +3122,12 @@ ${JSON.stringify(
|
||||
);
|
||||
const events = await fromAsync(stream);
|
||||
|
||||
expect(events).toContainEqual({
|
||||
type: GeminiEventType.AgentExecutionStopped,
|
||||
value: { reason: 'Stopped after agent' },
|
||||
});
|
||||
expect(events).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: GeminiEventType.AgentExecutionStopped,
|
||||
value: expect.objectContaining({ reason: 'Stopped after agent' }),
|
||||
}),
|
||||
);
|
||||
// sendMessageStream should not recurse
|
||||
expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
@@ -3135,11 +3138,60 @@ ${JSON.stringify(
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => true,
|
||||
getEffectiveReason: () => 'Please explain',
|
||||
shouldClearContext: () => false,
|
||||
systemMessage: undefined,
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => false,
|
||||
shouldClearContext: () => false,
|
||||
systemMessage: undefined,
|
||||
});
|
||||
|
||||
mockTurnRunFn.mockImplementation(async function* () {
|
||||
yield { type: GeminiEventType.Content, value: 'Response' };
|
||||
});
|
||||
|
||||
const stream = client.sendMessageStream(
|
||||
{ text: 'Hi' },
|
||||
new AbortController().signal,
|
||||
'test-prompt',
|
||||
);
|
||||
const events = await fromAsync(stream);
|
||||
|
||||
expect(events).toContainEqual(
|
||||
expect.objectContaining({
|
||||
type: GeminiEventType.AgentExecutionBlocked,
|
||||
value: expect.objectContaining({ reason: 'Please explain' }),
|
||||
}),
|
||||
);
|
||||
// Should have called turn run twice (original + re-prompt)
|
||||
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
|
||||
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.anything(),
|
||||
[{ text: 'Please explain' }],
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
|
||||
it('should call resetChat when AfterAgent hook returns shouldClearContext: true', async () => {
|
||||
const resetChatSpy = vi
|
||||
.spyOn(client, 'resetChat')
|
||||
.mockResolvedValue(undefined);
|
||||
|
||||
mockHookSystem.fireAfterAgentEvent
|
||||
.mockResolvedValueOnce({
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => true,
|
||||
getEffectiveReason: () => 'Blocked and clearing context',
|
||||
shouldClearContext: () => true,
|
||||
systemMessage: undefined,
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => false,
|
||||
shouldClearContext: () => false,
|
||||
systemMessage: undefined,
|
||||
});
|
||||
|
||||
@@ -3156,16 +3208,15 @@ ${JSON.stringify(
|
||||
|
||||
expect(events).toContainEqual({
|
||||
type: GeminiEventType.AgentExecutionBlocked,
|
||||
value: { reason: 'Please explain' },
|
||||
value: {
|
||||
reason: 'Blocked and clearing context',
|
||||
systemMessage: undefined,
|
||||
contextCleared: true,
|
||||
},
|
||||
});
|
||||
// Should have called turn run twice (original + re-prompt)
|
||||
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
|
||||
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
expect.anything(),
|
||||
[{ text: 'Please explain' }],
|
||||
expect.anything(),
|
||||
);
|
||||
expect(resetChatSpy).toHaveBeenCalledTimes(1);
|
||||
|
||||
resetChatSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -40,7 +40,10 @@ import {
|
||||
logContentRetryFailure,
|
||||
logNextSpeakerCheck,
|
||||
} from '../telemetry/loggers.js';
|
||||
import type { DefaultHookOutput } from '../hooks/types.js';
|
||||
import type {
|
||||
DefaultHookOutput,
|
||||
AfterAgentHookOutput,
|
||||
} from '../hooks/types.js';
|
||||
import {
|
||||
ContentRetryFailureEvent,
|
||||
NextSpeakerCheckEvent,
|
||||
@@ -812,26 +815,41 @@ export class GeminiClient {
|
||||
turn,
|
||||
);
|
||||
|
||||
if (hookOutput?.shouldStopExecution()) {
|
||||
// Cast to AfterAgentHookOutput for access to shouldClearContext()
|
||||
const afterAgentOutput = hookOutput as AfterAgentHookOutput | undefined;
|
||||
|
||||
if (afterAgentOutput?.shouldStopExecution()) {
|
||||
const contextCleared = afterAgentOutput.shouldClearContext();
|
||||
yield {
|
||||
type: GeminiEventType.AgentExecutionStopped,
|
||||
value: {
|
||||
reason: hookOutput.getEffectiveReason(),
|
||||
systemMessage: hookOutput.systemMessage,
|
||||
reason: afterAgentOutput.getEffectiveReason(),
|
||||
systemMessage: afterAgentOutput.systemMessage,
|
||||
contextCleared,
|
||||
},
|
||||
};
|
||||
// Clear context if requested (honor both stop + clear)
|
||||
if (contextCleared) {
|
||||
await this.resetChat();
|
||||
}
|
||||
return turn;
|
||||
}
|
||||
|
||||
if (hookOutput?.isBlockingDecision()) {
|
||||
const continueReason = hookOutput.getEffectiveReason();
|
||||
if (afterAgentOutput?.isBlockingDecision()) {
|
||||
const continueReason = afterAgentOutput.getEffectiveReason();
|
||||
const contextCleared = afterAgentOutput.shouldClearContext();
|
||||
yield {
|
||||
type: GeminiEventType.AgentExecutionBlocked,
|
||||
value: {
|
||||
reason: continueReason,
|
||||
systemMessage: hookOutput.systemMessage,
|
||||
systemMessage: afterAgentOutput.systemMessage,
|
||||
contextCleared,
|
||||
},
|
||||
};
|
||||
// Clear context if requested
|
||||
if (contextCleared) {
|
||||
await this.resetChat();
|
||||
}
|
||||
const continueRequest = [{ text: continueReason }];
|
||||
yield* this.sendMessageStream(
|
||||
continueRequest,
|
||||
|
||||
@@ -79,6 +79,7 @@ export type ServerGeminiAgentExecutionStoppedEvent = {
|
||||
value: {
|
||||
reason: string;
|
||||
systemMessage?: string;
|
||||
contextCleared?: boolean;
|
||||
};
|
||||
};
|
||||
|
||||
@@ -87,6 +88,7 @@ export type ServerGeminiAgentExecutionBlockedEvent = {
|
||||
value: {
|
||||
reason: string;
|
||||
systemMessage?: string;
|
||||
contextCleared?: boolean;
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
BeforeModelHookOutput,
|
||||
BeforeToolSelectionHookOutput,
|
||||
AfterModelHookOutput,
|
||||
AfterAgentHookOutput,
|
||||
} from './types.js';
|
||||
import { HookEventName } from './types.js';
|
||||
|
||||
@@ -158,11 +159,21 @@ export class HookAggregator {
|
||||
merged.suppressOutput = true;
|
||||
}
|
||||
|
||||
// Merge hookSpecificOutput
|
||||
if (output.hookSpecificOutput) {
|
||||
// Handle clearContext (any true wins) - for AfterAgent hooks
|
||||
if (output.hookSpecificOutput?.['clearContext'] === true) {
|
||||
merged.hookSpecificOutput = {
|
||||
...(merged.hookSpecificOutput || {}),
|
||||
...output.hookSpecificOutput,
|
||||
clearContext: true,
|
||||
};
|
||||
}
|
||||
|
||||
// Merge hookSpecificOutput (excluding clearContext which is handled above)
|
||||
if (output.hookSpecificOutput) {
|
||||
const { clearContext: _clearContext, ...restSpecificOutput } =
|
||||
output.hookSpecificOutput;
|
||||
merged.hookSpecificOutput = {
|
||||
...(merged.hookSpecificOutput || {}),
|
||||
...restSpecificOutput,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -323,6 +334,8 @@ export class HookAggregator {
|
||||
return new BeforeToolSelectionHookOutput(output);
|
||||
case HookEventName.AfterModel:
|
||||
return new AfterModelHookOutput(output);
|
||||
case HookEventName.AfterAgent:
|
||||
return new AfterAgentHookOutput(output);
|
||||
default:
|
||||
return new DefaultHookOutput(output);
|
||||
}
|
||||
|
||||
@@ -140,6 +140,8 @@ export function createHookOutput(
|
||||
return new BeforeToolSelectionHookOutput(data);
|
||||
case 'BeforeTool':
|
||||
return new BeforeToolHookOutput(data);
|
||||
case 'AfterAgent':
|
||||
return new AfterAgentHookOutput(data);
|
||||
default:
|
||||
return new DefaultHookOutput(data);
|
||||
}
|
||||
@@ -238,6 +240,13 @@ export class DefaultHookOutput implements HookOutput {
|
||||
}
|
||||
return { blocked: false, reason: '' };
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if context clearing was requested by hook.
|
||||
*/
|
||||
shouldClearContext(): boolean {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -362,6 +371,21 @@ export class AfterModelHookOutput extends DefaultHookOutput {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Specific hook output class for AfterAgent events
|
||||
*/
|
||||
export class AfterAgentHookOutput extends DefaultHookOutput {
|
||||
/**
|
||||
* Check if context clearing was requested by hook
|
||||
*/
|
||||
override shouldClearContext(): boolean {
|
||||
if (this.hookSpecificOutput && 'clearContext' in this.hookSpecificOutput) {
|
||||
return this.hookSpecificOutput['clearContext'] === true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Context for MCP tool executions.
|
||||
* Contains non-sensitive connection information about the MCP server
|
||||
@@ -475,6 +499,16 @@ export interface AfterAgentInput extends HookInput {
|
||||
stop_hook_active: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* AfterAgent hook output
|
||||
*/
|
||||
export interface AfterAgentOutput extends HookOutput {
|
||||
hookSpecificOutput?: {
|
||||
hookEventName: 'AfterAgent';
|
||||
clearContext?: boolean;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* SessionStart source types
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user