mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-15 16:41:11 -07:00
fix: migrate BeforeModel and AfterModel hooks to HookSystem (#16599)
Co-authored-by: Tommaso Sciortino <sciortino@gmail.com>
This commit is contained in:
@@ -22,24 +22,12 @@ import { AuthType } from './contentGenerator.js';
|
||||
import { TerminalQuotaError } from '../utils/googleQuotaErrors.js';
|
||||
import { type RetryOptions } from '../utils/retry.js';
|
||||
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
|
||||
import { HookSystem } from '../hooks/hookSystem.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
|
||||
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
|
||||
import * as policyHelpers from '../availability/policyHelpers.js';
|
||||
import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js';
|
||||
import {
|
||||
fireBeforeModelHook,
|
||||
fireAfterModelHook,
|
||||
fireBeforeToolSelectionHook,
|
||||
} from './geminiChatHookTriggers.js';
|
||||
|
||||
// Mock hook triggers
|
||||
vi.mock('./geminiChatHookTriggers.js', () => ({
|
||||
fireBeforeModelHook: vi.fn(),
|
||||
fireAfterModelHook: vi.fn(),
|
||||
fireBeforeToolSelectionHook: vi.fn().mockResolvedValue({}),
|
||||
}));
|
||||
import type { HookSystem } from '../hooks/hookSystem.js';
|
||||
|
||||
// Mock fs module to prevent actual file system operations during tests
|
||||
const mockFileSystem = new Map<string, string>();
|
||||
@@ -204,9 +192,7 @@ describe('GeminiChat', () => {
|
||||
setSimulate429(false);
|
||||
// Reset history for each test by creating a new instance
|
||||
chat = new GeminiChat(mockConfig);
|
||||
mockConfig.getHookSystem = vi
|
||||
.fn()
|
||||
.mockReturnValue(new HookSystem(mockConfig));
|
||||
mockConfig.getHookSystem = vi.fn().mockReturnValue(undefined);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -2283,18 +2269,20 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
|
||||
describe('Hook execution control', () => {
|
||||
let mockHookSystem: HookSystem;
|
||||
beforeEach(() => {
|
||||
vi.mocked(mockConfig.getEnableHooks).mockReturnValue(true);
|
||||
// Default to allowing execution
|
||||
vi.mocked(fireBeforeModelHook).mockResolvedValue({ blocked: false });
|
||||
vi.mocked(fireAfterModelHook).mockResolvedValue({
|
||||
response: {} as GenerateContentResponse,
|
||||
});
|
||||
vi.mocked(fireBeforeToolSelectionHook).mockResolvedValue({});
|
||||
|
||||
mockHookSystem = {
|
||||
fireBeforeModelEvent: vi.fn().mockResolvedValue({ blocked: false }),
|
||||
fireAfterModelEvent: vi.fn().mockResolvedValue({ response: {} }),
|
||||
fireBeforeToolSelectionEvent: vi.fn().mockResolvedValue({}),
|
||||
} as unknown as HookSystem;
|
||||
mockConfig.getHookSystem = vi.fn().mockReturnValue(mockHookSystem);
|
||||
});
|
||||
|
||||
it('should yield AGENT_EXECUTION_STOPPED when BeforeModel hook stops execution', async () => {
|
||||
vi.mocked(fireBeforeModelHook).mockResolvedValue({
|
||||
vi.mocked(mockHookSystem.fireBeforeModelEvent).mockResolvedValue({
|
||||
blocked: true,
|
||||
stopped: true,
|
||||
reason: 'stopped by hook',
|
||||
@@ -2324,7 +2312,7 @@ describe('GeminiChat', () => {
|
||||
candidates: [{ content: { parts: [{ text: 'blocked' }] } }],
|
||||
} as GenerateContentResponse;
|
||||
|
||||
vi.mocked(fireBeforeModelHook).mockResolvedValue({
|
||||
vi.mocked(mockHookSystem.fireBeforeModelEvent).mockResolvedValue({
|
||||
blocked: true,
|
||||
reason: 'blocked by hook',
|
||||
syntheticResponse,
|
||||
@@ -2363,7 +2351,7 @@ describe('GeminiChat', () => {
|
||||
})(),
|
||||
);
|
||||
|
||||
vi.mocked(fireAfterModelHook).mockResolvedValue({
|
||||
vi.mocked(mockHookSystem.fireAfterModelEvent).mockResolvedValue({
|
||||
response: {} as GenerateContentResponse,
|
||||
stopped: true,
|
||||
reason: 'stopped by after hook',
|
||||
@@ -2399,7 +2387,7 @@ describe('GeminiChat', () => {
|
||||
})(),
|
||||
);
|
||||
|
||||
vi.mocked(fireAfterModelHook).mockResolvedValue({
|
||||
vi.mocked(mockHookSystem.fireAfterModelEvent).mockResolvedValue({
|
||||
response,
|
||||
blocked: true,
|
||||
reason: 'blocked by after hook',
|
||||
|
||||
@@ -49,11 +49,6 @@ import {
|
||||
applyModelSelection,
|
||||
createAvailabilityContextProvider,
|
||||
} from '../availability/policyHelpers.js';
|
||||
import {
|
||||
fireAfterModelHook,
|
||||
fireBeforeModelHook,
|
||||
fireBeforeToolSelectionHook,
|
||||
} from './geminiChatHookTriggers.js';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
|
||||
export enum StreamEventType {
|
||||
@@ -507,39 +502,26 @@ export class GeminiChat {
|
||||
? contentsForPreviewModel
|
||||
: requestContents;
|
||||
|
||||
// Fire BeforeModel and BeforeToolSelection hooks if enabled
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
const messageBus = this.config.getMessageBus();
|
||||
if (hooksEnabled && messageBus) {
|
||||
// Fire BeforeModel hook
|
||||
const beforeModelResult = await fireBeforeModelHook(messageBus, {
|
||||
const hookSystem = this.config.getHookSystem();
|
||||
if (hookSystem) {
|
||||
const beforeModelResult = await hookSystem.fireBeforeModelEvent({
|
||||
model: modelToUse,
|
||||
config,
|
||||
contents: contentsToUse,
|
||||
});
|
||||
|
||||
// Check if hook requested to stop execution
|
||||
if (beforeModelResult.stopped) {
|
||||
throw new AgentExecutionStoppedError(
|
||||
beforeModelResult.reason || 'Agent execution stopped by hook',
|
||||
);
|
||||
}
|
||||
|
||||
// Check if hook blocked the model call
|
||||
if (beforeModelResult.blocked) {
|
||||
// Return a synthetic response generator
|
||||
const syntheticResponse = beforeModelResult.syntheticResponse;
|
||||
if (syntheticResponse) {
|
||||
// Ensure synthetic response has a finish reason to prevent InvalidStreamError
|
||||
if (
|
||||
syntheticResponse.candidates &&
|
||||
syntheticResponse.candidates.length > 0
|
||||
) {
|
||||
for (const candidate of syntheticResponse.candidates) {
|
||||
if (!candidate.finishReason) {
|
||||
candidate.finishReason = FinishReason.STOP;
|
||||
}
|
||||
}
|
||||
|
||||
for (const candidate of syntheticResponse?.candidates ?? []) {
|
||||
if (!candidate.finishReason) {
|
||||
candidate.finishReason = FinishReason.STOP;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -549,7 +531,6 @@ export class GeminiChat {
|
||||
);
|
||||
}
|
||||
|
||||
// Apply modifications from BeforeModel hook
|
||||
if (beforeModelResult.modifiedConfig) {
|
||||
Object.assign(config, beforeModelResult.modifiedConfig);
|
||||
}
|
||||
@@ -560,17 +541,13 @@ export class GeminiChat {
|
||||
contentsToUse = beforeModelResult.modifiedContents as Content[];
|
||||
}
|
||||
|
||||
// Fire BeforeToolSelection hook
|
||||
const toolSelectionResult = await fireBeforeToolSelectionHook(
|
||||
messageBus,
|
||||
{
|
||||
const toolSelectionResult =
|
||||
await hookSystem.fireBeforeToolSelectionEvent({
|
||||
model: modelToUse,
|
||||
config,
|
||||
contents: contentsToUse,
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
// Apply tool configuration modifications
|
||||
if (toolSelectionResult.toolConfig) {
|
||||
config.toolConfig = toolSelectionResult.toolConfig;
|
||||
}
|
||||
@@ -825,12 +802,9 @@ export class GeminiChat {
|
||||
}
|
||||
}
|
||||
|
||||
// Fire AfterModel hook through MessageBus (only if hooks are enabled)
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
const messageBus = this.config.getMessageBus();
|
||||
if (hooksEnabled && messageBus && originalRequest && chunk) {
|
||||
const hookResult = await fireAfterModelHook(
|
||||
messageBus,
|
||||
const hookSystem = this.config.getHookSystem();
|
||||
if (originalRequest && chunk && hookSystem) {
|
||||
const hookResult = await hookSystem.fireAfterModelEvent(
|
||||
originalRequest,
|
||||
chunk,
|
||||
);
|
||||
@@ -850,7 +824,7 @@ export class GeminiChat {
|
||||
|
||||
yield hookResult.response;
|
||||
} else {
|
||||
yield chunk; // Yield every chunk to the UI immediately.
|
||||
yield chunk;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,13 +19,24 @@ import type {
|
||||
SessionEndReason,
|
||||
PreCompressTrigger,
|
||||
DefaultHookOutput,
|
||||
BeforeModelHookOutput,
|
||||
AfterModelHookOutput,
|
||||
BeforeToolSelectionHookOutput,
|
||||
} from './types.js';
|
||||
import type { AggregatedHookResult } from './hookAggregator.js';
|
||||
import type {
|
||||
GenerateContentParameters,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import type {
|
||||
AfterModelHookResult,
|
||||
BeforeModelHookResult,
|
||||
BeforeToolSelectionHookResult,
|
||||
} from '../core/geminiChatHookTriggers.js';
|
||||
/**
|
||||
* Main hook system that coordinates all hook-related functionality
|
||||
*/
|
||||
export class HookSystem {
|
||||
private readonly config: Config;
|
||||
private readonly hookRegistry: HookRegistry;
|
||||
private readonly hookRunner: HookRunner;
|
||||
private readonly hookAggregator: HookAggregator;
|
||||
@@ -33,7 +44,6 @@ export class HookSystem {
|
||||
private readonly hookEventHandler: HookEventHandler;
|
||||
|
||||
constructor(config: Config) {
|
||||
this.config = config;
|
||||
const logger: Logger = logs.getLogger(SERVICE_NAME);
|
||||
const messageBus = config.getMessageBus();
|
||||
|
||||
@@ -90,14 +100,10 @@ export class HookSystem {
|
||||
|
||||
/**
|
||||
* Fire hook events directly
|
||||
* Returns undefined if hooks are disabled
|
||||
*/
|
||||
async fireSessionStartEvent(
|
||||
source: SessionStartSource,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
if (!this.config.getEnableHooks()) {
|
||||
return undefined;
|
||||
}
|
||||
const result = await this.hookEventHandler.fireSessionStartEvent(source);
|
||||
return result.finalOutput;
|
||||
}
|
||||
@@ -105,27 +111,18 @@ export class HookSystem {
|
||||
async fireSessionEndEvent(
|
||||
reason: SessionEndReason,
|
||||
): Promise<AggregatedHookResult | undefined> {
|
||||
if (!this.config.getEnableHooks()) {
|
||||
return undefined;
|
||||
}
|
||||
return this.hookEventHandler.fireSessionEndEvent(reason);
|
||||
}
|
||||
|
||||
async firePreCompressEvent(
|
||||
trigger: PreCompressTrigger,
|
||||
): Promise<AggregatedHookResult | undefined> {
|
||||
if (!this.config.getEnableHooks()) {
|
||||
return undefined;
|
||||
}
|
||||
return this.hookEventHandler.firePreCompressEvent(trigger);
|
||||
}
|
||||
|
||||
async fireBeforeAgentEvent(
|
||||
prompt: string,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
if (!this.config.getEnableHooks()) {
|
||||
return undefined;
|
||||
}
|
||||
const result = await this.hookEventHandler.fireBeforeAgentEvent(prompt);
|
||||
return result.finalOutput;
|
||||
}
|
||||
@@ -135,9 +132,6 @@ export class HookSystem {
|
||||
response: string,
|
||||
stopHookActive: boolean = false,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
if (!this.config.getEnableHooks()) {
|
||||
return undefined;
|
||||
}
|
||||
const result = await this.hookEventHandler.fireAfterAgentEvent(
|
||||
prompt,
|
||||
response,
|
||||
@@ -145,4 +139,121 @@ export class HookSystem {
|
||||
);
|
||||
return result.finalOutput;
|
||||
}
|
||||
|
||||
async fireBeforeModelEvent(
|
||||
llmRequest: GenerateContentParameters,
|
||||
): Promise<BeforeModelHookResult> {
|
||||
try {
|
||||
const result =
|
||||
await this.hookEventHandler.fireBeforeModelEvent(llmRequest);
|
||||
const hookOutput = result.finalOutput;
|
||||
|
||||
if (hookOutput?.shouldStopExecution()) {
|
||||
return {
|
||||
blocked: true,
|
||||
stopped: true,
|
||||
reason: hookOutput.getEffectiveReason(),
|
||||
};
|
||||
}
|
||||
|
||||
const blockingError = hookOutput?.getBlockingError();
|
||||
if (blockingError?.blocked) {
|
||||
const beforeModelOutput = hookOutput as BeforeModelHookOutput;
|
||||
const syntheticResponse = beforeModelOutput.getSyntheticResponse();
|
||||
return {
|
||||
blocked: true,
|
||||
reason:
|
||||
hookOutput?.getEffectiveReason() || 'Model call blocked by hook',
|
||||
syntheticResponse,
|
||||
};
|
||||
}
|
||||
|
||||
if (hookOutput) {
|
||||
const beforeModelOutput = hookOutput as BeforeModelHookOutput;
|
||||
const modifiedRequest =
|
||||
beforeModelOutput.applyLLMRequestModifications(llmRequest);
|
||||
return {
|
||||
blocked: false,
|
||||
modifiedConfig: modifiedRequest?.config,
|
||||
modifiedContents: modifiedRequest?.contents,
|
||||
};
|
||||
}
|
||||
|
||||
return { blocked: false };
|
||||
} catch (error) {
|
||||
debugLogger.debug(`BeforeModelHookEvent failed:`, error);
|
||||
return { blocked: false };
|
||||
}
|
||||
}
|
||||
|
||||
async fireAfterModelEvent(
|
||||
originalRequest: GenerateContentParameters,
|
||||
chunk: GenerateContentResponse,
|
||||
): Promise<AfterModelHookResult> {
|
||||
try {
|
||||
const result = await this.hookEventHandler.fireAfterModelEvent(
|
||||
originalRequest,
|
||||
chunk,
|
||||
);
|
||||
const hookOutput = result.finalOutput;
|
||||
|
||||
if (hookOutput?.shouldStopExecution()) {
|
||||
return {
|
||||
response: chunk,
|
||||
stopped: true,
|
||||
reason: hookOutput.getEffectiveReason(),
|
||||
};
|
||||
}
|
||||
|
||||
const blockingError = hookOutput?.getBlockingError();
|
||||
if (blockingError?.blocked) {
|
||||
return {
|
||||
response: chunk,
|
||||
blocked: true,
|
||||
reason: hookOutput?.getEffectiveReason(),
|
||||
};
|
||||
}
|
||||
|
||||
if (hookOutput) {
|
||||
const afterModelOutput = hookOutput as AfterModelHookOutput;
|
||||
const modifiedResponse = afterModelOutput.getModifiedResponse();
|
||||
if (modifiedResponse) {
|
||||
return { response: modifiedResponse };
|
||||
}
|
||||
}
|
||||
|
||||
return { response: chunk };
|
||||
} catch (error) {
|
||||
debugLogger.debug(`AfterModelHookEvent failed:`, error);
|
||||
return { response: chunk };
|
||||
}
|
||||
}
|
||||
|
||||
async fireBeforeToolSelectionEvent(
|
||||
llmRequest: GenerateContentParameters,
|
||||
): Promise<BeforeToolSelectionHookResult> {
|
||||
try {
|
||||
const result =
|
||||
await this.hookEventHandler.fireBeforeToolSelectionEvent(llmRequest);
|
||||
const hookOutput = result.finalOutput;
|
||||
|
||||
if (hookOutput) {
|
||||
const toolSelectionOutput = hookOutput as BeforeToolSelectionHookOutput;
|
||||
const modifiedConfig = toolSelectionOutput.applyToolConfigModifications(
|
||||
{
|
||||
toolConfig: llmRequest.config?.toolConfig,
|
||||
tools: llmRequest.config?.tools,
|
||||
},
|
||||
);
|
||||
return {
|
||||
toolConfig: modifiedConfig.toolConfig,
|
||||
tools: modifiedConfig.tools,
|
||||
};
|
||||
}
|
||||
return {};
|
||||
} catch (error) {
|
||||
debugLogger.debug(`BeforeToolSelectionEvent failed:`, error);
|
||||
return {};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user