mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
Add other hook wrapper methods to hooksystem (#16361)
This commit is contained in:
@@ -46,9 +46,8 @@ import type {
|
||||
ResolvedModelConfig,
|
||||
} from '../services/modelConfigService.js';
|
||||
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
||||
import { HookSystem } from '../hooks/hookSystem.js';
|
||||
import type { DefaultHookOutput } from '../hooks/types.js';
|
||||
import * as policyCatalog from '../availability/policyCatalog.js';
|
||||
import { partToString } from '../utils/partUtils.js';
|
||||
|
||||
vi.mock('../services/chatCompressionService.js');
|
||||
|
||||
@@ -137,15 +136,22 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({
|
||||
},
|
||||
}));
|
||||
vi.mock('../hooks/hookSystem.js');
|
||||
vi.mock('./clientHookTriggers.js', () => ({
|
||||
fireBeforeAgentHook: vi.fn(),
|
||||
fireAfterAgentHook: vi.fn().mockResolvedValue({
|
||||
decision: 'allow',
|
||||
continue: false,
|
||||
suppressOutput: false,
|
||||
systemMessage: undefined,
|
||||
const mockHookSystem = {
|
||||
fireBeforeAgentEvent: vi.fn().mockResolvedValue({
|
||||
success: true,
|
||||
finalOutput: undefined,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 0,
|
||||
}),
|
||||
}));
|
||||
fireAfterAgentEvent: vi.fn().mockResolvedValue({
|
||||
success: true,
|
||||
finalOutput: undefined,
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 0,
|
||||
}),
|
||||
};
|
||||
|
||||
/**
|
||||
* Array.fromAsync ponyfill, which will be available in es 2024.
|
||||
@@ -286,9 +292,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
.fn()
|
||||
.mockReturnValue(createAvailabilityServiceMock()),
|
||||
} as unknown as Config;
|
||||
mockConfig.getHookSystem = vi
|
||||
.fn()
|
||||
.mockReturnValue(new HookSystem(mockConfig));
|
||||
mockConfig.getHookSystem = vi.fn().mockReturnValue(mockHookSystem);
|
||||
|
||||
client = new GeminiClient(mockConfig);
|
||||
await client.initialize();
|
||||
@@ -2688,9 +2692,6 @@ ${JSON.stringify(
|
||||
const promptId = 'test-prompt-hook-1';
|
||||
const request = { text: 'Hello Hooks' };
|
||||
const signal = new AbortController().signal;
|
||||
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
|
||||
'./clientHookTriggers.js'
|
||||
);
|
||||
|
||||
mockTurnRunFn.mockImplementation(async function* (
|
||||
this: MockTurnContext,
|
||||
@@ -2702,11 +2703,10 @@ ${JSON.stringify(
|
||||
const stream = client.sendMessageStream(request, signal, promptId);
|
||||
while (!(await stream.next()).done);
|
||||
|
||||
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
request,
|
||||
expect(mockHookSystem.fireBeforeAgentEvent).toHaveBeenCalledTimes(1);
|
||||
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledTimes(1);
|
||||
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith(
|
||||
partToString(request),
|
||||
'Hook Response',
|
||||
);
|
||||
|
||||
@@ -2725,9 +2725,6 @@ ${JSON.stringify(
|
||||
const promptId = 'test-prompt-hook-recursive';
|
||||
const request = { text: 'Recursion Test' };
|
||||
const signal = new AbortController().signal;
|
||||
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
|
||||
'./clientHookTriggers.js'
|
||||
);
|
||||
|
||||
let callCount = 0;
|
||||
mockTurnRunFn.mockImplementation(async function* (
|
||||
@@ -2743,15 +2740,14 @@ ${JSON.stringify(
|
||||
while (!(await stream.next()).done);
|
||||
|
||||
// BeforeAgent should fire ONLY once despite multiple internal turns
|
||||
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
|
||||
expect(mockHookSystem.fireBeforeAgentEvent).toHaveBeenCalledTimes(1);
|
||||
|
||||
// AfterAgent should fire ONLY when the stack unwinds
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
|
||||
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Check cumulative response (separated by newline)
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
request,
|
||||
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith(
|
||||
partToString(request),
|
||||
'Response 1\nResponse 2',
|
||||
);
|
||||
|
||||
@@ -2769,7 +2765,6 @@ ${JSON.stringify(
|
||||
const promptId = 'test-prompt-hook-original-req';
|
||||
const request = { text: 'Do something' };
|
||||
const signal = new AbortController().signal;
|
||||
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
|
||||
|
||||
mockTurnRunFn.mockImplementation(async function* (
|
||||
this: MockTurnContext,
|
||||
@@ -2781,9 +2776,8 @@ ${JSON.stringify(
|
||||
const stream = client.sendMessageStream(request, signal, promptId);
|
||||
while (!(await stream.next()).done);
|
||||
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
request, // Should be 'Do something'
|
||||
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith(
|
||||
partToString(request), // Should be 'Do something'
|
||||
expect.stringContaining('Ok'),
|
||||
);
|
||||
});
|
||||
@@ -2817,11 +2811,17 @@ ${JSON.stringify(
|
||||
});
|
||||
|
||||
it('should stop execution in BeforeAgent when hook returns continue: false', async () => {
|
||||
const { fireBeforeAgentHook } = await import('./clientHookTriggers.js');
|
||||
vi.mocked(fireBeforeAgentHook).mockResolvedValue({
|
||||
shouldStopExecution: () => true,
|
||||
getEffectiveReason: () => 'Stopped by hook',
|
||||
} as DefaultHookOutput);
|
||||
mockHookSystem.fireBeforeAgentEvent.mockResolvedValue({
|
||||
success: true,
|
||||
finalOutput: {
|
||||
shouldStopExecution: () => true,
|
||||
getEffectiveReason: () => 'Stopped by hook',
|
||||
systemMessage: undefined,
|
||||
},
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 0,
|
||||
});
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
@@ -2850,12 +2850,18 @@ ${JSON.stringify(
|
||||
});
|
||||
|
||||
it('should block execution in BeforeAgent when hook returns decision: block', async () => {
|
||||
const { fireBeforeAgentHook } = await import('./clientHookTriggers.js');
|
||||
vi.mocked(fireBeforeAgentHook).mockResolvedValue({
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => true,
|
||||
getEffectiveReason: () => 'Blocked by hook',
|
||||
} as DefaultHookOutput);
|
||||
mockHookSystem.fireBeforeAgentEvent.mockResolvedValue({
|
||||
success: true,
|
||||
finalOutput: {
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => true,
|
||||
getEffectiveReason: () => 'Blocked by hook',
|
||||
systemMessage: undefined,
|
||||
},
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 0,
|
||||
});
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
@@ -2883,11 +2889,17 @@ ${JSON.stringify(
|
||||
});
|
||||
|
||||
it('should stop execution in AfterAgent when hook returns continue: false', async () => {
|
||||
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
|
||||
vi.mocked(fireAfterAgentHook).mockResolvedValue({
|
||||
shouldStopExecution: () => true,
|
||||
getEffectiveReason: () => 'Stopped after agent',
|
||||
} as DefaultHookOutput);
|
||||
mockHookSystem.fireAfterAgentEvent.mockResolvedValue({
|
||||
success: true,
|
||||
finalOutput: {
|
||||
shouldStopExecution: () => true,
|
||||
getEffectiveReason: () => 'Stopped after agent',
|
||||
systemMessage: undefined,
|
||||
},
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 0,
|
||||
});
|
||||
|
||||
mockTurnRunFn.mockImplementation(async function* () {
|
||||
yield { type: GeminiEventType.Content, value: 'Hello' };
|
||||
@@ -2909,17 +2921,30 @@ ${JSON.stringify(
|
||||
});
|
||||
|
||||
it('should yield AgentExecutionBlocked and recurse in AfterAgent when hook returns decision: block', async () => {
|
||||
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
|
||||
vi.mocked(fireAfterAgentHook)
|
||||
mockHookSystem.fireAfterAgentEvent
|
||||
.mockResolvedValueOnce({
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => true,
|
||||
getEffectiveReason: () => 'Please explain',
|
||||
} as DefaultHookOutput)
|
||||
success: true,
|
||||
finalOutput: {
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => true,
|
||||
getEffectiveReason: () => 'Please explain',
|
||||
systemMessage: undefined,
|
||||
},
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 0,
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => false,
|
||||
} as DefaultHookOutput);
|
||||
success: true,
|
||||
finalOutput: {
|
||||
shouldStopExecution: () => false,
|
||||
isBlockingDecision: () => false,
|
||||
systemMessage: undefined,
|
||||
},
|
||||
allOutputs: [],
|
||||
errors: [],
|
||||
totalDuration: 0,
|
||||
});
|
||||
|
||||
mockTurnRunFn.mockImplementation(async function* () {
|
||||
yield { type: GeminiEventType.Content, value: 'Response' };
|
||||
|
||||
@@ -12,7 +12,6 @@ import type {
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import { createUserContent } from '@google/genai';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import {
|
||||
getDirectoryContextString,
|
||||
getInitialChatHistory,
|
||||
@@ -40,10 +39,6 @@ import {
|
||||
logContentRetryFailure,
|
||||
logNextSpeakerCheck,
|
||||
} from '../telemetry/loggers.js';
|
||||
import {
|
||||
fireBeforeAgentHook,
|
||||
fireAfterAgentHook,
|
||||
} from './clientHookTriggers.js';
|
||||
import type { DefaultHookOutput } from '../hooks/types.js';
|
||||
import {
|
||||
ContentRetryFailureEvent,
|
||||
@@ -62,6 +57,7 @@ import {
|
||||
} from '../availability/policyHelpers.js';
|
||||
import { resolveModel } from '../config/models.js';
|
||||
import type { RetryAvailabilityContext } from '../utils/retry.js';
|
||||
import { partToString } from '../utils/partUtils.js';
|
||||
|
||||
const MAX_TURNS = 100;
|
||||
|
||||
@@ -113,7 +109,6 @@ export class GeminiClient {
|
||||
>();
|
||||
|
||||
private async fireBeforeAgentHookSafe(
|
||||
messageBus: MessageBus,
|
||||
request: PartListUnion,
|
||||
prompt_id: string,
|
||||
): Promise<BeforeAgentHookReturn> {
|
||||
@@ -138,7 +133,10 @@ export class GeminiClient {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const hookOutput = await fireBeforeAgentHook(messageBus, request);
|
||||
const hookResult = await this.config
|
||||
.getHookSystem()
|
||||
?.fireBeforeAgentEvent(partToString(request));
|
||||
const hookOutput = hookResult?.finalOutput;
|
||||
hookState.hasFiredBeforeAgent = true;
|
||||
|
||||
if (hookOutput?.shouldStopExecution()) {
|
||||
@@ -169,7 +167,6 @@ export class GeminiClient {
|
||||
}
|
||||
|
||||
private async fireAfterAgentHookSafe(
|
||||
messageBus: MessageBus,
|
||||
currentRequest: PartListUnion,
|
||||
prompt_id: string,
|
||||
turn?: Turn,
|
||||
@@ -190,11 +187,11 @@ export class GeminiClient {
|
||||
'[no response text]';
|
||||
const finalRequest = hookState.originalRequest || currentRequest;
|
||||
|
||||
const hookOutput = await fireAfterAgentHook(
|
||||
messageBus,
|
||||
finalRequest,
|
||||
finalResponseText,
|
||||
);
|
||||
const hookResult = await this.config
|
||||
.getHookSystem()
|
||||
?.fireAfterAgentEvent(partToString(finalRequest), finalResponseText);
|
||||
const hookOutput = hookResult?.finalOutput;
|
||||
|
||||
return hookOutput;
|
||||
}
|
||||
|
||||
@@ -757,11 +754,7 @@ export class GeminiClient {
|
||||
}
|
||||
|
||||
if (hooksEnabled && messageBus) {
|
||||
const hookResult = await this.fireBeforeAgentHookSafe(
|
||||
messageBus,
|
||||
request,
|
||||
prompt_id,
|
||||
);
|
||||
const hookResult = await this.fireBeforeAgentHookSafe(request, prompt_id);
|
||||
if (hookResult) {
|
||||
if (
|
||||
'type' in hookResult &&
|
||||
@@ -802,7 +795,6 @@ export class GeminiClient {
|
||||
// Fire AfterAgent hook if we have a turn and no pending tools
|
||||
if (hooksEnabled && messageBus) {
|
||||
const hookOutput = await this.fireAfterAgentHookSafe(
|
||||
messageBus,
|
||||
request,
|
||||
prompt_id,
|
||||
turn,
|
||||
|
||||
@@ -117,4 +117,28 @@ export class HookSystem {
|
||||
}
|
||||
return this.hookEventHandler.firePreCompressEvent(trigger);
|
||||
}
|
||||
|
||||
async fireBeforeAgentEvent(
|
||||
prompt: string,
|
||||
): Promise<AggregatedHookResult | undefined> {
|
||||
if (!this.config.getEnableHooks()) {
|
||||
return undefined;
|
||||
}
|
||||
return this.hookEventHandler.fireBeforeAgentEvent(prompt);
|
||||
}
|
||||
|
||||
async fireAfterAgentEvent(
|
||||
prompt: string,
|
||||
response: string,
|
||||
stopHookActive: boolean = false,
|
||||
): Promise<AggregatedHookResult | undefined> {
|
||||
if (!this.config.getEnableHooks()) {
|
||||
return undefined;
|
||||
}
|
||||
return this.hookEventHandler.fireAfterAgentEvent(
|
||||
prompt,
|
||||
response,
|
||||
stopHookActive,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user