feat(hooks): Support explicit stop and block execution control in model hooks (#15947)

Co-authored-by: matt korwel <matt.korwel@gmail.com>
This commit is contained in:
Sandy Tao
2026-01-09 13:36:27 +08:00
committed by GitHub
parent 18dd399cb5
commit e1e3efc9d0
7 changed files with 517 additions and 65 deletions

View File

@@ -28,6 +28,18 @@ import { createAvailabilityServiceMock } from '../availability/testUtils.js';
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
import * as policyHelpers from '../availability/policyHelpers.js';
import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js';
import {
fireBeforeModelHook,
fireAfterModelHook,
fireBeforeToolSelectionHook,
} from './geminiChatHookTriggers.js';
// Mock hook triggers
vi.mock('./geminiChatHookTriggers.js', () => ({
fireBeforeModelHook: vi.fn(),
fireAfterModelHook: vi.fn(),
fireBeforeToolSelectionHook: vi.fn().mockResolvedValue({}),
}));
// Mock fs module to prevent actual file system operations during tests
const mockFileSystem = new Map<string, string>();
@@ -2269,4 +2281,151 @@ describe('GeminiChat', () => {
);
});
});
describe('Hook execution control', () => {
  /** Drains every event from a stream into an array for later assertions. */
  const drain = async (
    stream: AsyncIterable<StreamEvent>,
  ): Promise<StreamEvent[]> => {
    const received: StreamEvent[] = [];
    for await (const item of stream) {
      received.push(item);
    }
    return received;
  };

  /** Starts a streaming send with the fixed request parameters used by every test. */
  const startStream = () =>
    chat.sendMessageStream(
      { model: 'gemini-pro' },
      'test',
      'prompt-id',
      new AbortController().signal,
    );

  beforeEach(() => {
    vi.mocked(mockConfig.getEnableHooks).mockReturnValue(true);
    // By default every hook lets execution proceed untouched.
    vi.mocked(fireBeforeModelHook).mockResolvedValue({ blocked: false });
    vi.mocked(fireAfterModelHook).mockResolvedValue({
      response: {} as GenerateContentResponse,
    });
    vi.mocked(fireBeforeToolSelectionHook).mockResolvedValue({});
  });

  it('should yield AGENT_EXECUTION_STOPPED when BeforeModel hook stops execution', async () => {
    vi.mocked(fireBeforeModelHook).mockResolvedValue({
      blocked: true,
      stopped: true,
      reason: 'stopped by hook',
    });

    const events = await drain(await startStream());

    expect(events).toHaveLength(1);
    expect(events[0]).toEqual({
      type: StreamEventType.AGENT_EXECUTION_STOPPED,
      reason: 'stopped by hook',
    });
  });

  it('should yield AGENT_EXECUTION_BLOCKED and synthetic response when BeforeModel hook blocks execution', async () => {
    const syntheticResponse = {
      candidates: [{ content: { parts: [{ text: 'blocked' }] } }],
    } as GenerateContentResponse;
    vi.mocked(fireBeforeModelHook).mockResolvedValue({
      blocked: true,
      reason: 'blocked by hook',
      syntheticResponse,
    });

    const events = await drain(await startStream());

    expect(events).toHaveLength(2);
    expect(events[0]).toEqual({
      type: StreamEventType.AGENT_EXECUTION_BLOCKED,
      reason: 'blocked by hook',
    });
    expect(events[1]).toEqual({
      type: StreamEventType.CHUNK,
      value: syntheticResponse,
    });
  });

  it('should yield AGENT_EXECUTION_STOPPED when AfterModel hook stops execution', async () => {
    // The model must emit at least one chunk so the AfterModel hook fires.
    async function* singleChunk() {
      yield {
        candidates: [{ content: { parts: [{ text: 'response' }] } }],
      } as unknown as GenerateContentResponse;
    }
    vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
      singleChunk(),
    );
    vi.mocked(fireAfterModelHook).mockResolvedValue({
      response: {} as GenerateContentResponse,
      stopped: true,
      reason: 'stopped by after hook',
    });

    const events = await drain(await startStream());

    expect(events).toContainEqual({
      type: StreamEventType.AGENT_EXECUTION_STOPPED,
      reason: 'stopped by after hook',
    });
  });

  it('should yield AGENT_EXECUTION_BLOCKED and response when AfterModel hook blocks execution', async () => {
    const response = {
      candidates: [{ content: { parts: [{ text: 'response' }] } }],
    } as unknown as GenerateContentResponse;
    // The model must emit at least one chunk so the AfterModel hook fires.
    async function* singleChunk() {
      yield response;
    }
    vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
      singleChunk(),
    );
    vi.mocked(fireAfterModelHook).mockResolvedValue({
      response,
      blocked: true,
      reason: 'blocked by after hook',
    });

    const events = await drain(await startStream());

    expect(events).toContainEqual({
      type: StreamEventType.AGENT_EXECUTION_BLOCKED,
      reason: 'blocked by after hook',
    });
    // The hook's response is still surfaced to the UI as a chunk.
    expect(events).toContainEqual({
      type: StreamEventType.CHUNK,
      value: response,
    });
  });
});
});

