mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-24 12:04:56 -07:00
refactor(core): Use BaseLlmClient for utility LLM calls in edit corrector (#8443)
This commit is contained in:
@@ -9,6 +9,7 @@ import type { Mock } from 'vitest';
|
||||
import { vi, describe, it, expect, beforeEach, type Mocked } from 'vitest';
|
||||
import * as fs from 'node:fs';
|
||||
import { EditTool } from '../tools/edit.js';
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
|
||||
// MOCKS
|
||||
let callCount = 0;
|
||||
@@ -28,10 +29,9 @@ vi.mock('../core/client.js', () => ({
|
||||
this: any,
|
||||
_config: Config,
|
||||
) {
|
||||
this.generateJson = (...params: any[]) => mockGenerateJson(...params); // Corrected: use mockGenerateJson
|
||||
this.startChat = (...params: any[]) => mockStartChat(...params); // Corrected: use mockStartChat
|
||||
this.startChat = (...params: any[]) => mockStartChat(...params);
|
||||
this.sendMessageStream = (...params: any[]) =>
|
||||
mockSendMessageStream(...params); // Corrected: use mockSendMessageStream
|
||||
mockSendMessageStream(...params);
|
||||
return this;
|
||||
}),
|
||||
}));
|
||||
@@ -154,6 +154,7 @@ describe('editCorrector', () => {
|
||||
|
||||
describe('ensureCorrectEdit', () => {
|
||||
let mockGeminiClientInstance: Mocked<GeminiClient>;
|
||||
let mockBaseLlmClientInstance: Mocked<BaseLlmClient>;
|
||||
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||
let mockConfigInstance: Config;
|
||||
const abortSignal = new AbortController().signal;
|
||||
@@ -233,6 +234,9 @@ describe('editCorrector', () => {
|
||||
mockConfigInstance,
|
||||
) as Mocked<GeminiClient>;
|
||||
mockGeminiClientInstance.getHistory = vi.fn().mockResolvedValue([]);
|
||||
mockBaseLlmClientInstance = {
|
||||
generateJson: mockGenerateJson,
|
||||
} as unknown as Mocked<BaseLlmClient>;
|
||||
resetEditCorrectorCaches_TEST_ONLY();
|
||||
});
|
||||
|
||||
@@ -252,6 +256,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -271,6 +276,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
@@ -293,6 +299,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -312,6 +319,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
@@ -335,6 +343,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -354,6 +363,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
@@ -373,6 +383,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
@@ -397,6 +408,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -420,6 +432,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
||||
@@ -441,6 +454,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -464,6 +478,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -486,6 +501,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -505,6 +521,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
@@ -529,6 +546,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
||||
@@ -585,6 +603,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
@@ -595,65 +614,10 @@ describe('editCorrector', () => {
|
||||
});
|
||||
|
||||
describe('ensureCorrectFileContent', () => {
|
||||
let mockGeminiClientInstance: Mocked<GeminiClient>;
|
||||
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||
let mockConfigInstance: Config;
|
||||
let mockBaseLlmClientInstance: Mocked<BaseLlmClient>;
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
beforeEach(() => {
|
||||
mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
|
||||
const configParams = {
|
||||
apiKey: 'test-api-key',
|
||||
model: 'test-model',
|
||||
sandbox: false as boolean | string,
|
||||
targetDir: '/test',
|
||||
debugMode: false,
|
||||
question: undefined as string | undefined,
|
||||
fullContext: false,
|
||||
coreTools: undefined as string[] | undefined,
|
||||
toolDiscoveryCommand: undefined as string | undefined,
|
||||
toolCallCommand: undefined as string | undefined,
|
||||
mcpServerCommand: undefined as string | undefined,
|
||||
mcpServers: undefined as Record<string, any> | undefined,
|
||||
userAgent: 'test-agent',
|
||||
userMemory: '',
|
||||
geminiMdFileCount: 0,
|
||||
alwaysSkipModificationConfirmation: false,
|
||||
};
|
||||
mockConfigInstance = {
|
||||
...configParams,
|
||||
getApiKey: vi.fn(() => configParams.apiKey),
|
||||
getModel: vi.fn(() => configParams.model),
|
||||
getSandbox: vi.fn(() => configParams.sandbox),
|
||||
getTargetDir: vi.fn(() => configParams.targetDir),
|
||||
getToolRegistry: vi.fn(() => mockToolRegistry),
|
||||
getDebugMode: vi.fn(() => configParams.debugMode),
|
||||
getQuestion: vi.fn(() => configParams.question),
|
||||
getFullContext: vi.fn(() => configParams.fullContext),
|
||||
getCoreTools: vi.fn(() => configParams.coreTools),
|
||||
getToolDiscoveryCommand: vi.fn(() => configParams.toolDiscoveryCommand),
|
||||
getToolCallCommand: vi.fn(() => configParams.toolCallCommand),
|
||||
getMcpServerCommand: vi.fn(() => configParams.mcpServerCommand),
|
||||
getMcpServers: vi.fn(() => configParams.mcpServers),
|
||||
getUserAgent: vi.fn(() => configParams.userAgent),
|
||||
getUserMemory: vi.fn(() => configParams.userMemory),
|
||||
setUserMemory: vi.fn((mem: string) => {
|
||||
configParams.userMemory = mem;
|
||||
}),
|
||||
getGeminiMdFileCount: vi.fn(() => configParams.geminiMdFileCount),
|
||||
setGeminiMdFileCount: vi.fn((count: number) => {
|
||||
configParams.geminiMdFileCount = count;
|
||||
}),
|
||||
getAlwaysSkipModificationConfirmation: vi.fn(
|
||||
() => configParams.alwaysSkipModificationConfirmation,
|
||||
),
|
||||
setAlwaysSkipModificationConfirmation: vi.fn((skip: boolean) => {
|
||||
configParams.alwaysSkipModificationConfirmation = skip;
|
||||
}),
|
||||
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
|
||||
setQuotaErrorOccurred: vi.fn(),
|
||||
} as unknown as Config;
|
||||
|
||||
callCount = 0;
|
||||
mockResponses.length = 0;
|
||||
mockGenerateJson = vi
|
||||
@@ -667,12 +631,10 @@ describe('editCorrector', () => {
|
||||
if (response === undefined) return Promise.resolve({});
|
||||
return Promise.resolve(response);
|
||||
});
|
||||
mockStartChat = vi.fn();
|
||||
mockSendMessageStream = vi.fn();
|
||||
|
||||
mockGeminiClientInstance = new GeminiClient(
|
||||
mockConfigInstance,
|
||||
) as Mocked<GeminiClient>;
|
||||
mockBaseLlmClientInstance = {
|
||||
generateJson: mockGenerateJson,
|
||||
} as unknown as Mocked<BaseLlmClient>;
|
||||
resetEditCorrectorCaches_TEST_ONLY();
|
||||
});
|
||||
|
||||
@@ -680,7 +642,7 @@ describe('editCorrector', () => {
|
||||
const content = 'This is normal content without escaping issues';
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result).toBe(content);
|
||||
@@ -696,7 +658,7 @@ describe('editCorrector', () => {
|
||||
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
@@ -716,7 +678,7 @@ describe('editCorrector', () => {
|
||||
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
@@ -731,7 +693,7 @@ describe('editCorrector', () => {
|
||||
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
@@ -751,7 +713,7 @@ describe('editCorrector', () => {
|
||||
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import type { Content, GenerateContentConfig } from '@google/genai';
|
||||
import type { GeminiClient } from '../core/client.js';
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import type { EditToolParams } from '../tools/edit.js';
|
||||
import { EditTool } from '../tools/edit.js';
|
||||
import { WriteFileTool } from '../tools/write-file.js';
|
||||
@@ -19,14 +20,27 @@ import {
|
||||
isFunctionCall,
|
||||
} from '../utils/messageInspectors.js';
|
||||
import * as fs from 'node:fs';
|
||||
import { promptIdContext } from './promptIdContext.js';
|
||||
|
||||
const EditModel = DEFAULT_GEMINI_FLASH_LITE_MODEL;
|
||||
const EditConfig: GenerateContentConfig = {
|
||||
const EDIT_MODEL = DEFAULT_GEMINI_FLASH_LITE_MODEL;
|
||||
const EDIT_CONFIG: GenerateContentConfig = {
|
||||
thinkingConfig: {
|
||||
thinkingBudget: 0,
|
||||
},
|
||||
};
|
||||
|
||||
const CODE_CORRECTION_SYSTEM_PROMPT = `
|
||||
You are an expert code-editing assistant. Your task is to analyze a failed edit attempt and provide a corrected version of the text snippets.
|
||||
The correction should be as minimal as possible, staying very close to the original.
|
||||
Focus ONLY on fixing issues like whitespace, indentation, line endings, or incorrect escaping.
|
||||
Do NOT invent a completely new edit. Your job is to fix the provided parameters to make the edit succeed.
|
||||
Return ONLY the corrected snippet in the specified JSON format.
|
||||
`.trim();
|
||||
|
||||
function getPromptId(): string {
|
||||
return promptIdContext.getStore() ?? `edit-corrector-${Date.now()}`;
|
||||
}
|
||||
|
||||
const MAX_CACHE_SIZE = 50;
|
||||
|
||||
// Cache for ensureCorrectEdit results
|
||||
@@ -159,7 +173,8 @@ export async function ensureCorrectEdit(
|
||||
filePath: string,
|
||||
currentContent: string,
|
||||
originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\'
|
||||
client: GeminiClient,
|
||||
geminiClient: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<CorrectedEditResult> {
|
||||
const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`;
|
||||
@@ -181,7 +196,7 @@ export async function ensureCorrectEdit(
|
||||
if (occurrences === expectedReplacements) {
|
||||
if (newStringPotentiallyEscaped) {
|
||||
finalNewString = await correctNewStringEscaping(
|
||||
client,
|
||||
baseLlmClient,
|
||||
finalOldString,
|
||||
originalParams.new_string,
|
||||
abortSignal,
|
||||
@@ -228,7 +243,7 @@ export async function ensureCorrectEdit(
|
||||
finalOldString = unescapedOldStringAttempt;
|
||||
if (newStringPotentiallyEscaped) {
|
||||
finalNewString = await correctNewString(
|
||||
client,
|
||||
baseLlmClient,
|
||||
originalParams.old_string, // original old
|
||||
unescapedOldStringAttempt, // corrected old
|
||||
originalParams.new_string, // original new (which is potentially escaped)
|
||||
@@ -242,7 +257,7 @@ export async function ensureCorrectEdit(
|
||||
// our system has done
|
||||
const lastEditedByUsTime = await findLastEditTimestamp(
|
||||
filePath,
|
||||
client,
|
||||
geminiClient,
|
||||
);
|
||||
|
||||
// Add a 1-second buffer to account for timing inaccuracies. If the file
|
||||
@@ -265,7 +280,7 @@ export async function ensureCorrectEdit(
|
||||
}
|
||||
|
||||
const llmCorrectedOldString = await correctOldStringMismatch(
|
||||
client,
|
||||
baseLlmClient,
|
||||
currentContent,
|
||||
unescapedOldStringAttempt,
|
||||
abortSignal,
|
||||
@@ -284,7 +299,7 @@ export async function ensureCorrectEdit(
|
||||
originalParams.new_string,
|
||||
);
|
||||
finalNewString = await correctNewString(
|
||||
client,
|
||||
baseLlmClient,
|
||||
originalParams.old_string, // original old
|
||||
llmCorrectedOldString, // corrected old
|
||||
baseNewStringForLLMCorrection, // base new for correction
|
||||
@@ -335,7 +350,7 @@ export async function ensureCorrectEdit(
|
||||
|
||||
export async function ensureCorrectFileContent(
|
||||
content: string,
|
||||
client: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const cachedResult = fileContentCorrectionCache.get(content);
|
||||
@@ -352,7 +367,7 @@ export async function ensureCorrectFileContent(
|
||||
|
||||
const correctedContent = await correctStringEscaping(
|
||||
content,
|
||||
client,
|
||||
baseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
fileContentCorrectionCache.set(content, correctedContent);
|
||||
@@ -373,7 +388,7 @@ const OLD_STRING_CORRECTION_SCHEMA: Record<string, unknown> = {
|
||||
};
|
||||
|
||||
export async function correctOldStringMismatch(
|
||||
geminiClient: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
fileContent: string,
|
||||
problematicSnippet: string,
|
||||
abortSignal: AbortSignal,
|
||||
@@ -402,13 +417,15 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
|
||||
try {
|
||||
const result = await geminiClient.generateJson(
|
||||
const result = await baseLlmClient.generateJson({
|
||||
contents,
|
||||
OLD_STRING_CORRECTION_SCHEMA,
|
||||
schema: OLD_STRING_CORRECTION_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
model: EDIT_MODEL,
|
||||
config: EDIT_CONFIG,
|
||||
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
|
||||
promptId: getPromptId(),
|
||||
});
|
||||
|
||||
if (
|
||||
result &&
|
||||
@@ -450,7 +467,7 @@ const NEW_STRING_CORRECTION_SCHEMA: Record<string, unknown> = {
|
||||
* Adjusts the new_string to align with a corrected old_string, maintaining the original intent.
|
||||
*/
|
||||
export async function correctNewString(
|
||||
geminiClient: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
originalOldString: string,
|
||||
correctedOldString: string,
|
||||
originalNewString: string,
|
||||
@@ -490,13 +507,15 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
|
||||
try {
|
||||
const result = await geminiClient.generateJson(
|
||||
const result = await baseLlmClient.generateJson({
|
||||
contents,
|
||||
NEW_STRING_CORRECTION_SCHEMA,
|
||||
schema: NEW_STRING_CORRECTION_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
model: EDIT_MODEL,
|
||||
config: EDIT_CONFIG,
|
||||
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
|
||||
promptId: getPromptId(),
|
||||
});
|
||||
|
||||
if (
|
||||
result &&
|
||||
@@ -530,7 +549,7 @@ const CORRECT_NEW_STRING_ESCAPING_SCHEMA: Record<string, unknown> = {
|
||||
};
|
||||
|
||||
export async function correctNewStringEscaping(
|
||||
geminiClient: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
oldString: string,
|
||||
potentiallyProblematicNewString: string,
|
||||
abortSignal: AbortSignal,
|
||||
@@ -559,13 +578,15 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
|
||||
try {
|
||||
const result = await geminiClient.generateJson(
|
||||
const result = await baseLlmClient.generateJson({
|
||||
contents,
|
||||
CORRECT_NEW_STRING_ESCAPING_SCHEMA,
|
||||
schema: CORRECT_NEW_STRING_ESCAPING_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
model: EDIT_MODEL,
|
||||
config: EDIT_CONFIG,
|
||||
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
|
||||
promptId: getPromptId(),
|
||||
});
|
||||
|
||||
if (
|
||||
result &&
|
||||
@@ -603,7 +624,7 @@ const CORRECT_STRING_ESCAPING_SCHEMA: Record<string, unknown> = {
|
||||
|
||||
export async function correctStringEscaping(
|
||||
potentiallyProblematicString: string,
|
||||
client: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const prompt = `
|
||||
@@ -625,13 +646,15 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
|
||||
try {
|
||||
const result = await client.generateJson(
|
||||
const result = await baseLlmClient.generateJson({
|
||||
contents,
|
||||
CORRECT_STRING_ESCAPING_SCHEMA,
|
||||
schema: CORRECT_STRING_ESCAPING_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
model: EDIT_MODEL,
|
||||
config: EDIT_CONFIG,
|
||||
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
|
||||
promptId: getPromptId(),
|
||||
});
|
||||
|
||||
if (
|
||||
result &&
|
||||
|
||||
Reference in New Issue
Block a user