mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-22 19:14:33 -07:00
feat(loop-reduction): implement iterative loop detection and model feedback (#20763)
This commit is contained in:
@@ -39,7 +39,7 @@ const LLM_LOOP_CHECK_HISTORY_COUNT = 20;
|
||||
/**
|
||||
* The number of turns that must pass in a single prompt before the LLM-based loop check is activated.
|
||||
*/
|
||||
const LLM_CHECK_AFTER_TURNS = 40;
|
||||
const LLM_CHECK_AFTER_TURNS = 30;
|
||||
|
||||
/**
|
||||
* The default interval, in number of turns, at which the LLM-based loop check is performed.
|
||||
@@ -51,7 +51,7 @@ const DEFAULT_LLM_CHECK_INTERVAL = 10;
|
||||
* The minimum interval for LLM-based loop checks.
|
||||
* This is used when the confidence of a loop is high, to check more frequently.
|
||||
*/
|
||||
const MIN_LLM_CHECK_INTERVAL = 7;
|
||||
const MIN_LLM_CHECK_INTERVAL = 5;
|
||||
|
||||
/**
|
||||
* The maximum interval for LLM-based loop checks.
|
||||
@@ -117,6 +117,15 @@ const LOOP_DETECTION_SCHEMA: Record<string, unknown> = {
|
||||
required: ['unproductive_state_analysis', 'unproductive_state_confidence'],
|
||||
};
|
||||
|
||||
/**
|
||||
* Result of a loop detection check.
|
||||
*/
|
||||
export interface LoopDetectionResult {
|
||||
count: number;
|
||||
type?: LoopType;
|
||||
detail?: string;
|
||||
confirmedByModel?: string;
|
||||
}
|
||||
/**
|
||||
* Service for detecting and preventing infinite loops in AI responses.
|
||||
* Monitors tool call repetitions and content sentence repetitions.
|
||||
@@ -135,8 +144,11 @@ export class LoopDetectionService {
|
||||
private contentStats = new Map<string, number[]>();
|
||||
private lastContentIndex = 0;
|
||||
private loopDetected = false;
|
||||
private detectedCount = 0;
|
||||
private lastLoopDetail?: string;
|
||||
private inCodeBlock = false;
|
||||
|
||||
private lastLoopType?: LoopType;
|
||||
// LLM loop track tracking
|
||||
private turnsInCurrentPrompt = 0;
|
||||
private llmCheckInterval = DEFAULT_LLM_CHECK_INTERVAL;
|
||||
@@ -169,31 +181,68 @@ export class LoopDetectionService {
|
||||
/**
|
||||
* Processes a stream event and checks for loop conditions.
|
||||
* @param event - The stream event to process
|
||||
* @returns true if a loop is detected, false otherwise
|
||||
* @returns A LoopDetectionResult
|
||||
*/
|
||||
addAndCheck(event: ServerGeminiStreamEvent): boolean {
|
||||
addAndCheck(event: ServerGeminiStreamEvent): LoopDetectionResult {
|
||||
if (this.disabledForSession || this.config.getDisableLoopDetection()) {
|
||||
return false;
|
||||
return { count: 0 };
|
||||
}
|
||||
if (this.loopDetected) {
|
||||
return {
|
||||
count: this.detectedCount,
|
||||
type: this.lastLoopType,
|
||||
detail: this.lastLoopDetail,
|
||||
};
|
||||
}
|
||||
|
||||
if (this.loopDetected) {
|
||||
return this.loopDetected;
|
||||
}
|
||||
let isLoop = false;
|
||||
let detail: string | undefined;
|
||||
|
||||
switch (event.type) {
|
||||
case GeminiEventType.ToolCallRequest:
|
||||
// content chanting only happens in one single stream, reset if there
|
||||
// is a tool call in between
|
||||
this.resetContentTracking();
|
||||
this.loopDetected = this.checkToolCallLoop(event.value);
|
||||
isLoop = this.checkToolCallLoop(event.value);
|
||||
if (isLoop) {
|
||||
detail = `Repeated tool call: ${event.value.name} with arguments ${JSON.stringify(event.value.args)}`;
|
||||
}
|
||||
break;
|
||||
case GeminiEventType.Content:
|
||||
this.loopDetected = this.checkContentLoop(event.value);
|
||||
isLoop = this.checkContentLoop(event.value);
|
||||
if (isLoop) {
|
||||
detail = `Repeating content detected: "${this.streamContentHistory.substring(Math.max(0, this.lastContentIndex - 20), this.lastContentIndex + CONTENT_CHUNK_SIZE).trim()}..."`;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return this.loopDetected;
|
||||
|
||||
if (isLoop) {
|
||||
this.loopDetected = true;
|
||||
this.detectedCount++;
|
||||
this.lastLoopDetail = detail;
|
||||
this.lastLoopType =
|
||||
event.type === GeminiEventType.ToolCallRequest
|
||||
? LoopType.CONSECUTIVE_IDENTICAL_TOOL_CALLS
|
||||
: LoopType.CONTENT_CHANTING_LOOP;
|
||||
|
||||
logLoopDetected(
|
||||
this.config,
|
||||
new LoopDetectedEvent(
|
||||
this.lastLoopType,
|
||||
this.promptId,
|
||||
this.detectedCount,
|
||||
),
|
||||
);
|
||||
}
|
||||
return isLoop
|
||||
? {
|
||||
count: this.detectedCount,
|
||||
type: this.lastLoopType,
|
||||
detail: this.lastLoopDetail,
|
||||
}
|
||||
: { count: 0 };
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -204,12 +253,20 @@ export class LoopDetectionService {
|
||||
* is performed periodically based on the `llmCheckInterval`.
|
||||
*
|
||||
* @param signal - An AbortSignal to allow for cancellation of the asynchronous LLM check.
|
||||
* @returns A promise that resolves to `true` if a loop is detected, and `false` otherwise.
|
||||
* @returns A promise that resolves to a LoopDetectionResult.
|
||||
*/
|
||||
async turnStarted(signal: AbortSignal) {
|
||||
async turnStarted(signal: AbortSignal): Promise<LoopDetectionResult> {
|
||||
if (this.disabledForSession || this.config.getDisableLoopDetection()) {
|
||||
return false;
|
||||
return { count: 0 };
|
||||
}
|
||||
if (this.loopDetected) {
|
||||
return {
|
||||
count: this.detectedCount,
|
||||
type: this.lastLoopType,
|
||||
detail: this.lastLoopDetail,
|
||||
};
|
||||
}
|
||||
|
||||
this.turnsInCurrentPrompt++;
|
||||
|
||||
if (
|
||||
@@ -217,10 +274,35 @@ export class LoopDetectionService {
|
||||
this.turnsInCurrentPrompt - this.lastCheckTurn >= this.llmCheckInterval
|
||||
) {
|
||||
this.lastCheckTurn = this.turnsInCurrentPrompt;
|
||||
return this.checkForLoopWithLLM(signal);
|
||||
}
|
||||
const { isLoop, analysis, confirmedByModel } =
|
||||
await this.checkForLoopWithLLM(signal);
|
||||
if (isLoop) {
|
||||
this.loopDetected = true;
|
||||
this.detectedCount++;
|
||||
this.lastLoopDetail = analysis;
|
||||
this.lastLoopType = LoopType.LLM_DETECTED_LOOP;
|
||||
|
||||
return false;
|
||||
logLoopDetected(
|
||||
this.config,
|
||||
new LoopDetectedEvent(
|
||||
this.lastLoopType,
|
||||
this.promptId,
|
||||
this.detectedCount,
|
||||
confirmedByModel,
|
||||
analysis,
|
||||
LLM_CONFIDENCE_THRESHOLD,
|
||||
),
|
||||
);
|
||||
|
||||
return {
|
||||
count: this.detectedCount,
|
||||
type: this.lastLoopType,
|
||||
detail: this.lastLoopDetail,
|
||||
confirmedByModel,
|
||||
};
|
||||
}
|
||||
}
|
||||
return { count: 0 };
|
||||
}
|
||||
|
||||
private checkToolCallLoop(toolCall: { name: string; args: object }): boolean {
|
||||
@@ -232,13 +314,6 @@ export class LoopDetectionService {
|
||||
this.toolCallRepetitionCount = 1;
|
||||
}
|
||||
if (this.toolCallRepetitionCount >= TOOL_CALL_LOOP_THRESHOLD) {
|
||||
logLoopDetected(
|
||||
this.config,
|
||||
new LoopDetectedEvent(
|
||||
LoopType.CONSECUTIVE_IDENTICAL_TOOL_CALLS,
|
||||
this.promptId,
|
||||
),
|
||||
);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@@ -345,13 +420,6 @@ export class LoopDetectionService {
|
||||
const chunkHash = createHash('sha256').update(currentChunk).digest('hex');
|
||||
|
||||
if (this.isLoopDetectedForChunk(currentChunk, chunkHash)) {
|
||||
logLoopDetected(
|
||||
this.config,
|
||||
new LoopDetectedEvent(
|
||||
LoopType.CHANTING_IDENTICAL_SENTENCES,
|
||||
this.promptId,
|
||||
),
|
||||
);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -445,28 +513,29 @@ export class LoopDetectionService {
|
||||
return originalChunk === currentChunk;
|
||||
}
|
||||
|
||||
private trimRecentHistory(recentHistory: Content[]): Content[] {
|
||||
private trimRecentHistory(history: Content[]): Content[] {
|
||||
// A function response must be preceded by a function call.
|
||||
// Continuously removes dangling function calls from the end of the history
|
||||
// until the last turn is not a function call.
|
||||
while (
|
||||
recentHistory.length > 0 &&
|
||||
isFunctionCall(recentHistory[recentHistory.length - 1])
|
||||
) {
|
||||
recentHistory.pop();
|
||||
while (history.length > 0 && isFunctionCall(history[history.length - 1])) {
|
||||
history.pop();
|
||||
}
|
||||
|
||||
// A function response should follow a function call.
|
||||
// Continuously removes leading function responses from the beginning of history
|
||||
// until the first turn is not a function response.
|
||||
while (recentHistory.length > 0 && isFunctionResponse(recentHistory[0])) {
|
||||
recentHistory.shift();
|
||||
while (history.length > 0 && isFunctionResponse(history[0])) {
|
||||
history.shift();
|
||||
}
|
||||
|
||||
return recentHistory;
|
||||
return history;
|
||||
}
|
||||
|
||||
private async checkForLoopWithLLM(signal: AbortSignal) {
|
||||
private async checkForLoopWithLLM(signal: AbortSignal): Promise<{
|
||||
isLoop: boolean;
|
||||
analysis?: string;
|
||||
confirmedByModel?: string;
|
||||
}> {
|
||||
const recentHistory = this.config
|
||||
.getGeminiClient()
|
||||
.getHistory()
|
||||
@@ -506,13 +575,17 @@ export class LoopDetectionService {
|
||||
);
|
||||
|
||||
if (!flashResult) {
|
||||
return false;
|
||||
return { isLoop: false };
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const flashConfidence = flashResult[
|
||||
'unproductive_state_confidence'
|
||||
] as number;
|
||||
const flashConfidence =
|
||||
typeof flashResult['unproductive_state_confidence'] === 'number'
|
||||
? flashResult['unproductive_state_confidence']
|
||||
: 0;
|
||||
const flashAnalysis =
|
||||
typeof flashResult['unproductive_state_analysis'] === 'string'
|
||||
? flashResult['unproductive_state_analysis']
|
||||
: '';
|
||||
|
||||
const doubleCheckModelName =
|
||||
this.config.modelConfigService.getResolvedConfig({
|
||||
@@ -530,7 +603,7 @@ export class LoopDetectionService {
|
||||
),
|
||||
);
|
||||
this.updateCheckInterval(flashConfidence);
|
||||
return false;
|
||||
return { isLoop: false };
|
||||
}
|
||||
|
||||
const availability = this.config.getModelAvailabilityService();
|
||||
@@ -539,8 +612,11 @@ export class LoopDetectionService {
|
||||
const flashModelName = this.config.modelConfigService.getResolvedConfig({
|
||||
model: 'loop-detection',
|
||||
}).model;
|
||||
this.handleConfirmedLoop(flashResult, flashModelName);
|
||||
return true;
|
||||
return {
|
||||
isLoop: true,
|
||||
analysis: flashAnalysis,
|
||||
confirmedByModel: flashModelName,
|
||||
};
|
||||
}
|
||||
|
||||
// Double check with configured model
|
||||
@@ -550,10 +626,16 @@ export class LoopDetectionService {
|
||||
signal,
|
||||
);
|
||||
|
||||
const mainModelConfidence = mainModelResult
|
||||
? // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
(mainModelResult['unproductive_state_confidence'] as number)
|
||||
: 0;
|
||||
const mainModelConfidence =
|
||||
mainModelResult &&
|
||||
typeof mainModelResult['unproductive_state_confidence'] === 'number'
|
||||
? mainModelResult['unproductive_state_confidence']
|
||||
: 0;
|
||||
const mainModelAnalysis =
|
||||
mainModelResult &&
|
||||
typeof mainModelResult['unproductive_state_analysis'] === 'string'
|
||||
? mainModelResult['unproductive_state_analysis']
|
||||
: undefined;
|
||||
|
||||
logLlmLoopCheck(
|
||||
this.config,
|
||||
@@ -567,14 +649,17 @@ export class LoopDetectionService {
|
||||
|
||||
if (mainModelResult) {
|
||||
if (mainModelConfidence >= LLM_CONFIDENCE_THRESHOLD) {
|
||||
this.handleConfirmedLoop(mainModelResult, doubleCheckModelName);
|
||||
return true;
|
||||
return {
|
||||
isLoop: true,
|
||||
analysis: mainModelAnalysis,
|
||||
confirmedByModel: doubleCheckModelName,
|
||||
};
|
||||
} else {
|
||||
this.updateCheckInterval(mainModelConfidence);
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
return { isLoop: false };
|
||||
}
|
||||
|
||||
private async queryLoopDetectionModel(
|
||||
@@ -601,32 +686,16 @@ export class LoopDetectionService {
|
||||
return result;
|
||||
}
|
||||
return null;
|
||||
} catch (e) {
|
||||
this.config.getDebugMode() ? debugLogger.warn(e) : debugLogger.debug(e);
|
||||
} catch (error) {
|
||||
if (this.config.getDebugMode()) {
|
||||
debugLogger.warn(
|
||||
`Error querying loop detection model (${model}): ${String(error)}`,
|
||||
);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private handleConfirmedLoop(
|
||||
result: Record<string, unknown>,
|
||||
modelName: string,
|
||||
): void {
|
||||
if (
|
||||
typeof result['unproductive_state_analysis'] === 'string' &&
|
||||
result['unproductive_state_analysis']
|
||||
) {
|
||||
debugLogger.warn(result['unproductive_state_analysis']);
|
||||
}
|
||||
logLoopDetected(
|
||||
this.config,
|
||||
new LoopDetectedEvent(
|
||||
LoopType.LLM_DETECTED_LOOP,
|
||||
this.promptId,
|
||||
modelName,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
private updateCheckInterval(unproductive_state_confidence: number): void {
|
||||
this.llmCheckInterval = Math.round(
|
||||
MIN_LLM_CHECK_INTERVAL +
|
||||
@@ -645,6 +714,17 @@ export class LoopDetectionService {
|
||||
this.resetContentTracking();
|
||||
this.resetLlmCheckTracking();
|
||||
this.loopDetected = false;
|
||||
this.detectedCount = 0;
|
||||
this.lastLoopDetail = undefined;
|
||||
this.lastLoopType = undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resets the loop detected flag to allow a recovery turn to proceed.
|
||||
* This preserves the detectedCount so that the next detection will be count 2.
|
||||
*/
|
||||
clearDetection(): void {
|
||||
this.loopDetected = false;
|
||||
}
|
||||
|
||||
private resetToolCallCount(): void {
|
||||
|
||||
Reference in New Issue
Block a user