View File

@@ -61,11 +61,17 @@ export enum StreamEventType {
/** A signal that a retry is about to happen. The UI should discard any partial
* content from the attempt that just failed. */
RETRY = 'retry',
/** A signal that the agent execution has been stopped by a hook. */
AGENT_EXECUTION_STOPPED = 'agent_execution_stopped',
/** A signal that the agent execution has been blocked by a hook. */
AGENT_EXECUTION_BLOCKED = 'agent_execution_blocked',
}
export type StreamEvent =
| { type: StreamEventType.CHUNK; value: GenerateContentResponse }
| { type: StreamEventType.RETRY };
| { type: StreamEventType.RETRY }
| { type: StreamEventType.AGENT_EXECUTION_STOPPED; reason: string }
| { type: StreamEventType.AGENT_EXECUTION_BLOCKED; reason: string };
/**
* Options for retrying due to invalid content from the model.
@@ -197,6 +203,29 @@ export class InvalidStreamError extends Error {
}
}
/**
* Custom error to signal that agent execution has been stopped.
*/
/**
 * Error thrown to signal that a hook asked for agent execution to stop
 * entirely; the stream loop converts it into an AGENT_EXECUTION_STOPPED
 * event instead of treating it as a failure.
 */
export class AgentExecutionStoppedError extends Error {
  /** Human-readable explanation supplied by the hook. */
  public reason: string;

  constructor(reason: string) {
    super(reason);
    this.name = 'AgentExecutionStoppedError';
    this.reason = reason;
  }
}
/**
* Custom error to signal that agent execution has been blocked.
*/
/**
 * Error thrown to signal that a hook blocked the current model call;
 * the stream loop converts it into an AGENT_EXECUTION_BLOCKED event.
 * May carry a synthetic response to yield in place of real model output.
 */
export class AgentExecutionBlockedError extends Error {
  /** Human-readable explanation supplied by the hook. */
  public reason: string;
  /** Replacement response to surface instead of the model's output, if any. */
  public syntheticResponse?: GenerateContentResponse;

