diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 2c278bb3c2..059b72437f 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -28,6 +28,7 @@ import { GeminiEventType, Turn, type ChatCompressionInfo, + type ServerGeminiStreamEvent, } from './turn.js'; import { getCoreSystemPrompt } from './prompts.js'; import { DEFAULT_GEMINI_MODEL_AUTO } from '../config/models.js'; @@ -1118,6 +1119,54 @@ ${JSON.stringify( // The actual token calculation is unit tested in tokenCalculation.test.ts }); + it('should cleanly abort and return Turn on LoopDetected without unhandled promise rejections', async () => { + // Arrange + const mockStream = (async function* () { + // Yield an event that will trigger the loop detector + yield { type: 'content', value: 'Looping content' }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + setTools: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + getLastPromptTokenCount: vi.fn(), + }; + client['chat'] = mockChat as GeminiChat; + + // Mock loop detector to return count > 1 on the first event (loop detected) + vi.spyOn(client['loopDetector'], 'addAndCheck').mockReturnValue({ + count: 2, + }); + + const abortSpy = vi.spyOn(AbortController.prototype, 'abort'); + + // Act + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-id-1', + ); + + const events: ServerGeminiStreamEvent[] = []; + let finalResult: Turn | undefined; + + while (true) { + const result = await stream.next(); + if (result.done) { + finalResult = result.value; + break; + } + events.push(result.value); + } + + // Assert + expect(events).toContainEqual({ type: GeminiEventType.LoopDetected }); + expect(abortSpy).toHaveBeenCalled(); + expect(finalResult).toBeInstanceOf(Turn); + }); + it('should return the turn instance after the stream is complete', async () => { // Arrange const mockStream = (async function* () { diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index bb391ed645..14e2f42bc3 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -708,27 +708,22 @@ export class GeminiClient { let isError = false; let isInvalidStream = false; + let loopDetectedAbort = false; + let loopRecoverResult: { detail?: string } | undefined; for await (const event of resultStream) { const loopResult = this.loopDetector.addAndCheck(event); if (loopResult.count > 1) { yield { type: GeminiEventType.LoopDetected }; - controller.abort(); - return turn; + loopDetectedAbort = true; + break; } else if (loopResult.count === 1) { if (boundedTurns <= 1) { yield { type: GeminiEventType.MaxSessionTurns }; - controller.abort(); - return turn; + loopDetectedAbort = true; + break; } - return yield* this._recoverFromLoop( - loopResult, - signal, - prompt_id, - boundedTurns, - isInvalidStreamRetry, - displayContent, - controller, - ); + loopRecoverResult = loopResult; + break; } yield event; @@ -742,6 +737,23 @@ export class GeminiClient { } } + if (loopDetectedAbort) { + controller.abort(); + return turn; + } + + if (loopRecoverResult) { + return yield* this._recoverFromLoop( + loopRecoverResult, + signal, + prompt_id, + boundedTurns, + isInvalidStreamRetry, + displayContent, + controller, + ); + } + if (isError) { return turn; }