mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-24 21:10:43 -07:00
refactor(core): Introduce LlmUtilityService and promptIdContext (#7952)
This commit is contained in:
203
packages/core/src/utils/llm-edit-fixer.test.ts
Normal file
203
packages/core/src/utils/llm-edit-fixer.test.ts
Normal file
@@ -0,0 +1,203 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
FixLLMEditWithInstruction,
|
||||
resetLlmEditFixerCaches_TEST_ONLY,
|
||||
type SearchReplaceEdit,
|
||||
} from './llm-edit-fixer.js';
|
||||
import { promptIdContext } from './promptIdContext.js';
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
|
||||
// Mock the BaseLlmClient
|
||||
const mockGenerateJson = vi.fn();
|
||||
const mockBaseLlmClient = {
|
||||
generateJson: mockGenerateJson,
|
||||
} as unknown as BaseLlmClient;
|
||||
|
||||
describe('FixLLMEditWithInstruction', () => {
|
||||
const instruction = 'Replace the title';
|
||||
const old_string = '<h1>Old Title</h1>';
|
||||
const new_string = '<h1>New Title</h1>';
|
||||
const error = 'String not found';
|
||||
const current_content = '<body><h1>Old Title</h1></body>';
|
||||
const abortController = new AbortController();
|
||||
const abortSignal = abortController.signal;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
resetLlmEditFixerCaches_TEST_ONLY(); // Ensure cache is cleared before each test
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers(); // Reset timers after each test
|
||||
});
|
||||
|
||||
const mockApiResponse: SearchReplaceEdit = {
|
||||
search: '<h1>Old Title</h1>',
|
||||
replace: '<h1>New Title</h1>',
|
||||
noChangesRequired: false,
|
||||
explanation: 'The original search was correct.',
|
||||
};
|
||||
|
||||
it('should use the promptId from the AsyncLocalStorage context when available', async () => {
|
||||
const testPromptId = 'test-prompt-id-12345';
|
||||
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||
|
||||
await promptIdContext.run(testPromptId, async () => {
|
||||
await FixLLMEditWithInstruction(
|
||||
instruction,
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
});
|
||||
|
||||
// Verify that generateJson was called with the promptId from the context
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockGenerateJson).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
promptId: testPromptId,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should generate and use a fallback promptId when context is not available', async () => {
|
||||
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, 'warn')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
// Run the function outside of any context
|
||||
await FixLLMEditWithInstruction(
|
||||
instruction,
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
// Verify the warning was logged
|
||||
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'Could not find promptId in context. This is unexpected. Using a fallback ID: llm-fixer-fallback-',
|
||||
),
|
||||
);
|
||||
|
||||
// Verify that generateJson was called with the generated fallback promptId
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockGenerateJson).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
promptId: expect.stringContaining('llm-fixer-fallback-'),
|
||||
}),
|
||||
);
|
||||
|
||||
// Restore mocks
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should construct the user prompt correctly', async () => {
|
||||
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||
const promptId = 'test-prompt-id-prompt-construction';
|
||||
|
||||
await promptIdContext.run(promptId, async () => {
|
||||
await FixLLMEditWithInstruction(
|
||||
instruction,
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
});
|
||||
|
||||
const generateJsonCall = mockGenerateJson.mock.calls[0][0];
|
||||
const userPromptContent = generateJsonCall.contents[0].parts[0].text;
|
||||
|
||||
expect(userPromptContent).toContain(
|
||||
`<instruction>\n${instruction}\n</instruction>`,
|
||||
);
|
||||
expect(userPromptContent).toContain(`<search>\n${old_string}\n</search>`);
|
||||
expect(userPromptContent).toContain(`<replace>\n${new_string}\n</replace>`);
|
||||
expect(userPromptContent).toContain(`<error>\n${error}\n</error>`);
|
||||
expect(userPromptContent).toContain(
|
||||
`<file_content>\n${current_content}\n</file_content>`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return a cached result on subsequent identical calls', async () => {
|
||||
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||
const testPromptId = 'test-prompt-id-caching';
|
||||
|
||||
await promptIdContext.run(testPromptId, async () => {
|
||||
// First call - should call the API
|
||||
const result1 = await FixLLMEditWithInstruction(
|
||||
instruction,
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
// Second call with identical parameters - should hit the cache
|
||||
const result2 = await FixLLMEditWithInstruction(
|
||||
instruction,
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(result1).toEqual(mockApiResponse);
|
||||
expect(result2).toEqual(mockApiResponse);
|
||||
// Verify the underlying service was only called ONCE
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
it('should not use cache for calls with different parameters', async () => {
|
||||
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||
const testPromptId = 'test-prompt-id-cache-miss';
|
||||
|
||||
await promptIdContext.run(testPromptId, async () => {
|
||||
// First call
|
||||
await FixLLMEditWithInstruction(
|
||||
instruction,
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
// Second call with a different instruction
|
||||
await FixLLMEditWithInstruction(
|
||||
'A different instruction',
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
// Verify the underlying service was called TWICE
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -5,9 +5,10 @@
|
||||
*/
|
||||
|
||||
import { type Content, Type } from '@google/genai';
|
||||
import { type GeminiClient } from '../core/client.js';
|
||||
import { type BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import { LruCache } from './LruCache.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { promptIdContext } from './promptIdContext.js';
|
||||
|
||||
const MAX_CACHE_SIZE = 50;
|
||||
|
||||
@@ -93,8 +94,9 @@ const editCorrectionWithInstructionCache = new LruCache<
|
||||
* @param new_string The original replacement string.
|
||||
* @param error The error that occurred during the initial edit.
|
||||
* @param current_content The current content of the file.
|
||||
* @param geminiClient The Gemini client to use for the LLM call.
|
||||
* @param baseLlmClient The BaseLlmClient to use for the LLM call.
|
||||
* @param abortSignal An abort signal to cancel the operation.
|
||||
* @param promptId A unique ID for the prompt.
|
||||
* @returns A new search and replace pair.
|
||||
*/
|
||||
export async function FixLLMEditWithInstruction(
|
||||
@@ -103,9 +105,17 @@ export async function FixLLMEditWithInstruction(
|
||||
new_string: string,
|
||||
error: string,
|
||||
current_content: string,
|
||||
geminiClient: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<SearchReplaceEdit> {
|
||||
let promptId = promptIdContext.getStore();
|
||||
if (!promptId) {
|
||||
promptId = `llm-fixer-fallback-${Date.now()}-${Math.random().toString(16).slice(2)}`;
|
||||
console.warn(
|
||||
`Could not find promptId in context. This is unexpected. Using a fallback ID: ${promptId}`,
|
||||
);
|
||||
}
|
||||
|
||||
const cacheKey = `${instruction}---${old_string}---${new_string}--${current_content}--${error}`;
|
||||
const cachedResult = editCorrectionWithInstructionCache.get(cacheKey);
|
||||
if (cachedResult) {
|
||||
@@ -120,21 +130,18 @@ export async function FixLLMEditWithInstruction(
|
||||
const contents: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text: `${EDIT_SYS_PROMPT}
|
||||
${userPrompt}`,
|
||||
},
|
||||
],
|
||||
parts: [{ text: userPrompt }],
|
||||
},
|
||||
];
|
||||
|
||||
const result = (await geminiClient.generateJson(
|
||||
const result = (await baseLlmClient.generateJson({
|
||||
contents,
|
||||
SearchReplaceEditSchema,
|
||||
schema: SearchReplaceEditSchema,
|
||||
abortSignal,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
)) as unknown as SearchReplaceEdit;
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
systemInstruction: EDIT_SYS_PROMPT,
|
||||
promptId,
|
||||
})) as unknown as SearchReplaceEdit;
|
||||
|
||||
editCorrectionWithInstructionCache.set(cacheKey, result);
|
||||
return result;
|
||||
|
||||
9
packages/core/src/utils/promptIdContext.ts
Normal file
9
packages/core/src/utils/promptIdContext.ts
Normal file
@@ -0,0 +1,9 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { AsyncLocalStorage } from 'node:async_hooks';
|
||||
|
||||
export const promptIdContext = new AsyncLocalStorage<string>();
|
||||
Reference in New Issue
Block a user