  constructor(reason: string, syntheticResponse?: GenerateContentResponse) {
    super(reason);
    this.name = 'AgentExecutionBlockedError';
    this.reason = reason;
    this.syntheticResponse = syntheticResponse;
  }
}
/**
* Chat session that enables sending messages to the model with previous
* conversation context.
@@ -325,6 +354,30 @@ export class GeminiChat {
lastError = null;
break;
} catch (error) {
if (error instanceof AgentExecutionStoppedError) {
yield {
type: StreamEventType.AGENT_EXECUTION_STOPPED,
reason: error.reason,
};
lastError = null; // Clear error as this is an expected stop
return; // Stop the generator
}
if (error instanceof AgentExecutionBlockedError) {
yield {
type: StreamEventType.AGENT_EXECUTION_BLOCKED,
reason: error.reason,
};
if (error.syntheticResponse) {
yield {
type: StreamEventType.CHUNK,
value: error.syntheticResponse,
};
}
lastError = null; // Clear error as this is an expected stop
return; // Stop the generator
}
if (isConnectionPhase) {
throw error;
}
@@ -457,19 +510,35 @@ export class GeminiChat {
contents: contentsToUse,
});
// Check if hook requested to stop execution
if (beforeModelResult.stopped) {
throw new AgentExecutionStoppedError(
beforeModelResult.reason || 'Agent execution stopped by hook',
);
}
// Check if hook blocked the model call
if (beforeModelResult.blocked) {
// Return a synthetic response generator
const syntheticResponse = beforeModelResult.syntheticResponse;
if (syntheticResponse) {
return (async function* () {
yield syntheticResponse;
})();
// Ensure synthetic response has a finish reason to prevent InvalidStreamError
if (
syntheticResponse.candidates &&
syntheticResponse.candidates.length > 0
) {
for (const candidate of syntheticResponse.candidates) {
if (!candidate.finishReason) {
candidate.finishReason = FinishReason.STOP;
}
}
}
}
// If blocked without synthetic response, return empty generator
return (async function* () {
// Empty generator - no response
})();
throw new AgentExecutionBlockedError(
beforeModelResult.reason || 'Model call blocked by hook',
syntheticResponse,
);
}
// Apply modifications from BeforeModel hook
@@ -748,6 +817,20 @@ export class GeminiChat {
originalRequest,
chunk,
);
if (hookResult.stopped) {
throw new AgentExecutionStoppedError(
hookResult.reason || 'Agent execution stopped by hook',
);
}
if (hookResult.blocked) {
throw new AgentExecutionBlockedError(
hookResult.reason || 'Agent execution blocked by hook',
hookResult.response,
);
}
yield hookResult.response;
} else {
yield chunk; // Yield every chunk to the UI immediately.

View File

@@ -0,0 +1,204 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import {
fireBeforeModelHook,
fireAfterModelHook,
} from './geminiChatHookTriggers.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import type {
GenerateContentParameters,
GenerateContentResponse,
} from '@google/genai';
// A spy standing in for MessageBus.request so each test can script hook replies.
const mockRequest = vi.fn();
const mockMessageBus = {
  request: mockRequest,
} as unknown as MessageBus;
// Replace createHookOutput with a spy while keeping the rest of the real
// hooks/types module intact, so tests can shape the hook-output object directly.
vi.mock('../hooks/types.js', async () => {
  const actual = await vi.importActual('../hooks/types.js');
  return {
    ...actual,
    createHookOutput: vi.fn(),
  };
});
import { createHookOutput } from '../hooks/types.js';
describe('Gemini Chat Hook Triggers', () => {
  beforeEach(() => {
    vi.clearAllMocks();
  });

  describe('fireBeforeModelHook', () => {
    // Minimal request handed to the trigger in every test below.
    const llmRequest = {
      model: 'gemini-pro',
      contents: [{ parts: [{ text: 'test' }] }],
    } as GenerateContentParameters;

    it('should return stopped: true when hook requests stop execution', async () => {
      mockRequest.mockResolvedValue({
        output: { continue: false, stopReason: 'stopped by hook' },
      });
      const hookOutput = {
        shouldStopExecution: () => true,
        getEffectiveReason: () => 'stopped by hook',
        getBlockingError: () => ({ blocked: false, reason: '' }),
      };
      vi.mocked(createHookOutput).mockReturnValue(
        hookOutput as unknown as ReturnType<typeof createHookOutput>,
      );

      const outcome = await fireBeforeModelHook(mockMessageBus, llmRequest);

      expect(outcome).toEqual({
        blocked: true,
        stopped: true,
        reason: 'stopped by hook',
      });
    });

    it('should return blocked: true when hook blocks execution', async () => {
      mockRequest.mockResolvedValue({
        output: { decision: 'block', reason: 'blocked by hook' },
      });
      const hookOutput = {
        shouldStopExecution: () => false,
        getBlockingError: () => ({ blocked: true, reason: 'blocked by hook' }),
        getEffectiveReason: () => 'blocked by hook',
        getSyntheticResponse: () => undefined,
      };
      vi.mocked(createHookOutput).mockReturnValue(
        hookOutput as unknown as ReturnType<typeof createHookOutput>,
      );

      const outcome = await fireBeforeModelHook(mockMessageBus, llmRequest);

      expect(outcome).toEqual({
        blocked: true,
        reason: 'blocked by hook',
        syntheticResponse: undefined,
      });
    });

    it('should return modifications when hook allows execution', async () => {
      mockRequest.mockResolvedValue({
        output: { decision: 'allow' },
      });
      const hookOutput = {
        shouldStopExecution: () => false,
        getBlockingError: () => ({ blocked: false, reason: '' }),
        // Identity modification: the request passes through unchanged.
        applyLLMRequestModifications: (req: GenerateContentParameters) => req,
      };
      vi.mocked(createHookOutput).mockReturnValue(
        hookOutput as unknown as ReturnType<typeof createHookOutput>,
      );

      const outcome = await fireBeforeModelHook(mockMessageBus, llmRequest);

      expect(outcome).toEqual({
        blocked: false,
        modifiedConfig: undefined,
        modifiedContents: llmRequest.contents,
      });
    });
  });

  describe('fireAfterModelHook', () => {
    // Fixed request/response pair handed to the trigger in every test below.
    const llmRequest = {
      model: 'gemini-pro',
      contents: [],
    } as GenerateContentParameters;
    const llmResponse = {
      candidates: [
        { content: { role: 'model', parts: [{ text: 'response' }] } },
      ],
    } as GenerateContentResponse;

    it('should return stopped: true when hook requests stop execution', async () => {
      mockRequest.mockResolvedValue({
        output: { continue: false, stopReason: 'stopped by hook' },
      });
      const hookOutput = {
        shouldStopExecution: () => true,
        getEffectiveReason: () => 'stopped by hook',
      };
      vi.mocked(createHookOutput).mockReturnValue(
        hookOutput as unknown as ReturnType<typeof createHookOutput>,
      );

      const outcome = await fireAfterModelHook(
        mockMessageBus,
        llmRequest,
        llmResponse,
      );

      expect(outcome).toEqual({
        response: llmResponse,
        stopped: true,
        reason: 'stopped by hook',
      });
    });

    it('should return blocked: true when hook blocks execution', async () => {
      mockRequest.mockResolvedValue({
        output: { decision: 'block', reason: 'blocked by hook' },
      });
      const hookOutput = {
        shouldStopExecution: () => false,
        getBlockingError: () => ({ blocked: true, reason: 'blocked by hook' }),
        getEffectiveReason: () => 'blocked by hook',
      };
      vi.mocked(createHookOutput).mockReturnValue(
        hookOutput as unknown as ReturnType<typeof createHookOutput>,
      );

      const outcome = await fireAfterModelHook(
        mockMessageBus,
        llmRequest,
        llmResponse,
      );

      expect(outcome).toEqual({
        response: llmResponse,
        blocked: true,
        reason: 'blocked by hook',
      });
    });

    it('should return modified response when hook modifies response', async () => {
      const modifiedResponse = { ...llmResponse, text: 'modified' };
      mockRequest.mockResolvedValue({
        output: { hookSpecificOutput: { llm_response: {} } },
      });
      const hookOutput = {
        shouldStopExecution: () => false,
        getBlockingError: () => ({ blocked: false, reason: '' }),
        getModifiedResponse: () => modifiedResponse,
      };
      vi.mocked(createHookOutput).mockReturnValue(
        hookOutput as unknown as ReturnType<typeof createHookOutput>,
      );

      const outcome = await fireAfterModelHook(
        mockMessageBus,
        llmRequest,
        llmResponse,
      );

      expect(outcome).toEqual({
        response: modifiedResponse,
      });
    });

    it('should return original response when hook has no effect', async () => {
      mockRequest.mockResolvedValue({
        output: {},
      });
      const hookOutput = {
        shouldStopExecution: () => false,
        getBlockingError: () => ({ blocked: false, reason: '' }),
        getModifiedResponse: () => undefined,
      };
      vi.mocked(createHookOutput).mockReturnValue(
        hookOutput as unknown as ReturnType<typeof createHookOutput>,
      );

      const outcome = await fireAfterModelHook(
        mockMessageBus,
        llmRequest,
        llmResponse,
      );

      expect(outcome).toEqual({
        response: llmResponse,
      });
    });
  });
});

View File

@@ -32,6 +32,8 @@ import { debugLogger } from '../utils/debugLogger.js';
export interface BeforeModelHookResult {
/** Whether the model call was blocked */
blocked: boolean;
/** Whether the execution should be stopped entirely */
stopped?: boolean;
/** Reason for blocking (if blocked) */
reason?: string;
/** Synthetic response to return instead of calling the model (if blocked) */
@@ -59,14 +61,16 @@ export interface BeforeToolSelectionHookResult {
/** Result of firing the AfterModel hook for a single response chunk. */
export interface AfterModelHookResult {
  /** The response to yield (either hook-modified or the original chunk). */
  response: GenerateContentResponse;
  /** Whether the execution should be stopped entirely. */
  stopped?: boolean;
  /** Whether the model call was blocked. */
  blocked?: boolean;
  /** Human-readable reason accompanying `stopped` or `blocked`. */
  reason?: string;
}
/**
* Fires the BeforeModel hook and returns the result.
*
* @param messageBus The message bus to use for hook communication
* @param llmRequest The LLM request parameters
* @returns The hook result with blocking info or modifications
*/
export async function fireBeforeModelHook(
messageBus: MessageBus,
@@ -94,9 +98,18 @@ export async function fireBeforeModelHook(
const hookOutput = beforeResultFinalOutput;
// Check if hook blocked the model call or requested to stop execution
// Check if hook requested to stop execution
if (hookOutput?.shouldStopExecution()) {
return {
blocked: true,
stopped: true,
reason: hookOutput.getEffectiveReason(),
};
}
// Check if hook blocked the model call
const blockingError = hookOutput?.getBlockingError();
if (blockingError?.blocked || hookOutput?.shouldStopExecution()) {
if (blockingError?.blocked) {
const beforeModelOutput = hookOutput as BeforeModelHookOutput;
const syntheticResponse = beforeModelOutput.getSyntheticResponse();
const reason =
@@ -217,9 +230,30 @@ export async function fireAfterModelHook(
? createHookOutput('AfterModel', response.output)
: undefined;
// Apply modifications from hook (handles both normal modifications and stop execution)
if (afterResultFinalOutput) {
const afterModelOutput = afterResultFinalOutput as AfterModelHookOutput;
const hookOutput = afterResultFinalOutput;
// Check if hook requested to stop execution
if (hookOutput?.shouldStopExecution()) {
return {
response: chunk,
stopped: true,
reason: hookOutput.getEffectiveReason(),
};
}
// Check if hook blocked the model call
const blockingError = hookOutput?.getBlockingError();
if (blockingError?.blocked) {
return {
response: chunk,
blocked: true,
reason: hookOutput?.getEffectiveReason(),
};
}
// Apply modifications from hook
if (hookOutput) {
const afterModelOutput = hookOutput as AfterModelHookOutput;
const modifiedResponse = afterModelOutput.getModifiedResponse();
if (modifiedResponse) {
return { response: modifiedResponse };

View File

@@ -264,6 +264,22 @@ export class Turn {
continue; // Skip to the next event in the stream
}
if (streamEvent.type === 'agent_execution_stopped') {
yield {
type: GeminiEventType.AgentExecutionStopped,
value: { reason: streamEvent.reason },
};
return;
}
if (streamEvent.type === 'agent_execution_blocked') {
yield {
type: GeminiEventType.AgentExecutionBlocked,
value: { reason: streamEvent.reason },
};
continue;
}
// Assuming other events are chunks with a `value` property
const resp = streamEvent.value;
if (!resp) continue; // Skip if there's no response body

View File

@@ -319,45 +319,17 @@ describe('Hook Output Classes', () => {
expect(output.getModifiedResponse()).toBeUndefined();
});
it('getModifiedResponse should return a synthetic stop response if shouldStopExecution is true', () => {
it('getModifiedResponse should return undefined if shouldStopExecution is true', () => {
const output = new AfterModelHookOutput({
continue: false,
stopReason: 'stopped by hook',
});
const expectedResponse: LLMResponse = {
candidates: [
{
content: {
role: 'model',
parts: ['stopped by hook'],
},
finishReason: 'STOP',
},
],
};
expect(output.getModifiedResponse()).toEqual(expectedResponse);
expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith(
expectedResponse,
);
expect(output.getModifiedResponse()).toBeUndefined();
});
it('getModifiedResponse should return a synthetic stop response with default reason if shouldStopExecution is true and no stopReason', () => {
it('getModifiedResponse should return undefined if shouldStopExecution is true and no stopReason', () => {
const output = new AfterModelHookOutput({ continue: false });
const expectedResponse: LLMResponse = {
candidates: [
{
content: {
role: 'model',
parts: ['No reason provided'],
},
finishReason: 'STOP',
},
],
};
expect(output.getModifiedResponse()).toEqual(expectedResponse);
expect(defaultHookTranslator.fromHookLLMResponse).toHaveBeenCalledWith(
expectedResponse,
);
expect(output.getModifiedResponse()).toBeUndefined();
});
});
});

View File

@@ -353,22 +353,6 @@ export class AfterModelHookOutput extends DefaultHookOutput {
}
}
// If hook wants to stop execution, create a synthetic stop response
if (this.shouldStopExecution()) {
const stopResponse: LLMResponse = {
candidates: [
{
content: {
role: 'model',
parts: [this.getEffectiveReason() || 'Execution stopped by hook'],
},
finishReason: 'STOP',
},
],
};
return defaultHookTranslator.fromHookLLMResponse(stopResponse);
}
return undefined;
}
}