mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-28 05:55:17 -07:00
refactor(core): Use BaseLlmClient for LLM-based loop detection (#8427)
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { GeminiClient } from '../core/client.js';
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import type {
|
||||
ServerGeminiContentEvent,
|
||||
ServerGeminiStreamEvent,
|
||||
@@ -625,16 +626,21 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
let service: LoopDetectionService;
|
||||
let mockConfig: Config;
|
||||
let mockGeminiClient: GeminiClient;
|
||||
let mockBaseLlmClient: BaseLlmClient;
|
||||
let abortController: AbortController;
|
||||
|
||||
beforeEach(() => {
|
||||
mockGeminiClient = {
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
generateJson: vi.fn(),
|
||||
} as unknown as GeminiClient;
|
||||
|
||||
mockBaseLlmClient = {
|
||||
generateJson: vi.fn(),
|
||||
} as unknown as BaseLlmClient;
|
||||
|
||||
mockConfig = {
|
||||
getGeminiClient: () => mockGeminiClient,
|
||||
getBaseLlmClient: () => mockBaseLlmClient,
|
||||
getDebugMode: () => false,
|
||||
getTelemetryEnabled: () => true,
|
||||
} as unknown as Config;
|
||||
@@ -656,30 +662,39 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
|
||||
it('should not trigger LLM check before LLM_CHECK_AFTER_TURNS', async () => {
|
||||
await advanceTurns(29);
|
||||
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
|
||||
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should trigger LLM check on the 30th turn', async () => {
|
||||
mockGeminiClient.generateJson = vi
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ confidence: 0.1 });
|
||||
await advanceTurns(30);
|
||||
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
systemInstruction: expect.any(String),
|
||||
contents: expect.any(Array),
|
||||
model: expect.any(String),
|
||||
schema: expect.any(Object),
|
||||
promptId: expect.any(String),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should detect a cognitive loop when confidence is high', async () => {
|
||||
// First check at turn 30
|
||||
mockGeminiClient.generateJson = vi
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ confidence: 0.85, reasoning: 'Repetitive actions' });
|
||||
await advanceTurns(30);
|
||||
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
|
||||
// The confidence of 0.85 will result in a low interval.
|
||||
// The interval will be: 5 + (15 - 5) * (1 - 0.85) = 5 + 10 * 0.15 = 6.5 -> rounded to 7
|
||||
await advanceTurns(6); // advance to turn 36
|
||||
|
||||
mockGeminiClient.generateJson = vi
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ confidence: 0.95, reasoning: 'Repetitive actions' });
|
||||
const finalResult = await service.turnStarted(abortController.signal); // This is turn 37
|
||||
@@ -695,7 +710,7 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
});
|
||||
|
||||
it('should not detect a loop when confidence is low', async () => {
|
||||
mockGeminiClient.generateJson = vi
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ confidence: 0.5, reasoning: 'Looks okay' });
|
||||
await advanceTurns(30);
|
||||
@@ -706,21 +721,21 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
|
||||
it('should adjust the check interval based on confidence', async () => {
|
||||
// Confidence is 0.0, so interval should be MAX_LLM_CHECK_INTERVAL (15)
|
||||
mockGeminiClient.generateJson = vi
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockResolvedValue({ confidence: 0.0 });
|
||||
await advanceTurns(30); // First check at turn 30
|
||||
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
|
||||
await advanceTurns(14); // Advance to turn 44
|
||||
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
|
||||
await service.turnStarted(abortController.signal); // Turn 45
|
||||
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(2);
|
||||
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should handle errors from generateJson gracefully', async () => {
|
||||
mockGeminiClient.generateJson = vi
|
||||
mockBaseLlmClient.generateJson = vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('API error'));
|
||||
await advanceTurns(30);
|
||||
@@ -734,6 +749,6 @@ describe('LoopDetectionService LLM Checks', () => {
|
||||
await advanceTurns(30);
|
||||
const result = await service.turnStarted(abortController.signal);
|
||||
expect(result).toBe(false);
|
||||
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
|
||||
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user