mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-14 08:01:02 -07:00
fix(core): prevent unhandled AbortError crash during stream loop detection (#21123)
Co-authored-by: Gaurav <39389231+gsquared94@users.noreply.github.com> Co-authored-by: ruomeng <ruomeng@google.com>
This commit is contained in:
@@ -28,6 +28,7 @@ import {
|
||||
GeminiEventType,
|
||||
Turn,
|
||||
type ChatCompressionInfo,
|
||||
type ServerGeminiStreamEvent,
|
||||
} from './turn.js';
|
||||
import { getCoreSystemPrompt } from './prompts.js';
|
||||
import { DEFAULT_GEMINI_MODEL_AUTO } from '../config/models.js';
|
||||
@@ -1118,6 +1119,54 @@ ${JSON.stringify(
|
||||
// The actual token calculation is unit tested in tokenCalculation.test.ts
|
||||
});
|
||||
|
||||
it('should cleanly abort and return Turn on LoopDetected without unhandled promise rejections', async () => {
|
||||
// Arrange
|
||||
const mockStream = (async function* () {
|
||||
// Yield an event that will trigger the loop detector
|
||||
yield { type: 'content', value: 'Looping content' };
|
||||
})();
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
setTools: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
getLastPromptTokenCount: vi.fn(),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
// Mock loop detector to return count > 1 on the first event (loop detected)
|
||||
vi.spyOn(client['loopDetector'], 'addAndCheck').mockReturnValue({
|
||||
count: 2,
|
||||
});
|
||||
|
||||
const abortSpy = vi.spyOn(AbortController.prototype, 'abort');
|
||||
|
||||
// Act
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
new AbortController().signal,
|
||||
'prompt-id-1',
|
||||
);
|
||||
|
||||
const events: ServerGeminiStreamEvent[] = [];
|
||||
let finalResult: Turn | undefined;
|
||||
|
||||
while (true) {
|
||||
const result = await stream.next();
|
||||
if (result.done) {
|
||||
finalResult = result.value;
|
||||
break;
|
||||
}
|
||||
events.push(result.value);
|
||||
}
|
||||
|
||||
// Assert
|
||||
expect(events).toContainEqual({ type: GeminiEventType.LoopDetected });
|
||||
expect(abortSpy).toHaveBeenCalled();
|
||||
expect(finalResult).toBeInstanceOf(Turn);
|
||||
});
|
||||
|
||||
it('should return the turn instance after the stream is complete', async () => {
|
||||
// Arrange
|
||||
const mockStream = (async function* () {
|
||||
|
||||
@@ -708,27 +708,22 @@ export class GeminiClient {
|
||||
let isError = false;
|
||||
let isInvalidStream = false;
|
||||
|
||||
let loopDetectedAbort = false;
|
||||
let loopRecoverResult: { detail?: string } | undefined;
|
||||
for await (const event of resultStream) {
|
||||
const loopResult = this.loopDetector.addAndCheck(event);
|
||||
if (loopResult.count > 1) {
|
||||
yield { type: GeminiEventType.LoopDetected };
|
||||
controller.abort();
|
||||
return turn;
|
||||
loopDetectedAbort = true;
|
||||
break;
|
||||
} else if (loopResult.count === 1) {
|
||||
if (boundedTurns <= 1) {
|
||||
yield { type: GeminiEventType.MaxSessionTurns };
|
||||
controller.abort();
|
||||
return turn;
|
||||
loopDetectedAbort = true;
|
||||
break;
|
||||
}
|
||||
return yield* this._recoverFromLoop(
|
||||
loopResult,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns,
|
||||
isInvalidStreamRetry,
|
||||
displayContent,
|
||||
controller,
|
||||
);
|
||||
loopRecoverResult = loopResult;
|
||||
break;
|
||||
}
|
||||
yield event;
|
||||
|
||||
@@ -742,6 +737,23 @@ export class GeminiClient {
|
||||
}
|
||||
}
|
||||
|
||||
if (loopDetectedAbort) {
|
||||
controller.abort();
|
||||
return turn;
|
||||
}
|
||||
|
||||
if (loopRecoverResult) {
|
||||
return yield* this._recoverFromLoop(
|
||||
loopRecoverResult,
|
||||
signal,
|
||||
prompt_id,
|
||||
boundedTurns,
|
||||
isInvalidStreamRetry,
|
||||
displayContent,
|
||||
controller,
|
||||
);
|
||||
}
|
||||
|
||||
if (isError) {
|
||||
return turn;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user