refactor(core): Introduce LlmUtilityService and promptIdContext (#7952)

This commit is contained in:
Abhi
2025-09-09 01:14:15 -04:00
committed by GitHub
parent 471cbcd450
commit 1eaf21f6a2
12 changed files with 943 additions and 165 deletions

View File

@@ -0,0 +1,203 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
FixLLMEditWithInstruction,
resetLlmEditFixerCaches_TEST_ONLY,
type SearchReplaceEdit,
} from './llm-edit-fixer.js';
import { promptIdContext } from './promptIdContext.js';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
// Mock the BaseLlmClient: tests inspect the arguments passed to generateJson
// (promptId, contents, schema, …) without making a real LLM call. The double
// cast is the usual test-double idiom for satisfying the interface shape.
const mockGenerateJson = vi.fn();
const mockBaseLlmClient = {
generateJson: mockGenerateJson,
} as unknown as BaseLlmClient;
describe('FixLLMEditWithInstruction', () => {
// Shared fixture values reused by every test case.
const fixInstruction = 'Replace the title';
const originalSearch = '<h1>Old Title</h1>';
const originalReplace = '<h1>New Title</h1>';
const editError = 'String not found';
const fileContent = '<body><h1>Old Title</h1></body>';
const signal = new AbortController().signal;

// The canned edit the mocked LLM returns for every call.
const fakeEdit: SearchReplaceEdit = {
search: '<h1>Old Title</h1>',
replace: '<h1>New Title</h1>',
noChangesRequired: false,
explanation: 'The original search was correct.',
};

// Convenience wrapper: invokes the fixer with the default fixture values,
// letting an individual test override just the instruction.
const callFixer = (inst: string = fixInstruction) =>
FixLLMEditWithInstruction(
inst,
originalSearch,
originalReplace,
editError,
fileContent,
mockBaseLlmClient,
signal,
);

beforeEach(() => {
vi.clearAllMocks();
resetLlmEditFixerCaches_TEST_ONLY(); // Ensure cache is cleared before each test
});

afterEach(() => {
vi.useRealTimers(); // Reset timers after each test
});

it('should use the promptId from the AsyncLocalStorage context when available', async () => {
const testPromptId = 'test-prompt-id-12345';
mockGenerateJson.mockResolvedValue(fakeEdit);

await promptIdContext.run(testPromptId, () => callFixer());

// The promptId stored in the context must be forwarded to generateJson.
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(mockGenerateJson).toHaveBeenCalledWith(
expect.objectContaining({ promptId: testPromptId }),
);
});

it('should generate and use a fallback promptId when context is not available', async () => {
mockGenerateJson.mockResolvedValue(fakeEdit);
const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});

// Deliberately invoked outside of promptIdContext.run().
await callFixer();

// A warning must be emitted when no contextual promptId exists…
expect(warnSpy).toHaveBeenCalledWith(
expect.stringContaining(
'Could not find promptId in context. This is unexpected. Using a fallback ID: llm-fixer-fallback-',
),
);
// …and the generated fallback ID must still be passed through.
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(mockGenerateJson).toHaveBeenCalledWith(
expect.objectContaining({
promptId: expect.stringContaining('llm-fixer-fallback-'),
}),
);

warnSpy.mockRestore();
});

it('should construct the user prompt correctly', async () => {
mockGenerateJson.mockResolvedValue(fakeEdit);

await promptIdContext.run('test-prompt-id-prompt-construction', () =>
callFixer(),
);

const request = mockGenerateJson.mock.calls[0][0];
const prompt = request.contents[0].parts[0].text;

// Every piece of context must appear in its own tagged section.
expect(prompt).toContain(
`<instruction>\n${fixInstruction}\n</instruction>`,
);
expect(prompt).toContain(`<search>\n${originalSearch}\n</search>`);
expect(prompt).toContain(`<replace>\n${originalReplace}\n</replace>`);
expect(prompt).toContain(`<error>\n${editError}\n</error>`);
expect(prompt).toContain(
`<file_content>\n${fileContent}\n</file_content>`,
);
});

