mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-13 13:22:35 -07:00
fix(hooks): deduplicate agent hooks and add cross-platform integration tests (#15701)
This commit is contained in:
@@ -81,6 +81,10 @@ vi.mock('node:fs', () => {
|
||||
});
|
||||
|
||||
// --- Mocks ---
|
||||
interface MockTurnContext {
|
||||
getResponseText: Mock<() => string>;
|
||||
}
|
||||
|
||||
const mockTurnRunFn = vi.fn();
|
||||
|
||||
vi.mock('./turn', async (importOriginal) => {
|
||||
@@ -94,6 +98,8 @@ vi.mock('./turn', async (importOriginal) => {
|
||||
constructor() {
|
||||
// The constructor can be empty or do some mock setup
|
||||
}
|
||||
|
||||
getResponseText = vi.fn().mockReturnValue('Mock Response');
|
||||
}
|
||||
// Export the mock class as 'Turn'
|
||||
return {
|
||||
@@ -129,6 +135,15 @@ vi.mock('../telemetry/uiTelemetry.js', () => ({
|
||||
},
|
||||
}));
|
||||
vi.mock('../hooks/hookSystem.js');
|
||||
vi.mock('./clientHookTriggers.js', () => ({
|
||||
fireBeforeAgentHook: vi.fn(),
|
||||
fireAfterAgentHook: vi.fn().mockResolvedValue({
|
||||
decision: 'allow',
|
||||
continue: false,
|
||||
suppressOutput: false,
|
||||
systemMessage: undefined,
|
||||
}),
|
||||
}));
|
||||
|
||||
/**
|
||||
* Array.fromAsync ponyfill, which will be available in es 2024.
|
||||
@@ -543,16 +558,22 @@ describe('Gemini Client (client.ts)', () => {
|
||||
await client.tryCompressChat('prompt-1', false); // force = false
|
||||
|
||||
// 3. Assert Step 1: Check that the flag became true
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((client as any).hasFailedCompressionAttempt).toBe(true);
|
||||
// 3. Assert Step 1: Check that the flag became true
|
||||
expect(
|
||||
(client as unknown as { hasFailedCompressionAttempt: boolean })
|
||||
.hasFailedCompressionAttempt,
|
||||
).toBe(true);
|
||||
|
||||
// 4. Test Step 2: Trigger a forced failure
|
||||
|
||||
await client.tryCompressChat('prompt-2', true); // force = true
|
||||
|
||||
// 5. Assert Step 2: Check that the flag REMAINS true
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
expect((client as any).hasFailedCompressionAttempt).toBe(true);
|
||||
// 5. Assert Step 2: Check that the flag REMAINS true
|
||||
expect(
|
||||
(client as unknown as { hasFailedCompressionAttempt: boolean })
|
||||
.hasFailedCompressionAttempt,
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it('should not trigger summarization if token count is below threshold', async () => {
|
||||
@@ -2615,5 +2636,152 @@ ${JSON.stringify(
|
||||
'test-session-id',
|
||||
);
|
||||
});
|
||||
|
||||
describe('Hook System', () => {
|
||||
let mockMessageBus: { publish: Mock; subscribe: Mock };
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockMessageBus = { publish: vi.fn(), subscribe: vi.fn() };
|
||||
|
||||
// Force override config methods on the client instance
|
||||
client['config'].getEnableHooks = vi.fn().mockReturnValue(true);
|
||||
client['config'].getMessageBus = vi
|
||||
.fn()
|
||||
.mockReturnValue(mockMessageBus);
|
||||
});
|
||||
|
||||
it('should fire BeforeAgent and AfterAgent exactly once for a simple turn', async () => {
|
||||
const promptId = 'test-prompt-hook-1';
|
||||
const request = { text: 'Hello Hooks' };
|
||||
const signal = new AbortController().signal;
|
||||
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
|
||||
'./clientHookTriggers.js'
|
||||
);
|
||||
|
||||
mockTurnRunFn.mockImplementation(async function* (
|
||||
this: MockTurnContext,
|
||||
) {
|
||||
this.getResponseText.mockReturnValue('Hook Response');
|
||||
yield { type: GeminiEventType.Content, value: 'Hook Response' };
|
||||
});
|
||||
|
||||
const stream = client.sendMessageStream(request, signal, promptId);
|
||||
while (!(await stream.next()).done);
|
||||
|
||||
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
request,
|
||||
'Hook Response',
|
||||
);
|
||||
|
||||
// Map should be empty
|
||||
expect(client['hookStateMap'].size).toBe(0);
|
||||
});
|
||||
|
||||
it('should fire BeforeAgent once and AfterAgent once even with recursion', async () => {
|
||||
const { checkNextSpeaker } = await import(
|
||||
'../utils/nextSpeakerChecker.js'
|
||||
);
|
||||
vi.mocked(checkNextSpeaker)
|
||||
.mockResolvedValueOnce({ next_speaker: 'model', reasoning: 'more' })
|
||||
.mockResolvedValueOnce(null);
|
||||
|
||||
const promptId = 'test-prompt-hook-recursive';
|
||||
const request = { text: 'Recursion Test' };
|
||||
const signal = new AbortController().signal;
|
||||
const { fireBeforeAgentHook, fireAfterAgentHook } = await import(
|
||||
'./clientHookTriggers.js'
|
||||
);
|
||||
|
||||
let callCount = 0;
|
||||
mockTurnRunFn.mockImplementation(async function* (
|
||||
this: MockTurnContext,
|
||||
) {
|
||||
callCount++;
|
||||
const response = `Response ${callCount}`;
|
||||
this.getResponseText.mockReturnValue(response);
|
||||
yield { type: GeminiEventType.Content, value: response };
|
||||
});
|
||||
|
||||
const stream = client.sendMessageStream(request, signal, promptId);
|
||||
while (!(await stream.next()).done);
|
||||
|
||||
// BeforeAgent should fire ONLY once despite multiple internal turns
|
||||
expect(fireBeforeAgentHook).toHaveBeenCalledTimes(1);
|
||||
|
||||
// AfterAgent should fire ONLY when the stack unwinds
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Check cumulative response (separated by newline)
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
request,
|
||||
'Response 1\nResponse 2',
|
||||
);
|
||||
|
||||
expect(client['hookStateMap'].size).toBe(0);
|
||||
});
|
||||
|
||||
it('should use original request in AfterAgent hook even when continuation happened', async () => {
|
||||
const { checkNextSpeaker } = await import(
|
||||
'../utils/nextSpeakerChecker.js'
|
||||
);
|
||||
vi.mocked(checkNextSpeaker)
|
||||
.mockResolvedValueOnce({ next_speaker: 'model', reasoning: 'more' })
|
||||
.mockResolvedValueOnce(null);
|
||||
|
||||
const promptId = 'test-prompt-hook-original-req';
|
||||
const request = { text: 'Do something' };
|
||||
const signal = new AbortController().signal;
|
||||
const { fireAfterAgentHook } = await import('./clientHookTriggers.js');
|
||||
|
||||
mockTurnRunFn.mockImplementation(async function* (
|
||||
this: MockTurnContext,
|
||||
) {
|
||||
this.getResponseText.mockReturnValue('Ok');
|
||||
yield { type: GeminiEventType.Content, value: 'Ok' };
|
||||
});
|
||||
|
||||
const stream = client.sendMessageStream(request, signal, promptId);
|
||||
while (!(await stream.next()).done);
|
||||
|
||||
expect(fireAfterAgentHook).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
request, // Should be 'Do something'
|
||||
expect.stringContaining('Ok'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should cleanup state when prompt_id changes', async () => {
|
||||
const signal = new AbortController().signal;
|
||||
mockTurnRunFn.mockImplementation(async function* (
|
||||
this: MockTurnContext,
|
||||
) {
|
||||
this.getResponseText.mockReturnValue('Ok');
|
||||
yield { type: GeminiEventType.Content, value: 'Ok' };
|
||||
});
|
||||
|
||||
client['hookStateMap'].set('old-id', {
|
||||
hasFiredBeforeAgent: true,
|
||||
cumulativeResponse: 'Old',
|
||||
activeCalls: 0,
|
||||
originalRequest: { text: 'Old' },
|
||||
});
|
||||
client['lastPromptId'] = 'old-id';
|
||||
|
||||
const stream = client.sendMessageStream(
|
||||
{ text: 'New' },
|
||||
signal,
|
||||
'new-id',
|
||||
);
|
||||
await stream.next();
|
||||
|
||||
expect(client['hookStateMap'].has('old-id')).toBe(false);
|
||||
expect(client['hookStateMap'].has('new-id')).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
+272
-118
@@ -11,6 +11,7 @@ import type {
|
||||
Tool,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import type { MessageBus } from '../confirmation-bus/message-bus.js';
|
||||
import {
|
||||
getDirectoryContextString,
|
||||
getInitialChatHistory,
|
||||
@@ -42,6 +43,7 @@ import {
|
||||
fireBeforeAgentHook,
|
||||
fireAfterAgentHook,
|
||||
} from './clientHookTriggers.js';
|
||||
import type { DefaultHookOutput } from '../hooks/types.js';
|
||||
import {
|
||||
ContentRetryFailureEvent,
|
||||
NextSpeakerCheckEvent,
|
||||
@@ -61,6 +63,14 @@ import type { RetryAvailabilityContext } from '../utils/retry.js';
|
||||
|
||||
const MAX_TURNS = 100;
|
||||
|
||||
type BeforeAgentHookReturn =
|
||||
| {
|
||||
type: GeminiEventType.Error;
|
||||
value: { error: Error };
|
||||
}
|
||||
| { additionalContext: string | undefined }
|
||||
| undefined;
|
||||
|
||||
export class GeminiClient {
|
||||
private chat?: GeminiChat;
|
||||
private sessionTurnCount = 0;
|
||||
@@ -84,6 +94,95 @@ export class GeminiClient {
|
||||
this.lastPromptId = this.config.getSessionId();
|
||||
}
|
||||
|
||||
// Hook state to deduplicate BeforeAgent calls and track response for
|
||||
// AfterAgent
|
||||
private hookStateMap = new Map<
|
||||
string,
|
||||
{
|
||||
hasFiredBeforeAgent: boolean;
|
||||
cumulativeResponse: string;
|
||||
activeCalls: number;
|
||||
originalRequest: PartListUnion;
|
||||
}
|
||||
>();
|
||||
|
||||
private async fireBeforeAgentHookSafe(
|
||||
messageBus: MessageBus,
|
||||
request: PartListUnion,
|
||||
prompt_id: string,
|
||||
): Promise<BeforeAgentHookReturn> {
|
||||
let hookState = this.hookStateMap.get(prompt_id);
|
||||
if (!hookState) {
|
||||
hookState = {
|
||||
hasFiredBeforeAgent: false,
|
||||
cumulativeResponse: '',
|
||||
activeCalls: 0,
|
||||
originalRequest: request,
|
||||
};
|
||||
this.hookStateMap.set(prompt_id, hookState);
|
||||
}
|
||||
|
||||
// Increment active calls for this prompt_id
|
||||
// This is called at the start of sendMessageStream, so it acts as an entry
|
||||
// counter. We increment here, assuming this helper is ALWAYS called at
|
||||
// entry.
|
||||
hookState.activeCalls++;
|
||||
|
||||
if (hookState.hasFiredBeforeAgent) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const hookOutput = await fireBeforeAgentHook(messageBus, request);
|
||||
hookState.hasFiredBeforeAgent = true;
|
||||
|
||||
if (hookOutput?.isBlockingDecision() || hookOutput?.shouldStopExecution()) {
|
||||
return {
|
||||
type: GeminiEventType.Error,
|
||||
value: {
|
||||
error: new Error(
|
||||
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
|
||||
),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const additionalContext = hookOutput?.getAdditionalContext();
|
||||
if (additionalContext) {
|
||||
return { additionalContext };
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private async fireAfterAgentHookSafe(
|
||||
messageBus: MessageBus,
|
||||
currentRequest: PartListUnion,
|
||||
prompt_id: string,
|
||||
turn?: Turn,
|
||||
): Promise<DefaultHookOutput | undefined> {
|
||||
const hookState = this.hookStateMap.get(prompt_id);
|
||||
// Only fire on the outermost call (when activeCalls is 1)
|
||||
if (!hookState || hookState.activeCalls !== 1) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (turn && turn.pendingToolCalls.length > 0) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const finalResponseText =
|
||||
hookState.cumulativeResponse ||
|
||||
turn?.getResponseText() ||
|
||||
'[no response text]';
|
||||
const finalRequest = hookState.originalRequest || currentRequest;
|
||||
|
||||
const hookOutput = await fireAfterAgentHook(
|
||||
messageBus,
|
||||
finalRequest,
|
||||
finalResponseText,
|
||||
);
|
||||
return hookOutput;
|
||||
}
|
||||
|
||||
private updateTelemetryTokenCount() {
|
||||
if (this.chat) {
|
||||
uiTelemetryService.setLastPromptTokenCount(
|
||||
@@ -400,63 +499,27 @@ export class GeminiClient {
|
||||
return this.config.getActiveModel();
|
||||
}
|
||||
|
||||
async *sendMessageStream(
|
||||
private async *processTurn(
|
||||
request: PartListUnion,
|
||||
signal: AbortSignal,
|
||||
prompt_id: string,
|
||||
turns: number = MAX_TURNS,
|
||||
isInvalidStreamRetry: boolean = false,
|
||||
boundedTurns: number,
|
||||
isInvalidStreamRetry: boolean,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
if (!isInvalidStreamRetry) {
|
||||
this.config.resetTurn();
|
||||
}
|
||||
// Re-initialize turn (it was empty before if in loop, or new instance)
|
||||
let turn = new Turn(this.getChat(), prompt_id);
|
||||
|
||||
// Fire BeforeAgent hook through MessageBus (only if hooks are enabled)
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
const messageBus = this.config.getMessageBus();
|
||||
if (hooksEnabled && messageBus) {
|
||||
const hookOutput = await fireBeforeAgentHook(messageBus, request);
|
||||
|
||||
if (
|
||||
hookOutput?.isBlockingDecision() ||
|
||||
hookOutput?.shouldStopExecution()
|
||||
) {
|
||||
yield {
|
||||
type: GeminiEventType.Error,
|
||||
value: {
|
||||
error: new Error(
|
||||
`BeforeAgent hook blocked processing: ${hookOutput.getEffectiveReason()}`,
|
||||
),
|
||||
},
|
||||
};
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
}
|
||||
|
||||
// Add additional context from hooks to the request
|
||||
const additionalContext = hookOutput?.getAdditionalContext();
|
||||
if (additionalContext) {
|
||||
const requestArray = Array.isArray(request) ? request : [request];
|
||||
request = [...requestArray, { text: additionalContext }];
|
||||
}
|
||||
}
|
||||
|
||||
if (this.lastPromptId !== prompt_id) {
|
||||
this.loopDetector.reset(prompt_id);
|
||||
this.lastPromptId = prompt_id;
|
||||
this.currentSequenceModel = null;
|
||||
}
|
||||
this.sessionTurnCount++;
|
||||
if (
|
||||
this.config.getMaxSessionTurns() > 0 &&
|
||||
this.sessionTurnCount > this.config.getMaxSessionTurns()
|
||||
) {
|
||||
yield { type: GeminiEventType.MaxSessionTurns };
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
return turn;
|
||||
}
|
||||
// Ensure turns never exceeds MAX_TURNS to prevent infinite loops
|
||||
const boundedTurns = Math.min(turns, MAX_TURNS);
|
||||
|
||||
if (!boundedTurns) {
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
return turn;
|
||||
}
|
||||
|
||||
// Check for context window overflow
|
||||
@@ -478,7 +541,7 @@ export class GeminiClient {
|
||||
type: GeminiEventType.ContextWindowWillOverflow,
|
||||
value: { estimatedRequestTokenCount, remainingTokenCount },
|
||||
};
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
return turn;
|
||||
}
|
||||
|
||||
const compressed = await this.tryCompressChat(prompt_id, false);
|
||||
@@ -514,7 +577,8 @@ export class GeminiClient {
|
||||
this.forceFullIdeContext = false;
|
||||
}
|
||||
|
||||
const turn = new Turn(this.getChat(), prompt_id);
|
||||
// Re-initialize turn with fresh history
|
||||
turn = new Turn(this.getChat(), prompt_id);
|
||||
|
||||
const controller = new AbortController();
|
||||
const linkedSignal = AbortSignal.any([signal, controller.signal]);
|
||||
@@ -555,6 +619,9 @@ export class GeminiClient {
|
||||
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
|
||||
|
||||
const resultStream = turn.run(modelConfigKey, request, linkedSignal);
|
||||
let isError = false;
|
||||
let isInvalidStream = false;
|
||||
|
||||
for await (const event of resultStream) {
|
||||
if (this.loopDetector.addAndCheck(event)) {
|
||||
yield { type: GeminiEventType.LoopDetected };
|
||||
@@ -566,94 +633,181 @@ export class GeminiClient {
|
||||
this.updateTelemetryTokenCount();
|
||||
|
||||
if (event.type === GeminiEventType.InvalidStream) {
|
||||
if (this.config.getContinueOnFailedApiCall()) {
|
||||
if (isInvalidStreamRetry) {
|
||||
// We already retried once, so stop here.
|
||||
logContentRetryFailure(
|
||||
this.config,
|
||||
new ContentRetryFailureEvent(
|
||||
4, // 2 initial + 2 after injections
|
||||
'FAILED_AFTER_PROMPT_INJECTION',
|
||||
modelToUse,
|
||||
),
|
||||
);
|
||||
return turn;
|
||||
}
|
||||
const nextRequest = [{ text: 'System: Please continue.' }];
|
||||
yield* this.sendMessageStream(
|
||||
nextRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
true, // Set isInvalidStreamRetry to true
|
||||
isInvalidStream = true;
|
||||
}
|
||||
if (event.type === GeminiEventType.Error) {
|
||||
isError = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (isError) {
|
||||
return turn;
|
||||
}
|
||||
|
||||
// Update cumulative response in hook state
|
||||
// We do this immediately after the stream finishes for THIS turn.
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
if (hooksEnabled) {
|
||||
const responseText = turn.getResponseText() || '';
|
||||
const hookState = this.hookStateMap.get(prompt_id);
|
||||
if (hookState && responseText) {
|
||||
// Append with newline if not empty
|
||||
hookState.cumulativeResponse = hookState.cumulativeResponse
|
||||
? `${hookState.cumulativeResponse}\n${responseText}`
|
||||
: responseText;
|
||||
}
|
||||
}
|
||||
|
||||
if (isInvalidStream) {
|
||||
if (this.config.getContinueOnFailedApiCall()) {
|
||||
if (isInvalidStreamRetry) {
|
||||
logContentRetryFailure(
|
||||
this.config,
|
||||
new ContentRetryFailureEvent(
|
||||
4,
|
||||
'FAILED_AFTER_PROMPT_INJECTION',
|
||||
modelToUse,
|
||||
),
|
||||
);
|
||||
return turn;
|
||||
}
|
||||
}
|
||||
if (event.type === GeminiEventType.Error) {
|
||||
return turn;
|
||||
}
|
||||
}
|
||||
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
||||
// Check if next speaker check is needed
|
||||
if (this.config.getQuotaErrorOccurred()) {
|
||||
return turn;
|
||||
}
|
||||
|
||||
if (this.config.getSkipNextSpeakerCheck()) {
|
||||
return turn;
|
||||
}
|
||||
|
||||
const nextSpeakerCheck = await checkNextSpeaker(
|
||||
this.getChat(),
|
||||
this.config.getBaseLlmClient(),
|
||||
signal,
|
||||
prompt_id,
|
||||
);
|
||||
logNextSpeakerCheck(
|
||||
this.config,
|
||||
new NextSpeakerCheckEvent(
|
||||
prompt_id,
|
||||
turn.finishReason?.toString() || '',
|
||||
nextSpeakerCheck?.next_speaker || '',
|
||||
),
|
||||
);
|
||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||
const nextRequest = [{ text: 'Please continue.' }];
|
||||
// This recursive call's events will be yielded out, and the final
|
||||
// turn object from the recursive call will be returned.
|
||||
return yield* this.sendMessageStream(
|
||||
const nextRequest = [{ text: 'System: Please continue.' }];
|
||||
// Recursive call - update turn with result
|
||||
turn = yield* this.sendMessageStream(
|
||||
nextRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
// isInvalidStreamRetry is false here, as this is a next speaker check
|
||||
true,
|
||||
);
|
||||
return turn;
|
||||
}
|
||||
}
|
||||
|
||||
// Fire AfterAgent hook through MessageBus (only if hooks are enabled)
|
||||
if (hooksEnabled && messageBus) {
|
||||
const responseText = turn.getResponseText() || '[no response text]';
|
||||
const hookOutput = await fireAfterAgentHook(
|
||||
messageBus,
|
||||
request,
|
||||
responseText,
|
||||
);
|
||||
|
||||
// For AfterAgent hooks, blocking/stop execution should force continuation
|
||||
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
||||
if (
|
||||
hookOutput?.isBlockingDecision() ||
|
||||
hookOutput?.shouldStopExecution()
|
||||
!this.config.getQuotaErrorOccurred() &&
|
||||
!this.config.getSkipNextSpeakerCheck()
|
||||
) {
|
||||
const continueReason = hookOutput.getEffectiveReason();
|
||||
const continueRequest = [{ text: continueReason }];
|
||||
yield* this.sendMessageStream(
|
||||
continueRequest,
|
||||
const nextSpeakerCheck = await checkNextSpeaker(
|
||||
this.getChat(),
|
||||
this.config.getBaseLlmClient(),
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
);
|
||||
logNextSpeakerCheck(
|
||||
this.config,
|
||||
new NextSpeakerCheckEvent(
|
||||
prompt_id,
|
||||
turn.finishReason?.toString() || '',
|
||||
nextSpeakerCheck?.next_speaker || '',
|
||||
),
|
||||
);
|
||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||
const nextRequest = [{ text: 'Please continue.' }];
|
||||
turn = yield* this.sendMessageStream(
|
||||
nextRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
// isInvalidStreamRetry is false
|
||||
);
|
||||
return turn;
|
||||
}
|
||||
}
|
||||
}
|
||||
return turn;
|
||||
}
|
||||
|
||||
async *sendMessageStream(
|
||||
request: PartListUnion,
|
||||
signal: AbortSignal,
|
||||
prompt_id: string,
|
||||
turns: number = MAX_TURNS,
|
||||
isInvalidStreamRetry: boolean = false,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
if (!isInvalidStreamRetry) {
|
||||
this.config.resetTurn();
|
||||
}
|
||||
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
const messageBus = this.config.getMessageBus();
|
||||
|
||||
if (this.lastPromptId !== prompt_id) {
|
||||
this.loopDetector.reset(prompt_id);
|
||||
this.hookStateMap.delete(this.lastPromptId);
|
||||
this.lastPromptId = prompt_id;
|
||||
this.currentSequenceModel = null;
|
||||
}
|
||||
|
||||
if (hooksEnabled && messageBus) {
|
||||
const hookResult = await this.fireBeforeAgentHookSafe(
|
||||
messageBus,
|
||||
request,
|
||||
prompt_id,
|
||||
);
|
||||
if (hookResult) {
|
||||
if ('type' in hookResult && hookResult.type === GeminiEventType.Error) {
|
||||
yield hookResult;
|
||||
return new Turn(this.getChat(), prompt_id);
|
||||
} else if ('additionalContext' in hookResult) {
|
||||
const additionalContext = hookResult.additionalContext;
|
||||
if (additionalContext) {
|
||||
const requestArray = Array.isArray(request) ? request : [request];
|
||||
request = [...requestArray, { text: additionalContext }];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const boundedTurns = Math.min(turns, MAX_TURNS);
|
||||
let turn = new Turn(this.getChat(), prompt_id);
|
||||
|
||||
try {
|
||||
turn = yield* this.processTurn(
|
||||
request,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns,
|
||||
isInvalidStreamRetry,
|
||||
);
|
||||
|
||||
// Fire AfterAgent hook if we have a turn and no pending tools
|
||||
if (hooksEnabled && messageBus) {
|
||||
const hookOutput = await this.fireAfterAgentHookSafe(
|
||||
messageBus,
|
||||
request,
|
||||
prompt_id,
|
||||
turn,
|
||||
);
|
||||
|
||||
if (
|
||||
hookOutput?.isBlockingDecision() ||
|
||||
hookOutput?.shouldStopExecution()
|
||||
) {
|
||||
const continueReason = hookOutput.getEffectiveReason();
|
||||
const continueRequest = [{ text: continueReason }];
|
||||
yield* this.sendMessageStream(
|
||||
continueRequest,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns - 1,
|
||||
);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
const hookState = this.hookStateMap.get(prompt_id);
|
||||
if (hookState) {
|
||||
hookState.activeCalls--;
|
||||
const isPendingTools =
|
||||
turn?.pendingToolCalls && turn.pendingToolCalls.length > 0;
|
||||
const isAborted = signal?.aborted;
|
||||
|
||||
if (hookState.activeCalls <= 0) {
|
||||
if (!isPendingTools || isAborted) {
|
||||
this.hookStateMap.delete(prompt_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user