feat(hooks): implement granular stop and block behavior for agent hooks (#15824)

This commit is contained in:
Sandy Tao
2026-01-04 18:58:34 -08:00
committed by GitHub
parent bdb349e7f6
commit dd84c2fb83
7 changed files with 388 additions and 13 deletions
+131
View File
@@ -46,6 +46,7 @@ import type {
} from '../services/modelConfigService.js';
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
import { HookSystem } from '../hooks/hookSystem.js';
import type { DefaultHookOutput } from '../hooks/types.js';
import * as policyCatalog from '../availability/policyCatalog.js';
vi.mock('../services/chatCompressionService.js');
@@ -2781,6 +2782,136 @@ ${JSON.stringify(
expect(client['hookStateMap'].has('old-id')).toBe(false);
expect(client['hookStateMap'].has('new-id')).toBe(true);
});
it('should stop execution in BeforeAgent when hook returns continue: false', async () => {
const { fireBeforeAgentHook } = await import('./clientHookTriggers.js');
vi.mocked(fireBeforeAgentHook).mockResolvedValue({
shouldStopExecution: () => true,
getEffectiveReason: () => 'Stopped by hook',
} as DefaultHookOutput);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
getLastPromptTokenCount: vi.fn(),
};
client['chat'] = mockChat as GeminiChat;
const request = [{ text: 'Hello' }];
const stream = client.sendMessageStream(
request,
new AbortController().signal,
'test-prompt',
);
const events = await fromAsync(stream);
expect(events).toContainEqual({
type: GeminiEventType.AgentExecutionStopped,
value: { reason: 'Stopped by hook' },
});
expect(mockChat.addHistory).toHaveBeenCalledWith({
role: 'user',
parts: request,
});
expect(mockTurnRunFn).not.toHaveBeenCalled();
});
it('should block execution in BeforeAgent when hook returns decision: block', async () => {
const { fireBeforeAgentHook } = await import('./clientHookTriggers.js');
vi.mocked(fireBeforeAgentHook).mockResolvedValue({
shouldStopExecution: () => false,
isBlockingDecision: () => true,
getEffectiveReason: () => 'Blocked by hook',
} as DefaultHookOutput);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
getLastPromptTokenCount: vi.fn(),
};
client['chat'] = mockChat as GeminiChat;
const request = [{ text: 'Hello' }];
const stream = client.sendMessageStream(
request,
new AbortController().signal,
'test-prompt',
);
const events = await fromAsync(stream);
expect(events).toContainEqual({
type: GeminiEventType.AgentExecutionBlocked,
value: {
reason: 'Blocked by hook',
},
});
expect(mockChat.addHistory).not.toHaveBeenCalled();
expect(mockTurnRunFn).not.toHaveBeenCalled();
});
it('should stop execution in AfterAgent when hook returns continue: false', async () => {
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
vi.mocked(fireAfterAgentHook).mockResolvedValue({
shouldStopExecution: () => true,
getEffectiveReason: () => 'Stopped after agent',
} as DefaultHookOutput);
mockTurnRunFn.mockImplementation(async function* () {
yield { type: GeminiEventType.Content, value: 'Hello' };
});
const stream = client.sendMessageStream(
{ text: 'Hi' },
new AbortController().signal,
'test-prompt',
);
const events = await fromAsync(stream);
expect(events).toContainEqual({
type: GeminiEventType.AgentExecutionStopped,
value: { reason: 'Stopped after agent' },
});
// sendMessageStream should not recurse
expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
});
it('should yield AgentExecutionBlocked and recurse in AfterAgent when hook returns decision: block', async () => {
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
vi.mocked(fireAfterAgentHook)
.mockResolvedValueOnce({
shouldStopExecution: () => false,
isBlockingDecision: () => true,
getEffectiveReason: () => 'Please explain',
} as DefaultHookOutput)
.mockResolvedValueOnce({
shouldStopExecution: () => false,
isBlockingDecision: () => false,
} as DefaultHookOutput);
mockTurnRunFn.mockImplementation(async function* () {
yield { type: GeminiEventType.Content, value: 'Response' };
});
const stream = client.sendMessageStream(
{ text: 'Hi' },
new AbortController().signal,
'test-prompt',
);
const events = await fromAsync(stream);
expect(events).toContainEqual({
type: GeminiEventType.AgentExecutionBlocked,
value: { reason: 'Please explain' },
});
// Should have called turn run twice (original + re-prompt)
expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
2,
expect.anything(),
[{ text: 'Please explain' }],
expect.anything(),
);
});
});
});
});
+48 -12
View File
@@ -11,6 +11,7 @@ import type {
Tool,
GenerateContentResponse,
} from '@google/genai';
import { createUserContent } from '@google/genai';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import {
getDirectoryContextString,
@@ -65,8 +66,12 @@ const MAX_TURNS = 100;
type BeforeAgentHookReturn =
| {
type: GeminiEventType.Error;
value: { error: Error };
type: GeminiEventType.AgentExecutionStopped;
value: { reason: string };
}
| {
type: GeminiEventType.AgentExecutionBlocked;
value: { reason: string };
}
| { additionalContext: string | undefined }
| undefined;
@@ -135,13 +140,20 @@ export class GeminiClient {
const hookOutput = await fireBeforeAgentHook(messageBus, request);
hookState.hasFiredBeforeAgent = true;
if (hookOutput?.isBlockingDecision() || hookOutput?.shouldStopExecution()) {
if (hookOutput?.shouldStopExecution()) {
return {
type: GeminiEventType.Error,
type: GeminiEventType.AgentExecutionStopped,
value: {
error: new Error(
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
),
reason: hookOutput.getEffectiveReason(),
},
};
}
if (hookOutput?.isBlockingDecision()) {
return {
type: GeminiEventType.AgentExecutionBlocked,
value: {
reason: hookOutput.getEffectiveReason(),
},
};
}
@@ -747,7 +759,18 @@ export class GeminiClient {
prompt_id,
);
if (hookResult) {
if ('type' in hookResult && hookResult.type === GeminiEventType.Error) {
if (
'type' in hookResult &&
hookResult.type === GeminiEventType.AgentExecutionStopped
) {
// Add user message to history before returning so it's kept in the transcript
this.getChat().addHistory(createUserContent(request));
yield hookResult;
return new Turn(this.getChat(), prompt_id);
} else if (
'type' in hookResult &&
hookResult.type === GeminiEventType.AgentExecutionBlocked
) {
yield hookResult;
return new Turn(this.getChat(), prompt_id);
} else if ('additionalContext' in hookResult) {
@@ -781,11 +804,24 @@ export class GeminiClient {
turn,
);
if (
hookOutput?.isBlockingDecision() ||
hookOutput?.shouldStopExecution()
) {
if (hookOutput?.shouldStopExecution()) {
yield {
type: GeminiEventType.AgentExecutionStopped,
value: {
reason: hookOutput.getEffectiveReason(),
},
};
return turn;
}
if (hookOutput?.isBlockingDecision()) {
const continueReason = hookOutput.getEffectiveReason();
yield {
type: GeminiEventType.AgentExecutionBlocked,
value: {
reason: continueReason,
},
};
const continueRequest = [{ text: continueReason }];
yield* this.sendMessageStream(
continueRequest,
+19 -1
View File
@@ -66,12 +66,28 @@ export enum GeminiEventType {
ContextWindowWillOverflow = 'context_window_will_overflow',
InvalidStream = 'invalid_stream',
ModelInfo = 'model_info',
AgentExecutionStopped = 'agent_execution_stopped',
AgentExecutionBlocked = 'agent_execution_blocked',
}
export type ServerGeminiRetryEvent = {
type: GeminiEventType.Retry;
};
export type ServerGeminiAgentExecutionStoppedEvent = {
type: GeminiEventType.AgentExecutionStopped;
value: {
reason: string;
};
};
export type ServerGeminiAgentExecutionBlockedEvent = {
type: GeminiEventType.AgentExecutionBlocked;
value: {
reason: string;
};
};
export type ServerGeminiContextWindowWillOverflowEvent = {
type: GeminiEventType.ContextWindowWillOverflow;
value: {
@@ -204,7 +220,9 @@ export type ServerGeminiStreamEvent =
| ServerGeminiRetryEvent
| ServerGeminiContextWindowWillOverflowEvent
| ServerGeminiInvalidStreamEvent
| ServerGeminiModelInfoEvent;
| ServerGeminiModelInfoEvent
| ServerGeminiAgentExecutionStoppedEvent
| ServerGeminiAgentExecutionBlockedEvent;
// A turn manages the agentic loop turn within the server context.
export class Turn {