From dd84c2fb837ac6d84b90416b5388341b8d83ab87 Mon Sep 17 00:00:00 2001
From: Sandy Tao
Date: Sun, 4 Jan 2026 18:58:34 -0800
Subject: [PATCH] feat(hooks): implement granular stop and block behavior for agent hooks (#15824)

---
Note: the unified-diff line structure and indentation of this patch were
restored from a whitespace-collapsed copy. Hunk line counts match the headers,
but verify context-line whitespace against the target tree (or apply with
`git apply --ignore-whitespace`).

 packages/cli/src/nonInteractiveCli.test.ts    |  59 ++++++++
 packages/cli/src/nonInteractiveCli.ts         |  25 ++++
 .../cli/src/ui/hooks/useGeminiStream.test.tsx |  57 ++++++++
 packages/cli/src/ui/hooks/useGeminiStream.ts  |  49 +++++++
 packages/core/src/core/client.test.ts         | 131 ++++++++++++++++++
 packages/core/src/core/client.ts              |  60 ++++++--
 packages/core/src/core/turn.ts                |  20 ++-
 7 files changed, 388 insertions(+), 13 deletions(-)

diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts
index f2a2a43592..c171d95a74 100644
--- a/packages/cli/src/nonInteractiveCli.test.ts
+++ b/packages/cli/src/nonInteractiveCli.test.ts
@@ -1797,4 +1797,63 @@ describe('runNonInteractive', () => {
     // The key assertion: sendMessageStream should have been called ONLY ONCE (initial user input).
     expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(1);
   });
+
+  describe('Agent Execution Events', () => {
+    it('should handle AgentExecutionStopped event', async () => {
+      const events: ServerGeminiStreamEvent[] = [
+        {
+          type: GeminiEventType.AgentExecutionStopped,
+          value: { reason: 'Stopped by hook' },
+        },
+      ];
+      mockGeminiClient.sendMessageStream.mockReturnValue(
+        createStreamFromEvents(events),
+      );
+
+      await runNonInteractive({
+        config: mockConfig,
+        settings: mockSettings,
+        input: 'test stop',
+        prompt_id: 'prompt-id-stop',
+      });
+
+      expect(processStderrSpy).toHaveBeenCalledWith(
+        'Agent execution stopped: Stopped by hook\n',
+      );
+      // Should exit without calling sendMessageStream again
+      expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(1);
+    });
+
+    it('should handle AgentExecutionBlocked event', async () => {
+      const allEvents: ServerGeminiStreamEvent[] = [
+        {
+          type: GeminiEventType.AgentExecutionBlocked,
+          value: { reason: 'Blocked by hook' },
+        },
+        { type: GeminiEventType.Content, value: 'Final answer' },
+        {
+          type: GeminiEventType.Finished,
+          value: { reason: undefined, usageMetadata: { totalTokenCount: 10 } },
+        },
+      ];
+
+      mockGeminiClient.sendMessageStream.mockReturnValue(
+        createStreamFromEvents(allEvents),
+      );
+
+      await runNonInteractive({
+        config: mockConfig,
+        settings: mockSettings,
+        input: 'test block',
+        prompt_id: 'prompt-id-block',
+      });
+
+      expect(processStderrSpy).toHaveBeenCalledWith(
+        '[WARNING] Agent execution blocked: Blocked by hook\n',
+      );
+      // sendMessageStream is called once, recursion is internal to it and transparent to the caller
+      expect(mockGeminiClient.sendMessageStream).toHaveBeenCalledTimes(1);
+      expect(getWrittenOutput()).toBe('Final answer\n');
+    });
+  });
 });
diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts
index c81efd72f5..d1f468ef39 100644
--- a/packages/cli/src/nonInteractiveCli.ts
+++ b/packages/cli/src/nonInteractiveCli.ts
@@ -348,6 +348,31 @@ export async function runNonInteractive({
             }
           } else if (event.type === GeminiEventType.Error) {
             throw event.value.error;
+          } else if (event.type === GeminiEventType.AgentExecutionStopped) {
+            const stopMessage = `Agent execution stopped: ${event.value.reason}`;
+            if (config.getOutputFormat() === OutputFormat.TEXT) {
+              process.stderr.write(`${stopMessage}\n`);
+            }
+            // Emit final result event for streaming JSON if needed
+            if (streamFormatter) {
+              const metrics = uiTelemetryService.getMetrics();
+              const durationMs = Date.now() - startTime;
+              streamFormatter.emitEvent({
+                type: JsonStreamEventType.RESULT,
+                timestamp: new Date().toISOString(),
+                status: 'success',
+                stats: streamFormatter.convertToStreamStats(
+                  metrics,
+                  durationMs,
+                ),
+              });
+            }
+            return;
+          } else if (event.type === GeminiEventType.AgentExecutionBlocked) {
+            const blockMessage = `Agent execution blocked: ${event.value.reason}`;
+            if (config.getOutputFormat() === OutputFormat.TEXT) {
+              process.stderr.write(`[WARNING] ${blockMessage}\n`);
+            }
           }
         }
 
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
index 7c1b3a7dc9..2414c340f4 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
+++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx
@@ -2798,4 +2798,61 @@ describe('useGeminiStream', () => {
       });
     });
   });
+
+  describe('Agent Execution Events', () => {
+    it('should handle AgentExecutionStopped event', async () => {
+      mockSendMessageStream.mockReturnValue(
+        (async function* () {
+          yield {
+            type: ServerGeminiEventType.AgentExecutionStopped,
+            value: { reason: 'Stopped by hook' },
+          };
+        })(),
+      );
+
+      const { result } = renderTestHook();
+
+      await act(async () => {
+        await result.current.submitQuery('test stop');
+      });
+
+      await waitFor(() => {
+        expect(mockAddItem).toHaveBeenCalledWith(
+          {
+            type: MessageType.INFO,
+            text: 'Agent execution stopped: Stopped by hook',
+          },
+          expect.any(Number),
+        );
+        expect(result.current.streamingState).toBe(StreamingState.Idle);
+      });
+    });
+
+    it('should handle AgentExecutionBlocked event', async () => {
+      mockSendMessageStream.mockReturnValue(
+        (async function* () {
+          yield {
+            type: ServerGeminiEventType.AgentExecutionBlocked,
+            value: { reason: 'Blocked by hook' },
+          };
+        })(),
+      );
+
+      const { result } = renderTestHook();
+
+      await act(async () => {
+        await result.current.submitQuery('test block');
+      });
+
+      await waitFor(() => {
+        expect(mockAddItem).toHaveBeenCalledWith(
+          {
+            type: MessageType.WARNING,
+            text: 'Agent execution blocked: Blocked by hook',
+          },
+          expect.any(Number),
+        );
+      });
+    });
+  });
 });
diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts
index d36d9f57ed..4522af13c7 100644
--- a/packages/cli/src/ui/hooks/useGeminiStream.ts
+++ b/packages/cli/src/ui/hooks/useGeminiStream.ts
@@ -793,6 +793,41 @@ export const useGeminiStream = (
     [addItem, pendingHistoryItemRef, setPendingHistoryItem, settings],
   );
 
