feat(core): enhance loop detection with 2-stage check (#12902)

This commit is contained in:
Sandy Tao
2025-11-11 20:49:00 -08:00
committed by GitHub
parent cab9b1f370
commit 408b885689
10 changed files with 438 additions and 60 deletions

View File

@@ -139,6 +139,12 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
extends: 'gemini-2.5-flash-base',
modelConfig: {},
},
'loop-detection-double-check': {
extends: 'base',
modelConfig: {
model: 'gemini-2.5-pro',
},
},
'llm-edit-fixer': {
extends: 'gemini-2.5-flash-base',
modelConfig: {},

View File

@@ -22,6 +22,7 @@ import { LoopDetectionService } from './loopDetectionService.js';
vi.mock('../telemetry/loggers.js', () => ({
logLoopDetected: vi.fn(),
logLoopDetectionDisabled: vi.fn(),
logLlmLoopCheck: vi.fn(),
}));
const TOOL_CALL_LOOP_THRESHOLD = 5;
@@ -736,10 +737,17 @@ describe('LoopDetectionService LLM Checks', () => {
getBaseLlmClient: () => mockBaseLlmClient,
getDebugMode: () => false,
getTelemetryEnabled: () => true,
getModel: vi.fn().mockReturnValue('cognitive-loop-v1'),
isInFallbackMode: vi.fn().mockReturnValue(false),
modelConfigService: {
getResolvedConfig: vi.fn().mockReturnValue({
model: 'cognitive-loop-v1',
generateContentConfig: {},
getResolvedConfig: vi.fn().mockImplementation((key) => {
if (key.model === 'loop-detection') {
return { model: 'gemini-2.5-flash', generateContentConfig: {} };
}
return {
model: 'cognitive-loop-v1',
generateContentConfig: {},
};
}),
},
isInteractive: () => false,
@@ -773,7 +781,7 @@ describe('LoopDetectionService LLM Checks', () => {
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect.objectContaining({
modelConfigKey: expect.any(Object),
modelConfigKey: { model: 'loop-detection' },
systemInstruction: expect.any(String),
contents: expect.any(Array),
schema: expect.any(Object),
@@ -790,6 +798,11 @@ describe('LoopDetectionService LLM Checks', () => {
});
await advanceTurns(30);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect.objectContaining({
modelConfigKey: { model: 'loop-detection' },
}),
);
// The confidence of 0.85 will result in a low interval.
// The interval will be: 5 + (15 - 5) * (1 - 0.85) = 5 + 10 * 0.15 = 6.5 -> rounded to 7
@@ -807,6 +820,7 @@ describe('LoopDetectionService LLM Checks', () => {
expect.objectContaining({
'event.name': 'loop_detected',
loop_type: LoopType.LLM_DETECTED_LOOP,
confirmed_by_model: 'cognitive-loop-v1',
}),
);
});
@@ -885,4 +899,122 @@ describe('LoopDetectionService LLM Checks', () => {
// Verify the original history follows
expect(calledArg.contents[1]).toEqual(functionCallHistory[0]);
});
it('should detect a loop when confidence is exactly equal to the threshold (0.9)', async () => {
// Mock isInFallbackMode to false so it double checks
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValueOnce({
unproductive_state_confidence: 0.9,
unproductive_state_analysis: 'Flash says loop',
})
.mockResolvedValueOnce({
unproductive_state_confidence: 0.9,
unproductive_state_analysis: 'Main says loop',
});
await advanceTurns(30);
// It should have called generateJson twice
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
expect(mockBaseLlmClient.generateJson).toHaveBeenNthCalledWith(
1,
expect.objectContaining({
modelConfigKey: { model: 'loop-detection' },
}),
);
expect(mockBaseLlmClient.generateJson).toHaveBeenNthCalledWith(
2,
expect.objectContaining({
modelConfigKey: { model: 'loop-detection-double-check' },
}),
);
// And it should have detected a loop
expect(loggers.logLoopDetected).toHaveBeenCalledWith(
mockConfig,
expect.objectContaining({
'event.name': 'loop_detected',
loop_type: LoopType.LLM_DETECTED_LOOP,
confirmed_by_model: 'cognitive-loop-v1',
}),
);
});
it('should not detect a loop when Flash is confident (0.9) but Main model is not (0.89)', async () => {
// Mock isInFallbackMode to false so it double checks
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValueOnce({
unproductive_state_confidence: 0.9,
unproductive_state_analysis: 'Flash says loop',
})
.mockResolvedValueOnce({
unproductive_state_confidence: 0.89,
unproductive_state_analysis: 'Main says no loop',
});
await advanceTurns(30);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
expect(mockBaseLlmClient.generateJson).toHaveBeenNthCalledWith(
1,
expect.objectContaining({
modelConfigKey: { model: 'loop-detection' },
}),
);
expect(mockBaseLlmClient.generateJson).toHaveBeenNthCalledWith(
2,
expect.objectContaining({
modelConfigKey: { model: 'loop-detection-double-check' },
}),
);
// Should NOT have detected a loop
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
// But should have updated the interval based on the main model's confidence (0.89)
// Interval = 5 + (15-5) * (1 - 0.89) = 5 + 10 * 0.11 = 5 + 1.1 = 6.1 -> 6
// Advance by 6 turns
await advanceTurns(6);
// Next turn (37) should trigger another check
await service.turnStarted(abortController.signal);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(3);
});
it('should only call Flash model if in fallback mode', async () => {
// Mock isInFallbackMode to true
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
mockBaseLlmClient.generateJson = vi.fn().mockResolvedValueOnce({
unproductive_state_confidence: 0.9,
unproductive_state_analysis: 'Flash says loop',
});
await advanceTurns(30);
// It should have called generateJson only once
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect.objectContaining({
modelConfigKey: { model: 'loop-detection' },
}),
);
// And it should have detected a loop
expect(loggers.logLoopDetected).toHaveBeenCalledWith(
mockConfig,
expect.objectContaining({
'event.name': 'loop_detected',
loop_type: LoopType.LLM_DETECTED_LOOP,
confirmed_by_model: 'gemini-2.5-flash',
}),
);
});
});

View File

@@ -11,11 +11,13 @@ import { GeminiEventType } from '../core/turn.js';
import {
logLoopDetected,
logLoopDetectionDisabled,
logLlmLoopCheck,
} from '../telemetry/loggers.js';
import {
LoopDetectedEvent,
LoopDetectionDisabledEvent,
LoopType,
LlmLoopCheckEvent,
} from '../telemetry/types.js';
import type { Config } from '../config/config.js';
import {
@@ -57,6 +59,12 @@ const MIN_LLM_CHECK_INTERVAL = 5;
*/
const MAX_LLM_CHECK_INTERVAL = 15;
/**
* The confidence threshold above which the LLM is considered to have detected a loop.
*/
const LLM_CONFIDENCE_THRESHOLD = 0.9;
const DOUBLE_CHECK_MODEL_ALIAS = 'loop-detection-double-check';
const LOOP_DETECTION_SYSTEM_PROMPT = `You are a sophisticated AI diagnostic agent specializing in identifying when a conversational AI is stuck in an unproductive state. Your task is to analyze the provided conversation history and determine if the assistant has ceased to make meaningful progress.
An unproductive state is characterized by one or more of the following patterns over the last 5 or more assistant turns:
@@ -68,6 +76,23 @@ Cognitive Loop: The assistant seems unable to determine the next logical step. I
Crucially, differentiate between a true unproductive state and legitimate, incremental progress.
For example, a series of 'tool_A' or 'tool_B' tool calls that make small, distinct changes to the same file (like adding docstrings to functions one by one) is considered forward progress and is NOT a loop. A loop would be repeatedly replacing the same text with the same content, or cycling between a small set of files with no net change.`;
const LOOP_DETECTION_SCHEMA: Record<string, unknown> = {
type: 'object',
properties: {
unproductive_state_analysis: {
type: 'string',
description:
'Your reasoning on if the conversation is looping without forward progress.',
},
unproductive_state_confidence: {
type: 'number',
description:
'A number between 0.0 and 1.0 representing your confidence that the conversation is in an unproductive state.',
},
},
required: ['unproductive_state_analysis', 'unproductive_state_confidence'],
};
/**
* Service for detecting and preventing infinite loops in AI responses.
* Monitors tool call repetitions and content sentence repetitions.
@@ -413,65 +438,138 @@ export class LoopDetectionService {
parts: [{ text: 'Recent conversation history:' }],
});
}
const schema: Record<string, unknown> = {
type: 'object',
properties: {
unproductive_state_analysis: {
type: 'string',
description:
'Your reasoning on if the conversation is looping without forward progress.',
},
unproductive_state_confidence: {
type: 'number',
description:
'A number between 0.0 and 1.0 representing your confidence that the conversation is in an unproductive state.',
},
},
required: [
'unproductive_state_analysis',
'unproductive_state_confidence',
],
};
let result;
try {
result = await this.config.getBaseLlmClient().generateJson({
modelConfigKey: { model: 'loop-detection' },
contents,
schema,
systemInstruction: LOOP_DETECTION_SYSTEM_PROMPT,
abortSignal: signal,
promptId: this.promptId,
});
} catch (e) {
// Do nothing, treat it as a non-loop.
this.config.getDebugMode() ? debugLogger.warn(e) : debugLogger.debug(e);
const flashResult = await this.queryLoopDetectionModel(
'loop-detection',
contents,
signal,
);
if (!flashResult) {
return false;
}
if (typeof result['unproductive_state_confidence'] === 'number') {
if (result['unproductive_state_confidence'] > 0.9) {
if (
typeof result['unproductive_state_analysis'] === 'string' &&
result['unproductive_state_analysis']
) {
debugLogger.warn(result['unproductive_state_analysis']);
}
logLoopDetected(
this.config,
new LoopDetectedEvent(LoopType.LLM_DETECTED_LOOP, this.promptId),
);
const flashConfidence = flashResult[
'unproductive_state_confidence'
] as number;
const doubleCheckModelName =
this.config.modelConfigService.getResolvedConfig({
model: DOUBLE_CHECK_MODEL_ALIAS,
}).model;
if (flashConfidence < LLM_CONFIDENCE_THRESHOLD) {
logLlmLoopCheck(
this.config,
new LlmLoopCheckEvent(
this.promptId,
flashConfidence,
doubleCheckModelName,
-1,
),
);
this.updateCheckInterval(flashConfidence);
return false;
}
if (this.config.isInFallbackMode()) {
const flashModelName = this.config.modelConfigService.getResolvedConfig({
model: 'loop-detection',
}).model;
this.handleConfirmedLoop(flashResult, flashModelName);
return true;
}
// Double check with configured model
const mainModelResult = await this.queryLoopDetectionModel(
DOUBLE_CHECK_MODEL_ALIAS,
contents,
signal,
);
const mainModelConfidence = mainModelResult
? (mainModelResult['unproductive_state_confidence'] as number)
: 0;
logLlmLoopCheck(
this.config,
new LlmLoopCheckEvent(
this.promptId,
flashConfidence,
doubleCheckModelName,
mainModelConfidence,
),
);
if (mainModelResult) {
if (mainModelConfidence >= LLM_CONFIDENCE_THRESHOLD) {
this.handleConfirmedLoop(mainModelResult, doubleCheckModelName);
return true;
} else {
this.llmCheckInterval = Math.round(
MIN_LLM_CHECK_INTERVAL +
(MAX_LLM_CHECK_INTERVAL - MIN_LLM_CHECK_INTERVAL) *
(1 - result['unproductive_state_confidence']),
);
this.updateCheckInterval(mainModelConfidence);
}
}
return false;
}
private async queryLoopDetectionModel(
model: string,
contents: Content[],
signal: AbortSignal,
): Promise<Record<string, unknown> | null> {
try {
const result = (await this.config.getBaseLlmClient().generateJson({
modelConfigKey: { model },
contents,
schema: LOOP_DETECTION_SCHEMA,
systemInstruction: LOOP_DETECTION_SYSTEM_PROMPT,
abortSignal: signal,
promptId: this.promptId,
maxAttempts: 2,
})) as Record<string, unknown>;
if (
result &&
typeof result['unproductive_state_confidence'] === 'number'
) {
return result;
}
return null;
} catch (e) {
this.config.getDebugMode() ? debugLogger.warn(e) : debugLogger.debug(e);
return null;
}
}
private handleConfirmedLoop(
result: Record<string, unknown>,
modelName: string,
): void {
if (
typeof result['unproductive_state_analysis'] === 'string' &&
result['unproductive_state_analysis']
) {
debugLogger.warn(result['unproductive_state_analysis']);
}
logLoopDetected(
this.config,
new LoopDetectedEvent(
LoopType.LLM_DETECTED_LOOP,
this.promptId,
modelName,
),
);
}
private updateCheckInterval(unproductive_state_confidence: number): void {
this.llmCheckInterval = Math.round(
MIN_LLM_CHECK_INTERVAL +
(MAX_LLM_CHECK_INTERVAL - MIN_LLM_CHECK_INTERVAL) *
(1 - unproductive_state_confidence),
);
}
/**
* Resets all loop detection state.
*/

View File

@@ -141,6 +141,13 @@
"topP": 1
}
},
"loop-detection-double-check": {
"model": "gemini-2.5-pro",
"generateContentConfig": {
"temperature": 0,
"topP": 1
}
},
"llm-edit-fixer": {
"model": "gemini-2.5-flash",
"generateContentConfig": {

View File

@@ -37,6 +37,7 @@ import type {
RecoveryAttemptEvent,
WebFetchFallbackAttemptEvent,
ExtensionUpdateEvent,
LlmLoopCheckEvent,
} from '../types.js';
import { EventMetadataKey } from './event-metadata-key.js';
import type { Config } from '../../config/config.js';
@@ -92,6 +93,7 @@ export enum EventNames {
AGENT_FINISH = 'agent_finish',
RECOVERY_ATTEMPT = 'recovery_attempt',
WEB_FETCH_FALLBACK_ATTEMPT = 'web_fetch_fallback_attempt',
LLM_LOOP_CHECK = 'llm_loop_check',
}
export interface LogResponse {
@@ -735,6 +737,14 @@ export class ClearcutLogger {
},
];
if (event.confirmed_by_model) {
data.push({
gemini_cli_key:
EventMetadataKey.GEMINI_CLI_LOOP_DETECTED_CONFIRMED_BY_MODEL,
value: event.confirmed_by_model,
});
}
this.enqueueLogEvent(this.createLogEvent(EventNames.LOOP_DETECTED, data));
this.flushIfNeeded();
}
@@ -1269,6 +1279,32 @@ export class ClearcutLogger {
this.flushIfNeeded();
}
logLlmLoopCheckEvent(event: LlmLoopCheckEvent): void {
const data: EventValue[] = [
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_PROMPT_ID,
value: event.prompt_id,
},
{
gemini_cli_key:
EventMetadataKey.GEMINI_CLI_LLM_LOOP_CHECK_FLASH_CONFIDENCE,
value: event.flash_confidence.toString(),
},
{
gemini_cli_key: EventMetadataKey.GEMINI_CLI_LLM_LOOP_CHECK_MAIN_MODEL,
value: event.main_model,
},
{
gemini_cli_key:
EventMetadataKey.GEMINI_CLI_LLM_LOOP_CHECK_MAIN_MODEL_CONFIDENCE,
value: event.main_model_confidence.toString(),
},
];
this.enqueueLogEvent(this.createLogEvent(EventNames.LLM_LOOP_CHECK, data));
this.flushIfNeeded();
}
/**
* Adds default fields to data, and returns a new data array. This fields
* should exist on all log events.

View File

@@ -7,7 +7,7 @@
// Defines valid event metadata keys for Clearcut logging.
export enum EventMetadataKey {
// Deleted enums: 24
// Next ID: 122
// Next ID: 129
GEMINI_CLI_KEY_UNKNOWN = 0,
@@ -476,4 +476,20 @@ export enum EventMetadataKey {
// Logs whether the session is interactive.
GEMINI_CLI_INTERACTIVE = 125,
// ==========================================================================
// LLM Loop Check Event Keys
// ==========================================================================
// Logs the confidence score from the flash model loop check.
GEMINI_CLI_LLM_LOOP_CHECK_FLASH_CONFIDENCE = 126,
// Logs the name of the main model used for the secondary loop check.
GEMINI_CLI_LLM_LOOP_CHECK_MAIN_MODEL = 127,
// Logs the confidence score from the main model loop check.
GEMINI_CLI_LLM_LOOP_CHECK_MAIN_MODEL_CONFIDENCE = 128,
// Logs the model that confirmed the loop.
GEMINI_CLI_LOOP_DETECTED_CONFIRMED_BY_MODEL = 129,
}

View File

@@ -48,6 +48,7 @@ import type {
RecoveryAttemptEvent,
WebFetchFallbackAttemptEvent,
ExtensionUpdateEvent,
LlmLoopCheckEvent,
} from './types.js';
import {
recordApiErrorMetrics,
@@ -654,3 +655,18 @@ export function logWebFetchFallbackAttempt(
};
logger.emit(logRecord);
}
export function logLlmLoopCheck(
config: Config,
event: LlmLoopCheckEvent,
): void {
ClearcutLogger.getInstance(config)?.logLlmLoopCheckEvent(event);
if (!isTelemetrySdkInitialized()) return;
const logger = logs.getLogger(SERVICE_NAME);
const logRecord: LogRecord = {
body: event.toLogBody(),
attributes: event.toOpenTelemetryAttributes(config),
};
logger.emit(logRecord);
}

View File

@@ -708,26 +708,38 @@ export class LoopDetectedEvent implements BaseTelemetryEvent {
'event.timestamp': string;
loop_type: LoopType;
prompt_id: string;
confirmed_by_model?: string;
constructor(loop_type: LoopType, prompt_id: string) {
constructor(
loop_type: LoopType,
prompt_id: string,
confirmed_by_model?: string,
) {
this['event.name'] = 'loop_detected';
this['event.timestamp'] = new Date().toISOString();
this.loop_type = loop_type;
this.prompt_id = prompt_id;
this.confirmed_by_model = confirmed_by_model;
}
toOpenTelemetryAttributes(config: Config): LogAttributes {
return {
const attributes: LogAttributes = {
...getCommonAttributes(config),
'event.name': this['event.name'],
'event.timestamp': this['event.timestamp'],
loop_type: this.loop_type,
prompt_id: this.prompt_id,
};
if (this.confirmed_by_model) {
attributes['confirmed_by_model'] = this.confirmed_by_model;
}
return attributes;
}
toLogBody(): string {
return `Loop detected. Type: ${this.loop_type}.`;
return `Loop detected. Type: ${this.loop_type}.${this.confirmed_by_model ? ` Confirmed by: ${this.confirmed_by_model}` : ''}`;
}
}
@@ -1417,6 +1429,48 @@ export class ModelSlashCommandEvent implements BaseTelemetryEvent {
}
}
export const EVENT_LLM_LOOP_CHECK = 'gemini_cli.llm_loop_check';
export class LlmLoopCheckEvent implements BaseTelemetryEvent {
'event.name': 'llm_loop_check';
'event.timestamp': string;
prompt_id: string;
flash_confidence: number;
main_model: string;
main_model_confidence: number;
constructor(
prompt_id: string,
flash_confidence: number,
main_model: string,
main_model_confidence: number,
) {
this['event.name'] = 'llm_loop_check';
this['event.timestamp'] = new Date().toISOString();
this.prompt_id = prompt_id;
this.flash_confidence = flash_confidence;
this.main_model = main_model;
this.main_model_confidence = main_model_confidence;
}
toOpenTelemetryAttributes(config: Config): LogAttributes {
return {
...getCommonAttributes(config),
'event.name': EVENT_LLM_LOOP_CHECK,
'event.timestamp': this['event.timestamp'],
prompt_id: this.prompt_id,
flash_confidence: this.flash_confidence,
main_model: this.main_model,
main_model_confidence: this.main_model_confidence,
};
}
toLogBody(): string {
return this.main_model_confidence === -1
? `LLM loop check. Flash confidence: ${this.flash_confidence.toFixed(2)}. Main model (${this.main_model}) check skipped`
: `LLM loop check. Flash confidence: ${this.flash_confidence.toFixed(2)}. Main model (${this.main_model}) confidence: ${this.main_model_confidence.toFixed(2)}`;
}
}
export type TelemetryEvent =
| StartSessionEvent
| EndSessionEvent
@@ -1446,6 +1500,7 @@ export type TelemetryEvent =
| AgentStartEvent
| AgentFinishEvent
| RecoveryAttemptEvent
| LlmLoopCheckEvent
| WebFetchFallbackAttemptEvent;
export const EVENT_EXTENSION_DISABLE = 'gemini_cli.extension_disable';