From e1e3efc9d04a1e93899e559e888b93c95b14ae2f Mon Sep 17 00:00:00 2001 From: Sandy Tao Date: Fri, 9 Jan 2026 13:36:27 +0800 Subject: [PATCH] feat(hooks): Support explicit stop and block execution control in model hooks (#15947) Co-authored-by: matt korwel --- packages/core/src/core/geminiChat.test.ts | 159 ++++++++++++++ packages/core/src/core/geminiChat.ts | 99 ++++++++- .../src/core/geminiChatHookTriggers.test.ts | 204 ++++++++++++++++++ .../core/src/core/geminiChatHookTriggers.ts | 52 ++++- packages/core/src/core/turn.ts | 16 ++ packages/core/src/hooks/types.test.ts | 36 +--- packages/core/src/hooks/types.ts | 16 -- 7 files changed, 517 insertions(+), 65 deletions(-) create mode 100644 packages/core/src/core/geminiChatHookTriggers.test.ts diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index baf8973904..1f60565f0d 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -28,6 +28,18 @@ import { createAvailabilityServiceMock } from '../availability/testUtils.js'; import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js'; import * as policyHelpers from '../availability/policyHelpers.js'; import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js'; +import { + fireBeforeModelHook, + fireAfterModelHook, + fireBeforeToolSelectionHook, +} from './geminiChatHookTriggers.js'; + +// Mock hook triggers +vi.mock('./geminiChatHookTriggers.js', () => ({ + fireBeforeModelHook: vi.fn(), + fireAfterModelHook: vi.fn(), + fireBeforeToolSelectionHook: vi.fn().mockResolvedValue({}), +})); // Mock fs module to prevent actual file system operations during tests const mockFileSystem = new Map(); @@ -2269,4 +2281,151 @@ describe('GeminiChat', () => { ); }); }); + + describe('Hook execution control', () => { + beforeEach(() => { + vi.mocked(mockConfig.getEnableHooks).mockReturnValue(true); + // Default to allowing execution + vi.mocked(fireBeforeModelHook).mockResolvedValue({ blocked: false }); + vi.mocked(fireAfterModelHook).mockResolvedValue({ + response: {} as GenerateContentResponse, + }); + vi.mocked(fireBeforeToolSelectionHook).mockResolvedValue({}); + }); + + it('should yield AGENT_EXECUTION_STOPPED when BeforeModel hook stops execution', async () => { + vi.mocked(fireBeforeModelHook).mockResolvedValue({ + blocked: true, + stopped: true, + reason: 'stopped by hook', + }); + + const stream = await chat.sendMessageStream( + { model: 'gemini-pro' }, + 'test', + 'prompt-id', + new AbortController().signal, + ); + + const events: StreamEvent[] = []; + for await (const event of stream) { + events.push(event); + } + + expect(events).toHaveLength(1); + expect(events[0]).toEqual({ + type: StreamEventType.AGENT_EXECUTION_STOPPED, + reason: 'stopped by hook', + }); + }); + + it('should yield AGENT_EXECUTION_BLOCKED and synthetic response when BeforeModel hook blocks execution', async () => { + const syntheticResponse = { + candidates: [{ content: { parts: [{ text: 'blocked' }] } }], + } as GenerateContentResponse; + + vi.mocked(fireBeforeModelHook).mockResolvedValue({ + blocked: true, + reason: 'blocked by hook', + syntheticResponse, + }); + + const stream = await chat.sendMessageStream( + { model: 'gemini-pro' }, + 'test', + 'prompt-id', + new AbortController().signal, + ); + + const events: StreamEvent[] = []; + for await (const event of stream) { + events.push(event); + } + + expect(events).toHaveLength(2); + expect(events[0]).toEqual({ + type: StreamEventType.AGENT_EXECUTION_BLOCKED, + reason: 'blocked by hook', + }); + expect(events[1]).toEqual({ + type: StreamEventType.CHUNK, + value: syntheticResponse, + }); + }); + + it('should yield AGENT_EXECUTION_STOPPED when AfterModel hook stops execution', async () => { + // Mock content generator to return a stream + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'response' }] } }], + } as unknown as GenerateContentResponse; + })(), + ); + + vi.mocked(fireAfterModelHook).mockResolvedValue({ + response: {} as GenerateContentResponse, + stopped: true, + reason: 'stopped by after hook', + }); + + const stream = await chat.sendMessageStream( + { model: 'gemini-pro' }, + 'test', + 'prompt-id', + new AbortController().signal, + ); + + const events: StreamEvent[] = []; + for await (const event of stream) { + events.push(event); + } + + expect(events).toContainEqual({ + type: StreamEventType.AGENT_EXECUTION_STOPPED, + reason: 'stopped by after hook', + }); + }); + + it('should yield AGENT_EXECUTION_BLOCKED and response when AfterModel hook blocks execution', async () => { + const response = { + candidates: [{ content: { parts: [{ text: 'response' }] } }], + } as unknown as GenerateContentResponse; + + // Mock content generator to return a stream + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + (async function* () { + yield response; + })(), + ); + + vi.mocked(fireAfterModelHook).mockResolvedValue({ + response, + blocked: true, + reason: 'blocked by after hook', + }); + + const stream = await chat.sendMessageStream( + { model: 'gemini-pro' }, + 'test', + 'prompt-id', + new AbortController().signal, + ); + + const events: StreamEvent[] = []; + for await (const event of stream) { + events.push(event); + } + + expect(events).toContainEqual({ + type: StreamEventType.AGENT_EXECUTION_BLOCKED, + reason: 'blocked by after hook', + }); + // Should also contain the chunk (hook response) + expect(events).toContainEqual({ + type: StreamEventType.CHUNK, + value: response, + }); + }); + }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 3bc928c6fb..2dff70c16d 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -61,11 +61,17 @@ export enum StreamEventType { /** A signal that a retry is about to happen. The UI should discard any partial * content from the attempt that just failed. */ RETRY = 'retry', + /** A signal that the agent execution has been stopped by a hook. */ + AGENT_EXECUTION_STOPPED = 'agent_execution_stopped', + /** A signal that the agent execution has been blocked by a hook. */ + AGENT_EXECUTION_BLOCKED = 'agent_execution_blocked', } export type StreamEvent = | { type: StreamEventType.CHUNK; value: GenerateContentResponse } - | { type: StreamEventType.RETRY }; + | { type: StreamEventType.RETRY } + | { type: StreamEventType.AGENT_EXECUTION_STOPPED; reason: string } + | { type: StreamEventType.AGENT_EXECUTION_BLOCKED; reason: string }; /** * Options for retrying due to invalid content from the model. @@ -197,6 +203,29 @@ export class InvalidStreamError extends Error { } } +/** + * Custom error to signal that agent execution has been stopped. + */ +export class AgentExecutionStoppedError extends Error { + constructor(public reason: string) { + super(reason); + this.name = 'AgentExecutionStoppedError'; + } +} + +/** + * Custom error to signal that agent execution has been blocked. + */ +export class AgentExecutionBlockedError extends Error { + constructor( + public reason: string, + public syntheticResponse?: GenerateContentResponse, + ) { + super(reason); + this.name = 'AgentExecutionBlockedError'; + } +} + /** * Chat session that enables sending messages to the model with previous * conversation context. @@ -325,6 +354,30 @@ export class GeminiChat { lastError = null; break; } catch (error) { + if (error instanceof AgentExecutionStoppedError) { + yield { + type: StreamEventType.AGENT_EXECUTION_STOPPED, + reason: error.reason, + }; + lastError = null; // Clear error as this is an expected stop + return; // Stop the generator + } + + if (error instanceof AgentExecutionBlockedError) { + yield { + type: StreamEventType.AGENT_EXECUTION_BLOCKED, + reason: error.reason, + }; + if (error.syntheticResponse) { + yield { + type: StreamEventType.CHUNK, + value: error.syntheticResponse, + }; + } + lastError = null; // Clear error as this is an expected stop + return; // Stop the generator + } + if (isConnectionPhase) { throw error; } @@ -457,19 +510,35 @@ export class GeminiChat { contents: contentsToUse, }); + // Check if hook requested to stop execution + if (beforeModelResult.stopped) { + throw new AgentExecutionStoppedError( + beforeModelResult.reason || 'Agent execution stopped by hook', + ); + } + // Check if hook blocked the model call if (beforeModelResult.blocked) { // Return a synthetic response generator const syntheticResponse = beforeModelResult.syntheticResponse; if (syntheticResponse) { - return (async function* () { - yield syntheticResponse; - })(); + // Ensure synthetic response has a finish reason to prevent InvalidStreamError + if ( + syntheticResponse.candidates && + syntheticResponse.candidates.length > 0 + ) { + for (const candidate of syntheticResponse.candidates) { + if (!candidate.finishReason) { + candidate.finishReason = FinishReason.STOP; + } + } + } } - // If blocked without synthetic response, return empty generator - return (async function* () { - // Empty generator - no response - })(); + + throw new AgentExecutionBlockedError( + beforeModelResult.reason || 'Model call blocked by hook', + syntheticResponse, + ); } // Apply modifications from BeforeModel hook @@ -748,6 +817,20 @@ export class GeminiChat { originalRequest, chunk, ); + + if (hookResult.stopped) { + throw new AgentExecutionStoppedError( + hookResult.reason || 'Agent execution stopped by hook', + ); + } + + if (hookResult.blocked) { + throw new AgentExecutionBlockedError( + hookResult.reason || 'Agent execution blocked by hook', + hookResult.response, + ); + } + yield hookResult.response; } else { yield chunk; // Yield every chunk to the UI immediately. diff --git a/packages/core/src/core/geminiChatHookTriggers.test.ts b/packages/core/src/core/geminiChatHookTriggers.test.ts new file mode 100644 index 0000000000..0bc1501386 --- /dev/null +++ b/packages/core/src/core/geminiChatHookTriggers.test.ts @@ -0,0 +1,204 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { + fireBeforeModelHook, + fireAfterModelHook, +} from './geminiChatHookTriggers.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; +import type { + GenerateContentParameters, + GenerateContentResponse, +} from '@google/genai'; + +// Mock dependencies +const mockRequest = vi.fn(); +const mockMessageBus = { + request: mockRequest, +} as unknown as MessageBus; + +// Mock hook types +vi.mock('../hooks/types.js', async () => { + const actual = await vi.importActual('../hooks/types.js'); + return { + ...actual, + createHookOutput: vi.fn(), + }; +}); + +import { createHookOutput } from '../hooks/types.js'; + +describe('Gemini Chat Hook Triggers', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + describe('fireBeforeModelHook', () => { + const llmRequest = { + model: 'gemini-pro', + contents: [{ parts: [{ text: 'test' }] }], + } as GenerateContentParameters; + + it('should return stopped: true when hook requests stop execution', async () => { + mockRequest.mockResolvedValue({ + output: { continue: false, stopReason: 'stopped by hook' }, + }); + vi.mocked(createHookOutput).mockReturnValue({ + shouldStopExecution: () => true, + getEffectiveReason: () => 'stopped by hook', + getBlockingError: () => ({ blocked: false, reason: '' }), + } as unknown as ReturnType); + + const result = await fireBeforeModelHook(mockMessageBus, llmRequest); + + expect(result).toEqual({ + blocked: true, + stopped: true, + reason: 'stopped by hook', + }); + }); + + it('should return blocked: true when hook blocks execution', async () => { + mockRequest.mockResolvedValue({ + output: { decision: 'block', reason: 'blocked by hook' }, + }); + vi.mocked(createHookOutput).mockReturnValue({ + shouldStopExecution: () => false, + getBlockingError: () => ({ blocked: true, reason: 'blocked by hook' }), + getEffectiveReason: () => 'blocked by hook', + getSyntheticResponse: () => undefined, + } as unknown as ReturnType); + + const result = await fireBeforeModelHook(mockMessageBus, llmRequest); + + expect(result).toEqual({ + blocked: true, + reason: 'blocked by hook', + syntheticResponse: undefined, + }); + }); + + it('should return modifications when hook allows execution', async () => { + mockRequest.mockResolvedValue({ + output: { decision: 'allow' }, + }); + vi.mocked(createHookOutput).mockReturnValue({ + shouldStopExecution: () => false, + getBlockingError: () => ({ blocked: false, reason: '' }), + applyLLMRequestModifications: (req: GenerateContentParameters) => req, + } as unknown as ReturnType); + + const result = await fireBeforeModelHook(mockMessageBus, llmRequest); + + expect(result).toEqual({ + blocked: false, + modifiedConfig: undefined, + modifiedContents: llmRequest.contents, + }); + }); + }); + + describe('fireAfterModelHook', () => { + const llmRequest = { + model: 'gemini-pro', + contents: [], + } as GenerateContentParameters; + const llmResponse = { + candidates: [ + { content: { role: 'model', parts: [{ text: 'response' }] } }, + ], + } as GenerateContentResponse; + + it('should return stopped: true when hook requests stop execution', async () => { + mockRequest.mockResolvedValue({ + output: { continue: false, stopReason: 'stopped by hook' }, + }); + vi.mocked(createHookOutput).mockReturnValue({ + shouldStopExecution: () => true, + getEffectiveReason: () => 'stopped by hook', + } as unknown as ReturnType); + + const result = await fireAfterModelHook( + mockMessageBus, + llmRequest, + llmResponse, + ); + + expect(result).toEqual({ + response: llmResponse, + stopped: true, + reason: 'stopped by hook', + }); + }); + + it('should return blocked: true when hook blocks execution', async () => { + mockRequest.mockResolvedValue({ + output: { decision: 'block', reason: 'blocked by hook' }, + }); + vi.mocked(createHookOutput).mockReturnValue({ + shouldStopExecution: () => false, + getBlockingError: () => ({ blocked: true, reason: 'blocked by hook' }), + getEffectiveReason: () => 'blocked by hook', + } as unknown as ReturnType); + + const result = await fireAfterModelHook( + mockMessageBus, + llmRequest, + llmResponse, + ); + + expect(result).toEqual({ + response: llmResponse, + blocked: true, + reason: 'blocked by hook', + }); + }); + + it('should return modified response when hook modifies response', async () => { + const modifiedResponse = { ...llmResponse, text: 'modified' }; + mockRequest.mockResolvedValue({ + output: { hookSpecificOutput: { llm_response: {} } }, + }); + vi.mocked(createHookOutput).mockReturnValue({ + shouldStopExecution: () => false, + getBlockingError: () => ({ blocked: false, reason: '' }), + getModifiedResponse: () => modifiedResponse, + } as unknown as ReturnType); + + const result = await fireAfterModelHook( + mockMessageBus, + llmRequest, + llmResponse, + ); + + expect(result).toEqual({ + response: modifiedResponse, + }); + }); + + it('should return original response when hook has no effect', async () => { + mockRequest.mockResolvedValue({ + output: {}, + }); + vi.mocked(createHookOutput).mockReturnValue({ + shouldStopExecution: () => false, + getBlockingError: () => ({ blocked: false, reason: '' }), + getModifiedResponse: () => undefined, + } as unknown as ReturnType); + + const result = await fireAfterModelHook( + mockMessageBus, + llmRequest, + llmResponse, + ); + + expect(result).toEqual({ + response: llmResponse, + }); + }); + }); +}); diff --git a/packages/core/src/core/geminiChatHookTriggers.ts b/packages/core/src/core/geminiChatHookTriggers.ts index 0672ec961d..e0632105de 100644 --- a/packages/core/src/core/geminiChatHookTriggers.ts +++ b/packages/core/src/core/geminiChatHookTriggers.ts @@ -32,6 +32,8 @@ import { debugLogger } from '../utils/debugLogger.js'; export interface BeforeModelHookResult { /** Whether the model call was blocked */ blocked: boolean; + /** Whether the execution should be stopped entirely */ + stopped?: boolean; /** Reason for blocking (if blocked) */ reason?: string; /** Synthetic response to return instead of calling the model (if blocked) */ @@ -59,14 +61,16 @@ export interface BeforeToolSelectionHookResult { export interface AfterModelHookResult { /** The response to yield (either modified or original) */ response: GenerateContentResponse; + /** Whether the execution should be stopped entirely */ + stopped?: boolean; + /** Whether the model call was blocked */ + blocked?: boolean; + /** Reason for blocking or stopping */ + reason?: string; } /** * Fires the BeforeModel hook and returns the result. - * - * @param messageBus The message bus to use for hook communication - * @param llmRequest The LLM request parameters - * @returns The hook result with blocking info or modifications */ export async function fireBeforeModelHook( messageBus: MessageBus, @@ -94,9 +98,18 @@ export async function fireBeforeModelHook( const hookOutput = beforeResultFinalOutput; - // Check if hook blocked the model call or requested to stop execution + // Check if hook requested to stop execution + if (hookOutput?.shouldStopExecution()) { + return { + blocked: true, + stopped: true, + reason: hookOutput.getEffectiveReason(), + }; + } + + // Check if hook blocked the model call const blockingError = hookOutput?.getBlockingError(); - if (blockingError?.blocked || hookOutput?.shouldStopExecution()) { + if (blockingError?.blocked) { const beforeModelOutput = hookOutput as BeforeModelHookOutput; const syntheticResponse = beforeModelOutput.getSyntheticResponse(); const reason = @@ -217,9 +230,30 @@ export async function fireAfterModelHook( ? createHookOutput('AfterModel', response.output) : undefined; - // Apply modifications from hook (handles both normal modifications and stop execution) - if (afterResultFinalOutput) { - const afterModelOutput = afterResultFinalOutput as AfterModelHookOutput; + const hookOutput = afterResultFinalOutput; + + // Check if hook requested to stop execution + if (hookOutput?.shouldStopExecution()) { + return { + response: chunk, + stopped: true, + reason: hookOutput.getEffectiveReason(), + }; + } + + // Check if hook blocked the model call + const blockingError = hookOutput?.getBlockingError(); + if (blockingError?.blocked) { + return { + response: chunk, + blocked: true, + reason: hookOutput?.getEffectiveReason(), + }; + } + + // Apply modifications from hook + if (hookOutput) { + const afterModelOutput = hookOutput as AfterModelHookOutput; const modifiedResponse = afterModelOutput.getModifiedResponse(); if (modifiedResponse) { return { response: modifiedResponse }; diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 11825d9d7b..fcb8e18e04 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -264,6 +264,22 @@ export class Turn { continue; // Skip to the next event in the stream } + if (streamEvent.type === 'agent_execution_stopped') { + yield { + type: GeminiEventType.AgentExecutionStopped, + value: { reason: streamEvent.reason }, + }; + return; + } + + if (streamEvent.type === 'agent_execution_blocked') { + yield { + type: GeminiEventType.AgentExecutionBlocked, + value: { reason: streamEvent.reason }, + }; + continue; + } + // Assuming other events are chunks with a `value` property const resp = streamEvent.value; if (!resp) continue; // Skip if there's no response body diff --git a/packages/core/src/hooks/types.test.ts b/packages/core/src/hooks/types.test.ts index 18a18fe121..fb3e6d062c 100644 --- a/packages/core/src/hooks/types.test.ts +++ b/packages/core/src/hooks/types.test.ts @@ -319,45 +319,17 @@ describe('Hook Output Classes', () => { expect(output.getModifiedResponse()).toBeUndefined(); }); - it('getModifiedResponse should return a synthetic stop response if shouldStopExecution is true', () => { + it('getModifiedResponse should return undefined if shouldStopExecution is true', () => { const output = new AfterModelHookOutput({ continue: false, stopReason: 'stopped by hook', }); - const expectedResponse: LLMResponse = { - candidates: [ - { - content: { - role: 'model', - parts: ['stopped by hook'], - }, - finishReason: 'STOP', - }, - ], - }; - expect(output.getModifiedResponse()).toEqual(expectedResponse); - expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith( - expectedResponse, - ); + expect(output.getModifiedResponse()).toBeUndefined(); }); - it('getModifiedResponse should return a synthetic stop response with default reason if shouldStopExecution is true and no stopReason', () => { + it('getModifiedResponse should return undefined if shouldStopExecution is true and no stopReason', () => { const output = new AfterModelHookOutput({ continue: false }); - const expectedResponse: LLMResponse = { - candidates: [ - { - content: { - role: 'model', - parts: ['No reason provided'], - }, - finishReason: 'STOP', - }, - ], - }; - expect(output.getModifiedResponse()).toEqual(expectedResponse); - expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith( - expectedResponse, - ); + expect(output.getModifiedResponse()).toBeUndefined(); }); }); }); diff --git a/packages/core/src/hooks/types.ts b/packages/core/src/hooks/types.ts index 5ca7bd5fb1..8d6e203778 100644 --- a/packages/core/src/hooks/types.ts +++ b/packages/core/src/hooks/types.ts @@ -353,22 +353,6 @@ export class AfterModelHookOutput extends DefaultHookOutput { } } - // If hook wants to stop execution, create a synthetic stop response - if (this.shouldStopExecution()) { - const stopResponse: LLMResponse = { - candidates: [ - { - content: { - role: 'model', - parts: [this.getEffectiveReason() || 'Execution stopped by hook'], - }, - finishReason: 'STOP', - }, - ], - }; - return defaultHookTranslator.fromHookLLMResponse(stopResponse); - } - return undefined; } }