fix(core): stop streaming request on loop detection (#8377)

This commit is contained in:
Sandy Tao
2025-09-12 16:44:24 -07:00
committed by GitHub
parent 1f70a27e9c
commit 5a05fb0dd0
2 changed files with 74 additions and 1 deletions

View File

@@ -2405,6 +2405,75 @@ ${JSON.stringify(
// Assert
expect(mockCheckNextSpeaker).not.toHaveBeenCalled();
});
it('should create linked abort signal and pass it to turn.run', async () => {
// Arrange
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = mockChat as GeminiChat;
const originalSignal = new AbortController().signal;
// Act
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
originalSignal,
'prompt-id-signal',
);
for await (const _ of stream) {
// consume stream
}
// Assert
expect(mockTurnRunFn).toHaveBeenCalledWith(
expect.any(String),
[{ text: 'Hi' }],
expect.not.objectContaining({ signal: originalSignal }),
);
});
it('should abort linked signal when loop is detected', async () => {
// Arrange
vi.spyOn(client['loopDetector'], 'turnStarted').mockResolvedValue(false);
vi.spyOn(client['loopDetector'], 'addAndCheck')
.mockReturnValueOnce(false)
.mockReturnValueOnce(true);
const mockStream = (async function* () {
yield { type: 'content', value: 'First event' };
yield { type: 'content', value: 'Second event' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = mockChat as GeminiChat;
// Act
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-id-loop',
);
const events = [];
for await (const event of stream) {
events.push(event);
}
// Assert
expect(events).toContainEqual({ type: GeminiEventType.LoopDetected });
expect(client['loopDetector'].addAndCheck).toHaveBeenCalledTimes(2);
});
});
describe('generateContent', () => {

View File

@@ -489,6 +489,9 @@ export class GeminiClient {
const turn = new Turn(this.getChat(), prompt_id);
const controller = new AbortController();
const linkedSignal = AbortSignal.any([signal, controller.signal]);
const loopDetected = await this.loopDetector.turnStarted(signal);
if (loopDetected) {
yield { type: GeminiEventType.LoopDetected };
@@ -514,10 +517,11 @@ export class GeminiClient {
this.currentSequenceModel = modelToUse;
}
const resultStream = turn.run(modelToUse, request, signal);
const resultStream = turn.run(modelToUse, request, linkedSignal);
for await (const event of resultStream) {
if (this.loopDetector.addAndCheck(event)) {
yield { type: GeminiEventType.LoopDetected };
controller.abort();
return turn;
}
yield event;