fix(hooks): deduplicate agent hooks and add cross-platform integration tests (#15701)

This commit is contained in:
Abhi
2025-12-30 14:13:16 -05:00
committed by GitHub
parent 4e6fee7fcd
commit 15c9f88da6
6 changed files with 779 additions and 227 deletions
+172 -4
View File
@@ -81,6 +81,10 @@ vi.mock('node:fs', () => {
});
// --- Mocks ---
// Minimal shape of the mocked Turn instance that `mockTurnRunFn`
// implementations receive as `this` — only the member the tests stub.
interface MockTurnContext {
  // Vitest mock standing in for Turn.getResponseText().
  getResponseText: Mock<() => string>;
}
const mockTurnRunFn = vi.fn();
vi.mock('./turn', async (importOriginal) => {
@@ -94,6 +98,8 @@ vi.mock('./turn', async (importOriginal) => {
constructor() {
// The constructor can be empty or do some mock setup
}
getResponseText = vi.fn().mockReturnValue('Mock Response');
}
// Export the mock class as 'Turn'
return {
@@ -129,6 +135,15 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({
},
}));
vi.mock('../hooks/hookSystem.js');
// Mock the hook trigger helpers. BeforeAgent resolves to undefined (no
// decision), which the client handles via optional chaining. AfterAgent must
// resolve to an object that also carries the DefaultHookOutput query methods
// (`isBlockingDecision`, `shouldStopExecution`): the client calls them
// directly on the resolved value, and a plain data object without them would
// throw `TypeError: hookOutput.isBlockingDecision is not a function` instead
// of behaving as a non-blocking "allow" result.
vi.mock('./clientHookTriggers.js', () => ({
  fireBeforeAgentHook: vi.fn(),
  fireAfterAgentHook: vi.fn().mockResolvedValue({
    decision: 'allow',
    continue: false,
    suppressOutput: false,
    systemMessage: undefined,
    // Non-blocking stubs for the methods client.ts invokes on the output.
    isBlockingDecision: () => false,
    shouldStopExecution: () => false,
  }),
}));
/**
* Array.fromAsync ponyfill, which will be available in es 2024.
@@ -543,16 +558,22 @@ describe('Gemini Client (client.ts)', () => {
await client.tryCompressChat('prompt-1', false); // force = false
// 3. Assert Step 1: Check that the flag became true
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((client as any).hasFailedCompressionAttempt).toBe(true);
// 3. Assert Step 1: Check that the flag became true
expect(
(client as unknown as { hasFailedCompressionAttempt: boolean })
.hasFailedCompressionAttempt,
).toBe(true);
// 4. Test Step 2: Trigger a forced failure
await client.tryCompressChat('prompt-2', true); // force = true
// 5. Assert Step 2: Check that the flag REMAINS true
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((client as any).hasFailedCompressionAttempt).toBe(true);
// 5. Assert Step 2: Check that the flag REMAINS true
expect(
(client as unknown as { hasFailedCompressionAttempt: boolean })
.hasFailedCompressionAttempt,
).toBe(true);
});
it('should not trigger summarization if token count is below threshold', async () => {
@@ -2615,5 +2636,152 @@ ${JSON.stringify(
'test-session-id',
);
});
// Integration tests for the BeforeAgent/AfterAgent hook lifecycle:
// deduplication across recursive continuation turns, reporting of the
// original request, and per-prompt_id state cleanup.
describe('Hook System', () => {
  let mockMessageBus: { publish: Mock; subscribe: Mock };

  beforeEach(() => {
    vi.clearAllMocks();
    mockMessageBus = { publish: vi.fn(), subscribe: vi.fn() };

    // Force override config methods on the client instance so hooks are
    // enabled and routed through our mock message bus.
    client['config'].getEnableHooks = vi.fn().mockReturnValue(true);
    client['config'].getMessageBus = vi
      .fn()
      .mockReturnValue(mockMessageBus);
  });

  it('should fire BeforeAgent and AfterAgent exactly once for a simple turn', async () => {
    const promptId = 'test-prompt-hook-1';
    const request = { text: 'Hello Hooks' };
    const signal = new AbortController().signal;
    // Grab the mocked trigger fns installed by vi.mock above.
    const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
      './clientHookTriggers.js'
    );
    // Single turn: one content chunk, response text 'Hook Response'.
    mockTurnRunFn.mockImplementation(async function* (
      this: MockTurnContext,
    ) {
      this.getResponseText.mockReturnValue('Hook Response');
      yield { type: GeminiEventType.Content, value: 'Hook Response' };
    });
    const stream = client.sendMessageStream(request, signal, promptId);
    // Drain the stream to completion.
    while (!(await stream.next()).done);
    expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
    expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
    expect(fireAfterAgentHook).toHaveBeenCalledWith(
      expect.anything(),
      request,
      'Hook Response',
    );
    // Map should be empty — state for this prompt_id was cleaned up.
    expect(client['hookStateMap'].size).toBe(0);
  });

  it('should fire BeforeAgent once and AfterAgent once even with recursion', async () => {
    const { checkNextSpeaker } = await import(
      '../utils/nextSpeakerChecker.js'
    );
    // First check says the model should continue (forces one recursive
    // continuation turn); second check ends the sequence.
    vi.mocked(checkNextSpeaker)
      .mockResolvedValueOnce({ next_speaker: 'model', reasoning: 'more' })
      .mockResolvedValueOnce(null);
    const promptId = 'test-prompt-hook-recursive';
    const request = { text: 'Recursion Test' };
    const signal = new AbortController().signal;
    const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
      './clientHookTriggers.js'
    );
    // Each turn yields a distinct numbered response so the cumulative
    // AfterAgent payload can be asserted below.
    let callCount = 0;
    mockTurnRunFn.mockImplementation(async function* (
      this: MockTurnContext,
    ) {
      callCount++;
      const response = `Response ${callCount}`;
      this.getResponseText.mockReturnValue(response);
      yield { type: GeminiEventType.Content, value: response };
    });
    const stream = client.sendMessageStream(request, signal, promptId);
    while (!(await stream.next()).done);
    // BeforeAgent should fire ONLY once despite multiple internal turns
    expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
    // AfterAgent should fire ONLY when the stack unwinds
    expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
    // Check cumulative response (separated by newline)
    expect(fireAfterAgentHook).toHaveBeenCalledWith(
      expect.anything(),
      request,
      'Response 1\nResponse 2',
    );
    expect(client['hookStateMap'].size).toBe(0);
  });

  it('should use original request in AfterAgent hook even when continuation happened', async () => {
    const { checkNextSpeaker } = await import(
      '../utils/nextSpeakerChecker.js'
    );
    // Force one continuation so the internal request is replaced with a
    // synthetic "Please continue" — AfterAgent must still see the original.
    vi.mocked(checkNextSpeaker)
      .mockResolvedValueOnce({ next_speaker: 'model', reasoning: 'more' })
      .mockResolvedValueOnce(null);
    const promptId = 'test-prompt-hook-original-req';
    const request = { text: 'Do something' };
    const signal = new AbortController().signal;
    const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
    mockTurnRunFn.mockImplementation(async function* (
      this: MockTurnContext,
    ) {
      this.getResponseText.mockReturnValue('Ok');
      yield { type: GeminiEventType.Content, value: 'Ok' };
    });
    const stream = client.sendMessageStream(request, signal, promptId);
    while (!(await stream.next()).done);
    expect(fireAfterAgentHook).toHaveBeenCalledWith(
      expect.anything(),
      request, // Should be 'Do something'
      expect.stringContaining('Ok'),
    );
  });

  it('should cleanup state when prompt_id changes', async () => {
    const signal = new AbortController().signal;
    mockTurnRunFn.mockImplementation(async function* (
      this: MockTurnContext,
    ) {
      this.getResponseText.mockReturnValue('Ok');
      yield { type: GeminiEventType.Content, value: 'Ok' };
    });
    // Seed stale state for a previous prompt, then start a new prompt and
    // verify the old entry is evicted while the new one is created.
    client['hookStateMap'].set('old-id', {
      hasFiredBeforeAgent: true,
      cumulativeResponse: 'Old',
      activeCalls: 0,
      originalRequest: { text: 'Old' },
    });
    client['lastPromptId'] = 'old-id';
    const stream = client.sendMessageStream(
      { text: 'New' },
      signal,
      'new-id',
    );
    await stream.next();
    expect(client['hookStateMap'].has('old-id')).toBe(false);
    expect(client['hookStateMap'].has('new-id')).toBe(true);
  });
});
});
});
+272 -118
View File
@@ -11,6 +11,7 @@ import type {
Tool,
GenerateContentResponse,
} from '@google/genai';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
import {
getDirectoryContextString,
getInitialChatHistory,
@@ -42,6 +43,7 @@ import {
fireBeforeAgentHook,
fireAfterAgentHook,
} from './clientHookTriggers.js';
import type { DefaultHookOutput } from '../hooks/types.js';
import {
ContentRetryFailureEvent,
NextSpeakerCheckEvent,
@@ -61,6 +63,14 @@ import type { RetryAvailabilityContext } from '../utils/retry.js';
const MAX_TURNS = 100;
/**
 * Result of the BeforeAgent hook pre-flight:
 * - an Error stream event when the hook blocked processing,
 * - additional context text to append to the outgoing request, or
 * - `undefined` when execution should proceed unchanged.
 */
type BeforeAgentHookReturn =
  | {
      type: GeminiEventType.Error;
      value: { error: Error };
    }
  | { additionalContext: string | undefined }
  | undefined;
export class GeminiClient {
private chat?: GeminiChat;
private sessionTurnCount = 0;
@@ -84,6 +94,95 @@ export class GeminiClient {
this.lastPromptId = this.config.getSessionId();
}
  // Per-prompt_id hook state used to deduplicate BeforeAgent calls and to
  // accumulate the response text that is eventually handed to AfterAgent.
  private hookStateMap = new Map<
    string,
    {
      // True once BeforeAgent has fired for this prompt_id.
      hasFiredBeforeAgent: boolean;
      // Response text accumulated across continuation turns, joined with
      // newlines.
      cumulativeResponse: string;
      // Number of sendMessageStream invocations currently in flight for this
      // prompt_id; AfterAgent fires only when the outermost call unwinds.
      activeCalls: number;
      // The request that started the prompt, reported to AfterAgent even
      // after internal continuations replace the current request.
      originalRequest: PartListUnion;
    }
  >();
private async fireBeforeAgentHookSafe(
messageBus: MessageBus,
request: PartListUnion,
prompt_id: string,
): Promise<BeforeAgentHookReturn> {
let hookState = this.hookStateMap.get(prompt_id);
if (!hookState) {
hookState = {
hasFiredBeforeAgent: false,
cumulativeResponse: '',
activeCalls: 0,
originalRequest: request,
};
this.hookStateMap.set(prompt_id, hookState);
}
// Increment active calls for this prompt_id
// This is called at the start of sendMessageStream, so it acts as an entry
// counter. We increment here, assuming this helper is ALWAYS called at
// entry.
hookState.activeCalls++;
if (hookState.hasFiredBeforeAgent) {
return undefined;
}
const hookOutput = await fireBeforeAgentHook(messageBus, request);
hookState.hasFiredBeforeAgent = true;
if (hookOutput?.isBlockingDecision() || hookOutput?.shouldStopExecution()) {
return {
type: GeminiEventType.Error,
value: {
error: new Error(
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
),
},
};
}
const additionalContext = hookOutput?.getAdditionalContext();
if (additionalContext) {
return { additionalContext };
}
return undefined;
}
private async fireAfterAgentHookSafe(
messageBus: MessageBus,
currentRequest: PartListUnion,
prompt_id: string,
turn?: Turn,
): Promise<DefaultHookOutput | undefined> {
const hookState = this.hookStateMap.get(prompt_id);
// Only fire on the outermost call (when activeCalls is 1)
if (!hookState || hookState.activeCalls !== 1) {
return undefined;
}
if (turn && turn.pendingToolCalls.length > 0) {
return undefined;
}
const finalResponseText =
hookState.cumulativeResponse ||
turn?.getResponseText() ||
'[no response text]';
const finalRequest = hookState.originalRequest || currentRequest;
const hookOutput = await fireAfterAgentHook(
messageBus,
finalRequest,
finalResponseText,
);
return hookOutput;
}
private updateTelemetryTokenCount() {
if (this.chat) {
uiTelemetryService.setLastPromptTokenCount(
@@ -400,63 +499,27 @@ export class GeminiClient {
return this.config.getActiveModel();
}
async *sendMessageStream(
private async *processTurn(
request: PartListUnion,
signal: AbortSignal,
prompt_id: string,
turns: number = MAX_TURNS,
isInvalidStreamRetry: boolean = false,
boundedTurns: number,
isInvalidStreamRetry: boolean,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (!isInvalidStreamRetry) {
this.config.resetTurn();
}
// Re-initialize turn (it was empty before if in loop, or new instance)
let turn = new Turn(this.getChat(), prompt_id);
// Fire BeforeAgent hook through MessageBus (only if hooks are enabled)
const hooksEnabled = this.config.getEnableHooks();
const messageBus = this.config.getMessageBus();
if (hooksEnabled && messageBus) {
const hookOutput = await fireBeforeAgentHook(messageBus, request);
if (
hookOutput?.isBlockingDecision() ||
hookOutput?.shouldStopExecution()
) {
yield {
type: GeminiEventType.Error,
value: {
error: new Error(
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
),
},
};
return new Turn(this.getChat(), prompt_id);
}
// Add additional context from hooks to the request
const additionalContext = hookOutput?.getAdditionalContext();
if (additionalContext) {
const requestArray = Array.isArray(request) ? request : [request];
request = [...requestArray, { text: additionalContext }];
}
}
if (this.lastPromptId !== prompt_id) {
this.loopDetector.reset(prompt_id);
this.lastPromptId = prompt_id;
this.currentSequenceModel = null;
}
this.sessionTurnCount++;
if (
this.config.getMaxSessionTurns() > 0 &&
this.sessionTurnCount > this.config.getMaxSessionTurns()
) {
yield { type: GeminiEventType.MaxSessionTurns };
return new Turn(this.getChat(), prompt_id);
return turn;
}
// Ensure turns never exceeds MAX_TURNS to prevent infinite loops
const boundedTurns = Math.min(turns, MAX_TURNS);
if (!boundedTurns) {
return new Turn(this.getChat(), prompt_id);
return turn;
}
// Check for context window overflow
@@ -478,7 +541,7 @@ export class GeminiClient {
type: GeminiEventType.ContextWindowWillOverflow,
value: { estimatedRequestTokenCount, remainingTokenCount },
};
return new Turn(this.getChat(), prompt_id);
return turn;
}
const compressed = await this.tryCompressChat(prompt_id, false);
@@ -514,7 +577,8 @@ export class GeminiClient {
this.forceFullIdeContext = false;
}
const turn = new Turn(this.getChat(), prompt_id);
// Re-initialize turn with fresh history
turn = new Turn(this.getChat(), prompt_id);
const controller = new AbortController();
const linkedSignal = AbortSignal.any([signal, controller.signal]);
@@ -555,6 +619,9 @@ export class GeminiClient {
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
const resultStream = turn.run(modelConfigKey, request, linkedSignal);
let isError = false;
let isInvalidStream = false;
for await (const event of resultStream) {
if (this.loopDetector.addAndCheck(event)) {
yield { type: GeminiEventType.LoopDetected };
@@ -566,94 +633,181 @@ export class GeminiClient {
this.updateTelemetryTokenCount();
if (event.type === GeminiEventType.InvalidStream) {
if (this.config.getContinueOnFailedApiCall()) {
if (isInvalidStreamRetry) {
// We already retried once, so stop here.
logContentRetryFailure(
this.config,
new ContentRetryFailureEvent(
4, // 2 initial + 2 after injections
'FAILED_AFTER_PROMPT_INJECTION',
modelToUse,
),
);
return turn;
}
const nextRequest = [{ text: 'System: Please continue.' }];
yield* this.sendMessageStream(
nextRequest,
signal,
prompt_id,
boundedTurns - 1,
true, // Set isInvalidStreamRetry to true
isInvalidStream = true;
}
if (event.type === GeminiEventType.Error) {
isError = true;
}
}
if (isError) {
return turn;
}
// Update cumulative response in hook state
// We do this immediately after the stream finishes for THIS turn.
const hooksEnabled = this.config.getEnableHooks();
if (hooksEnabled) {
const responseText = turn.getResponseText() || '';
const hookState = this.hookStateMap.get(prompt_id);
if (hookState && responseText) {
// Append with newline if not empty
hookState.cumulativeResponse = hookState.cumulativeResponse
? `${hookState.cumulativeResponse}\n${responseText}`
: responseText;
}
}
if (isInvalidStream) {
if (this.config.getContinueOnFailedApiCall()) {
if (isInvalidStreamRetry) {
logContentRetryFailure(
this.config,
new ContentRetryFailureEvent(
4,
'FAILED_AFTER_PROMPT_INJECTION',
modelToUse,
),
);
return turn;
}
}
if (event.type === GeminiEventType.Error) {
return turn;
}
}
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
// Check if next speaker check is needed
if (this.config.getQuotaErrorOccurred()) {
return turn;
}
if (this.config.getSkipNextSpeakerCheck()) {
return turn;
}
const nextSpeakerCheck = await checkNextSpeaker(
this.getChat(),
this.config.getBaseLlmClient(),
signal,
prompt_id,
);
logNextSpeakerCheck(
this.config,
new NextSpeakerCheckEvent(
prompt_id,
turn.finishReason?.toString() || '',
nextSpeakerCheck?.next_speaker || '',
),
);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
// This recursive call's events will be yielded out, and the final
// turn object from the recursive call will be returned.
return yield* this.sendMessageStream(
const nextRequest = [{ text: 'System: Please continue.' }];
// Recursive call - update turn with result
turn = yield* this.sendMessageStream(
nextRequest,
signal,
prompt_id,
boundedTurns - 1,
// isInvalidStreamRetry is false here, as this is a next speaker check
true,
);
return turn;
}
}
// Fire AfterAgent hook through MessageBus (only if hooks are enabled)
if (hooksEnabled && messageBus) {
const responseText = turn.getResponseText() || '[no response text]';
const hookOutput = await fireAfterAgentHook(
messageBus,
request,
responseText,
);
// For AfterAgent hooks, blocking/stop execution should force continuation
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
if (
hookOutput?.isBlockingDecision() ||
hookOutput?.shouldStopExecution()
!this.config.getQuotaErrorOccurred() &&
!this.config.getSkipNextSpeakerCheck()
) {
const continueReason = hookOutput.getEffectiveReason();
const continueRequest = [{ text: continueReason }];
yield* this.sendMessageStream(
continueRequest,
const nextSpeakerCheck = await checkNextSpeaker(
this.getChat(),
this.config.getBaseLlmClient(),
signal,
prompt_id,
boundedTurns - 1,
);
logNextSpeakerCheck(
this.config,
new NextSpeakerCheckEvent(
prompt_id,
turn.finishReason?.toString() || '',
nextSpeakerCheck?.next_speaker || '',
),
);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
turn = yield* this.sendMessageStream(
nextRequest,
signal,
prompt_id,
boundedTurns - 1,
// isInvalidStreamRetry is false
);
return turn;
}
}
}
return turn;
}
async *sendMessageStream(
request: PartListUnion,
signal: AbortSignal,
prompt_id: string,
turns: number = MAX_TURNS,
isInvalidStreamRetry: boolean = false,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (!isInvalidStreamRetry) {
this.config.resetTurn();
}
const hooksEnabled = this.config.getEnableHooks();
const messageBus = this.config.getMessageBus();
if (this.lastPromptId !== prompt_id) {
this.loopDetector.reset(prompt_id);
this.hookStateMap.delete(this.lastPromptId);
this.lastPromptId = prompt_id;
this.currentSequenceModel = null;
}
if (hooksEnabled && messageBus) {
const hookResult = await this.fireBeforeAgentHookSafe(
messageBus,
request,
prompt_id,
);
if (hookResult) {
if ('type' in hookResult && hookResult.type === GeminiEventType.Error) {
yield hookResult;
return new Turn(this.getChat(), prompt_id);
} else if ('additionalContext' in hookResult) {
const additionalContext = hookResult.additionalContext;
if (additionalContext) {
const requestArray = Array.isArray(request) ? request : [request];
request = [...requestArray, { text: additionalContext }];
}
}
}
}
const boundedTurns = Math.min(turns, MAX_TURNS);
let turn = new Turn(this.getChat(), prompt_id);
try {
turn = yield* this.processTurn(
request,
signal,
prompt_id,
boundedTurns,
isInvalidStreamRetry,
);
// Fire AfterAgent hook if we have a turn and no pending tools
if (hooksEnabled && messageBus) {
const hookOutput = await this.fireAfterAgentHookSafe(
messageBus,
request,
prompt_id,
turn,
);
if (
hookOutput?.isBlockingDecision() ||
hookOutput?.shouldStopExecution()
) {
const continueReason = hookOutput.getEffectiveReason();
const continueRequest = [{ text: continueReason }];
yield* this.sendMessageStream(
continueRequest,
signal,
prompt_id,
boundedTurns - 1,
);
}
}
} finally {
const hookState = this.hookStateMap.get(prompt_id);
if (hookState) {
hookState.activeCalls--;
const isPendingTools =
turn?.pendingToolCalls && turn.pendingToolCalls.length > 0;
const isAborted = signal?.aborted;
if (hookState.activeCalls <= 0) {
if (!isPendingTools || isAborted) {
this.hookStateMap.delete(prompt_id);
}
}
}
}