mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 06:31:01 -07:00
feat(hooks): Hook LLM Request/Response Integration (#9110)
This commit is contained in:
@@ -27,6 +27,8 @@ import { AuthType } from './contentGenerator.js';
|
||||
import { TerminalQuotaError } from '../utils/googleQuotaErrors.js';
|
||||
import { retryWithBackoff, type RetryOptions } from '../utils/retry.js';
|
||||
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
|
||||
import { HookSystem } from '../hooks/hookSystem.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
|
||||
// Mock fs module to prevent actual file system operations during tests
|
||||
const mockFileSystem = new Map<string, string>();
|
||||
@@ -154,12 +156,20 @@ describe('GeminiChat', () => {
|
||||
isPreviewModelFallbackMode: vi.fn().mockReturnValue(false),
|
||||
setPreviewModelFallbackMode: vi.fn(),
|
||||
isInteractive: vi.fn().mockReturnValue(false),
|
||||
getEnableHooks: vi.fn().mockReturnValue(false),
|
||||
} as unknown as Config;
|
||||
|
||||
// Use proper MessageBus mocking for Phase 3 preparation
|
||||
const mockMessageBus = createMockMessageBus();
|
||||
mockConfig.getMessageBus = vi.fn().mockReturnValue(mockMessageBus);
|
||||
|
||||
// Disable 429 simulation for tests
|
||||
setSimulate429(false);
|
||||
// Reset history for each test by creating a new instance
|
||||
chat = new GeminiChat(mockConfig);
|
||||
mockConfig.getHookSystem = vi
|
||||
.fn()
|
||||
.mockReturnValue(new HookSystem(mockConfig));
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
||||
@@ -14,6 +14,7 @@ import type {
|
||||
Tool,
|
||||
PartListUnion,
|
||||
GenerateContentConfig,
|
||||
GenerateContentParameters,
|
||||
} from '@google/genai';
|
||||
import { ThinkingLevel } from '@google/genai';
|
||||
import { toParts } from '../code_assist/converter.js';
|
||||
@@ -47,6 +48,11 @@ import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { partListUnionToString } from './geminiRequest.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import { estimateTokenCountSync } from '../utils/tokenCalculation.js';
|
||||
import {
|
||||
fireAfterModelHook,
|
||||
fireBeforeModelHook,
|
||||
fireBeforeToolSelectionHook,
|
||||
} from './geminiChatHookTriggers.js';
|
||||
|
||||
export enum StreamEventType {
|
||||
/** A regular content chunk from the API. */
|
||||
@@ -287,9 +293,9 @@ export class GeminiChat {
|
||||
this.history.push(userContent);
|
||||
const requestContents = this.getHistory(true);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-this-alias
|
||||
const self = this;
|
||||
return (async function* () {
|
||||
const streamWithRetries = async function* (
|
||||
this: GeminiChat,
|
||||
): AsyncGenerator<StreamEvent, void, void> {
|
||||
try {
|
||||
let lastError: unknown = new Error('Request failed after all retries.');
|
||||
|
||||
@@ -297,7 +303,7 @@ export class GeminiChat {
|
||||
// If we are in Preview Model Fallback Mode, we want to fail fast (1 attempt)
|
||||
// when probing the Preview Model.
|
||||
if (
|
||||
self.config.isPreviewModelFallbackMode() &&
|
||||
this.config.isPreviewModelFallbackMode() &&
|
||||
model === PREVIEW_GEMINI_MODEL
|
||||
) {
|
||||
maxAttempts = 1;
|
||||
@@ -314,7 +320,7 @@ export class GeminiChat {
|
||||
generateContentConfig.temperature = 1;
|
||||
}
|
||||
|
||||
const stream = await self.makeApiCallAndProcessStream(
|
||||
const stream = await this.makeApiCallAndProcessStream(
|
||||
model,
|
||||
generateContentConfig,
|
||||
requestContents,
|
||||
@@ -335,7 +341,7 @@ export class GeminiChat {
|
||||
// Check if we have more attempts left.
|
||||
if (attempt < maxAttempts - 1) {
|
||||
logContentRetry(
|
||||
self.config,
|
||||
this.config,
|
||||
new ContentRetryEvent(
|
||||
attempt,
|
||||
(error as InvalidStreamError).type,
|
||||
@@ -363,7 +369,7 @@ export class GeminiChat {
|
||||
isGemini2Model(model)
|
||||
) {
|
||||
logContentRetryFailure(
|
||||
self.config,
|
||||
this.config,
|
||||
new ContentRetryFailureEvent(
|
||||
maxAttempts,
|
||||
(lastError as InvalidStreamError).type,
|
||||
@@ -377,15 +383,17 @@ export class GeminiChat {
|
||||
// We only do this if we didn't bypass Preview Model (i.e. we actually used it).
|
||||
if (
|
||||
model === PREVIEW_GEMINI_MODEL &&
|
||||
!self.config.isPreviewModelBypassMode()
|
||||
!this.config.isPreviewModelBypassMode()
|
||||
) {
|
||||
self.config.setPreviewModelFallbackMode(false);
|
||||
this.config.setPreviewModelFallbackMode(false);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
streamDoneResolver!();
|
||||
}
|
||||
})();
|
||||
};
|
||||
|
||||
return streamWithRetries.call(this);
|
||||
}
|
||||
|
||||
private async makeApiCallAndProcessStream(
|
||||
@@ -397,7 +405,13 @@ export class GeminiChat {
|
||||
let effectiveModel = model;
|
||||
const contentsForPreviewModel =
|
||||
this.ensureActiveLoopHasThoughtSignatures(requestContents);
|
||||
const apiCall = () => {
|
||||
|
||||
// Track final request parameters for AfterModel hooks
|
||||
let lastModelToUse = model;
|
||||
let lastConfig: GenerateContentConfig = generateContentConfig;
|
||||
let lastContentsToUse: Content[] = requestContents;
|
||||
|
||||
const apiCall = async () => {
|
||||
let modelToUse = getEffectiveModel(
|
||||
this.config.isInFallbackMode(),
|
||||
model,
|
||||
@@ -439,14 +453,79 @@ export class GeminiChat {
|
||||
};
|
||||
delete config.thinkingConfig?.thinkingLevel;
|
||||
}
|
||||
let contentsToUse =
|
||||
modelToUse === PREVIEW_GEMINI_MODEL
|
||||
? contentsForPreviewModel
|
||||
: requestContents;
|
||||
|
||||
// Fire BeforeModel and BeforeToolSelection hooks if enabled
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
const messageBus = this.config.getMessageBus();
|
||||
if (hooksEnabled && messageBus) {
|
||||
// Fire BeforeModel hook
|
||||
const beforeModelResult = await fireBeforeModelHook(messageBus, {
|
||||
model: modelToUse,
|
||||
config,
|
||||
contents: contentsToUse,
|
||||
});
|
||||
|
||||
// Check if hook blocked the model call
|
||||
if (beforeModelResult.blocked) {
|
||||
// Return a synthetic response generator
|
||||
const syntheticResponse = beforeModelResult.syntheticResponse;
|
||||
if (syntheticResponse) {
|
||||
return (async function* () {
|
||||
yield syntheticResponse;
|
||||
})();
|
||||
}
|
||||
// If blocked without synthetic response, return empty generator
|
||||
return (async function* () {
|
||||
// Empty generator - no response
|
||||
})();
|
||||
}
|
||||
|
||||
// Apply modifications from BeforeModel hook
|
||||
if (beforeModelResult.modifiedConfig) {
|
||||
Object.assign(config, beforeModelResult.modifiedConfig);
|
||||
}
|
||||
if (
|
||||
beforeModelResult.modifiedContents &&
|
||||
Array.isArray(beforeModelResult.modifiedContents)
|
||||
) {
|
||||
contentsToUse = beforeModelResult.modifiedContents as Content[];
|
||||
}
|
||||
|
||||
// Fire BeforeToolSelection hook
|
||||
const toolSelectionResult = await fireBeforeToolSelectionHook(
|
||||
messageBus,
|
||||
{
|
||||
model: modelToUse,
|
||||
config,
|
||||
contents: contentsToUse,
|
||||
},
|
||||
);
|
||||
|
||||
// Apply tool configuration modifications
|
||||
if (toolSelectionResult.toolConfig) {
|
||||
config.toolConfig = toolSelectionResult.toolConfig;
|
||||
}
|
||||
if (
|
||||
toolSelectionResult.tools &&
|
||||
Array.isArray(toolSelectionResult.tools)
|
||||
) {
|
||||
config.tools = toolSelectionResult.tools as Tool[];
|
||||
}
|
||||
}
|
||||
|
||||
// Track final request parameters for AfterModel hooks
|
||||
lastModelToUse = modelToUse;
|
||||
lastConfig = config;
|
||||
lastContentsToUse = contentsToUse;
|
||||
|
||||
return this.config.getContentGenerator().generateContentStream(
|
||||
{
|
||||
model: modelToUse,
|
||||
contents:
|
||||
modelToUse === PREVIEW_GEMINI_MODEL
|
||||
? contentsForPreviewModel
|
||||
: requestContents,
|
||||
contents: contentsToUse,
|
||||
config,
|
||||
},
|
||||
prompt_id,
|
||||
@@ -470,7 +549,18 @@ export class GeminiChat {
|
||||
: undefined,
|
||||
});
|
||||
|
||||
return this.processStreamResponse(effectiveModel, streamResponse);
|
||||
// Store the original request for AfterModel hooks
|
||||
const originalRequest: GenerateContentParameters = {
|
||||
model: lastModelToUse,
|
||||
config: lastConfig,
|
||||
contents: lastContentsToUse,
|
||||
};
|
||||
|
||||
return this.processStreamResponse(
|
||||
effectiveModel,
|
||||
streamResponse,
|
||||
originalRequest,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -624,6 +714,7 @@ export class GeminiChat {
|
||||
private async *processStreamResponse(
|
||||
model: string,
|
||||
streamResponse: AsyncGenerator<GenerateContentResponse>,
|
||||
originalRequest: GenerateContentParameters,
|
||||
): AsyncGenerator<GenerateContentResponse> {
|
||||
const modelResponseParts: Part[] = [];
|
||||
|
||||
@@ -663,7 +754,19 @@ export class GeminiChat {
|
||||
}
|
||||
}
|
||||
|
||||
yield chunk; // Yield every chunk to the UI immediately.
|
||||
// Fire AfterModel hook through MessageBus (only if hooks are enabled)
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
const messageBus = this.config.getMessageBus();
|
||||
if (hooksEnabled && messageBus && originalRequest && chunk) {
|
||||
const hookResult = await fireAfterModelHook(
|
||||
messageBus,
|
||||
originalRequest,
|
||||
chunk,
|
||||
);
|
||||
yield hookResult.response;
|
||||
} else {
|
||||
yield chunk; // Yield every chunk to the UI immediately.
|
||||
}
|
||||
}
|
||||
|
||||
// String thoughts and consolidate text parts.
|
||||
|
||||
235
packages/core/src/core/geminiChatHookTriggers.ts
Normal file
235
packages/core/src/core/geminiChatHookTriggers.ts
Normal file
@@ -0,0 +1,235 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
GenerateContentResponse,
|
||||
GenerateContentParameters,
|
||||
GenerateContentConfig,
|
||||
ContentListUnion,
|
||||
ToolConfig,
|
||||
ToolListUnion,
|
||||
} from '@google/genai';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import {
|
||||
MessageBusType,
|
||||
type HookExecutionRequest,
|
||||
type HookExecutionResponse,
|
||||
} from '../confirmation-bus/types.js';
|
||||
import {
|
||||
createHookOutput,
|
||||
type BeforeModelHookOutput,
|
||||
type BeforeToolSelectionHookOutput,
|
||||
type AfterModelHookOutput,
|
||||
} from '../hooks/types.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
/**
|
||||
* Result from firing the BeforeModel hook.
|
||||
*/
|
||||
export interface BeforeModelHookResult {
|
||||
/** Whether the model call was blocked */
|
||||
blocked: boolean;
|
||||
/** Reason for blocking (if blocked) */
|
||||
reason?: string;
|
||||
/** Synthetic response to return instead of calling the model (if blocked) */
|
||||
syntheticResponse?: GenerateContentResponse;
|
||||
/** Modified config (if not blocked) */
|
||||
modifiedConfig?: GenerateContentConfig;
|
||||
/** Modified contents (if not blocked) */
|
||||
modifiedContents?: ContentListUnion;
|
||||
}
|
||||
|
||||
/**
|
||||
* Result from firing the BeforeToolSelection hook.
|
||||
*/
|
||||
export interface BeforeToolSelectionHookResult {
|
||||
/** Modified tool config */
|
||||
toolConfig?: ToolConfig;
|
||||
/** Modified tools */
|
||||
tools?: ToolListUnion;
|
||||
}
|
||||
|
||||
/**
|
||||
* Result from firing the AfterModel hook.
|
||||
* Contains either a modified response or indicates to use the original chunk.
|
||||
*/
|
||||
export interface AfterModelHookResult {
|
||||
/** The response to yield (either modified or original) */
|
||||
response: GenerateContentResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Fires the BeforeModel hook and returns the result.
|
||||
*
|
||||
* @param messageBus The message bus to use for hook communication
|
||||
* @param llmRequest The LLM request parameters
|
||||
* @returns The hook result with blocking info or modifications
|
||||
*/
|
||||
export async function fireBeforeModelHook(
|
||||
messageBus: MessageBus,
|
||||
llmRequest: GenerateContentParameters,
|
||||
): Promise<BeforeModelHookResult> {
|
||||
try {
|
||||
const response = await messageBus.request<
|
||||
HookExecutionRequest,
|
||||
HookExecutionResponse
|
||||
>(
|
||||
{
|
||||
type: MessageBusType.HOOK_EXECUTION_REQUEST,
|
||||
eventName: 'BeforeModel',
|
||||
input: {
|
||||
llm_request: llmRequest as unknown as Record<string, unknown>,
|
||||
},
|
||||
},
|
||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
);
|
||||
|
||||
// Reconstruct result from response
|
||||
const beforeResultFinalOutput = response.output
|
||||
? createHookOutput('BeforeModel', response.output)
|
||||
: undefined;
|
||||
|
||||
const hookOutput = beforeResultFinalOutput;
|
||||
|
||||
// Check if hook blocked the model call or requested to stop execution
|
||||
const blockingError = hookOutput?.getBlockingError();
|
||||
if (blockingError?.blocked || hookOutput?.shouldStopExecution()) {
|
||||
const beforeModelOutput = hookOutput as BeforeModelHookOutput;
|
||||
const syntheticResponse = beforeModelOutput.getSyntheticResponse();
|
||||
const reason =
|
||||
hookOutput?.getEffectiveReason() || 'Model call blocked by hook';
|
||||
|
||||
return {
|
||||
blocked: true,
|
||||
reason,
|
||||
syntheticResponse,
|
||||
};
|
||||
}
|
||||
|
||||
// Apply modifications from hook
|
||||
if (hookOutput) {
|
||||
const beforeModelOutput = hookOutput as BeforeModelHookOutput;
|
||||
const modifiedRequest =
|
||||
beforeModelOutput.applyLLMRequestModifications(llmRequest);
|
||||
|
||||
return {
|
||||
blocked: false,
|
||||
modifiedConfig: modifiedRequest.config,
|
||||
modifiedContents: modifiedRequest.contents,
|
||||
};
|
||||
}
|
||||
|
||||
return { blocked: false };
|
||||
} catch (error) {
|
||||
debugLogger.warn(`BeforeModel hook failed:`, error);
|
||||
return { blocked: false };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fires the BeforeToolSelection hook and returns the result.
|
||||
*
|
||||
* @param messageBus The message bus to use for hook communication
|
||||
* @param llmRequest The LLM request parameters
|
||||
* @returns The hook result with tool configuration modifications
|
||||
*/
|
||||
export async function fireBeforeToolSelectionHook(
|
||||
messageBus: MessageBus,
|
||||
llmRequest: GenerateContentParameters,
|
||||
): Promise<BeforeToolSelectionHookResult> {
|
||||
try {
|
||||
const response = await messageBus.request<
|
||||
HookExecutionRequest,
|
||||
HookExecutionResponse
|
||||
>(
|
||||
{
|
||||
type: MessageBusType.HOOK_EXECUTION_REQUEST,
|
||||
eventName: 'BeforeToolSelection',
|
||||
input: {
|
||||
llm_request: llmRequest as unknown as Record<string, unknown>,
|
||||
},
|
||||
},
|
||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
);
|
||||
|
||||
// Reconstruct result from response
|
||||
const toolSelectionResultFinalOutput = response.output
|
||||
? createHookOutput('BeforeToolSelection', response.output)
|
||||
: undefined;
|
||||
|
||||
// Apply tool configuration modifications
|
||||
if (toolSelectionResultFinalOutput) {
|
||||
const beforeToolSelectionOutput =
|
||||
toolSelectionResultFinalOutput as BeforeToolSelectionHookOutput;
|
||||
const modifiedConfig =
|
||||
beforeToolSelectionOutput.applyToolConfigModifications({
|
||||
toolConfig: llmRequest.config?.toolConfig,
|
||||
tools: llmRequest.config?.tools,
|
||||
});
|
||||
|
||||
return {
|
||||
toolConfig: modifiedConfig.toolConfig,
|
||||
tools: modifiedConfig.tools,
|
||||
};
|
||||
}
|
||||
|
||||
return {};
|
||||
} catch (error) {
|
||||
debugLogger.warn(`BeforeToolSelection hook failed:`, error);
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Fires the AfterModel hook and returns the result.
|
||||
*
|
||||
* @param messageBus The message bus to use for hook communication
|
||||
* @param originalRequest The original LLM request parameters
|
||||
* @param chunk The current response chunk from the model
|
||||
* @returns The hook result containing the response to yield
|
||||
*/
|
||||
export async function fireAfterModelHook(
|
||||
messageBus: MessageBus,
|
||||
originalRequest: GenerateContentParameters,
|
||||
chunk: GenerateContentResponse,
|
||||
): Promise<AfterModelHookResult> {
|
||||
try {
|
||||
const response = await messageBus.request<
|
||||
HookExecutionRequest,
|
||||
HookExecutionResponse
|
||||
>(
|
||||
{
|
||||
type: MessageBusType.HOOK_EXECUTION_REQUEST,
|
||||
eventName: 'AfterModel',
|
||||
input: {
|
||||
llm_request: originalRequest as unknown as Record<string, unknown>,
|
||||
llm_response: chunk as unknown as Record<string, unknown>,
|
||||
},
|
||||
},
|
||||
MessageBusType.HOOK_EXECUTION_RESPONSE,
|
||||
);
|
||||
|
||||
// Reconstruct result from response
|
||||
const afterResultFinalOutput = response.output
|
||||
? createHookOutput('AfterModel', response.output)
|
||||
: undefined;
|
||||
|
||||
// Apply modifications from hook (handles both normal modifications and stop execution)
|
||||
if (afterResultFinalOutput) {
|
||||
const afterModelOutput = afterResultFinalOutput as AfterModelHookOutput;
|
||||
const modifiedResponse = afterModelOutput.getModifiedResponse();
|
||||
if (modifiedResponse) {
|
||||
return { response: modifiedResponse };
|
||||
}
|
||||
}
|
||||
|
||||
return { response: chunk };
|
||||
} catch (error) {
|
||||
debugLogger.warn(`AfterModel hook failed:`, error);
|
||||
// On error, return original chunk to avoid interrupting the stream.
|
||||
return { response: chunk };
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user