refactor(core): Use BaseLlmClient for LLM-based loop detection (#8427)

This commit is contained in:
Abhi
2025-09-13 23:07:33 -04:00
committed by GitHub
parent f11d79a931
commit 12720a9fa7
2 changed files with 50 additions and 29 deletions
@@ -7,6 +7,7 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import type { Config } from '../config/config.js';
import type { GeminiClient } from '../core/client.js';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
import type {
ServerGeminiContentEvent,
ServerGeminiStreamEvent,
@@ -625,16 +626,21 @@ describe('LoopDetectionService LLM Checks', () => {
let service: LoopDetectionService;
let mockConfig: Config;
let mockGeminiClient: GeminiClient;
let mockBaseLlmClient: BaseLlmClient;
let abortController: AbortController;
beforeEach(() => {
mockGeminiClient = {
getHistory: vi.fn().mockReturnValue([]),
generateJson: vi.fn(),
} as unknown as GeminiClient;
mockBaseLlmClient = {
generateJson: vi.fn(),
} as unknown as BaseLlmClient;
mockConfig = {
getGeminiClient: () => mockGeminiClient,
getBaseLlmClient: () => mockBaseLlmClient,
getDebugMode: () => false,
getTelemetryEnabled: () => true,
} as unknown as Config;
@@ -656,30 +662,39 @@ describe('LoopDetectionService LLM Checks', () => {
it('should not trigger LLM check before LLM_CHECK_AFTER_TURNS', async () => {
await advanceTurns(29);
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});
it('should trigger LLM check on the 30th turn', async () => {
mockGeminiClient.generateJson = vi
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValue({ confidence: 0.1 });
await advanceTurns(30);
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect.objectContaining({
systemInstruction: expect.any(String),
contents: expect.any(Array),
model: expect.any(String),
schema: expect.any(Object),
promptId: expect.any(String),
}),
);
});
it('should detect a cognitive loop when confidence is high', async () => {
// First check at turn 30
mockGeminiClient.generateJson = vi
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValue({ confidence: 0.85, reasoning: 'Repetitive actions' });
await advanceTurns(30);
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
// The confidence of 0.85 will result in a low interval.
// The interval will be: 5 + (15 - 5) * (1 - 0.85) = 5 + 10 * 0.15 = 6.5 -> rounded to 7
await advanceTurns(6); // advance to turn 36
mockGeminiClient.generateJson = vi
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValue({ confidence: 0.95, reasoning: 'Repetitive actions' });
const finalResult = await service.turnStarted(abortController.signal); // This is turn 37
@@ -695,7 +710,7 @@ describe('LoopDetectionService LLM Checks', () => {
});
it('should not detect a loop when confidence is low', async () => {
mockGeminiClient.generateJson = vi
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValue({ confidence: 0.5, reasoning: 'Looks okay' });
await advanceTurns(30);
@@ -706,21 +721,21 @@ describe('LoopDetectionService LLM Checks', () => {
it('should adjust the check interval based on confidence', async () => {
// Confidence is 0.0, so interval should be MAX_LLM_CHECK_INTERVAL (15)
mockGeminiClient.generateJson = vi
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValue({ confidence: 0.0 });
await advanceTurns(30); // First check at turn 30
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
await advanceTurns(14); // Advance to turn 44
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
await service.turnStarted(abortController.signal); // Turn 45
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(2);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
});
it('should handle errors from generateJson gracefully', async () => {
mockGeminiClient.generateJson = vi
mockBaseLlmClient.generateJson = vi
.fn()
.mockRejectedValue(new Error('API error'));
await advanceTurns(30);
@@ -734,6 +749,6 @@ describe('LoopDetectionService LLM Checks', () => {
await advanceTurns(30);
const result = await service.turnStarted(abortController.signal);
expect(result).toBe(false);
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});
});