it('should return a cached result on subsequent identical calls', async () => {
mockGenerateJson.mockResolvedValue(fakeEdit);

await promptIdContext.run('test-prompt-id-caching', async () => {
const first = await callFixer();
const second = await callFixer(); // identical parameters -> cache hit
expect(first).toEqual(fakeEdit);
expect(second).toEqual(fakeEdit);
});

// Only the first call should have reached the underlying service.
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
});

it('should not use cache for calls with different parameters', async () => {
mockGenerateJson.mockResolvedValue(fakeEdit);

await promptIdContext.run('test-prompt-id-cache-miss', async () => {
await callFixer();
await callFixer('A different instruction'); // changed instruction -> cache miss
});

// Both calls should have reached the underlying service.
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
});
});

View File

@@ -5,9 +5,10 @@
*/
import { type Content, Type } from '@google/genai';
import { type GeminiClient } from '../core/client.js';
import { type BaseLlmClient } from '../core/baseLlmClient.js';
import { LruCache } from './LruCache.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { promptIdContext } from './promptIdContext.js';
const MAX_CACHE_SIZE = 50;
@@ -93,8 +94,9 @@ const editCorrectionWithInstructionCache = new LruCache<
* @param new_string The original replacement string.
* @param error The error that occurred during the initial edit.
* @param current_content The current content of the file.
* @param geminiClient The Gemini client to use for the LLM call.
* @param baseLlmClient The BaseLlmClient to use for the LLM call.
* @param abortSignal An abort signal to cancel the operation.
* @param promptId A unique ID for the prompt.
* @returns A new search and replace pair.
*/
export async function FixLLMEditWithInstruction(
@@ -103,9 +105,17 @@ export async function FixLLMEditWithInstruction(
new_string: string,
error: string,
current_content: string,
geminiClient: GeminiClient,
baseLlmClient: BaseLlmClient,
abortSignal: AbortSignal,
): Promise<SearchReplaceEdit> {
let promptId = promptIdContext.getStore();
if (!promptId) {
promptId = `llm-fixer-fallback-${Date.now()}-${Math.random().toString(16).slice(2)}`;
console.warn(
`Could not find promptId in context. This is unexpected. Using a fallback ID: ${promptId}`,
);
}
const cacheKey = `${instruction}---${old_string}---${new_string}--${current_content}--${error}`;
const cachedResult = editCorrectionWithInstructionCache.get(cacheKey);
if (cachedResult) {
@@ -120,21 +130,18 @@ export async function FixLLMEditWithInstruction(
const contents: Content[] = [
{
role: 'user',
parts: [
{
text: `${EDIT_SYS_PROMPT}
${userPrompt}`,
},
],
parts: [{ text: userPrompt }],
},
];
const result = (await geminiClient.generateJson(
const result = (await baseLlmClient.generateJson({
contents,
SearchReplaceEditSchema,
schema: SearchReplaceEditSchema,
abortSignal,
DEFAULT_GEMINI_FLASH_MODEL,
)) as unknown as SearchReplaceEdit;
model: DEFAULT_GEMINI_FLASH_MODEL,
systemInstruction: EDIT_SYS_PROMPT,
promptId,
})) as unknown as SearchReplaceEdit;
editCorrectionWithInstructionCache.set(cacheKey, result);
return result;

View File

@@ -0,0 +1,9 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { AsyncLocalStorage } from 'node:async_hooks';
/**
 * Async-local store carrying the current prompt ID so that deeply nested
 * helpers (e.g. the LLM edit fixer, which reads it via
 * `promptIdContext.getStore()`) can retrieve it without threading it through
 * every function signature. Populate it with `promptIdContext.run(id, fn)`.
 */
export const promptIdContext = new AsyncLocalStorage<string>();