mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-26 13:04:49 -07:00
refactor(core): Use BaseLlmClient for utility LLM calls in edit corrector (#8443)
This commit is contained in:
@@ -58,6 +58,7 @@ describe('EditTool', () => {
|
||||
let rootDir: string;
|
||||
let mockConfig: Config;
|
||||
let geminiClient: any;
|
||||
let baseLlmClient: any;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
@@ -69,8 +70,13 @@ describe('EditTool', () => {
|
||||
generateJson: mockGenerateJson, // mockGenerateJson is already defined and hoisted
|
||||
};
|
||||
|
||||
baseLlmClient = {
|
||||
generateJson: vi.fn(),
|
||||
};
|
||||
|
||||
mockConfig = {
|
||||
getGeminiClient: vi.fn().mockReturnValue(geminiClient),
|
||||
getBaseLlmClient: vi.fn().mockReturnValue(baseLlmClient),
|
||||
getTargetDir: () => rootDir,
|
||||
getApprovalMode: vi.fn(),
|
||||
setApprovalMode: vi.fn(),
|
||||
@@ -424,11 +430,12 @@ describe('EditTool', () => {
|
||||
// Set a specific mock for this test case
|
||||
let mockCalled = false;
|
||||
mockEnsureCorrectEdit.mockImplementationOnce(
|
||||
async (_, content, p, client) => {
|
||||
async (_, content, p, client, baseClient) => {
|
||||
mockCalled = true;
|
||||
expect(content).toBe(originalContent);
|
||||
expect(p).toBe(params);
|
||||
expect(client).toBe(geminiClient);
|
||||
expect(baseClient).toBe(baseLlmClient);
|
||||
return {
|
||||
params: {
|
||||
file_path: filePath,
|
||||
|
||||
@@ -164,6 +164,7 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
currentContent,
|
||||
params,
|
||||
this.config.getGeminiClient(),
|
||||
this.config.getBaseLlmClient(),
|
||||
abortSignal,
|
||||
);
|
||||
finalOldString = correctedEdit.params.old_string;
|
||||
|
||||
@@ -26,6 +26,7 @@ import path from 'node:path';
|
||||
import fs from 'node:fs';
|
||||
import os from 'node:os';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import type { CorrectedEditResult } from '../utils/editCorrector.js';
|
||||
import {
|
||||
ensureCorrectEdit,
|
||||
@@ -47,6 +48,7 @@ vi.mock('../ide/ide-client.js', () => ({
|
||||
},
|
||||
}));
|
||||
let mockGeminiClientInstance: Mocked<GeminiClient>;
|
||||
let mockBaseLlmClientInstance: Mocked<BaseLlmClient>;
|
||||
const mockEnsureCorrectEdit = vi.fn<typeof ensureCorrectEdit>();
|
||||
const mockEnsureCorrectFileContent = vi.fn<typeof ensureCorrectFileContent>();
|
||||
const mockIdeClient = {
|
||||
@@ -70,6 +72,7 @@ const mockConfigInternal = {
|
||||
getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT),
|
||||
setApprovalMode: vi.fn(),
|
||||
getGeminiClient: vi.fn(), // Initialize as a plain mock function
|
||||
getBaseLlmClient: vi.fn(), // Initialize as a plain mock function
|
||||
getFileSystemService: () => fsService,
|
||||
getIdeMode: vi.fn(() => false),
|
||||
getWorkspaceContext: () => createMockWorkspaceContext(rootDir),
|
||||
@@ -123,15 +126,23 @@ describe('WriteFileTool', () => {
|
||||
) as Mocked<GeminiClient>;
|
||||
vi.mocked(GeminiClient).mockImplementation(() => mockGeminiClientInstance);
|
||||
|
||||
// Setup BaseLlmClient mock
|
||||
mockBaseLlmClientInstance = {
|
||||
generateJson: vi.fn(),
|
||||
} as unknown as Mocked<BaseLlmClient>;
|
||||
|
||||
vi.mocked(ensureCorrectEdit).mockImplementation(mockEnsureCorrectEdit);
|
||||
vi.mocked(ensureCorrectFileContent).mockImplementation(
|
||||
mockEnsureCorrectFileContent,
|
||||
);
|
||||
|
||||
// Now that mockGeminiClientInstance is initialized, set the mock implementation for getGeminiClient
|
||||
// Now that mock instances are initialized, set the mock implementations for config getters
|
||||
mockConfigInternal.getGeminiClient.mockReturnValue(
|
||||
mockGeminiClientInstance,
|
||||
);
|
||||
mockConfigInternal.getBaseLlmClient.mockReturnValue(
|
||||
mockBaseLlmClientInstance,
|
||||
);
|
||||
|
||||
tool = new WriteFileTool(mockConfig);
|
||||
|
||||
@@ -148,7 +159,8 @@ describe('WriteFileTool', () => {
|
||||
_currentContent: string,
|
||||
params: EditToolParams,
|
||||
_client: GeminiClient,
|
||||
signal?: AbortSignal, // Make AbortSignal optional to match usage
|
||||
_baseClient: BaseLlmClient,
|
||||
signal?: AbortSignal,
|
||||
): Promise<CorrectedEditResult> => {
|
||||
if (signal?.aborted) {
|
||||
return Promise.reject(new Error('Aborted'));
|
||||
@@ -162,10 +174,9 @@ describe('WriteFileTool', () => {
|
||||
mockEnsureCorrectFileContent.mockImplementation(
|
||||
async (
|
||||
content: string,
|
||||
_client: GeminiClient,
|
||||
_baseClient: BaseLlmClient,
|
||||
signal?: AbortSignal,
|
||||
): Promise<string> => {
|
||||
// Make AbortSignal optional
|
||||
if (signal?.aborted) {
|
||||
return Promise.reject(new Error('Aborted'));
|
||||
}
|
||||
@@ -263,7 +274,7 @@ describe('WriteFileTool', () => {
|
||||
|
||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||
proposedContent,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockEnsureCorrectEdit).not.toHaveBeenCalled();
|
||||
@@ -307,6 +318,7 @@ describe('WriteFileTool', () => {
|
||||
file_path: filePath,
|
||||
},
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockEnsureCorrectFileContent).not.toHaveBeenCalled();
|
||||
@@ -383,7 +395,7 @@ describe('WriteFileTool', () => {
|
||||
|
||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||
proposedContent,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(confirmation).toEqual(
|
||||
@@ -433,6 +445,7 @@ describe('WriteFileTool', () => {
|
||||
file_path: filePath,
|
||||
},
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(confirmation).toEqual(
|
||||
@@ -599,7 +612,7 @@ describe('WriteFileTool', () => {
|
||||
|
||||
expect(mockEnsureCorrectFileContent).toHaveBeenCalledWith(
|
||||
proposedContent,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result.llmContent).toMatch(
|
||||
@@ -663,6 +676,7 @@ describe('WriteFileTool', () => {
|
||||
file_path: filePath,
|
||||
},
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result.llmContent).toMatch(/Successfully overwrote file/);
|
||||
|
||||
@@ -121,6 +121,7 @@ export async function getCorrectedFileContent(
|
||||
file_path: filePath,
|
||||
},
|
||||
config.getGeminiClient(),
|
||||
config.getBaseLlmClient(),
|
||||
abortSignal,
|
||||
);
|
||||
correctedContent = correctedParams.new_string;
|
||||
@@ -128,7 +129,7 @@ export async function getCorrectedFileContent(
|
||||
// This implies new file (ENOENT)
|
||||
correctedContent = await ensureCorrectFileContent(
|
||||
proposedContent,
|
||||
config.getGeminiClient(),
|
||||
config.getBaseLlmClient(),
|
||||
abortSignal,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import type { Mock } from 'vitest';
|
||||
import { vi, describe, it, expect, beforeEach, type Mocked } from 'vitest';
|
||||
import * as fs from 'node:fs';
|
||||
import { EditTool } from '../tools/edit.js';
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
|
||||
// MOCKS
|
||||
let callCount = 0;
|
||||
@@ -28,10 +29,9 @@ vi.mock('../core/client.js', () => ({
|
||||
this: any,
|
||||
_config: Config,
|
||||
) {
|
||||
this.generateJson = (...params: any[]) => mockGenerateJson(...params); // Corrected: use mockGenerateJson
|
||||
this.startChat = (...params: any[]) => mockStartChat(...params); // Corrected: use mockStartChat
|
||||
this.startChat = (...params: any[]) => mockStartChat(...params);
|
||||
this.sendMessageStream = (...params: any[]) =>
|
||||
mockSendMessageStream(...params); // Corrected: use mockSendMessageStream
|
||||
mockSendMessageStream(...params);
|
||||
return this;
|
||||
}),
|
||||
}));
|
||||
@@ -154,6 +154,7 @@ describe('editCorrector', () => {
|
||||
|
||||
describe('ensureCorrectEdit', () => {
|
||||
let mockGeminiClientInstance: Mocked<GeminiClient>;
|
||||
let mockBaseLlmClientInstance: Mocked<BaseLlmClient>;
|
||||
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||
let mockConfigInstance: Config;
|
||||
const abortSignal = new AbortController().signal;
|
||||
@@ -233,6 +234,9 @@ describe('editCorrector', () => {
|
||||
mockConfigInstance,
|
||||
) as Mocked<GeminiClient>;
|
||||
mockGeminiClientInstance.getHistory = vi.fn().mockResolvedValue([]);
|
||||
mockBaseLlmClientInstance = {
|
||||
generateJson: mockGenerateJson,
|
||||
} as unknown as Mocked<BaseLlmClient>;
|
||||
resetEditCorrectorCaches_TEST_ONLY();
|
||||
});
|
||||
|
||||
@@ -252,6 +256,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -271,6 +276,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
@@ -293,6 +299,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -312,6 +319,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
@@ -335,6 +343,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -354,6 +363,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
@@ -373,6 +383,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
@@ -397,6 +408,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -420,6 +432,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
||||
@@ -441,6 +454,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -464,6 +478,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -486,6 +501,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
@@ -505,6 +521,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
@@ -529,6 +546,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
||||
@@ -585,6 +603,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
@@ -595,65 +614,10 @@ describe('editCorrector', () => {
|
||||
});
|
||||
|
||||
describe('ensureCorrectFileContent', () => {
|
||||
let mockGeminiClientInstance: Mocked<GeminiClient>;
|
||||
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||
let mockConfigInstance: Config;
|
||||
let mockBaseLlmClientInstance: Mocked<BaseLlmClient>;
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
beforeEach(() => {
|
||||
mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
|
||||
const configParams = {
|
||||
apiKey: 'test-api-key',
|
||||
model: 'test-model',
|
||||
sandbox: false as boolean | string,
|
||||
targetDir: '/test',
|
||||
debugMode: false,
|
||||
question: undefined as string | undefined,
|
||||
fullContext: false,
|
||||
coreTools: undefined as string[] | undefined,
|
||||
toolDiscoveryCommand: undefined as string | undefined,
|
||||
toolCallCommand: undefined as string | undefined,
|
||||
mcpServerCommand: undefined as string | undefined,
|
||||
mcpServers: undefined as Record<string, any> | undefined,
|
||||
userAgent: 'test-agent',
|
||||
userMemory: '',
|
||||
geminiMdFileCount: 0,
|
||||
alwaysSkipModificationConfirmation: false,
|
||||
};
|
||||
mockConfigInstance = {
|
||||
...configParams,
|
||||
getApiKey: vi.fn(() => configParams.apiKey),
|
||||
getModel: vi.fn(() => configParams.model),
|
||||
getSandbox: vi.fn(() => configParams.sandbox),
|
||||
getTargetDir: vi.fn(() => configParams.targetDir),
|
||||
getToolRegistry: vi.fn(() => mockToolRegistry),
|
||||
getDebugMode: vi.fn(() => configParams.debugMode),
|
||||
getQuestion: vi.fn(() => configParams.question),
|
||||
getFullContext: vi.fn(() => configParams.fullContext),
|
||||
getCoreTools: vi.fn(() => configParams.coreTools),
|
||||
getToolDiscoveryCommand: vi.fn(() => configParams.toolDiscoveryCommand),
|
||||
getToolCallCommand: vi.fn(() => configParams.toolCallCommand),
|
||||
getMcpServerCommand: vi.fn(() => configParams.mcpServerCommand),
|
||||
getMcpServers: vi.fn(() => configParams.mcpServers),
|
||||
getUserAgent: vi.fn(() => configParams.userAgent),
|
||||
getUserMemory: vi.fn(() => configParams.userMemory),
|
||||
setUserMemory: vi.fn((mem: string) => {
|
||||
configParams.userMemory = mem;
|
||||
}),
|
||||
getGeminiMdFileCount: vi.fn(() => configParams.geminiMdFileCount),
|
||||
setGeminiMdFileCount: vi.fn((count: number) => {
|
||||
configParams.geminiMdFileCount = count;
|
||||
}),
|
||||
getAlwaysSkipModificationConfirmation: vi.fn(
|
||||
() => configParams.alwaysSkipModificationConfirmation,
|
||||
),
|
||||
setAlwaysSkipModificationConfirmation: vi.fn((skip: boolean) => {
|
||||
configParams.alwaysSkipModificationConfirmation = skip;
|
||||
}),
|
||||
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
|
||||
setQuotaErrorOccurred: vi.fn(),
|
||||
} as unknown as Config;
|
||||
|
||||
callCount = 0;
|
||||
mockResponses.length = 0;
|
||||
mockGenerateJson = vi
|
||||
@@ -667,12 +631,10 @@ describe('editCorrector', () => {
|
||||
if (response === undefined) return Promise.resolve({});
|
||||
return Promise.resolve(response);
|
||||
});
|
||||
mockStartChat = vi.fn();
|
||||
mockSendMessageStream = vi.fn();
|
||||
|
||||
mockGeminiClientInstance = new GeminiClient(
|
||||
mockConfigInstance,
|
||||
) as Mocked<GeminiClient>;
|
||||
mockBaseLlmClientInstance = {
|
||||
generateJson: mockGenerateJson,
|
||||
} as unknown as Mocked<BaseLlmClient>;
|
||||
resetEditCorrectorCaches_TEST_ONLY();
|
||||
});
|
||||
|
||||
@@ -680,7 +642,7 @@ describe('editCorrector', () => {
|
||||
const content = 'This is normal content without escaping issues';
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result).toBe(content);
|
||||
@@ -696,7 +658,7 @@ describe('editCorrector', () => {
|
||||
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
@@ -716,7 +678,7 @@ describe('editCorrector', () => {
|
||||
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
@@ -731,7 +693,7 @@ describe('editCorrector', () => {
|
||||
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
@@ -751,7 +713,7 @@ describe('editCorrector', () => {
|
||||
|
||||
const result = await ensureCorrectFileContent(
|
||||
content,
|
||||
mockGeminiClientInstance,
|
||||
mockBaseLlmClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import type { Content, GenerateContentConfig } from '@google/genai';
|
||||
import type { GeminiClient } from '../core/client.js';
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import type { EditToolParams } from '../tools/edit.js';
|
||||
import { EditTool } from '../tools/edit.js';
|
||||
import { WriteFileTool } from '../tools/write-file.js';
|
||||
@@ -19,14 +20,27 @@ import {
|
||||
isFunctionCall,
|
||||
} from '../utils/messageInspectors.js';
|
||||
import * as fs from 'node:fs';
|
||||
import { promptIdContext } from './promptIdContext.js';
|
||||
|
||||
const EditModel = DEFAULT_GEMINI_FLASH_LITE_MODEL;
|
||||
const EditConfig: GenerateContentConfig = {
|
||||
const EDIT_MODEL = DEFAULT_GEMINI_FLASH_LITE_MODEL;
|
||||
const EDIT_CONFIG: GenerateContentConfig = {
|
||||
thinkingConfig: {
|
||||
thinkingBudget: 0,
|
||||
},
|
||||
};
|
||||
|
||||
const CODE_CORRECTION_SYSTEM_PROMPT = `
|
||||
You are an expert code-editing assistant. Your task is to analyze a failed edit attempt and provide a corrected version of the text snippets.
|
||||
The correction should be as minimal as possible, staying very close to the original.
|
||||
Focus ONLY on fixing issues like whitespace, indentation, line endings, or incorrect escaping.
|
||||
Do NOT invent a completely new edit. Your job is to fix the provided parameters to make the edit succeed.
|
||||
Return ONLY the corrected snippet in the specified JSON format.
|
||||
`.trim();
|
||||
|
||||
function getPromptId(): string {
|
||||
return promptIdContext.getStore() ?? `edit-corrector-${Date.now()}`;
|
||||
}
|
||||
|
||||
const MAX_CACHE_SIZE = 50;
|
||||
|
||||
// Cache for ensureCorrectEdit results
|
||||
@@ -159,7 +173,8 @@ export async function ensureCorrectEdit(
|
||||
filePath: string,
|
||||
currentContent: string,
|
||||
originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\'
|
||||
client: GeminiClient,
|
||||
geminiClient: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<CorrectedEditResult> {
|
||||
const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`;
|
||||
@@ -181,7 +196,7 @@ export async function ensureCorrectEdit(
|
||||
if (occurrences === expectedReplacements) {
|
||||
if (newStringPotentiallyEscaped) {
|
||||
finalNewString = await correctNewStringEscaping(
|
||||
client,
|
||||
baseLlmClient,
|
||||
finalOldString,
|
||||
originalParams.new_string,
|
||||
abortSignal,
|
||||
@@ -228,7 +243,7 @@ export async function ensureCorrectEdit(
|
||||
finalOldString = unescapedOldStringAttempt;
|
||||
if (newStringPotentiallyEscaped) {
|
||||
finalNewString = await correctNewString(
|
||||
client,
|
||||
baseLlmClient,
|
||||
originalParams.old_string, // original old
|
||||
unescapedOldStringAttempt, // corrected old
|
||||
originalParams.new_string, // original new (which is potentially escaped)
|
||||
@@ -242,7 +257,7 @@ export async function ensureCorrectEdit(
|
||||
// our system has done
|
||||
const lastEditedByUsTime = await findLastEditTimestamp(
|
||||
filePath,
|
||||
client,
|
||||
geminiClient,
|
||||
);
|
||||
|
||||
// Add a 1-second buffer to account for timing inaccuracies. If the file
|
||||
@@ -265,7 +280,7 @@ export async function ensureCorrectEdit(
|
||||
}
|
||||
|
||||
const llmCorrectedOldString = await correctOldStringMismatch(
|
||||
client,
|
||||
baseLlmClient,
|
||||
currentContent,
|
||||
unescapedOldStringAttempt,
|
||||
abortSignal,
|
||||
@@ -284,7 +299,7 @@ export async function ensureCorrectEdit(
|
||||
originalParams.new_string,
|
||||
);
|
||||
finalNewString = await correctNewString(
|
||||
client,
|
||||
baseLlmClient,
|
||||
originalParams.old_string, // original old
|
||||
llmCorrectedOldString, // corrected old
|
||||
baseNewStringForLLMCorrection, // base new for correction
|
||||
@@ -335,7 +350,7 @@ export async function ensureCorrectEdit(
|
||||
|
||||
export async function ensureCorrectFileContent(
|
||||
content: string,
|
||||
client: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const cachedResult = fileContentCorrectionCache.get(content);
|
||||
@@ -352,7 +367,7 @@ export async function ensureCorrectFileContent(
|
||||
|
||||
const correctedContent = await correctStringEscaping(
|
||||
content,
|
||||
client,
|
||||
baseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
fileContentCorrectionCache.set(content, correctedContent);
|
||||
@@ -373,7 +388,7 @@ const OLD_STRING_CORRECTION_SCHEMA: Record<string, unknown> = {
|
||||
};
|
||||
|
||||
export async function correctOldStringMismatch(
|
||||
geminiClient: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
fileContent: string,
|
||||
problematicSnippet: string,
|
||||
abortSignal: AbortSignal,
|
||||
@@ -402,13 +417,15 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
|
||||
try {
|
||||
const result = await geminiClient.generateJson(
|
||||
const result = await baseLlmClient.generateJson({
|
||||
contents,
|
||||
OLD_STRING_CORRECTION_SCHEMA,
|
||||
schema: OLD_STRING_CORRECTION_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
model: EDIT_MODEL,
|
||||
config: EDIT_CONFIG,
|
||||
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
|
||||
promptId: getPromptId(),
|
||||
});
|
||||
|
||||
if (
|
||||
result &&
|
||||
@@ -450,7 +467,7 @@ const NEW_STRING_CORRECTION_SCHEMA: Record<string, unknown> = {
|
||||
* Adjusts the new_string to align with a corrected old_string, maintaining the original intent.
|
||||
*/
|
||||
export async function correctNewString(
|
||||
geminiClient: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
originalOldString: string,
|
||||
correctedOldString: string,
|
||||
originalNewString: string,
|
||||
@@ -490,13 +507,15 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
|
||||
try {
|
||||
const result = await geminiClient.generateJson(
|
||||
const result = await baseLlmClient.generateJson({
|
||||
contents,
|
||||
NEW_STRING_CORRECTION_SCHEMA,
|
||||
schema: NEW_STRING_CORRECTION_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
model: EDIT_MODEL,
|
||||
config: EDIT_CONFIG,
|
||||
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
|
||||
promptId: getPromptId(),
|
||||
});
|
||||
|
||||
if (
|
||||
result &&
|
||||
@@ -530,7 +549,7 @@ const CORRECT_NEW_STRING_ESCAPING_SCHEMA: Record<string, unknown> = {
|
||||
};
|
||||
|
||||
export async function correctNewStringEscaping(
|
||||
geminiClient: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
oldString: string,
|
||||
potentiallyProblematicNewString: string,
|
||||
abortSignal: AbortSignal,
|
||||
@@ -559,13 +578,15 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
|
||||
try {
|
||||
const result = await geminiClient.generateJson(
|
||||
const result = await baseLlmClient.generateJson({
|
||||
contents,
|
||||
CORRECT_NEW_STRING_ESCAPING_SCHEMA,
|
||||
schema: CORRECT_NEW_STRING_ESCAPING_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
model: EDIT_MODEL,
|
||||
config: EDIT_CONFIG,
|
||||
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
|
||||
promptId: getPromptId(),
|
||||
});
|
||||
|
||||
if (
|
||||
result &&
|
||||
@@ -603,7 +624,7 @@ const CORRECT_STRING_ESCAPING_SCHEMA: Record<string, unknown> = {
|
||||
|
||||
export async function correctStringEscaping(
|
||||
potentiallyProblematicString: string,
|
||||
client: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const prompt = `
|
||||
@@ -625,13 +646,15 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||
const contents: Content[] = [{ role: 'user', parts: [{ text: prompt }] }];
|
||||
|
||||
try {
|
||||
const result = await client.generateJson(
|
||||
const result = await baseLlmClient.generateJson({
|
||||
contents,
|
||||
CORRECT_STRING_ESCAPING_SCHEMA,
|
||||
schema: CORRECT_STRING_ESCAPING_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
model: EDIT_MODEL,
|
||||
config: EDIT_CONFIG,
|
||||
systemInstruction: CODE_CORRECTION_SYSTEM_PROMPT,
|
||||
promptId: getPromptId(),
|
||||
});
|
||||
|
||||
if (
|
||||
result &&
|
||||
|
||||
Reference in New Issue
Block a user