feat(core): enhance loop detection with 2-stage check (#12902)

This commit is contained in:
Sandy Tao
2025-11-11 20:49:00 -08:00
committed by GitHub
parent cab9b1f370
commit 408b885689
10 changed files with 438 additions and 60 deletions

View File

@@ -22,6 +22,7 @@ import { LoopDetectionService } from './loopDetectionService.js';
vi.mock('../telemetry/loggers.js', () => ({
logLoopDetected: vi.fn(),
logLoopDetectionDisabled: vi.fn(),
logLlmLoopCheck: vi.fn(),
}));
const TOOL_CALL_LOOP_THRESHOLD = 5;
@@ -736,10 +737,17 @@ describe('LoopDetectionService LLM Checks', () => {
getBaseLlmClient: () => mockBaseLlmClient,
getDebugMode: () => false,
getTelemetryEnabled: () => true,
getModel: vi.fn().mockReturnValue('cognitive-loop-v1'),
isInFallbackMode: vi.fn().mockReturnValue(false),
modelConfigService: {
getResolvedConfig: vi.fn().mockReturnValue({
model: 'cognitive-loop-v1',
generateContentConfig: {},
getResolvedConfig: vi.fn().mockImplementation((key) => {
if (key.model === 'loop-detection') {
return { model: 'gemini-2.5-flash', generateContentConfig: {} };
}
return {
model: 'cognitive-loop-v1',
generateContentConfig: {},
};
}),
},
isInteractive: () => false,
@@ -773,7 +781,7 @@ describe('LoopDetectionService LLM Checks', () => {
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect.objectContaining({
modelConfigKey: expect.any(Object),
modelConfigKey: { model: 'loop-detection' },
systemInstruction: expect.any(String),
contents: expect.any(Array),
schema: expect.any(Object),
@@ -790,6 +798,11 @@ describe('LoopDetectionService LLM Checks', () => {
});
await advanceTurns(30);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect.objectContaining({
modelConfigKey: { model: 'loop-detection' },
}),
);
// The confidence of 0.85 will result in a low interval.
// The interval will be: 5 + (15 - 5) * (1 - 0.85) = 5 + 10 * 0.15 = 6.5 -> rounded to 7
@@ -807,6 +820,7 @@ describe('LoopDetectionService LLM Checks', () => {
expect.objectContaining({
'event.name': 'loop_detected',
loop_type: LoopType.LLM_DETECTED_LOOP,
confirmed_by_model: 'cognitive-loop-v1',
}),
);
});
@@ -885,4 +899,122 @@ describe('LoopDetectionService LLM Checks', () => {
// Verify the original history follows
expect(calledArg.contents[1]).toEqual(functionCallHistory[0]);
});
it('should detect a loop when confidence is exactly equal to the threshold (0.9)', async () => {
// Mock isInFallbackMode to false so it double checks
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValueOnce({
unproductive_state_confidence: 0.9,
unproductive_state_analysis: 'Flash says loop',
})
.mockResolvedValueOnce({
unproductive_state_confidence: 0.9,
unproductive_state_analysis: 'Main says loop',
});
await advanceTurns(30);
// It should have called generateJson twice
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
expect(mockBaseLlmClient.generateJson).toHaveBeenNthCalledWith(
1,
expect.objectContaining({
modelConfigKey: { model: 'loop-detection' },
}),
);
expect(mockBaseLlmClient.generateJson).toHaveBeenNthCalledWith(
2,
expect.objectContaining({
modelConfigKey: { model: 'loop-detection-double-check' },
}),
);
// And it should have detected a loop
expect(loggers.logLoopDetected).toHaveBeenCalledWith(
mockConfig,
expect.objectContaining({
'event.name': 'loop_detected',
loop_type: LoopType.LLM_DETECTED_LOOP,
confirmed_by_model: 'cognitive-loop-v1',
}),
);
});
it('should not detect a loop when Flash is confident (0.9) but Main model is not (0.89)', async () => {
// Mock isInFallbackMode to false so it double checks
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValueOnce({
unproductive_state_confidence: 0.9,
unproductive_state_analysis: 'Flash says loop',
})
.mockResolvedValueOnce({
unproductive_state_confidence: 0.89,
unproductive_state_analysis: 'Main says no loop',
});
await advanceTurns(30);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
expect(mockBaseLlmClient.generateJson).toHaveBeenNthCalledWith(
1,
expect.objectContaining({
modelConfigKey: { model: 'loop-detection' },
}),
);
expect(mockBaseLlmClient.generateJson).toHaveBeenNthCalledWith(
2,
expect.objectContaining({
modelConfigKey: { model: 'loop-detection-double-check' },
}),
);
// Should NOT have detected a loop
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
// But should have updated the interval based on the main model's confidence (0.89)
// Interval = 5 + (15-5) * (1 - 0.89) = 5 + 10 * 0.11 = 5 + 1.1 = 6.1 -> 6
// Advance by 6 turns
await advanceTurns(6);
// Next turn (37) should trigger another check
await service.turnStarted(abortController.signal);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(3);
});
it('should only call Flash model if in fallback mode', async () => {
// Mock isInFallbackMode to true
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
mockBaseLlmClient.generateJson = vi.fn().mockResolvedValueOnce({
unproductive_state_confidence: 0.9,
unproductive_state_analysis: 'Flash says loop',
});
await advanceTurns(30);
// It should have called generateJson only once
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect.objectContaining({
modelConfigKey: { model: 'loop-detection' },
}),
);
// And it should have detected a loop
expect(loggers.logLoopDetected).toHaveBeenCalledWith(
mockConfig,
expect.objectContaining({
'event.name': 'loop_detected',
loop_type: LoopType.LLM_DETECTED_LOOP,
confirmed_by_model: 'gemini-2.5-flash',
}),
);
});
});

View File

@@ -11,11 +11,13 @@ import { GeminiEventType } from '../core/turn.js';
import {
logLoopDetected,
logLoopDetectionDisabled,
logLlmLoopCheck,
} from '../telemetry/loggers.js';
import {
LoopDetectedEvent,
LoopDetectionDisabledEvent,
LoopType,
LlmLoopCheckEvent,
} from '../telemetry/types.js';
import type { Config } from '../config/config.js';
import {
@@ -57,6 +59,12 @@ const MIN_LLM_CHECK_INTERVAL = 5;
*/
const MAX_LLM_CHECK_INTERVAL = 15;
/**
* The confidence threshold above which the LLM is considered to have detected a loop.
*/
const LLM_CONFIDENCE_THRESHOLD = 0.9;
const DOUBLE_CHECK_MODEL_ALIAS = 'loop-detection-double-check';
const LOOP_DETECTION_SYSTEM_PROMPT = `You are a sophisticated AI diagnostic agent specializing in identifying when a conversational AI is stuck in an unproductive state. Your task is to analyze the provided conversation history and determine if the assistant has ceased to make meaningful progress.
An unproductive state is characterized by one or more of the following patterns over the last 5 or more assistant turns:
@@ -68,6 +76,23 @@ Cognitive Loop: The assistant seems unable to determine the next logical step. I
Crucially, differentiate between a true unproductive state and legitimate, incremental progress.
For example, a series of 'tool_A' or 'tool_B' tool calls that make small, distinct changes to the same file (like adding docstrings to functions one by one) is considered forward progress and is NOT a loop. A loop would be repeatedly replacing the same text with the same content, or cycling between a small set of files with no net change.`;
const LOOP_DETECTION_SCHEMA: Record<string, unknown> = {
type: 'object',
properties: {
unproductive_state_analysis: {
type: 'string',
description:
'Your reasoning on if the conversation is looping without forward progress.',
},
unproductive_state_confidence: {
type: 'number',
description:
'A number between 0.0 and 1.0 representing your confidence that the conversation is in an unproductive state.',
},
},
required: ['unproductive_state_analysis', 'unproductive_state_confidence'],
};
/**
* Service for detecting and preventing infinite loops in AI responses.
* Monitors tool call repetitions and content sentence repetitions.
@@ -413,65 +438,138 @@ export class LoopDetectionService {
parts: [{ text: 'Recent conversation history:' }],
});
}
const schema: Record<string, unknown> = {
type: 'object',
properties: {
unproductive_state_analysis: {
type: 'string',
description:
'Your reasoning on if the conversation is looping without forward progress.',
},
unproductive_state_confidence: {
type: 'number',
description:
'A number between 0.0 and 1.0 representing your confidence that the conversation is in an unproductive state.',
},
},
required: [
'unproductive_state_analysis',
'unproductive_state_confidence',
],
};
let result;
try {
result = await this.config.getBaseLlmClient().generateJson({
modelConfigKey: { model: 'loop-detection' },
contents,
schema,
systemInstruction: LOOP_DETECTION_SYSTEM_PROMPT,
abortSignal: signal,
promptId: this.promptId,
});
} catch (e) {
// Do nothing, treat it as a non-loop.
this.config.getDebugMode() ? debugLogger.warn(e) : debugLogger.debug(e);
const flashResult = await this.queryLoopDetectionModel(
'loop-detection',
contents,
signal,
);
if (!flashResult) {
return false;
}
if (typeof result['unproductive_state_confidence'] === 'number') {
if (result['unproductive_state_confidence'] > 0.9) {
if (
typeof result['unproductive_state_analysis'] === 'string' &&
result['unproductive_state_analysis']
) {
debugLogger.warn(result['unproductive_state_analysis']);
}
logLoopDetected(
this.config,
new LoopDetectedEvent(LoopType.LLM_DETECTED_LOOP, this.promptId),
);
const flashConfidence = flashResult[
'unproductive_state_confidence'
] as number;
const doubleCheckModelName =
this.config.modelConfigService.getResolvedConfig({
model: DOUBLE_CHECK_MODEL_ALIAS,
}).model;
if (flashConfidence < LLM_CONFIDENCE_THRESHOLD) {
logLlmLoopCheck(
this.config,
new LlmLoopCheckEvent(
this.promptId,
flashConfidence,
doubleCheckModelName,
-1,
),
);
this.updateCheckInterval(flashConfidence);
return false;
}
if (this.config.isInFallbackMode()) {
const flashModelName = this.config.modelConfigService.getResolvedConfig({
model: 'loop-detection',
}).model;
this.handleConfirmedLoop(flashResult, flashModelName);
return true;
}
// Double check with configured model
const mainModelResult = await this.queryLoopDetectionModel(
DOUBLE_CHECK_MODEL_ALIAS,
contents,
signal,
);
const mainModelConfidence = mainModelResult
? (mainModelResult['unproductive_state_confidence'] as number)
: 0;
logLlmLoopCheck(
this.config,
new LlmLoopCheckEvent(
this.promptId,
flashConfidence,
doubleCheckModelName,
mainModelConfidence,
),
);
if (mainModelResult) {
if (mainModelConfidence >= LLM_CONFIDENCE_THRESHOLD) {
this.handleConfirmedLoop(mainModelResult, doubleCheckModelName);
return true;
} else {
this.llmCheckInterval = Math.round(
MIN_LLM_CHECK_INTERVAL +
(MAX_LLM_CHECK_INTERVAL - MIN_LLM_CHECK_INTERVAL) *
(1 - result['unproductive_state_confidence']),
);
this.updateCheckInterval(mainModelConfidence);
}
}
return false;
}
private async queryLoopDetectionModel(
model: string,
contents: Content[],
signal: AbortSignal,
): Promise<Record<string, unknown> | null> {
try {
const result = (await this.config.getBaseLlmClient().generateJson({
modelConfigKey: { model },
contents,
schema: LOOP_DETECTION_SCHEMA,
systemInstruction: LOOP_DETECTION_SYSTEM_PROMPT,
abortSignal: signal,
promptId: this.promptId,
maxAttempts: 2,
})) as Record<string, unknown>;
if (
result &&
typeof result['unproductive_state_confidence'] === 'number'
) {
return result;
}
return null;
} catch (e) {
this.config.getDebugMode() ? debugLogger.warn(e) : debugLogger.debug(e);
return null;
}
}
private handleConfirmedLoop(
result: Record<string, unknown>,
modelName: string,
): void {
if (
typeof result['unproductive_state_analysis'] === 'string' &&
result['unproductive_state_analysis']
) {
debugLogger.warn(result['unproductive_state_analysis']);
}
logLoopDetected(
this.config,
new LoopDetectedEvent(
LoopType.LLM_DETECTED_LOOP,
this.promptId,
modelName,
),
);
}
private updateCheckInterval(unproductive_state_confidence: number): void {
this.llmCheckInterval = Math.round(
MIN_LLM_CHECK_INTERVAL +
(MAX_LLM_CHECK_INTERVAL - MIN_LLM_CHECK_INTERVAL) *
(1 - unproductive_state_confidence),
);
}
/**
* Resets all loop detection state.
*/

View File

@@ -141,6 +141,13 @@
"topP": 1
}
},
"loop-detection-double-check": {
"model": "gemini-2.5-pro",
"generateContentConfig": {
"temperature": 0,
"topP": 1
}
},
"llm-edit-fixer": {
"model": "gemini-2.5-flash",
"generateContentConfig": {