mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-18 18:11:02 -07:00
fix(core): reduce LLM-based loop detection false positives (#20701)
This commit is contained in:
@@ -806,15 +806,15 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
};
|
||||
|
||||
it('should not trigger LLM check before LLM_CHECK_AFTER_TURNS', async () => {
|
||||
await advanceTurns(29);
|
||||
await advanceTurns(39);
|
||||
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should trigger LLM check on the 30th turn', async () => {
|
||||
it('should trigger LLM check on the 40th turn', async () => {
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ unproductive_state_confidence: 0.1 });
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(40);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
@@ -828,12 +828,12 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
});
|
||||
|
||||
it('should detect a cognitive loop when confidence is high', async () => {
|
||||
// First check at turn 30
|
||||
// First check at turn 40
|
||||
mockBaseLlmClient.generateJson = vi.fn().mockResolvedValue({
|
||||
unproductive_state_confidence: 0.85,
|
||||
unproductive_state_analysis: 'Repetitive actions',
|
||||
});
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(40);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
@@ -842,14 +842,14 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
);
|
||||
|
||||
// The confidence of 0.85 will result in a low interval.
|
||||
// The interval will be: 5 + (15 - 5) * (1 - 0.85) = 5 + 10 * 0.15 = 6.5 -> rounded to 7
|
||||
await advanceTurns(6); // advance to turn 36
|
||||
// The interval will be: 7 + (15 - 7) * (1 - 0.85) = 7 + 8 * 0.15 = 8.2 -> rounded to 8
|
||||
await advanceTurns(7); // advance to turn 47
|
||||
|
||||
mockBaseLlmClient.generateJson = vi.fn().mockResolvedValue({
|
||||
unproductive_state_confidence: 0.95,
|
||||
unproductive_state_analysis: 'Repetitive actions',
|
||||
});
|
||||
const finalResult = await service.turnStarted(abortController.signal); // This is turn 37
|
||||
const finalResult = await service.turnStarted(abortController.signal); // This is turn 48
|
||||
|
||||
expect(finalResult).toBe(true);
|
||||
expect(loggers.logLoopDetected).toHaveBeenCalledWith(
|
||||
@@ -867,7 +867,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
unproductive_state_confidence: 0.5,
|
||||
unproductive_state_analysis: 'Looks okay',
|
||||
});
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(40);
|
||||
const result = await service.turnStarted(abortController.signal);
|
||||
expect(result).toBe(false);
|
||||
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
|
||||
@@ -875,16 +875,17 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
|
||||
it('should adjust the check interval based on confidence', async () => {
|
||||
// Confidence is 0.0, so interval should be MAX_LLM_CHECK_INTERVAL (15)
|
||||
// Interval = 7 + (15 - 7) * (1 - 0.0) = 15
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ unproductive_state_confidence: 0.0 });
|
||||
await advanceTurns(30); // First check at turn 30
|
||||
await advanceTurns(40); // First check at turn 40
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
|
||||
await advanceTurns(14); // Advance to turn 44
|
||||
await advanceTurns(14); // Advance to turn 54
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
|
||||
await service.turnStarted(abortController.signal); // Turn 45
|
||||
await service.turnStarted(abortController.signal); // Turn 55
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
@@ -892,7 +893,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('API error'));
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(40);
|
||||
const result = await service.turnStarted(abortController.signal);
|
||||
expect(result).toBe(false);
|
||||
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
|
||||
@@ -901,7 +902,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
it('should not trigger LLM check when disabled for session', async () => {
|
||||
service.disableForSession();
|
||||
expect(loggers.logLoopDetectionDisabled).toHaveBeenCalledTimes(1);
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(40);
|
||||
const result = await service.turnStarted(abortController.signal);
|
||||
expect(result).toBe(false);
|
||||
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||
@@ -924,7 +925,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
.fn()
|
||||
.mockResolvedValue({ unproductive_state_confidence: 0.1 });
|
||||
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(40);
|
||||
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
const calledArg = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
@@ -949,7 +950,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
unproductive_state_analysis: 'Main says loop',
|
||||
});
|
||||
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(40);
|
||||
|
||||
// It should have called generateJson twice
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
|
||||
@@ -989,7 +990,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
unproductive_state_analysis: 'Main says no loop',
|
||||
});
|
||||
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(40);
|
||||
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenNthCalledWith(
|
||||
@@ -1009,12 +1010,12 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
|
||||
|
||||
// But should have updated the interval based on the main model's confidence (0.89)
|
||||
// Interval = 5 + (15-5) * (1 - 0.89) = 5 + 10 * 0.11 = 5 + 1.1 = 6.1 -> 6
|
||||
// Interval = 7 + (15-7) * (1 - 0.89) = 7 + 8 * 0.11 = 7 + 0.88 = 7.88 -> 8
|
||||
|
||||
// Advance by 6 turns
|
||||
await advanceTurns(6);
|
||||
// Advance by 7 turns
|
||||
await advanceTurns(7);
|
||||
|
||||
// Next turn (37) should trigger another check
|
||||
// Next turn (48) should trigger another check
|
||||
await service.turnStarted(abortController.signal);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
@@ -1032,7 +1033,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
unproductive_state_analysis: 'Flash says loop',
|
||||
});
|
||||
|
||||
await advanceTurns(30);
|
||||
await advanceTurns(40);
|
||||
|
||||
// It should have called generateJson only once
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -1052,4 +1053,53 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should include user prompt in LLM check contents when provided', async () => {
|
||||
service.reset('test-prompt-id', 'Add license headers to all files');
|
||||
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ unproductive_state_confidence: 0.1 });
|
||||
|
||||
await advanceTurns(40);
|
||||
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
const calledArg = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
// First content should be the user prompt context wrapped in XML
|
||||
expect(calledArg.contents[0]).toEqual({
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text: '<original_user_request>\nAdd license headers to all files\n</original_user_request>',
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('should not include user prompt in contents when not provided', async () => {
|
||||
service.reset('test-prompt-id');
|
||||
|
||||
vi.mocked(mockGeminiClient.getHistory).mockReturnValue([
|
||||
{
|
||||
role: 'model',
|
||||
parts: [{ text: 'Some response' }],
|
||||
},
|
||||
]);
|
||||
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ unproductive_state_confidence: 0.1 });
|
||||
|
||||
await advanceTurns(40);
|
||||
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
const calledArg = vi.mocked(mockBaseLlmClient.generateJson).mock
|
||||
.calls[0][0];
|
||||
// First content should be the history, not a user prompt message
|
||||
expect(calledArg.contents[0]).toEqual({
|
||||
role: 'model',
|
||||
parts: [{ text: 'Some response' }],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user