mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-12 12:54:07 -07:00
Add other hook wrapper methods to hooksystem (#16361)
This commit is contained in:
@@ -46,9 +46,8 @@ import type {
|
|||||||
ResolvedModelConfig,
|
ResolvedModelConfig,
|
||||||
} from '../services/modelConfigService.js';
|
} from '../services/modelConfigService.js';
|
||||||
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
||||||
import { HookSystem } from '../hooks/hookSystem.js';
|
|
||||||
import type { DefaultHookOutput } from '../hooks/types.js';
|
|
||||||
import * as policyCatalog from '../availability/policyCatalog.js';
|
import * as policyCatalog from '../availability/policyCatalog.js';
|
||||||
|
import { partToString } from '../utils/partUtils.js';
|
||||||
|
|
||||||
vi.mock('../services/chatCompressionService.js');
|
vi.mock('../services/chatCompressionService.js');
|
||||||
|
|
||||||
@@ -137,15 +136,22 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({
|
|||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
vi.mock('../hooks/hookSystem.js');
|
vi.mock('../hooks/hookSystem.js');
|
||||||
vi.mock('./clientHookTriggers.js', () => ({
|
const mockHookSystem = {
|
||||||
fireBeforeAgentHook: vi.fn(),
|
fireBeforeAgentEvent: vi.fn().mockResolvedValue({
|
||||||
fireAfterAgentHook: vi.fn().mockResolvedValue({
|
success: true,
|
||||||
decision: 'allow',
|
finalOutput: undefined,
|
||||||
continue: false,
|
allOutputs: [],
|
||||||
suppressOutput: false,
|
errors: [],
|
||||||
systemMessage: undefined,
|
totalDuration: 0,
|
||||||
}),
|
}),
|
||||||
}));
|
fireAfterAgentEvent: vi.fn().mockResolvedValue({
|
||||||
|
success: true,
|
||||||
|
finalOutput: undefined,
|
||||||
|
allOutputs: [],
|
||||||
|
errors: [],
|
||||||
|
totalDuration: 0,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Array.fromAsync ponyfill, which will be available in es 2024.
|
* Array.fromAsync ponyfill, which will be available in es 2024.
|
||||||
@@ -286,9 +292,7 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
.fn()
|
.fn()
|
||||||
.mockReturnValue(createAvailabilityServiceMock()),
|
.mockReturnValue(createAvailabilityServiceMock()),
|
||||||
} as unknown as Config;
|
} as unknown as Config;
|
||||||
mockConfig.getHookSystem = vi
|
mockConfig.getHookSystem = vi.fn().mockReturnValue(mockHookSystem);
|
||||||
.fn()
|
|
||||||
.mockReturnValue(new HookSystem(mockConfig));
|
|
||||||
|
|
||||||
client = new GeminiClient(mockConfig);
|
client = new GeminiClient(mockConfig);
|
||||||
await client.initialize();
|
await client.initialize();
|
||||||
@@ -2688,9 +2692,6 @@ ${JSON.stringify(
|
|||||||
const promptId = 'test-prompt-hook-1';
|
const promptId = 'test-prompt-hook-1';
|
||||||
const request = { text: 'Hello Hooks' };
|
const request = { text: 'Hello Hooks' };
|
||||||
const signal = new AbortController().signal;
|
const signal = new AbortController().signal;
|
||||||
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
|
|
||||||
'./clientHookTriggers.js'
|
|
||||||
);
|
|
||||||
|
|
||||||
mockTurnRunFn.mockImplementation(async function* (
|
mockTurnRunFn.mockImplementation(async function* (
|
||||||
this: MockTurnContext,
|
this: MockTurnContext,
|
||||||
@@ -2702,11 +2703,10 @@ ${JSON.stringify(
|
|||||||
const stream = client.sendMessageStream(request, signal, promptId);
|
const stream = client.sendMessageStream(request, signal, promptId);
|
||||||
while (!(await stream.next()).done);
|
while (!(await stream.next()).done);
|
||||||
|
|
||||||
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
|
expect(mockHookSystem.fireBeforeAgentEvent).toHaveBeenCalledTimes(1);
|
||||||
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
|
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledTimes(1);
|
||||||
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith(
|
||||||
expect.anything(),
|
partToString(request),
|
||||||
request,
|
|
||||||
'Hook Response',
|
'Hook Response',
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -2725,9 +2725,6 @@ ${JSON.stringify(
|
|||||||
const promptId = 'test-prompt-hook-recursive';
|
const promptId = 'test-prompt-hook-recursive';
|
||||||
const request = { text: 'Recursion Test' };
|
const request = { text: 'Recursion Test' };
|
||||||
const signal = new AbortController().signal;
|
const signal = new AbortController().signal;
|
||||||
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
|
|
||||||
'./clientHookTriggers.js'
|
|
||||||
);
|
|
||||||
|
|
||||||
let callCount = 0;
|
let callCount = 0;
|
||||||
mockTurnRunFn.mockImplementation(async function* (
|
mockTurnRunFn.mockImplementation(async function* (
|
||||||
@@ -2743,15 +2740,14 @@ ${JSON.stringify(
|
|||||||
while (!(await stream.next()).done);
|
while (!(await stream.next()).done);
|
||||||
|
|
||||||
// BeforeAgent should fire ONLY once despite multiple internal turns
|
// BeforeAgent should fire ONLY once despite multiple internal turns
|
||||||
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
|
expect(mockHookSystem.fireBeforeAgentEvent).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
// AfterAgent should fire ONLY when the stack unwinds
|
// AfterAgent should fire ONLY when the stack unwinds
|
||||||
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
|
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
// Check cumulative response (separated by newline)
|
// Check cumulative response (separated by newline)
|
||||||
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith(
|
||||||
expect.anything(),
|
partToString(request),
|
||||||
request,
|
|
||||||
'Response 1\nResponse 2',
|
'Response 1\nResponse 2',
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -2769,7 +2765,6 @@ ${JSON.stringify(
|
|||||||
const promptId = 'test-prompt-hook-original-req';
|
const promptId = 'test-prompt-hook-original-req';
|
||||||
const request = { text: 'Do something' };
|
const request = { text: 'Do something' };
|
||||||
const signal = new AbortController().signal;
|
const signal = new AbortController().signal;
|
||||||
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
|
|
||||||
|
|
||||||
mockTurnRunFn.mockImplementation(async function* (
|
mockTurnRunFn.mockImplementation(async function* (
|
||||||
this: MockTurnContext,
|
this: MockTurnContext,
|
||||||
@@ -2781,9 +2776,8 @@ ${JSON.stringify(
|
|||||||
const stream = client.sendMessageStream(request, signal, promptId);
|
const stream = client.sendMessageStream(request, signal, promptId);
|
||||||
while (!(await stream.next()).done);
|
while (!(await stream.next()).done);
|
||||||
|
|
||||||
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
expect(mockHookSystem.fireAfterAgentEvent).toHaveBeenCalledWith(
|
||||||
expect.anything(),
|
partToString(request), // Should be 'Do something'
|
||||||
request, // Should be 'Do something'
|
|
||||||
expect.stringContaining('Ok'),
|
expect.stringContaining('Ok'),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
@@ -2817,11 +2811,17 @@ ${JSON.stringify(
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should stop execution in BeforeAgent when hook returns continue: false', async () => {
|
it('should stop execution in BeforeAgent when hook returns continue: false', async () => {
|
||||||
const { fireBeforeAgentHook } = await import('./clientHookTriggers.js');
|
mockHookSystem.fireBeforeAgentEvent.mockResolvedValue({
|
||||||
vi.mocked(fireBeforeAgentHook).mockResolvedValue({
|
success: true,
|
||||||
shouldStopExecution: () => true,
|
finalOutput: {
|
||||||
getEffectiveReason: () => 'Stopped by hook',
|
shouldStopExecution: () => true,
|
||||||
} as DefaultHookOutput);
|
getEffectiveReason: () => 'Stopped by hook',
|
||||||
|
systemMessage: undefined,
|
||||||
|
},
|
||||||
|
allOutputs: [],
|
||||||
|
errors: [],
|
||||||
|
totalDuration: 0,
|
||||||
|
});
|
||||||
|
|
||||||
const mockChat: Partial<GeminiChat> = {
|
const mockChat: Partial<GeminiChat> = {
|
||||||
addHistory: vi.fn(),
|
addHistory: vi.fn(),
|
||||||
@@ -2850,12 +2850,18 @@ ${JSON.stringify(
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should block execution in BeforeAgent when hook returns decision: block', async () => {
|
it('should block execution in BeforeAgent when hook returns decision: block', async () => {
|
||||||
const { fireBeforeAgentHook } = await import('./clientHookTriggers.js');
|
mockHookSystem.fireBeforeAgentEvent.mockResolvedValue({
|
||||||
vi.mocked(fireBeforeAgentHook).mockResolvedValue({
|
success: true,
|
||||||
shouldStopExecution: () => false,
|
finalOutput: {
|
||||||
isBlockingDecision: () => true,
|
shouldStopExecution: () => false,
|
||||||
getEffectiveReason: () => 'Blocked by hook',
|
isBlockingDecision: () => true,
|
||||||
} as DefaultHookOutput);
|
getEffectiveReason: () => 'Blocked by hook',
|
||||||
|
systemMessage: undefined,
|
||||||
|
},
|
||||||
|
allOutputs: [],
|
||||||
|
errors: [],
|
||||||
|
totalDuration: 0,
|
||||||
|
});
|
||||||
|
|
||||||
const mockChat: Partial<GeminiChat> = {
|
const mockChat: Partial<GeminiChat> = {
|
||||||
addHistory: vi.fn(),
|
addHistory: vi.fn(),
|
||||||
@@ -2883,11 +2889,17 @@ ${JSON.stringify(
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should stop execution in AfterAgent when hook returns continue: false', async () => {
|
it('should stop execution in AfterAgent when hook returns continue: false', async () => {
|
||||||
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
|
mockHookSystem.fireAfterAgentEvent.mockResolvedValue({
|
||||||
vi.mocked(fireAfterAgentHook).mockResolvedValue({
|
success: true,
|
||||||
shouldStopExecution: () => true,
|
finalOutput: {
|
||||||
getEffectiveReason: () => 'Stopped after agent',
|
shouldStopExecution: () => true,
|
||||||
} as DefaultHookOutput);
|
getEffectiveReason: () => 'Stopped after agent',
|
||||||
|
systemMessage: undefined,
|
||||||
|
},
|
||||||
|
allOutputs: [],
|
||||||
|
errors: [],
|
||||||
|
totalDuration: 0,
|
||||||
|
});
|
||||||
|
|
||||||
mockTurnRunFn.mockImplementation(async function* () {
|
mockTurnRunFn.mockImplementation(async function* () {
|
||||||
yield { type: GeminiEventType.Content, value: 'Hello' };
|
yield { type: GeminiEventType.Content, value: 'Hello' };
|
||||||
@@ -2909,17 +2921,30 @@ ${JSON.stringify(
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should yield AgentExecutionBlocked and recurse in AfterAgent when hook returns decision: block', async () => {
|
it('should yield AgentExecutionBlocked and recurse in AfterAgent when hook returns decision: block', async () => {
|
||||||
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
|
mockHookSystem.fireAfterAgentEvent
|
||||||
vi.mocked(fireAfterAgentHook)
|
|
||||||
.mockResolvedValueOnce({
|
.mockResolvedValueOnce({
|
||||||
shouldStopExecution: () => false,
|
success: true,
|
||||||
isBlockingDecision: () => true,
|
finalOutput: {
|
||||||
getEffectiveReason: () => 'Please explain',
|
shouldStopExecution: () => false,
|
||||||
} as DefaultHookOutput)
|
isBlockingDecision: () => true,
|
||||||
|
getEffectiveReason: () => 'Please explain',
|
||||||
|
systemMessage: undefined,
|
||||||
|
},
|
||||||
|
allOutputs: [],
|
||||||
|
errors: [],
|
||||||
|
totalDuration: 0,
|
||||||
|
})
|
||||||
.mockResolvedValueOnce({
|
.mockResolvedValueOnce({
|
||||||
shouldStopExecution: () => false,
|
success: true,
|
||||||
isBlockingDecision: () => false,
|
finalOutput: {
|
||||||
} as DefaultHookOutput);
|
shouldStopExecution: () => false,
|
||||||
|
isBlockingDecision: () => false,
|
||||||
|
systemMessage: undefined,
|
||||||
|
},
|
||||||
|
allOutputs: [],
|
||||||
|
errors: [],
|
||||||
|
totalDuration: 0,
|
||||||
|
});
|
||||||
|
|
||||||
mockTurnRunFn.mockImplementation(async function* () {
|
mockTurnRunFn.mockImplementation(async function* () {
|
||||||
yield { type: GeminiEventType.Content, value: 'Response' };
|
yield { type: GeminiEventType.Content, value: 'Response' };
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import type {
|
|||||||
GenerateContentResponse,
|
GenerateContentResponse,
|
||||||
} from '@google/genai';
|
} from '@google/genai';
|
||||||
import { createUserContent } from '@google/genai';
|
import { createUserContent } from '@google/genai';
|
||||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
|
||||||
import {
|
import {
|
||||||
getDirectoryContextString,
|
getDirectoryContextString,
|
||||||
getInitialChatHistory,
|
getInitialChatHistory,
|
||||||
@@ -40,10 +39,6 @@ import {
|
|||||||
logContentRetryFailure,
|
logContentRetryFailure,
|
||||||
logNextSpeakerCheck,
|
logNextSpeakerCheck,
|
||||||
} from '../telemetry/loggers.js';
|
} from '../telemetry/loggers.js';
|
||||||
import {
|
|
||||||
fireBeforeAgentHook,
|
|
||||||
fireAfterAgentHook,
|
|
||||||
} from './clientHookTriggers.js';
|
|
||||||
import type { DefaultHookOutput } from '../hooks/types.js';
|
import type { DefaultHookOutput } from '../hooks/types.js';
|
||||||
import {
|
import {
|
||||||
ContentRetryFailureEvent,
|
ContentRetryFailureEvent,
|
||||||
@@ -62,6 +57,7 @@ import {
|
|||||||
} from '../availability/policyHelpers.js';
|
} from '../availability/policyHelpers.js';
|
||||||
import { resolveModel } from '../config/models.js';
|
import { resolveModel } from '../config/models.js';
|
||||||
import type { RetryAvailabilityContext } from '../utils/retry.js';
|
import type { RetryAvailabilityContext } from '../utils/retry.js';
|
||||||
|
import { partToString } from '../utils/partUtils.js';
|
||||||
|
|
||||||
const MAX_TURNS = 100;
|
const MAX_TURNS = 100;
|
||||||
|
|
||||||
@@ -113,7 +109,6 @@ export class GeminiClient {
|
|||||||
>();
|
>();
|
||||||
|
|
||||||
private async fireBeforeAgentHookSafe(
|
private async fireBeforeAgentHookSafe(
|
||||||
messageBus: MessageBus,
|
|
||||||
request: PartListUnion,
|
request: PartListUnion,
|
||||||
prompt_id: string,
|
prompt_id: string,
|
||||||
): Promise<BeforeAgentHookReturn> {
|
): Promise<BeforeAgentHookReturn> {
|
||||||
@@ -138,7 +133,10 @@ export class GeminiClient {
|
|||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
const hookOutput = await fireBeforeAgentHook(messageBus, request);
|
const hookResult = await this.config
|
||||||
|
.getHookSystem()
|
||||||
|
?.fireBeforeAgentEvent(partToString(request));
|
||||||
|
const hookOutput = hookResult?.finalOutput;
|
||||||
hookState.hasFiredBeforeAgent = true;
|
hookState.hasFiredBeforeAgent = true;
|
||||||
|
|
||||||
if (hookOutput?.shouldStopExecution()) {
|
if (hookOutput?.shouldStopExecution()) {
|
||||||
@@ -169,7 +167,6 @@ export class GeminiClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private async fireAfterAgentHookSafe(
|
private async fireAfterAgentHookSafe(
|
||||||
messageBus: MessageBus,
|
|
||||||
currentRequest: PartListUnion,
|
currentRequest: PartListUnion,
|
||||||
prompt_id: string,
|
prompt_id: string,
|
||||||
turn?: Turn,
|
turn?: Turn,
|
||||||
@@ -190,11 +187,11 @@ export class GeminiClient {
|
|||||||
'[no response text]';
|
'[no response text]';
|
||||||
const finalRequest = hookState.originalRequest || currentRequest;
|
const finalRequest = hookState.originalRequest || currentRequest;
|
||||||
|
|
||||||
const hookOutput = await fireAfterAgentHook(
|
const hookResult = await this.config
|
||||||
messageBus,
|
.getHookSystem()
|
||||||
finalRequest,
|
?.fireAfterAgentEvent(partToString(finalRequest), finalResponseText);
|
||||||
finalResponseText,
|
const hookOutput = hookResult?.finalOutput;
|
||||||
);
|
|
||||||
return hookOutput;
|
return hookOutput;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -757,11 +754,7 @@ export class GeminiClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (hooksEnabled && messageBus) {
|
if (hooksEnabled && messageBus) {
|
||||||
const hookResult = await this.fireBeforeAgentHookSafe(
|
const hookResult = await this.fireBeforeAgentHookSafe(request, prompt_id);
|
||||||
messageBus,
|
|
||||||
request,
|
|
||||||
prompt_id,
|
|
||||||
);
|
|
||||||
if (hookResult) {
|
if (hookResult) {
|
||||||
if (
|
if (
|
||||||
'type' in hookResult &&
|
'type' in hookResult &&
|
||||||
@@ -802,7 +795,6 @@ export class GeminiClient {
|
|||||||
// Fire AfterAgent hook if we have a turn and no pending tools
|
// Fire AfterAgent hook if we have a turn and no pending tools
|
||||||
if (hooksEnabled && messageBus) {
|
if (hooksEnabled && messageBus) {
|
||||||
const hookOutput = await this.fireAfterAgentHookSafe(
|
const hookOutput = await this.fireAfterAgentHookSafe(
|
||||||
messageBus,
|
|
||||||
request,
|
request,
|
||||||
prompt_id,
|
prompt_id,
|
||||||
turn,
|
turn,
|
||||||
|
|||||||
@@ -117,4 +117,28 @@ export class HookSystem {
|
|||||||
}
|
}
|
||||||
return this.hookEventHandler.firePreCompressEvent(trigger);
|
return this.hookEventHandler.firePreCompressEvent(trigger);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fireBeforeAgentEvent(
|
||||||
|
prompt: string,
|
||||||
|
): Promise<AggregatedHookResult | undefined> {
|
||||||
|
if (!this.config.getEnableHooks()) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
return this.hookEventHandler.fireBeforeAgentEvent(prompt);
|
||||||
|
}
|
||||||
|
|
||||||
|
async fireAfterAgentEvent(
|
||||||
|
prompt: string,
|
||||||
|
response: string,
|
||||||
|
stopHookActive: boolean = false,
|
||||||
|
): Promise<AggregatedHookResult | undefined> {
|
||||||
|
if (!this.config.getEnableHooks()) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
return this.hookEventHandler.fireAfterAgentEvent(
|
||||||
|
prompt,
|
||||||
|
response,
|
||||||
|
stopHookActive,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user