mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
fix(core): stop streaming request on loop detection (#8377)
This commit is contained in:
@@ -2405,6 +2405,75 @@ ${JSON.stringify(
|
||||
// Assert
|
||||
expect(mockCheckNextSpeaker).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should create linked abort signal and pass it to turn.run', async () => {
|
||||
// Arrange
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
})();
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const originalSignal = new AbortController().signal;
|
||||
|
||||
// Act
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
originalSignal,
|
||||
'prompt-id-signal',
|
||||
);
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
}
|
||||
|
||||
// Assert
|
||||
expect(mockTurnRunFn).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
[{ text: 'Hi' }],
|
||||
expect.not.objectContaining({ signal: originalSignal }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should abort linked signal when loop is detected', async () => {
|
||||
// Arrange
|
||||
vi.spyOn(client['loopDetector'], 'turnStarted').mockResolvedValue(false);
|
||||
vi.spyOn(client['loopDetector'], 'addAndCheck')
|
||||
.mockReturnValueOnce(false)
|
||||
.mockReturnValueOnce(true);
|
||||
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'First event' };
|
||||
yield { type: 'content', value: 'Second event' };
|
||||
})();
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
// Act
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
new AbortController().signal,
|
||||
'prompt-id-loop',
|
||||
);
|
||||
|
||||
const events = [];
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// Assert
|
||||
expect(events).toContainEqual({ type: GeminiEventType.LoopDetected });
|
||||
expect(client['loopDetector'].addAndCheck).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('generateContent', () => {
|
||||
|
||||
@@ -489,6 +489,9 @@ export class GeminiClient {
|
||||
|
||||
const turn = new Turn(this.getChat(), prompt_id);
|
||||
|
||||
const controller = new AbortController();
|
||||
const linkedSignal = AbortSignal.any([signal, controller.signal]);
|
||||
|
||||
const loopDetected = await this.loopDetector.turnStarted(signal);
|
||||
if (loopDetected) {
|
||||
yield { type: GeminiEventType.LoopDetected };
|
||||
@@ -514,10 +517,11 @@ export class GeminiClient {
|
||||
this.currentSequenceModel = modelToUse;
|
||||
}
|
||||
|
||||
const resultStream = turn.run(modelToUse, request, signal);
|
||||
const resultStream = turn.run(modelToUse, request, linkedSignal);
|
||||
for await (const event of resultStream) {
|
||||
if (this.loopDetector.addAndCheck(event)) {
|
||||
yield { type: GeminiEventType.LoopDetected };
|
||||
controller.abort();
|
||||
return turn;
|
||||
}
|
||||
yield event;
|
||||
|
||||
Reference in New Issue
Block a user