mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-14 08:01:02 -07:00
feat(core): tune loop detection for earlier catch and alternating patterns
TOOL_CALL_LOOP_THRESHOLD 5→4, LLM_CHECK_AFTER_TURNS 30→20, and new alternating-pattern detection (A→B→A→B) that the consecutive-only check missed.
This commit is contained in:
@@ -26,7 +26,7 @@ vi.mock('../telemetry/loggers.js', () => ({
|
||||
logLlmLoopCheck: vi.fn(),
|
||||
}));
|
||||
|
||||
const TOOL_CALL_LOOP_THRESHOLD = 5;
|
||||
const TOOL_CALL_LOOP_THRESHOLD = 4;
|
||||
const CONTENT_LOOP_THRESHOLD = 10;
|
||||
const CONTENT_CHUNK_SIZE = 50;
|
||||
|
||||
@@ -806,15 +806,15 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
};
|
||||
|
||||
it('should not trigger LLM check before LLM_CHECK_AFTER_TURNS', async () => {
|
||||
await advanceTurns(29);
|
||||
await advanceTurns(19);
|
||||
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should trigger LLM check on the 30th turn', async () => {
|
||||
it('should trigger LLM check on the 20th turn', async () => {
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ unproductive_state_confidence: 0.1 });
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(20);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
@@ -828,12 +828,12 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
});
|
||||
|
||||
it('should detect a cognitive loop when confidence is high', async () => {
|
||||
// First check at turn 30
|
||||
// First check at turn 20
|
||||
mockBaseLlmClient.generateJson = vi.fn().mockResolvedValue({
|
||||
unproductive_state_confidence: 0.85,
|
||||
unproductive_state_analysis: 'Repetitive actions',
|
||||
});
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(20);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
@@ -843,13 +843,13 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
|
||||
// The confidence of 0.85 will result in a low interval.
|
||||
// The interval will be: 5 + (15 - 5) * (1 - 0.85) = 5 + 10 * 0.15 = 6.5 -> rounded to 7
|
||||
await advanceTurns(6); // advance to turn 36
|
||||
await advanceTurns(6); // advance to turn 26
|
||||
|
||||
mockBaseLlmClient.generateJson = vi.fn().mockResolvedValue({
|
||||
unproductive_state_confidence: 0.95,
|
||||
unproductive_state_analysis: 'Repetitive actions',
|
||||
});
|
||||
const finalResult = await service.turnStarted(abortController.signal); // This is turn 37
|
||||
const finalResult = await service.turnStarted(abortController.signal); // This is turn 27
|
||||
|
||||
expect(finalResult).toBe(true);
|
||||
expect(loggers.logLoopDetected).toHaveBeenCalledWith(
|
||||
@@ -867,7 +867,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
unproductive_state_confidence: 0.5,
|
||||
unproductive_state_analysis: 'Looks okay',
|
||||
});
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(20);
|
||||
const result = await service.turnStarted(abortController.signal);
|
||||
expect(result).toBe(false);
|
||||
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
|
||||
@@ -878,13 +878,13 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ unproductive_state_confidence: 0.0 });
|
||||
await advanceTurns(30); // First check at turn 30
|
||||
await advanceTurns(20); // First check at turn 20
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
|
||||
await advanceTurns(14); // Advance to turn 44
|
||||
await advanceTurns(14); // Advance to turn 34
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
|
||||
await service.turnStarted(abortController.signal); // Turn 45
|
||||
await service.turnStarted(abortController.signal); // Turn 35
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
@@ -892,7 +892,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('API error'));
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(20);
|
||||
const result = await service.turnStarted(abortController.signal);
|
||||
expect(result).toBe(false);
|
||||
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
|
||||
@@ -901,7 +901,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
it('should not trigger LLM check when disabled for session', async () => {
|
||||
service.disableForSession();
|
||||
expect(loggers.logLoopDetectionDisabled).toHaveBeenCalledTimes(1);
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(20);
|
||||
const result = await service.turnStarted(abortController.signal);
|
||||
expect(result).toBe(false);
|
||||
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||
@@ -924,7 +924,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
.fn()
|
||||
.mockResolvedValue({ unproductive_state_confidence: 0.1 });
|
||||
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(20);
|
||||
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
const calledArg = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
@@ -949,7 +949,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
unproductive_state_analysis: 'Main says loop',
|
||||
});
|
||||
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(20);
|
||||
|
||||
// It should have called generateJson twice
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
|
||||
@@ -989,7 +989,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
unproductive_state_analysis: 'Main says no loop',
|
||||
});
|
||||
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(20);
|
||||
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenNthCalledWith(
|
||||
@@ -1032,7 +1032,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
unproductive_state_analysis: 'Flash says loop',
|
||||
});
|
||||
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(20);
|
||||
|
||||
// It should have called generateJson only once
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
|
||||
@@ -27,7 +27,7 @@ import {
|
||||
} from '../utils/messageInspectors.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
const TOOL_CALL_LOOP_THRESHOLD = 5;
|
||||
const TOOL_CALL_LOOP_THRESHOLD = 4;
|
||||
const CONTENT_LOOP_THRESHOLD = 10;
|
||||
const CONTENT_CHUNK_SIZE = 50;
|
||||
const MAX_HISTORY_LENGTH = 5000;
|
||||
@@ -40,7 +40,7 @@ const LLM_LOOP_CHECK_HISTORY_COUNT = 20;
|
||||
/**
|
||||
* The number of turns that must pass in a single prompt before the LLM-based loop check is activated.
|
||||
*/
|
||||
const LLM_CHECK_AFTER_TURNS = 30;
|
||||
const LLM_CHECK_AFTER_TURNS = 20;
|
||||
|
||||
/**
|
||||
* The default interval, in number of turns, at which the LLM-based loop check is performed.
|
||||
@@ -105,6 +105,7 @@ export class LoopDetectionService {
|
||||
// Tool call tracking
|
||||
private lastToolCallKey: string | null = null;
|
||||
private toolCallRepetitionCount: number = 0;
|
||||
private recentToolCallKeys: string[] = [];
|
||||
|
||||
// Content streaming tracking
|
||||
private streamContentHistory = '';
|
||||
@@ -217,6 +218,53 @@ export class LoopDetectionService {
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Alternating pattern detection: track last 12 tool calls and detect
|
||||
// when a pattern of 2-3 distinct calls repeats 3+ times.
|
||||
this.recentToolCallKeys.push(key);
|
||||
if (this.recentToolCallKeys.length > 12) {
|
||||
this.recentToolCallKeys.shift();
|
||||
}
|
||||
if (this.detectAlternatingPattern()) {
|
||||
logLoopDetected(
|
||||
this.config,
|
||||
new LoopDetectedEvent(
|
||||
LoopType.CONSECUTIVE_IDENTICAL_TOOL_CALLS,
|
||||
this.promptId,
|
||||
),
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Detects alternating patterns like A->B->A->B->A->B or A->B->C->A->B->C.
|
||||
* Checks if a pattern of length 2 or 3 repeats at least 3 times at the
|
||||
* end of the recent tool call history.
|
||||
*/
|
||||
private detectAlternatingPattern(): boolean {
|
||||
const keys = this.recentToolCallKeys;
|
||||
// Check patterns of length 2 and 3
|
||||
for (const patternLen of [2, 3]) {
|
||||
const minRequired = patternLen * 3; // Need at least 3 repetitions
|
||||
if (keys.length < minRequired) continue;
|
||||
|
||||
const pattern = keys.slice(keys.length - patternLen);
|
||||
let repetitions = 1;
|
||||
for (let i = keys.length - patternLen * 2; i >= 0; i -= patternLen) {
|
||||
const segment = keys.slice(i, i + patternLen);
|
||||
if (segment.every((k, idx) => k === pattern[idx])) {
|
||||
repetitions++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (repetitions >= 3) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -613,6 +661,7 @@ export class LoopDetectionService {
|
||||
private resetToolCallCount(): void {
|
||||
this.lastToolCallKey = null;
|
||||
this.toolCallRepetitionCount = 0;
|
||||
this.recentToolCallKeys = [];
|
||||
}
|
||||
|
||||
private resetContentTracking(resetHistory = true): void {
|
||||
|
||||
Reference in New Issue
Block a user