+  const handleAgentExecutionStoppedEvent = useCallback(
+    (reason: string, userMessageTimestamp: number) => {
+      if (pendingHistoryItemRef.current) {
+        addItem(pendingHistoryItemRef.current, userMessageTimestamp);
+        setPendingHistoryItem(null);
+      }
+      addItem(
+        {
+          type: MessageType.INFO,
+          text: `Agent execution stopped: ${reason}`,
+        },
+        userMessageTimestamp,
+      );
+      setIsResponding(false);
+    },
+    [addItem, pendingHistoryItemRef, setPendingHistoryItem, setIsResponding],
+  );
+
+  const handleAgentExecutionBlockedEvent = useCallback(
+    (reason: string, userMessageTimestamp: number) => {
+      if (pendingHistoryItemRef.current) {
+        addItem(pendingHistoryItemRef.current, userMessageTimestamp);
+        setPendingHistoryItem(null);
+      }
+      addItem(
+        {
+          type: MessageType.WARNING,
+          text: `Agent execution blocked: ${reason}`,
+        },
+        userMessageTimestamp,
+      );
+    },
+    [addItem, pendingHistoryItemRef, setPendingHistoryItem],
+  );
+
   const processGeminiStreamEvents = useCallback(
     async (
       stream: AsyncIterable<GeminiEvent>,
@@ -822,6 +857,18 @@ export const useGeminiStream = (
             case ServerGeminiEventType.Error:
               handleErrorEvent(event.value, userMessageTimestamp);
               break;
+            case ServerGeminiEventType.AgentExecutionStopped:
+              handleAgentExecutionStoppedEvent(
+                event.value.reason,
+                userMessageTimestamp,
+              );
+              break;
+            case ServerGeminiEventType.AgentExecutionBlocked:
+              handleAgentExecutionBlockedEvent(
+                event.value.reason,
+                userMessageTimestamp,
+              );
+              break;
             case ServerGeminiEventType.ChatCompressed:
               handleChatCompressionEvent(event.value, userMessageTimestamp);
               break;
@@ -879,6 +926,8 @@ export const useGeminiStream = (
       handleContextWindowWillOverflowEvent,
       handleCitationEvent,
       handleChatModelEvent,
+      handleAgentExecutionStoppedEvent,
+      handleAgentExecutionBlockedEvent,
     ],
   );
   const submitQuery = useCallback(
diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts
index 34facc4737..6045088c04 100644
--- a/packages/core/src/core/client.test.ts
+++ b/packages/core/src/core/client.test.ts
@@ -46,6 +46,7 @@ import type {
 } from '../services/modelConfigService.js';
 import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
 import { HookSystem } from '../hooks/hookSystem.js';
+import type { DefaultHookOutput } from '../hooks/types.js';
 import * as policyCatalog from '../availability/policyCatalog.js';
 
 vi.mock('../services/chatCompressionService.js');
@@ -2781,6 +2782,136 @@ ${JSON.stringify(
         expect(client['hookStateMap'].has('old-id')).toBe(false);
         expect(client['hookStateMap'].has('new-id')).toBe(true);
       });
+
+      it('should stop execution in BeforeAgent when hook returns continue: false', async () => {
+        const { fireBeforeAgentHook } = await import('./clientHookTriggers.js');
+        vi.mocked(fireBeforeAgentHook).mockResolvedValue({
+          shouldStopExecution: () => true,
+          getEffectiveReason: () => 'Stopped by hook',
+        } as DefaultHookOutput);
+
+        const mockChat: Partial<GeminiChat> = {
+          addHistory: vi.fn(),
+          getHistory: vi.fn().mockReturnValue([]),
+          getLastPromptTokenCount: vi.fn(),
+        };
+        client['chat'] = mockChat as GeminiChat;
+
+        const request = [{ text: 'Hello' }];
+        const stream = client.sendMessageStream(
+          request,
+          new AbortController().signal,
+          'test-prompt',
+        );
+        const events = await fromAsync(stream);
+
+        expect(events).toContainEqual({
+          type: GeminiEventType.AgentExecutionStopped,
+          value: { reason: 'Stopped by hook' },
+        });
+        expect(mockChat.addHistory).toHaveBeenCalledWith({
+          role: 'user',
+          parts: request,
+        });
+        expect(mockTurnRunFn).not.toHaveBeenCalled();
+      });
+
+      it('should block execution in BeforeAgent when hook returns decision: block', async () => {
+        const { fireBeforeAgentHook } = await import('./clientHookTriggers.js');
+        vi.mocked(fireBeforeAgentHook).mockResolvedValue({
+          shouldStopExecution: () => false,
+          isBlockingDecision: () => true,
+          getEffectiveReason: () => 'Blocked by hook',
+        } as DefaultHookOutput);
+
+        const mockChat: Partial<GeminiChat> = {
+          addHistory: vi.fn(),
+          getHistory: vi.fn().mockReturnValue([]),
+          getLastPromptTokenCount: vi.fn(),
+        };
+        client['chat'] = mockChat as GeminiChat;
+
+        const request = [{ text: 'Hello' }];
+        const stream = client.sendMessageStream(
+          request,
+          new AbortController().signal,
+          'test-prompt',
+        );
+        const events = await fromAsync(stream);
+
+        expect(events).toContainEqual({
+          type: GeminiEventType.AgentExecutionBlocked,
+          value: {
+            reason: 'Blocked by hook',
+          },
+        });
+        expect(mockChat.addHistory).not.toHaveBeenCalled();
+        expect(mockTurnRunFn).not.toHaveBeenCalled();
+      });
+
+      it('should stop execution in AfterAgent when hook returns continue: false', async () => {
+        const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
+        vi.mocked(fireAfterAgentHook).mockResolvedValue({
+          shouldStopExecution: () => true,
+          getEffectiveReason: () => 'Stopped after agent',
+        } as DefaultHookOutput);
+
+        mockTurnRunFn.mockImplementation(async function* () {
+          yield { type: GeminiEventType.Content, value: 'Hello' };
+        });
+
+        const stream = client.sendMessageStream(
+          { text: 'Hi' },
+          new AbortController().signal,
+          'test-prompt',
+        );
+        const events = await fromAsync(stream);
+
+        expect(events).toContainEqual({
+          type: GeminiEventType.AgentExecutionStopped,
+          value: { reason: 'Stopped after agent' },
+        });
+        // sendMessageStream should not recurse
+        expect(mockTurnRunFn).toHaveBeenCalledTimes(1);
+      });
+
+      it('should yield AgentExecutionBlocked and recurse in AfterAgent when hook returns decision: block', async () => {
+        const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
+        vi.mocked(fireAfterAgentHook)
+          .mockResolvedValueOnce({
+            shouldStopExecution: () => false,
+            isBlockingDecision: () => true,
+            getEffectiveReason: () => 'Please explain',
+          } as DefaultHookOutput)
+          .mockResolvedValueOnce({
+            shouldStopExecution: () => false,
+            isBlockingDecision: () => false,
+          } as DefaultHookOutput);
+
+        mockTurnRunFn.mockImplementation(async function* () {
+          yield { type: GeminiEventType.Content, value: 'Response' };
+        });
+
+        const stream = client.sendMessageStream(
+          { text: 'Hi' },
+          new AbortController().signal,
+          'test-prompt',
+        );
+        const events = await fromAsync(stream);
+
+        expect(events).toContainEqual({
+          type: GeminiEventType.AgentExecutionBlocked,
+          value: { reason: 'Please explain' },
+        });
+        // Should have called turn run twice (original + re-prompt)
+        expect(mockTurnRunFn).toHaveBeenCalledTimes(2);
+        expect(mockTurnRunFn).toHaveBeenNthCalledWith(
+          2,
+          expect.anything(),
+          [{ text: 'Please explain' }],
+          expect.anything(),
+        );
+      });
     });
   });
 });
diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts
index 448dd99310..ecd1eff471 100644
--- a/packages/core/src/core/client.ts
+++ b/packages/core/src/core/client.ts
@@ -11,6 +11,7 @@ import type {
   Tool,
   GenerateContentResponse,
 } from '@google/genai';
+import { createUserContent } from '@google/genai';
 import type { MessageBus } from '../confirmation-bus/message-bus.js';
 import {
   getDirectoryContextString,
@@ -65,8 +66,12 @@ const MAX_TURNS = 100;
 
 type BeforeAgentHookReturn =
   | {
-      type: GeminiEventType.Error;
-      value: { error: Error };
+      type: GeminiEventType.AgentExecutionStopped;
+      value: { reason: string };
+    }
+  | {
+      type: GeminiEventType.AgentExecutionBlocked;
+      value: { reason: string };
     }
   | { additionalContext: string | undefined }
   | undefined;
@@ -135,13 +140,20 @@ export class GeminiClient {
     const hookOutput = await fireBeforeAgentHook(messageBus, request);
     hookState.hasFiredBeforeAgent = true;
 
-    if (hookOutput?.isBlockingDecision() || hookOutput?.shouldStopExecution()) {
+    if (hookOutput?.shouldStopExecution()) {
       return {
-        type: GeminiEventType.Error,
+        type: GeminiEventType.AgentExecutionStopped,
         value: {
-          error: new Error(
-            `BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
-          ),
+          reason: hookOutput.getEffectiveReason(),
+        },
+      };
+    }
+
+    if (hookOutput?.isBlockingDecision()) {
+      return {
+        type: GeminiEventType.AgentExecutionBlocked,
+        value: {
+          reason: hookOutput.getEffectiveReason(),
         },
       };
     }
@@ -747,7 +759,18 @@ export class GeminiClient {
         prompt_id,
       );
       if (hookResult) {
-        if ('type' in hookResult && hookResult.type === GeminiEventType.Error) {
+        if (
+          'type' in hookResult &&
+          hookResult.type === GeminiEventType.AgentExecutionStopped
+        ) {
+          // Add user message to history before returning so it's kept in the transcript
+          this.getChat().addHistory(createUserContent(request));
+          yield hookResult;
+          return new Turn(this.getChat(), prompt_id);
+        } else if (
+          'type' in hookResult &&
+          hookResult.type === GeminiEventType.AgentExecutionBlocked
+        ) {
           yield hookResult;
           return new Turn(this.getChat(), prompt_id);
         } else if ('additionalContext' in hookResult) {
@@ -781,11 +804,24 @@ export class GeminiClient {
           turn,
         );
 
-        if (
-          hookOutput?.isBlockingDecision() ||
-          hookOutput?.shouldStopExecution()
-        ) {
+        if (hookOutput?.shouldStopExecution()) {
+          yield {
+            type: GeminiEventType.AgentExecutionStopped,
+            value: {
+              reason: hookOutput.getEffectiveReason(),
+            },
+          };
+          return turn;
+        }
+
+        if (hookOutput?.isBlockingDecision()) {
           const continueReason = hookOutput.getEffectiveReason();
+          yield {
+            type: GeminiEventType.AgentExecutionBlocked,
+            value: {
+              reason: continueReason,
+            },
+          };
           const continueRequest = [{ text: continueReason }];
           yield* this.sendMessageStream(
             continueRequest,
diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts
index 21191b34f2..11825d9d7b 100644
--- a/packages/core/src/core/turn.ts
+++ b/packages/core/src/core/turn.ts
@@ -66,12 +66,28 @@ export enum GeminiEventType {
   ContextWindowWillOverflow = 'context_window_will_overflow',
   InvalidStream = 'invalid_stream',
   ModelInfo = 'model_info',
+  AgentExecutionStopped = 'agent_execution_stopped',
+  AgentExecutionBlocked = 'agent_execution_blocked',
 }
 
 export type ServerGeminiRetryEvent = {
   type: GeminiEventType.Retry;
 };
 
+export type ServerGeminiAgentExecutionStoppedEvent = {
+  type: GeminiEventType.AgentExecutionStopped;
+  value: {
+    reason: string;
+  };
+};
+
+export type ServerGeminiAgentExecutionBlockedEvent = {
+  type: GeminiEventType.AgentExecutionBlocked;
+  value: {
+    reason: string;
+  };
+};
+
 export type ServerGeminiContextWindowWillOverflowEvent = {
   type: GeminiEventType.ContextWindowWillOverflow;
   value: {
@@ -204,7 +220,9 @@ export type ServerGeminiStreamEvent =
   | ServerGeminiRetryEvent
   | ServerGeminiContextWindowWillOverflowEvent
   | ServerGeminiInvalidStreamEvent
-  | ServerGeminiModelInfoEvent;
+  | ServerGeminiModelInfoEvent
+  | ServerGeminiAgentExecutionStoppedEvent
+  | ServerGeminiAgentExecutionBlockedEvent;
 
 // A turn manages the agentic loop turn within the server context.
 export class Turn {