feat(loop-reduction): implement iterative loop detection and model feedback (#20763)

This commit is contained in:
Aishanee Shah
2026-03-04 14:38:36 -05:00
committed by GitHub
parent 8f36051f32
commit e200b40408
5 changed files with 668 additions and 252 deletions
@@ -39,7 +39,7 @@ const LLM_LOOP_CHECK_HISTORY_COUNT = 20;
/**
* The number of turns that must pass in a single prompt before the LLM-based loop check is activated.
*/
const LLM_CHECK_AFTER_TURNS = 40;
const LLM_CHECK_AFTER_TURNS = 30;
/**
* The default interval, in number of turns, at which the LLM-based loop check is performed.
@@ -51,7 +51,7 @@ const DEFAULT_LLM_CHECK_INTERVAL = 10;
* The minimum interval for LLM-based loop checks.
* This is used when the confidence of a loop is high, to check more frequently.
*/
const MIN_LLM_CHECK_INTERVAL = 7;
const MIN_LLM_CHECK_INTERVAL = 5;
/**
* The maximum interval for LLM-based loop checks.
@@ -117,6 +117,15 @@ const LOOP_DETECTION_SCHEMA: Record<string, unknown> = {
required: ['unproductive_state_analysis', 'unproductive_state_confidence'],
};
/**
* Result of a loop detection check.
*
* A `count` of 0 means no loop is being reported (detection disabled or no
* loop found); in that case the optional fields are omitted.
*/
export interface LoopDetectionResult {
/**
* Number of loops detected so far, taken from the service's running
* `detectedCount`; 0 when no loop is reported.
*/
count: number;
/** The kind of loop that was detected, when one is reported. */
type?: LoopType;
/**
* Human-readable description of the detected loop: the repeated tool call,
* an excerpt of the repeating content, or the analysis text returned by the
* LLM-based check.
*/
detail?: string;
/**
* Name of the model that confirmed the loop. In the visible code this is
* populated only by `turnStarted` when the LLM-based check reports a loop.
*/
confirmedByModel?: string;
}
/**
* Service for detecting and preventing infinite loops in AI responses.
* Monitors tool call repetitions and content sentence repetitions.
@@ -135,8 +144,11 @@ export class LoopDetectionService {
private contentStats = new Map<string, number[]>();
private lastContentIndex = 0;
private loopDetected = false;
private detectedCount = 0;
private lastLoopDetail?: string;
private inCodeBlock = false;
private lastLoopType?: LoopType;
// LLM loop check tracking
private turnsInCurrentPrompt = 0;
private llmCheckInterval = DEFAULT_LLM_CHECK_INTERVAL;
@@ -169,31 +181,68 @@ export class LoopDetectionService {
/**
* Processes a stream event and checks for loop conditions.
* @param event - The stream event to process
* @returns true if a loop is detected, false otherwise
* @returns A LoopDetectionResult
*/
addAndCheck(event: ServerGeminiStreamEvent): boolean {
addAndCheck(event: ServerGeminiStreamEvent): LoopDetectionResult {
if (this.disabledForSession || this.config.getDisableLoopDetection()) {
return false;
return { count: 0 };
}
if (this.loopDetected) {
return {
count: this.detectedCount,
type: this.lastLoopType,
detail: this.lastLoopDetail,
};
}
if (this.loopDetected) {
return this.loopDetected;
}
let isLoop = false;
let detail: string | undefined;
switch (event.type) {
case GeminiEventType.ToolCallRequest:
// content chanting only happens in one single stream, reset if there
// is a tool call in between
this.resetContentTracking();
this.loopDetected = this.checkToolCallLoop(event.value);
isLoop = this.checkToolCallLoop(event.value);
if (isLoop) {
detail = `Repeated tool call: ${event.value.name} with arguments ${JSON.stringify(event.value.args)}`;
}
break;
case GeminiEventType.Content:
this.loopDetected = this.checkContentLoop(event.value);
isLoop = this.checkContentLoop(event.value);
if (isLoop) {
detail = `Repeating content detected: "${this.streamContentHistory.substring(Math.max(0, this.lastContentIndex - 20), this.lastContentIndex + CONTENT_CHUNK_SIZE).trim()}..."`;
}
break;
default:
break;
}
return this.loopDetected;
if (isLoop) {
this.loopDetected = true;
this.detectedCount++;
this.lastLoopDetail = detail;
this.lastLoopType =
event.type === GeminiEventType.ToolCallRequest
? LoopType.CONSECUTIVE_IDENTICAL_TOOL_CALLS
: LoopType.CONTENT_CHANTING_LOOP;
logLoopDetected(
this.config,
new LoopDetectedEvent(
this.lastLoopType,
this.promptId,
this.detectedCount,
),
);
}
return isLoop
? {
count: this.detectedCount,
type: this.lastLoopType,
detail: this.lastLoopDetail,
}
: { count: 0 };
}
/**
@@ -204,12 +253,20 @@ export class LoopDetectionService {
* is performed periodically based on the `llmCheckInterval`.
*
* @param signal - An AbortSignal to allow for cancellation of the asynchronous LLM check.
* @returns A promise that resolves to `true` if a loop is detected, and `false` otherwise.
* @returns A promise that resolves to a LoopDetectionResult.
*/
async turnStarted(signal: AbortSignal) {
async turnStarted(signal: AbortSignal): Promise<LoopDetectionResult> {
if (this.disabledForSession || this.config.getDisableLoopDetection()) {
return false;
return { count: 0 };
}
if (this.loopDetected) {
return {
count: this.detectedCount,
type: this.lastLoopType,
detail: this.lastLoopDetail,
};
}
this.turnsInCurrentPrompt++;
if (
@@ -217,10 +274,35 @@ export class LoopDetectionService {
this.turnsInCurrentPrompt - this.lastCheckTurn >= this.llmCheckInterval
) {
this.lastCheckTurn = this.turnsInCurrentPrompt;
return this.checkForLoopWithLLM(signal);
}
const { isLoop, analysis, confirmedByModel } =
await this.checkForLoopWithLLM(signal);
if (isLoop) {
this.loopDetected = true;
this.detectedCount++;
this.lastLoopDetail = analysis;
this.lastLoopType = LoopType.LLM_DETECTED_LOOP;
return false;
logLoopDetected(
this.config,
new LoopDetectedEvent(
this.lastLoopType,
this.promptId,
this.detectedCount,
confirmedByModel,
analysis,
LLM_CONFIDENCE_THRESHOLD,
),
);
return {
count: this.detectedCount,
type: this.lastLoopType,
detail: this.lastLoopDetail,
confirmedByModel,
};
}
}
return { count: 0 };
}
private checkToolCallLoop(toolCall: { name: string; args: object }): boolean {
@@ -232,13 +314,6 @@ export class LoopDetectionService {
this.toolCallRepetitionCount = 1;
}
if (this.toolCallRepetitionCount >= TOOL_CALL_LOOP_THRESHOLD) {
logLoopDetected(
this.config,
new LoopDetectedEvent(
LoopType.CONSECUTIVE_IDENTICAL_TOOL_CALLS,
this.promptId,
),
);
return true;
}
return false;
@@ -345,13 +420,6 @@ export class LoopDetectionService {
const chunkHash = createHash('sha256').update(currentChunk).digest('hex');
if (this.isLoopDetectedForChunk(currentChunk, chunkHash)) {
logLoopDetected(
this.config,
new LoopDetectedEvent(
LoopType.CHANTING_IDENTICAL_SENTENCES,
this.promptId,
),
);
return true;
}
@@ -445,28 +513,29 @@ export class LoopDetectionService {
return originalChunk === currentChunk;
}
private trimRecentHistory(recentHistory: Content[]): Content[] {
private trimRecentHistory(history: Content[]): Content[] {
// A function call must be followed by a function response.
// Continuously removes dangling function calls from the end of the history
// until the last turn is not a function call.
while (
recentHistory.length > 0 &&
isFunctionCall(recentHistory[recentHistory.length - 1])
) {
recentHistory.pop();
while (history.length > 0 && isFunctionCall(history[history.length - 1])) {
history.pop();
}
// A function response should follow a function call.
// Continuously removes leading function responses from the beginning of history
// until the first turn is not a function response.
while (recentHistory.length > 0 && isFunctionResponse(recentHistory[0])) {
recentHistory.shift();
while (history.length > 0 && isFunctionResponse(history[0])) {
history.shift();
}
return recentHistory;
return history;
}
private async checkForLoopWithLLM(signal: AbortSignal) {
private async checkForLoopWithLLM(signal: AbortSignal): Promise<{
isLoop: boolean;
analysis?: string;
confirmedByModel?: string;
}> {
const recentHistory = this.config
.getGeminiClient()
.getHistory()
@@ -506,13 +575,17 @@ export class LoopDetectionService {
);
if (!flashResult) {
return false;
return { isLoop: false };
}
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
const flashConfidence = flashResult[
'unproductive_state_confidence'
] as number;
const flashConfidence =
typeof flashResult['unproductive_state_confidence'] === 'number'
? flashResult['unproductive_state_confidence']
: 0;
const flashAnalysis =
typeof flashResult['unproductive_state_analysis'] === 'string'
? flashResult['unproductive_state_analysis']
: '';
const doubleCheckModelName =
this.config.modelConfigService.getResolvedConfig({
@@ -530,7 +603,7 @@ export class LoopDetectionService {
),
);
this.updateCheckInterval(flashConfidence);
return false;
return { isLoop: false };
}
const availability = this.config.getModelAvailabilityService();
@@ -539,8 +612,11 @@ export class LoopDetectionService {
const flashModelName = this.config.modelConfigService.getResolvedConfig({
model: 'loop-detection',
}).model;
this.handleConfirmedLoop(flashResult, flashModelName);
return true;
return {
isLoop: true,
analysis: flashAnalysis,
confirmedByModel: flashModelName,
};
}
// Double check with configured model
@@ -550,10 +626,16 @@ export class LoopDetectionService {
signal,
);
const mainModelConfidence = mainModelResult
? // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
(mainModelResult['unproductive_state_confidence'] as number)
: 0;
const mainModelConfidence =
mainModelResult &&
typeof mainModelResult['unproductive_state_confidence'] === 'number'
? mainModelResult['unproductive_state_confidence']
: 0;
const mainModelAnalysis =
mainModelResult &&
typeof mainModelResult['unproductive_state_analysis'] === 'string'
? mainModelResult['unproductive_state_analysis']
: undefined;
logLlmLoopCheck(
this.config,
@@ -567,14 +649,17 @@ export class LoopDetectionService {
if (mainModelResult) {
if (mainModelConfidence >= LLM_CONFIDENCE_THRESHOLD) {
this.handleConfirmedLoop(mainModelResult, doubleCheckModelName);
return true;
return {
isLoop: true,
analysis: mainModelAnalysis,
confirmedByModel: doubleCheckModelName,
};
} else {
this.updateCheckInterval(mainModelConfidence);
}
}
return false;
return { isLoop: false };
}
private async queryLoopDetectionModel(
@@ -601,32 +686,16 @@ export class LoopDetectionService {
return result;
}
return null;
} catch (e) {
this.config.getDebugMode() ? debugLogger.warn(e) : debugLogger.debug(e);
} catch (error) {
if (this.config.getDebugMode()) {
debugLogger.warn(
`Error querying loop detection model (${model}): ${String(error)}`,
);
}
return null;
}
}
private handleConfirmedLoop(
result: Record<string, unknown>,
modelName: string,
): void {
if (
typeof result['unproductive_state_analysis'] === 'string' &&
result['unproductive_state_analysis']
) {
debugLogger.warn(result['unproductive_state_analysis']);
}
logLoopDetected(
this.config,
new LoopDetectedEvent(
LoopType.LLM_DETECTED_LOOP,
this.promptId,
modelName,
),
);
}
private updateCheckInterval(unproductive_state_confidence: number): void {
this.llmCheckInterval = Math.round(
MIN_LLM_CHECK_INTERVAL +
@@ -645,6 +714,17 @@ export class LoopDetectionService {
this.resetContentTracking();
this.resetLlmCheckTracking();
this.loopDetected = false;
this.detectedCount = 0;
this.lastLoopDetail = undefined;
this.lastLoopType = undefined;
}
/**
* Clears the loop-detected flag so that a recovery turn can proceed.
*
* Only `loopDetected` is reset. `detectedCount`, `lastLoopDetail` and
* `lastLoopType` are left untouched, so the next detection increments the
* existing count (for example, it reports count 2 after a first detection).
*/
clearDetection(): void {
this.loopDetected = false;
}
private resetToolCallCount(): void {