refactor(core): Use BaseLlmClient for utility LLM calls in edit corrector (#8443)

This commit is contained in:
Abhi
2025-09-15 20:46:41 -04:00
committed by GitHub
parent d5d150449d
commit 1634d5fcca
6 changed files with 120 additions and 112 deletions
+31 -69
View File
@@ -9,6 +9,7 @@ import type { Mock } from 'vitest';
import { vi, describe, it, expect, beforeEach, type Mocked } from 'vitest';
import * as fs from 'node:fs';
import { EditTool } from '../tools/edit.js';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
// MOCKS
let callCount = 0;
@@ -28,10 +29,9 @@ vi.mock('../core/client.js', () => ({
this: any,
_config: Config,
) {
this.generateJson = (...params: any[]) => mockGenerateJson(...params); // Corrected: use mockGenerateJson
this.startChat = (...params: any[]) => mockStartChat(...params); // Corrected: use mockStartChat
this.startChat = (...params: any[]) => mockStartChat(...params);
this.sendMessageStream = (...params: any[]) =>
mockSendMessageStream(...params); // Corrected: use mockSendMessageStream
mockSendMessageStream(...params);
return this;
}),
}));
@@ -154,6 +154,7 @@ describe('editCorrector', () => {
describe('ensureCorrectEdit', () => {
let mockGeminiClientInstance: Mocked<GeminiClient>;
let mockBaseLlmClientInstance: Mocked<BaseLlmClient>;
let mockToolRegistry: Mocked<ToolRegistry>;
let mockConfigInstance: Config;
const abortSignal = new AbortController().signal;
@@ -233,6 +234,9 @@ describe('editCorrector', () => {
mockConfigInstance,
) as Mocked<GeminiClient>;
mockGeminiClientInstance.getHistory = vi.fn().mockResolvedValue([]);
mockBaseLlmClientInstance = {
generateJson: mockGenerateJson,
} as unknown as Mocked<BaseLlmClient>;
resetEditCorrectorCaches_TEST_ONLY();
});
@@ -252,6 +256,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
@@ -271,6 +276,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
@@ -293,6 +299,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
@@ -312,6 +319,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
@@ -335,6 +343,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
@@ -354,6 +363,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
@@ -373,6 +383,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
@@ -397,6 +408,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
@@ -420,6 +432,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
@@ -441,6 +454,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
@@ -464,6 +478,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
@@ -486,6 +501,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
@@ -505,6 +521,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
@@ -529,6 +546,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
@@ -585,6 +603,7 @@ describe('editCorrector', () => {
currentContent,
originalParams,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
@@ -595,65 +614,10 @@ describe('editCorrector', () => {
});
describe('ensureCorrectFileContent', () => {
let mockGeminiClientInstance: Mocked<GeminiClient>;
let mockToolRegistry: Mocked<ToolRegistry>;
let mockConfigInstance: Config;
let mockBaseLlmClientInstance: Mocked<BaseLlmClient>;
const abortSignal = new AbortController().signal;
beforeEach(() => {
mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
const configParams = {
apiKey: 'test-api-key',
model: 'test-model',
sandbox: false as boolean | string,
targetDir: '/test',
debugMode: false,
question: undefined as string | undefined,
fullContext: false,
coreTools: undefined as string[] | undefined,
toolDiscoveryCommand: undefined as string | undefined,
toolCallCommand: undefined as string | undefined,
mcpServerCommand: undefined as string | undefined,
mcpServers: undefined as Record<string, any> | undefined,
userAgent: 'test-agent',
userMemory: '',
geminiMdFileCount: 0,
alwaysSkipModificationConfirmation: false,
};
mockConfigInstance = {
...configParams,
getApiKey: vi.fn(() => configParams.apiKey),
getModel: vi.fn(() => configParams.model),
getSandbox: vi.fn(() => configParams.sandbox),
getTargetDir: vi.fn(() => configParams.targetDir),
getToolRegistry: vi.fn(() => mockToolRegistry),
getDebugMode: vi.fn(() => configParams.debugMode),
getQuestion: vi.fn(() => configParams.question),
getFullContext: vi.fn(() => configParams.fullContext),
getCoreTools: vi.fn(() => configParams.coreTools),
getToolDiscoveryCommand: vi.fn(() => configParams.toolDiscoveryCommand),
getToolCallCommand: vi.fn(() => configParams.toolCallCommand),
getMcpServerCommand: vi.fn(() => configParams.mcpServerCommand),
getMcpServers: vi.fn(() => configParams.mcpServers),
getUserAgent: vi.fn(() => configParams.userAgent),
getUserMemory: vi.fn(() => configParams.userMemory),
setUserMemory: vi.fn((mem: string) => {
configParams.userMemory = mem;
}),
getGeminiMdFileCount: vi.fn(() => configParams.geminiMdFileCount),
setGeminiMdFileCount: vi.fn((count: number) => {
configParams.geminiMdFileCount = count;
}),
getAlwaysSkipModificationConfirmation: vi.fn(
() => configParams.alwaysSkipModificationConfirmation,
),
setAlwaysSkipModificationConfirmation: vi.fn((skip: boolean) => {
configParams.alwaysSkipModificationConfirmation = skip;
}),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
setQuotaErrorOccurred: vi.fn(),
} as unknown as Config;
callCount = 0;
mockResponses.length = 0;
mockGenerateJson = vi
@@ -667,12 +631,10 @@ describe('editCorrector', () => {
if (response === undefined) return Promise.resolve({});
return Promise.resolve(response);
});
mockStartChat = vi.fn();
mockSendMessageStream = vi.fn();
mockGeminiClientInstance = new GeminiClient(
mockConfigInstance,
) as Mocked<GeminiClient>;
mockBaseLlmClientInstance = {
generateJson: mockGenerateJson,
} as unknown as Mocked<BaseLlmClient>;
resetEditCorrectorCaches_TEST_ONLY();
});
@@ -680,7 +642,7 @@ describe('editCorrector', () => {
const content = 'This is normal content without escaping issues';
const result = await ensureCorrectFileContent(
content,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
expect(result).toBe(content);
@@ -696,7 +658,7 @@ describe('editCorrector', () => {
const result = await ensureCorrectFileContent(
content,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
@@ -716,7 +678,7 @@ describe('editCorrector', () => {
const result = await ensureCorrectFileContent(
content,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
@@ -731,7 +693,7 @@ describe('editCorrector', () => {
const result = await ensureCorrectFileContent(
content,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
@@ -751,7 +713,7 @@ describe('editCorrector', () => {
const result = await ensureCorrectFileContent(
content,
mockGeminiClientInstance,
mockBaseLlmClientInstance,
abortSignal,
);
+57 -34
View File
@@ -6,6 +6,7 @@
import type { Content, GenerateContentConfig } from '@google/genai';
import type { GeminiClient } from '../core/client.js';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
import type { EditToolParams } from '../tools/edit.js';
import { EditTool } from '../tools/edit.js';
import { WriteFileTool } from '../tools/write-file.js';
@@ -19,14 +20,27 @@ import {
isFunctionCall,
} from '../utils/messageInspectors.js';
import * as fs from 'node:fs';
import { promptIdContext } from './promptIdContext.js';
const EditModel = DEFAULT_GEMINI_FLASH_LITE_MODEL;
const EditConfig: GenerateContentConfig = {
const EDIT_MODEL = DEFAULT_GEMINI_FLASH_LITE_MODEL;
const EDIT_CONFIG: GenerateContentConfig = {
thinkingConfig: {
thinkingBudget: 0,
},
};
const CODE_CORRECTION_SYSTEM_PROMPT = `
You are an expert code-editing assistant. Your task is to analyze a failed edit attempt and provide a corrected version of the text snippets.
The correction should be as minimal as possible, staying very close to the original.
Focus ONLY on fixing issues like whitespace, indentation, line endings, or incorrect escaping.
Do NOT invent a completely new edit. Your job is to fix the provided parameters to make the edit succeed.
Return ONLY the corrected snippet in the specified JSON format.
`.trim();
function getPromptId(): string {
return promptIdContext.getStore() ?? `edit-corrector-${Date.now()}`;
}
const MAX_CACHE_SIZE = 50;
// Cache for ensureCorrectEdit results
@@ -159,7 +173,8 @@ export async function ensureCorrectEdit(
filePath: string,
currentContent: string,
originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\'
client: GeminiClient,
geminiClient: GeminiClient,
baseLlmClient: BaseLlmClient,
abortSignal: AbortSignal,
): Promise<CorrectedEditResult> {
const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`;
@@ -181,7 +196,7 @@ export async function ensureCorrectEdit(
if (occurrences === expectedReplacements) {
if (newStringPotentiallyEscaped) {
finalNewString = await correctNewStringEscaping(
client,
baseLlmClient,
finalOldString,
originalParams.new_string,
abortSignal,
@@ -228,7 +243,7 @@ export async function ensureCorrectEdit(
finalOldString = unescapedOldStringAttempt;
if (newStringPotentiallyEscaped) {
finalNewString = await correctNewString(
client,
baseLlmClient,
originalParams.old_string, // original old
unescapedOldStringAttempt, // corrected old
originalParams.new_string, // original new (which is potentially escaped)
@@ -242,7 +257,7 @@ export async function ensureCorrectEdit(
// our system has done
const lastEditedByUsTime = await findLastEditTimestamp(
filePath,
client,
geminiClient,
);
// Add a 1-second buffer to account for timing inaccuracies. If the file
@@ -265,7 +280,7 @@ export async function ensureCorrectEdit(
}
const llmCorrectedOldString = await correctOldStringMismatch(
client,
baseLlmClient,
currentContent,
unescapedOldStringAttempt,
abortSignal,
@@ -284,7 +299,7 @@ export async function ensureCorrectEdit(
originalParams.new_string,
);
finalNewString = await correctNewString(
client,
baseLlmClient,
originalParams.old_string, // original old
llmCorrectedOldString, // corrected old
baseNewStringForLLMCorrection, // base new for correction
@@ -335,7 +350,7 @@ export async function ensureCorrectEdit(
export async function ensureCorrectFileContent(
content: string,
client: GeminiClient,
baseLlmClient: BaseLlmClient,
abortSignal: AbortSignal,
): Promise<string> {
const cachedResult = fileContentCorrectionCache.get(content);
@@ -352,7 +367,7 @@ export async function ensureCorrectFileContent(
const correctedContent = await correctStringEscaping(
content,
client,
baseLlmClient,
abortSignal,
);
fileContentCorrectionCache.set(content, correctedContent);
@@ -373,7 +388,7 @@ const OLD_STRING_CORRECTION_SCHEMA: Record<string, unknown> = {
};
export async function correctOldStringMismatch(
geminiClient: GeminiClient,
baseLlmClient: BaseLlmClient,
fileContent: string,
problematicSnippet: string,
abortSignal: AbortSignal,
@@ -402,13 +417,15 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
try {
const result = await geminiClient.generateJson(
const result = await baseLlmClient.generateJson({
contents,
OLD_STRING_CORRECTION_SCHEMA,
schema: OLD_STRING_CORRECTION_SCHEMA,
abortSignal,
EditModel,
EditConfig,
);
model: EDIT_MODEL,
config: EDIT_CONFIG,
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
promptId: getPromptId(),
});
if (
result &&
@@ -450,7 +467,7 @@ const NEW_STRING_CORRECTION_SCHEMA: Record<string, unknown> = {
* Adjusts the new_string to align with a corrected old_string, maintaining the original intent.
*/
export async function correctNewString(
geminiClient: GeminiClient,
baseLlmClient: BaseLlmClient,
originalOldString: string,
correctedOldString: string,
originalNewString: string,
@@ -490,13 +507,15 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
try {
const result = await geminiClient.generateJson(
const result = await baseLlmClient.generateJson({
contents,
NEW_STRING_CORRECTION_SCHEMA,
schema: NEW_STRING_CORRECTION_SCHEMA,
abortSignal,
EditModel,
EditConfig,
);
model: EDIT_MODEL,
config: EDIT_CONFIG,
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
promptId: getPromptId(),
});
if (
result &&
@@ -530,7 +549,7 @@ const CORRECT_NEW_STRING_ESCAPING_SCHEMA: Record<string, unknown> = {
};
export async function correctNewStringEscaping(
geminiClient: GeminiClient,
baseLlmClient: BaseLlmClient,
oldString: string,
potentiallyProblematicNewString: string,
abortSignal: AbortSignal,
@@ -559,13 +578,15 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
try {
const result = await geminiClient.generateJson(
const result = await baseLlmClient.generateJson({
contents,
CORRECT_NEW_STRING_ESCAPING_SCHEMA,
schema: CORRECT_NEW_STRING_ESCAPING_SCHEMA,
abortSignal,
EditModel,
EditConfig,
);
model: EDIT_MODEL,
config: EDIT_CONFIG,
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
promptId: getPromptId(),
});
if (
result &&
@@ -603,7 +624,7 @@ const CORRECT_STRING_ESCAPING_SCHEMA: Record<string, unknown> = {
export async function correctStringEscaping(
potentiallyProblematicString: string,
client: GeminiClient,
baseLlmClient: BaseLlmClient,
abortSignal: AbortSignal,
): Promise<string> {
const prompt = `
@@ -625,13 +646,15 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
try {
const result = await client.generateJson(
const result = await baseLlmClient.generateJson({
contents,
CORRECT_STRING_ESCAPING_SCHEMA,
schema: CORRECT_STRING_ESCAPING_SCHEMA,
abortSignal,
EditModel,
EditConfig,
);
model: EDIT_MODEL,
config: EDIT_CONFIG,
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
promptId: getPromptId(),
});
if (
result &&