fix(core): reduce LLM-based loop detection false positives (#20701)

This commit is contained in:
Sandy Tao
2026-03-02 11:08:15 -08:00
committed by GitHub
parent 740efa2ac2
commit 7c9fceba7f
3 changed files with 123 additions and 34 deletions

View File

@@ -12,6 +12,7 @@ import type {
GenerateContentResponse, GenerateContentResponse,
} from '@google/genai'; } from '@google/genai';
import { createUserContent } from '@google/genai'; import { createUserContent } from '@google/genai';
import { partListUnionToString } from './geminiRequest.js';
import { import {
getDirectoryContextString, getDirectoryContextString,
getInitialChatHistory, getInitialChatHistory,
@@ -802,7 +803,7 @@ export class GeminiClient {
const messageBus = this.config.getMessageBus(); const messageBus = this.config.getMessageBus();
if (this.lastPromptId !== prompt_id) { if (this.lastPromptId !== prompt_id) {
this.loopDetector.reset(prompt_id); this.loopDetector.reset(prompt_id, partListUnionToString(request));
this.hookStateMap.delete(this.lastPromptId); this.hookStateMap.delete(this.lastPromptId);
this.lastPromptId = prompt_id; this.lastPromptId = prompt_id;
this.currentSequenceModel = null; this.currentSequenceModel = null;

View File

@@ -806,15 +806,15 @@ describe('LoopDetectionService LLM Checks', () => {
}; };
it('should not trigger LLM check before LLM_CHECK_AFTER_TURNS', async () => { it('should not trigger LLM check before LLM_CHECK_AFTER_TURNS', async () => {
await advanceTurns(29); await advanceTurns(39);
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
}); });
it('should trigger LLM check on the 30th turn', async () => { it('should trigger LLM check on the 40th turn', async () => {
mockBaseLlmClient.generateJson = vi mockBaseLlmClient.generateJson = vi
.fn() .fn()
.mockResolvedValue({ unproductive_state_confidence: 0.1 }); .mockResolvedValue({ unproductive_state_confidence: 0.1 });
await advanceTurns(30); await advanceTurns(40);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith( expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect.objectContaining({ expect.objectContaining({
@@ -828,12 +828,12 @@ describe('LoopDetectionService LLM Checks', () => {
}); });
it('should detect a cognitive loop when confidence is high', async () => { it('should detect a cognitive loop when confidence is high', async () => {
// First check at turn 30 // First check at turn 40
mockBaseLlmClient.generateJson = vi.fn().mockResolvedValue({ mockBaseLlmClient.generateJson = vi.fn().mockResolvedValue({
unproductive_state_confidence: 0.85, unproductive_state_confidence: 0.85,
unproductive_state_analysis: 'Repetitive actions', unproductive_state_analysis: 'Repetitive actions',
}); });
await advanceTurns(30); await advanceTurns(40);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith( expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect.objectContaining({ expect.objectContaining({
@@ -842,14 +842,14 @@ describe('LoopDetectionService LLM Checks', () => {
); );
// The confidence of 0.85 will result in a low interval. // The confidence of 0.85 will result in a low interval.
// The interval will be: 5 + (15 - 5) * (1 - 0.85) = 5 + 10 * 0.15 = 6.5 -> rounded to 7 // The interval will be: 7 + (15 - 7) * (1 - 0.85) = 7 + 8 * 0.15 = 8.2 -> rounded to 8
await advanceTurns(6); // advance to turn 36 await advanceTurns(7); // advance to turn 47
mockBaseLlmClient.generateJson = vi.fn().mockResolvedValue({ mockBaseLlmClient.generateJson = vi.fn().mockResolvedValue({
unproductive_state_confidence: 0.95, unproductive_state_confidence: 0.95,
unproductive_state_analysis: 'Repetitive actions', unproductive_state_analysis: 'Repetitive actions',
}); });
const finalResult = await service.turnStarted(abortController.signal); // This is turn 37 const finalResult = await service.turnStarted(abortController.signal); // This is turn 48
expect(finalResult).toBe(true); expect(finalResult).toBe(true);
expect(loggers.logLoopDetected).toHaveBeenCalledWith( expect(loggers.logLoopDetected).toHaveBeenCalledWith(
@@ -867,7 +867,7 @@ describe('LoopDetectionService LLM Checks', () => {
unproductive_state_confidence: 0.5, unproductive_state_confidence: 0.5,
unproductive_state_analysis: 'Looks okay', unproductive_state_analysis: 'Looks okay',
}); });
await advanceTurns(30); await advanceTurns(40);
const result = await service.turnStarted(abortController.signal); const result = await service.turnStarted(abortController.signal);
expect(result).toBe(false); expect(result).toBe(false);
expect(loggers.logLoopDetected).not.toHaveBeenCalled(); expect(loggers.logLoopDetected).not.toHaveBeenCalled();
@@ -875,16 +875,17 @@ describe('LoopDetectionService LLM Checks', () => {
it('should adjust the check interval based on confidence', async () => { it('should adjust the check interval based on confidence', async () => {
// Confidence is 0.0, so interval should be MAX_LLM_CHECK_INTERVAL (15) // Confidence is 0.0, so interval should be MAX_LLM_CHECK_INTERVAL (15)
// Interval = 7 + (15 - 7) * (1 - 0.0) = 15
mockBaseLlmClient.generateJson = vi mockBaseLlmClient.generateJson = vi
.fn() .fn()
.mockResolvedValue({ unproductive_state_confidence: 0.0 }); .mockResolvedValue({ unproductive_state_confidence: 0.0 });
await advanceTurns(30); // First check at turn 30 await advanceTurns(40); // First check at turn 40
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
await advanceTurns(14); // Advance to turn 44 await advanceTurns(14); // Advance to turn 54
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
await service.turnStarted(abortController.signal); // Turn 45 await service.turnStarted(abortController.signal); // Turn 55
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
}); });
@@ -892,7 +893,7 @@ describe('LoopDetectionService LLM Checks', () => {
mockBaseLlmClient.generateJson = vi mockBaseLlmClient.generateJson = vi
.fn() .fn()
.mockRejectedValue(new Error('API error')); .mockRejectedValue(new Error('API error'));
await advanceTurns(30); await advanceTurns(40);
const result = await service.turnStarted(abortController.signal); const result = await service.turnStarted(abortController.signal);
expect(result).toBe(false); expect(result).toBe(false);
expect(loggers.logLoopDetected).not.toHaveBeenCalled(); expect(loggers.logLoopDetected).not.toHaveBeenCalled();
@@ -901,7 +902,7 @@ describe('LoopDetectionService LLM Checks', () => {
it('should not trigger LLM check when disabled for session', async () => { it('should not trigger LLM check when disabled for session', async () => {
service.disableForSession(); service.disableForSession();
expect(loggers.logLoopDetectionDisabled).toHaveBeenCalledTimes(1); expect(loggers.logLoopDetectionDisabled).toHaveBeenCalledTimes(1);
await advanceTurns(30); await advanceTurns(40);
const result = await service.turnStarted(abortController.signal); const result = await service.turnStarted(abortController.signal);
expect(result).toBe(false); expect(result).toBe(false);
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled(); expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
@@ -924,7 +925,7 @@ describe('LoopDetectionService LLM Checks', () => {
.fn() .fn()
.mockResolvedValue({ unproductive_state_confidence: 0.1 }); .mockResolvedValue({ unproductive_state_confidence: 0.1 });
await advanceTurns(30); await advanceTurns(40);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
const calledArg = vi.mocked(mockBaseLlmClient.generateJson).mock const calledArg = vi.mocked(mockBaseLlmClient.generateJson).mock
@@ -949,7 +950,7 @@ describe('LoopDetectionService LLM Checks', () => {
unproductive_state_analysis: 'Main says loop', unproductive_state_analysis: 'Main says loop',
}); });
await advanceTurns(30); await advanceTurns(40);
// It should have called generateJson twice // It should have called generateJson twice
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
@@ -989,7 +990,7 @@ describe('LoopDetectionService LLM Checks', () => {
unproductive_state_analysis: 'Main says no loop', unproductive_state_analysis: 'Main says no loop',
}); });
await advanceTurns(30); await advanceTurns(40);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
expect(mockBaseLlmClient.generateJson).toHaveBeenNthCalledWith( expect(mockBaseLlmClient.generateJson).toHaveBeenNthCalledWith(
@@ -1009,12 +1010,12 @@ describe('LoopDetectionService LLM Checks', () => {
expect(loggers.logLoopDetected).not.toHaveBeenCalled(); expect(loggers.logLoopDetected).not.toHaveBeenCalled();
// But should have updated the interval based on the main model's confidence (0.89) // But should have updated the interval based on the main model's confidence (0.89)
// Interval = 5 + (15-5) * (1 - 0.89) = 5 + 10 * 0.11 = 5 + 1.1 = 6.1 -> 6 // Interval = 7 + (15-7) * (1 - 0.89) = 7 + 8 * 0.11 = 7 + 0.88 = 7.88 -> 8
// Advance by 6 turns // Advance by 7 turns
await advanceTurns(6); await advanceTurns(7);
// Next turn (37) should trigger another check // Next turn (48) should trigger another check
await service.turnStarted(abortController.signal); await service.turnStarted(abortController.signal);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(3); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(3);
}); });
@@ -1032,7 +1033,7 @@ describe('LoopDetectionService LLM Checks', () => {
unproductive_state_analysis: 'Flash says loop', unproductive_state_analysis: 'Flash says loop',
}); });
await advanceTurns(30); await advanceTurns(40);
// It should have called generateJson only once // It should have called generateJson only once
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
@@ -1052,4 +1053,53 @@ describe('LoopDetectionService LLM Checks', () => {
}), }),
); );
}); });
it('should include user prompt in LLM check contents when provided', async () => {
service.reset('test-prompt-id', 'Add license headers to all files');
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValue({ unproductive_state_confidence: 0.1 });
await advanceTurns(40);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
const calledArg = vi.mocked(mockBaseLlmClient.generateJson).mock
.calls[0][0];
// First content should be the user prompt context wrapped in XML
expect(calledArg.contents[0]).toEqual({
role: 'user',
parts: [
{
text: '<original_user_request>\nAdd license headers to all files\n</original_user_request>',
},
],
});
});
it('should not include user prompt in contents when not provided', async () => {
service.reset('test-prompt-id');
vi.mocked(mockGeminiClient.getHistory).mockReturnValue([
{
role: 'model',
parts: [{ text: 'Some response' }],
},
]);
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValue({ unproductive_state_confidence: 0.1 });
await advanceTurns(40);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
const calledArg = vi.mocked(mockBaseLlmClient.generateJson).mock
.calls[0][0];
// First content should be the history, not a user prompt message
expect(calledArg.contents[0]).toEqual({
role: 'model',
parts: [{ text: 'Some response' }],
});
});
}); });

View File

@@ -40,19 +40,19 @@ const LLM_LOOP_CHECK_HISTORY_COUNT = 20;
/** /**
* The number of turns that must pass in a single prompt before the LLM-based loop check is activated. * The number of turns that must pass in a single prompt before the LLM-based loop check is activated.
*/ */
const LLM_CHECK_AFTER_TURNS = 30; const LLM_CHECK_AFTER_TURNS = 40;
/** /**
* The default interval, in number of turns, at which the LLM-based loop check is performed. * The default interval, in number of turns, at which the LLM-based loop check is performed.
* This value is adjusted dynamically based on the LLM's confidence. * This value is adjusted dynamically based on the LLM's confidence.
*/ */
const DEFAULT_LLM_CHECK_INTERVAL = 3; const DEFAULT_LLM_CHECK_INTERVAL = 10;
/** /**
* The minimum interval for LLM-based loop checks. * The minimum interval for LLM-based loop checks.
* This is used when the confidence of a loop is high, to check more frequently. * This is used when the confidence of a loop is high, to check more frequently.
*/ */
const MIN_LLM_CHECK_INTERVAL = 5; const MIN_LLM_CHECK_INTERVAL = 7;
/** /**
* The maximum interval for LLM-based loop checks. * The maximum interval for LLM-based loop checks.
@@ -66,16 +66,40 @@ const MAX_LLM_CHECK_INTERVAL = 15;
const LLM_CONFIDENCE_THRESHOLD = 0.9; const LLM_CONFIDENCE_THRESHOLD = 0.9;
const DOUBLE_CHECK_MODEL_ALIAS = 'loop-detection-double-check'; const DOUBLE_CHECK_MODEL_ALIAS = 'loop-detection-double-check';
const LOOP_DETECTION_SYSTEM_PROMPT = `You are a sophisticated AI diagnostic agent specializing in identifying when a conversational AI is stuck in an unproductive state. Your task is to analyze the provided conversation history and determine if the assistant has ceased to make meaningful progress. const LOOP_DETECTION_SYSTEM_PROMPT = `You are a diagnostic agent that determines whether a conversational AI assistant is stuck in an unproductive loop. Analyze the conversation history (and, if provided, the original user request) to make this determination.
An unproductive state is characterized by one or more of the following patterns over the last 5 or more assistant turns: ## What constitutes an unproductive state
Repetitive Actions: The assistant repeats the same tool calls or conversational responses multiple times in succession. This includes simple loops (e.g., tool_A, tool_A, tool_A) and alternating patterns (e.g., tool_A, tool_B, tool_A, tool_B, ...). 1. The assistant has exhibited a repetitive pattern over at least 5 consecutive model actions (tool calls or text responses, counting only model-role turns).
1. The assistant has exhibited a repetitive pattern over at least 5 consecutive model actions (tool calls or text responses, counting only model-role turns).
2. The repetition produces NO net change or forward progress toward the user's goal.
Cognitive Loop: The assistant seems unable to determine the next logical step. It might express confusion, repeatedly ask the same questions, or generate responses that don't logically follow from the previous turns, indicating it's stuck and not advancing the task. Specific patterns to look for:
- **Alternating cycles with no net effect:** The assistant cycles between the same actions (e.g., edit_file → run_build → edit_file → run_build) where each iteration applies the same edit and encounters the same error, making zero progress. Note: alternating between actions is only a loop if the arguments and outcomes are substantially identical each cycle. If the assistant is modifying different code or getting different errors, that is debugging progress, not a loop.
- **Semantic repetition with identical outcomes:** The assistant calls the same tool with semantically equivalent arguments (same file, same line range, same content) multiple times consecutively, and each call produces the same outcome. This does NOT include build/test commands that are re-run after making code changes between invocations — re-running a build to verify a fix is normal workflow.
- **Stuck reasoning:** The assistant produces multiple consecutive text responses that restate the same plan, question, or analysis without taking any new action or making a decision. This does NOT include command output that happens to contain repeated status lines or warnings.
Crucially, differentiate between a true unproductive state and legitimate, incremental progress. ## What is NOT an unproductive state
For example, a series of 'tool_A' or 'tool_B' tool calls that make small, distinct changes to the same file (like adding docstrings to functions one by one) is considered forward progress and is NOT a loop. A loop would be repeatedly replacing the same text with the same content, or cycling between a small set of files with no net change.`;
You MUST distinguish repetitive-looking but productive work from true loops. The following are examples of forward progress and must NOT be flagged:
- **Cross-file batch operations:** A series of tool calls with the same tool name but targeting different files (different file paths in the arguments). For example, adding license headers to 20 files, or running the same refactoring across multiple modules.
- **Incremental same-file edits:** Multiple edits to the same file that target different line ranges, different functions, or different text content (e.g., adding docstrings to functions one by one).
- **Sequential processing:** A series of read or search operations on different files/paths to gather information.
- **Retry with variation:** Re-attempting a failed operation with modified arguments or a different approach.
## Argument analysis (critical)
When evaluating tool calls, you MUST compare the **arguments** of each call, not just the tool name. Pay close attention to:
- **File paths:** Different file paths mean different targets — this is distinct work, not repetition.
- **Line numbers and text content:** Different line ranges or different old_string/new_string values indicate distinct edits.
- **Search queries and patterns:** Different search terms indicate information gathering, not looping.
A loop exists only when the same tool is called with semantically equivalent arguments repeatedly, indicating no forward progress.
## Using the original user request
If the original user request is provided, use it to contextualize the assistant's behavior. If the request implies a batch or multi-step operation (e.g., "update all files", "refactor every module", "add tests for each function"), then repetitive tool calls with varying arguments are expected and should weigh heavily against flagging a loop.`;
const LOOP_DETECTION_SCHEMA: Record<string, unknown> = { const LOOP_DETECTION_SCHEMA: Record<string, unknown> = {
type: 'object', type: 'object',
@@ -101,6 +125,7 @@ const LOOP_DETECTION_SCHEMA: Record<string, unknown> = {
export class LoopDetectionService { export class LoopDetectionService {
private readonly config: Config; private readonly config: Config;
private promptId = ''; private promptId = '';
private userPrompt = '';
// Tool call tracking // Tool call tracking
private lastToolCallKey: string | null = null; private lastToolCallKey: string | null = null;
@@ -450,9 +475,21 @@ export class LoopDetectionService {
const trimmedHistory = this.trimRecentHistory(recentHistory); const trimmedHistory = this.trimRecentHistory(recentHistory);
const taskPrompt = `Please analyze the conversation history to determine the possibility that the conversation is stuck in a repetitive, non-productive state. Provide your response in the requested JSON format.`; const taskPrompt = `Please analyze the conversation history to determine the possibility that the conversation is stuck in a repetitive, non-productive state. Consider the original user request when evaluating whether repeated tool calls represent legitimate batch work or an actual loop. Provide your response in the requested JSON format.`;
const contents = [ const contents = [
...(this.userPrompt
? [
{
role: 'user' as const,
parts: [
{
text: `<original_user_request>\n${this.userPrompt}\n</original_user_request>`,
},
],
},
]
: []),
...trimmedHistory, ...trimmedHistory,
{ role: 'user', parts: [{ text: taskPrompt }] }, { role: 'user', parts: [{ text: taskPrompt }] },
]; ];
@@ -602,8 +639,9 @@ export class LoopDetectionService {
/** /**
* Resets all loop detection state. * Resets all loop detection state.
*/ */
reset(promptId: string): void { reset(promptId: string, userPrompt?: string): void {
this.promptId = promptId; this.promptId = promptId;
this.userPrompt = userPrompt ?? '';
this.resetToolCallCount(); this.resetToolCallCount();
this.resetContentTracking(); this.resetContentTracking();
this.resetLlmCheckTracking(); this.resetLlmCheckTracking();