mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
refactor(core): Introduce LlmUtilityService and promptIdContext (#7952)
This commit is contained in:
@@ -13,6 +13,7 @@ import {
|
||||
parseAndFormatApiError,
|
||||
FatalInputError,
|
||||
FatalTurnLimitedError,
|
||||
promptIdContext,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { Content, Part } from '@google/genai';
|
||||
|
||||
@@ -24,115 +25,117 @@ export async function runNonInteractive(
|
||||
input: string,
|
||||
prompt_id: string,
|
||||
): Promise<void> {
|
||||
const consolePatcher = new ConsolePatcher({
|
||||
stderr: true,
|
||||
debugMode: config.getDebugMode(),
|
||||
});
|
||||
|
||||
try {
|
||||
consolePatcher.patch();
|
||||
// Handle EPIPE errors when the output is piped to a command that closes early.
|
||||
process.stdout.on('error', (err: NodeJS.ErrnoException) => {
|
||||
if (err.code === 'EPIPE') {
|
||||
// Exit gracefully if the pipe is closed.
|
||||
process.exit(0);
|
||||
}
|
||||
return promptIdContext.run(prompt_id, async () => {
|
||||
const consolePatcher = new ConsolePatcher({
|
||||
stderr: true,
|
||||
debugMode: config.getDebugMode(),
|
||||
});
|
||||
|
||||
const geminiClient = config.getGeminiClient();
|
||||
try {
|
||||
consolePatcher.patch();
|
||||
// Handle EPIPE errors when the output is piped to a command that closes early.
|
||||
process.stdout.on('error', (err: NodeJS.ErrnoException) => {
|
||||
if (err.code === 'EPIPE') {
|
||||
// Exit gracefully if the pipe is closed.
|
||||
process.exit(0);
|
||||
}
|
||||
});
|
||||
|
||||
const abortController = new AbortController();
|
||||
const geminiClient = config.getGeminiClient();
|
||||
|
||||
const { processedQuery, shouldProceed } = await handleAtCommand({
|
||||
query: input,
|
||||
config,
|
||||
addItem: (_item, _timestamp) => 0,
|
||||
onDebugMessage: () => {},
|
||||
messageId: Date.now(),
|
||||
signal: abortController.signal,
|
||||
});
|
||||
const abortController = new AbortController();
|
||||
|
||||
if (!shouldProceed || !processedQuery) {
|
||||
// An error occurred during @include processing (e.g., file not found).
|
||||
// The error message is already logged by handleAtCommand.
|
||||
throw new FatalInputError(
|
||||
'Exiting due to an error processing the @ command.',
|
||||
);
|
||||
}
|
||||
const { processedQuery, shouldProceed } = await handleAtCommand({
|
||||
query: input,
|
||||
config,
|
||||
addItem: (_item, _timestamp) => 0,
|
||||
onDebugMessage: () => {},
|
||||
messageId: Date.now(),
|
||||
signal: abortController.signal,
|
||||
});
|
||||
|
||||
let currentMessages: Content[] = [
|
||||
{ role: 'user', parts: processedQuery as Part[] },
|
||||
];
|
||||
|
||||
let turnCount = 0;
|
||||
while (true) {
|
||||
turnCount++;
|
||||
if (
|
||||
config.getMaxSessionTurns() >= 0 &&
|
||||
turnCount > config.getMaxSessionTurns()
|
||||
) {
|
||||
throw new FatalTurnLimitedError(
|
||||
'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.',
|
||||
if (!shouldProceed || !processedQuery) {
|
||||
// An error occurred during @include processing (e.g., file not found).
|
||||
// The error message is already logged by handleAtCommand.
|
||||
throw new FatalInputError(
|
||||
'Exiting due to an error processing the @ command.',
|
||||
);
|
||||
}
|
||||
const toolCallRequests: ToolCallRequestInfo[] = [];
|
||||
|
||||
const responseStream = geminiClient.sendMessageStream(
|
||||
currentMessages[0]?.parts || [],
|
||||
abortController.signal,
|
||||
prompt_id,
|
||||
);
|
||||
let currentMessages: Content[] = [
|
||||
{ role: 'user', parts: processedQuery as Part[] },
|
||||
];
|
||||
|
||||
for await (const event of responseStream) {
|
||||
if (abortController.signal.aborted) {
|
||||
console.error('Operation cancelled.');
|
||||
let turnCount = 0;
|
||||
while (true) {
|
||||
turnCount++;
|
||||
if (
|
||||
config.getMaxSessionTurns() >= 0 &&
|
||||
turnCount > config.getMaxSessionTurns()
|
||||
) {
|
||||
throw new FatalTurnLimitedError(
|
||||
'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.',
|
||||
);
|
||||
}
|
||||
const toolCallRequests: ToolCallRequestInfo[] = [];
|
||||
|
||||
const responseStream = geminiClient.sendMessageStream(
|
||||
currentMessages[0]?.parts || [],
|
||||
abortController.signal,
|
||||
prompt_id,
|
||||
);
|
||||
|
||||
for await (const event of responseStream) {
|
||||
if (abortController.signal.aborted) {
|
||||
console.error('Operation cancelled.');
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.type === GeminiEventType.Content) {
|
||||
process.stdout.write(event.value);
|
||||
} else if (event.type === GeminiEventType.ToolCallRequest) {
|
||||
toolCallRequests.push(event.value);
|
||||
}
|
||||
}
|
||||
|
||||
if (toolCallRequests.length > 0) {
|
||||
const toolResponseParts: Part[] = [];
|
||||
for (const requestInfo of toolCallRequests) {
|
||||
const toolResponse = await executeToolCall(
|
||||
config,
|
||||
requestInfo,
|
||||
abortController.signal,
|
||||
);
|
||||
|
||||
if (toolResponse.error) {
|
||||
console.error(
|
||||
`Error executing tool ${requestInfo.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (toolResponse.responseParts) {
|
||||
toolResponseParts.push(...toolResponse.responseParts);
|
||||
}
|
||||
}
|
||||
currentMessages = [{ role: 'user', parts: toolResponseParts }];
|
||||
} else {
|
||||
process.stdout.write('\n'); // Ensure a final newline
|
||||
return;
|
||||
}
|
||||
|
||||
if (event.type === GeminiEventType.Content) {
|
||||
process.stdout.write(event.value);
|
||||
} else if (event.type === GeminiEventType.ToolCallRequest) {
|
||||
toolCallRequests.push(event.value);
|
||||
}
|
||||
}
|
||||
|
||||
if (toolCallRequests.length > 0) {
|
||||
const toolResponseParts: Part[] = [];
|
||||
for (const requestInfo of toolCallRequests) {
|
||||
const toolResponse = await executeToolCall(
|
||||
config,
|
||||
requestInfo,
|
||||
abortController.signal,
|
||||
);
|
||||
|
||||
if (toolResponse.error) {
|
||||
console.error(
|
||||
`Error executing tool ${requestInfo.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (toolResponse.responseParts) {
|
||||
toolResponseParts.push(...toolResponse.responseParts);
|
||||
}
|
||||
}
|
||||
currentMessages = [{ role: 'user', parts: toolResponseParts }];
|
||||
} else {
|
||||
process.stdout.write('\n'); // Ensure a final newline
|
||||
return;
|
||||
} catch (error) {
|
||||
console.error(
|
||||
parseAndFormatApiError(
|
||||
error,
|
||||
config.getContentGeneratorConfig()?.authType,
|
||||
),
|
||||
);
|
||||
throw error;
|
||||
} finally {
|
||||
consolePatcher.cleanup();
|
||||
if (isTelemetrySdkInitialized()) {
|
||||
await shutdownTelemetry(config);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(
|
||||
parseAndFormatApiError(
|
||||
error,
|
||||
config.getContentGeneratorConfig()?.authType,
|
||||
),
|
||||
);
|
||||
throw error;
|
||||
} finally {
|
||||
consolePatcher.cleanup();
|
||||
if (isTelemetrySdkInitialized()) {
|
||||
await shutdownTelemetry(config);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ import {
|
||||
parseAndFormatApiError,
|
||||
getCodeAssistServer,
|
||||
UserTierId,
|
||||
promptIdContext,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { type Part, type PartListUnion, FinishReason } from '@google/genai';
|
||||
import type {
|
||||
@@ -705,71 +706,72 @@ export const useGeminiStream = (
|
||||
if (!prompt_id) {
|
||||
prompt_id = config.getSessionId() + '########' + getPromptCount();
|
||||
}
|
||||
|
||||
const { queryToSend, shouldProceed } = await prepareQueryForGemini(
|
||||
query,
|
||||
userMessageTimestamp,
|
||||
abortSignal,
|
||||
prompt_id!,
|
||||
);
|
||||
|
||||
if (!shouldProceed || queryToSend === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!options?.isContinuation) {
|
||||
startNewPrompt();
|
||||
setThought(null); // Reset thought when starting a new prompt
|
||||
}
|
||||
|
||||
setIsResponding(true);
|
||||
setInitError(null);
|
||||
|
||||
try {
|
||||
const stream = geminiClient.sendMessageStream(
|
||||
queryToSend,
|
||||
abortSignal,
|
||||
prompt_id!,
|
||||
);
|
||||
const processingStatus = await processGeminiStreamEvents(
|
||||
stream,
|
||||
return promptIdContext.run(prompt_id, async () => {
|
||||
const { queryToSend, shouldProceed } = await prepareQueryForGemini(
|
||||
query,
|
||||
userMessageTimestamp,
|
||||
abortSignal,
|
||||
prompt_id,
|
||||
);
|
||||
|
||||
if (processingStatus === StreamProcessingStatus.UserCancelled) {
|
||||
if (!shouldProceed || queryToSend === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
|
||||
setPendingHistoryItem(null);
|
||||
if (!options?.isContinuation) {
|
||||
startNewPrompt();
|
||||
setThought(null); // Reset thought when starting a new prompt
|
||||
}
|
||||
if (loopDetectedRef.current) {
|
||||
loopDetectedRef.current = false;
|
||||
handleLoopDetectedEvent();
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof UnauthorizedError) {
|
||||
onAuthError('Session expired or is unauthorized.');
|
||||
} else if (!isNodeError(error) || error.name !== 'AbortError') {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.ERROR,
|
||||
text: parseAndFormatApiError(
|
||||
getErrorMessage(error) || 'Unknown error',
|
||||
config.getContentGeneratorConfig()?.authType,
|
||||
undefined,
|
||||
config.getModel(),
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
),
|
||||
},
|
||||
userMessageTimestamp,
|
||||
|
||||
setIsResponding(true);
|
||||
setInitError(null);
|
||||
|
||||
try {
|
||||
const stream = geminiClient.sendMessageStream(
|
||||
queryToSend,
|
||||
abortSignal,
|
||||
prompt_id,
|
||||
);
|
||||
const processingStatus = await processGeminiStreamEvents(
|
||||
stream,
|
||||
userMessageTimestamp,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
if (processingStatus === StreamProcessingStatus.UserCancelled) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
|
||||
setPendingHistoryItem(null);
|
||||
}
|
||||
if (loopDetectedRef.current) {
|
||||
loopDetectedRef.current = false;
|
||||
handleLoopDetectedEvent();
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof UnauthorizedError) {
|
||||
onAuthError('Session expired or is unauthorized.');
|
||||
} else if (!isNodeError(error) || error.name !== 'AbortError') {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.ERROR,
|
||||
text: parseAndFormatApiError(
|
||||
getErrorMessage(error) || 'Unknown error',
|
||||
config.getContentGeneratorConfig()?.authType,
|
||||
undefined,
|
||||
config.getModel(),
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
),
|
||||
},
|
||||
userMessageTimestamp,
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
setIsResponding(false);
|
||||
}
|
||||
} finally {
|
||||
setIsResponding(false);
|
||||
}
|
||||
});
|
||||
},
|
||||
[
|
||||
streamingState,
|
||||
|
||||
@@ -122,6 +122,10 @@ vi.mock('../ide/ide-client.js', () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
import { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
|
||||
vi.mock('../core/baseLlmClient.js');
|
||||
|
||||
describe('Server Config (config.ts)', () => {
|
||||
const MODEL = 'gemini-pro';
|
||||
const SANDBOX: SandboxConfig = {
|
||||
@@ -774,3 +778,58 @@ describe('setApprovalMode with folder trust', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('BaseLlmClient Lifecycle', () => {
|
||||
const MODEL = 'gemini-pro';
|
||||
const SANDBOX: SandboxConfig = {
|
||||
command: 'docker',
|
||||
image: 'gemini-cli-sandbox',
|
||||
};
|
||||
const TARGET_DIR = '/path/to/target';
|
||||
const DEBUG_MODE = false;
|
||||
const QUESTION = 'test question';
|
||||
const FULL_CONTEXT = false;
|
||||
const USER_MEMORY = 'Test User Memory';
|
||||
const TELEMETRY_SETTINGS = { enabled: false };
|
||||
const EMBEDDING_MODEL = 'gemini-embedding';
|
||||
const SESSION_ID = 'test-session-id';
|
||||
const baseParams: ConfigParameters = {
|
||||
cwd: '/tmp',
|
||||
embeddingModel: EMBEDDING_MODEL,
|
||||
sandbox: SANDBOX,
|
||||
targetDir: TARGET_DIR,
|
||||
debugMode: DEBUG_MODE,
|
||||
question: QUESTION,
|
||||
fullContext: FULL_CONTEXT,
|
||||
userMemory: USER_MEMORY,
|
||||
telemetry: TELEMETRY_SETTINGS,
|
||||
sessionId: SESSION_ID,
|
||||
model: MODEL,
|
||||
usageStatisticsEnabled: false,
|
||||
};
|
||||
|
||||
it('should throw an error if getBaseLlmClient is called before refreshAuth', () => {
|
||||
const config = new Config(baseParams);
|
||||
expect(() => config.getBaseLlmClient()).toThrow(
|
||||
'BaseLlmClient not initialized. Ensure authentication has occurred and ContentGenerator is ready.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should successfully initialize BaseLlmClient after refreshAuth is called', async () => {
|
||||
const config = new Config(baseParams);
|
||||
const authType = AuthType.USE_GEMINI;
|
||||
const mockContentConfig = { model: 'gemini-flash', apiKey: 'test-key' };
|
||||
|
||||
vi.mocked(createContentGeneratorConfig).mockReturnValue(mockContentConfig);
|
||||
|
||||
await config.refreshAuth(authType);
|
||||
|
||||
// Should not throw
|
||||
const llmService = config.getBaseLlmClient();
|
||||
expect(llmService).toBeDefined();
|
||||
expect(BaseLlmClient).toHaveBeenCalledWith(
|
||||
config.getContentGenerator(),
|
||||
config,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -31,6 +31,7 @@ import { ReadManyFilesTool } from '../tools/read-many-files.js';
|
||||
import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js';
|
||||
import { WebSearchTool } from '../tools/web-search.js';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
import { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
|
||||
import { GitService } from '../services/gitService.js';
|
||||
import type { TelemetryTarget } from '../telemetry/index.js';
|
||||
@@ -257,6 +258,7 @@ export class Config {
|
||||
private readonly telemetrySettings: TelemetrySettings;
|
||||
private readonly usageStatisticsEnabled: boolean;
|
||||
private geminiClient!: GeminiClient;
|
||||
private baseLlmClient!: BaseLlmClient;
|
||||
private readonly fileFiltering: {
|
||||
respectGitIgnore: boolean;
|
||||
respectGeminiIgnore: boolean;
|
||||
@@ -455,6 +457,9 @@ export class Config {
|
||||
// Only assign to instance properties after successful initialization
|
||||
this.contentGeneratorConfig = newContentGeneratorConfig;
|
||||
|
||||
// Initialize BaseLlmClient now that the ContentGenerator is available
|
||||
this.baseLlmClient = new BaseLlmClient(this.contentGenerator, this);
|
||||
|
||||
// Reset the session flag since we're explicitly changing auth and using default model
|
||||
this.inFallbackMode = false;
|
||||
}
|
||||
@@ -463,6 +468,26 @@ export class Config {
|
||||
return this.contentGenerator?.userTier;
|
||||
}
|
||||
|
||||
/**
|
||||
* Provides access to the BaseLlmClient for stateless LLM operations.
|
||||
*/
|
||||
getBaseLlmClient(): BaseLlmClient {
|
||||
if (!this.baseLlmClient) {
|
||||
// Handle cases where initialization might be deferred or authentication failed
|
||||
if (this.contentGenerator) {
|
||||
this.baseLlmClient = new BaseLlmClient(
|
||||
this.getContentGenerator(),
|
||||
this,
|
||||
);
|
||||
} else {
|
||||
throw new Error(
|
||||
'BaseLlmClient not initialized. Ensure authentication has occurred and ContentGenerator is ready.',
|
||||
);
|
||||
}
|
||||
}
|
||||
return this.baseLlmClient;
|
||||
}
|
||||
|
||||
getSessionId(): string {
|
||||
return this.sessionId;
|
||||
}
|
||||
|
||||
291
packages/core/src/core/baseLlmClient.test.ts
Normal file
291
packages/core/src/core/baseLlmClient.test.ts
Normal file
@@ -0,0 +1,291 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mocked,
|
||||
} from 'vitest';
|
||||
|
||||
import type { GenerateContentResponse } from '@google/genai';
|
||||
import { BaseLlmClient, type GenerateJsonOptions } from './baseLlmClient.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { AuthType } from './contentGenerator.js';
|
||||
import { reportError } from '../utils/errorReporting.js';
|
||||
import { logMalformedJsonResponse } from '../telemetry/loggers.js';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import { MalformedJsonResponseEvent } from '../telemetry/types.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
|
||||
vi.mock('../utils/errorReporting.js');
|
||||
vi.mock('../telemetry/loggers.js');
|
||||
vi.mock('../utils/errors.js', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('../utils/errors.js')>();
|
||||
return {
|
||||
...actual,
|
||||
getErrorMessage: vi.fn((e) => (e instanceof Error ? e.message : String(e))),
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../utils/retry.js', () => ({
|
||||
retryWithBackoff: vi.fn(async (fn) => await fn()),
|
||||
}));
|
||||
|
||||
const mockGenerateContent = vi.fn();
|
||||
|
||||
const mockContentGenerator = {
|
||||
generateContent: mockGenerateContent,
|
||||
} as unknown as Mocked<ContentGenerator>;
|
||||
|
||||
const mockConfig = {
|
||||
getSessionId: vi.fn().mockReturnValue('test-session-id'),
|
||||
getContentGeneratorConfig: vi
|
||||
.fn()
|
||||
.mockReturnValue({ authType: AuthType.USE_GEMINI }),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
// Helper to create a mock GenerateContentResponse
|
||||
const createMockResponse = (text: string): GenerateContentResponse =>
|
||||
({
|
||||
candidates: [{ content: { role: 'model', parts: [{ text }] }, index: 0 }],
|
||||
}) as GenerateContentResponse;
|
||||
|
||||
describe('BaseLlmClient', () => {
|
||||
let client: BaseLlmClient;
|
||||
let abortController: AbortController;
|
||||
let defaultOptions: GenerateJsonOptions;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
// Reset the mocked implementation for getErrorMessage for accurate error message assertions
|
||||
vi.mocked(getErrorMessage).mockImplementation((e) =>
|
||||
e instanceof Error ? e.message : String(e),
|
||||
);
|
||||
client = new BaseLlmClient(mockContentGenerator, mockConfig);
|
||||
abortController = new AbortController();
|
||||
defaultOptions = {
|
||||
contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }],
|
||||
schema: { type: 'object', properties: { color: { type: 'string' } } },
|
||||
model: 'test-model',
|
||||
abortSignal: abortController.signal,
|
||||
promptId: 'test-prompt-id',
|
||||
};
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
abortController.abort();
|
||||
});
|
||||
|
||||
describe('generateJson - Success Scenarios', () => {
|
||||
it('should call generateContent with correct parameters, defaults, and utilize retry mechanism', async () => {
|
||||
const mockResponse = createMockResponse('{"color": "blue"}');
|
||||
mockGenerateContent.mockResolvedValue(mockResponse);
|
||||
|
||||
const result = await client.generateJson(defaultOptions);
|
||||
|
||||
expect(result).toEqual({ color: 'blue' });
|
||||
|
||||
// Ensure the retry mechanism was engaged
|
||||
expect(retryWithBackoff).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Validate the parameters passed to the underlying generator
|
||||
expect(mockGenerateContent).toHaveBeenCalledTimes(1);
|
||||
expect(mockGenerateContent).toHaveBeenCalledWith(
|
||||
{
|
||||
model: 'test-model',
|
||||
contents: defaultOptions.contents,
|
||||
config: {
|
||||
abortSignal: defaultOptions.abortSignal,
|
||||
temperature: 0,
|
||||
topP: 1,
|
||||
responseJsonSchema: defaultOptions.schema,
|
||||
responseMimeType: 'application/json',
|
||||
// Crucial: systemInstruction should NOT be in the config object if not provided
|
||||
},
|
||||
},
|
||||
'test-prompt-id',
|
||||
);
|
||||
});
|
||||
|
||||
it('should respect configuration overrides', async () => {
|
||||
const mockResponse = createMockResponse('{"color": "red"}');
|
||||
mockGenerateContent.mockResolvedValue(mockResponse);
|
||||
|
||||
const options: GenerateJsonOptions = {
|
||||
...defaultOptions,
|
||||
config: { temperature: 0.8, topK: 10 },
|
||||
};
|
||||
|
||||
await client.generateJson(options);
|
||||
|
||||
expect(mockGenerateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
config: expect.objectContaining({
|
||||
temperature: 0.8,
|
||||
topP: 1, // Default should remain if not overridden
|
||||
topK: 10,
|
||||
}),
|
||||
}),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
|
||||
it('should include system instructions when provided', async () => {
|
||||
const mockResponse = createMockResponse('{"color": "green"}');
|
||||
mockGenerateContent.mockResolvedValue(mockResponse);
|
||||
const systemInstruction = 'You are a helpful assistant.';
|
||||
|
||||
const options: GenerateJsonOptions = {
|
||||
...defaultOptions,
|
||||
systemInstruction,
|
||||
};
|
||||
|
||||
await client.generateJson(options);
|
||||
|
||||
expect(mockGenerateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
config: expect.objectContaining({
|
||||
systemInstruction,
|
||||
}),
|
||||
}),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the provided promptId', async () => {
|
||||
const mockResponse = createMockResponse('{"color": "yellow"}');
|
||||
mockGenerateContent.mockResolvedValue(mockResponse);
|
||||
const customPromptId = 'custom-id-123';
|
||||
|
||||
const options: GenerateJsonOptions = {
|
||||
...defaultOptions,
|
||||
promptId: customPromptId,
|
||||
};
|
||||
|
||||
await client.generateJson(options);
|
||||
|
||||
expect(mockGenerateContent).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
customPromptId,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('generateJson - Response Cleaning', () => {
|
||||
it('should clean JSON wrapped in markdown backticks and log telemetry', async () => {
|
||||
const malformedResponse = '```json\n{"color": "purple"}\n```';
|
||||
mockGenerateContent.mockResolvedValue(
|
||||
createMockResponse(malformedResponse),
|
||||
);
|
||||
|
||||
const result = await client.generateJson(defaultOptions);
|
||||
|
||||
expect(result).toEqual({ color: 'purple' });
|
||||
expect(logMalformedJsonResponse).toHaveBeenCalledTimes(1);
|
||||
expect(logMalformedJsonResponse).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
expect.any(MalformedJsonResponseEvent),
|
||||
);
|
||||
// Validate the telemetry event content
|
||||
const event = vi.mocked(logMalformedJsonResponse).mock
|
||||
.calls[0][1] as MalformedJsonResponseEvent;
|
||||
expect(event.model).toBe('test-model');
|
||||
});
|
||||
|
||||
it('should handle extra whitespace correctly without logging malformed telemetry', async () => {
|
||||
const responseWithWhitespace = ' \n {"color": "orange"} \n';
|
||||
mockGenerateContent.mockResolvedValue(
|
||||
createMockResponse(responseWithWhitespace),
|
||||
);
|
||||
|
||||
const result = await client.generateJson(defaultOptions);
|
||||
|
||||
expect(result).toEqual({ color: 'orange' });
|
||||
expect(logMalformedJsonResponse).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('generateJson - Error Handling', () => {
|
||||
it('should throw and report error for empty response', async () => {
|
||||
mockGenerateContent.mockResolvedValue(createMockResponse(''));
|
||||
|
||||
// The final error message includes the prefix added by the client's outer catch block.
|
||||
await expect(client.generateJson(defaultOptions)).rejects.toThrow(
|
||||
'Failed to generate JSON content: API returned an empty response for generateJson.',
|
||||
);
|
||||
|
||||
// Verify error reporting details
|
||||
expect(reportError).toHaveBeenCalledTimes(1);
|
||||
expect(reportError).toHaveBeenCalledWith(
|
||||
expect.any(Error),
|
||||
'Error in generateJson: API returned an empty response.',
|
||||
defaultOptions.contents,
|
||||
'generateJson-empty-response',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw and report error for invalid JSON syntax', async () => {
|
||||
const invalidJson = '{"color": "blue"'; // missing closing brace
|
||||
mockGenerateContent.mockResolvedValue(createMockResponse(invalidJson));
|
||||
|
||||
await expect(client.generateJson(defaultOptions)).rejects.toThrow(
|
||||
/^Failed to generate JSON content: Failed to parse API response as JSON:/,
|
||||
);
|
||||
|
||||
expect(reportError).toHaveBeenCalledTimes(1);
|
||||
expect(reportError).toHaveBeenCalledWith(
|
||||
expect.any(Error),
|
||||
'Failed to parse JSON response from generateJson.',
|
||||
expect.objectContaining({ responseTextFailedToParse: invalidJson }),
|
||||
'generateJson-parse',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw and report generic API errors', async () => {
|
||||
const apiError = new Error('Service Unavailable (503)');
|
||||
// Simulate the generator failing
|
||||
mockGenerateContent.mockRejectedValue(apiError);
|
||||
|
||||
await expect(client.generateJson(defaultOptions)).rejects.toThrow(
|
||||
'Failed to generate JSON content: Service Unavailable (503)',
|
||||
);
|
||||
|
||||
// Verify generic error reporting
|
||||
expect(reportError).toHaveBeenCalledTimes(1);
|
||||
expect(reportError).toHaveBeenCalledWith(
|
||||
apiError,
|
||||
'Error generating JSON content via API.',
|
||||
defaultOptions.contents,
|
||||
'generateJson-api',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw immediately without reporting if aborted', async () => {
|
||||
const abortError = new DOMException('Aborted', 'AbortError');
|
||||
|
||||
// Simulate abortion happening during the API call
|
||||
mockGenerateContent.mockImplementation(() => {
|
||||
abortController.abort(); // Ensure the signal is aborted when the service checks
|
||||
throw abortError;
|
||||
});
|
||||
|
||||
const options = {
|
||||
...defaultOptions,
|
||||
abortSignal: abortController.signal,
|
||||
};
|
||||
|
||||
await expect(client.generateJson(options)).rejects.toThrow(abortError);
|
||||
|
||||
// Crucially, it should not report a cancellation as an application error
|
||||
expect(reportError).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
171
packages/core/src/core/baseLlmClient.ts
Normal file
171
packages/core/src/core/baseLlmClient.ts
Normal file
@@ -0,0 +1,171 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Content, GenerateContentConfig, Part } from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import { getResponseText } from '../utils/partUtils.js';
|
||||
import { reportError } from '../utils/errorReporting.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { logMalformedJsonResponse } from '../telemetry/loggers.js';
|
||||
import { MalformedJsonResponseEvent } from '../telemetry/types.js';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
|
||||
/**
|
||||
* Options for the generateJson utility function.
|
||||
*/
|
||||
export interface GenerateJsonOptions {
|
||||
/** The input prompt or history. */
|
||||
contents: Content[];
|
||||
/** The required JSON schema for the output. */
|
||||
schema: Record<string, unknown>;
|
||||
/** The specific model to use for this task. */
|
||||
model: string;
|
||||
/**
|
||||
* Task-specific system instructions.
|
||||
* If omitted, no system instruction is sent.
|
||||
*/
|
||||
systemInstruction?: string | Part | Part[] | Content;
|
||||
/**
|
||||
* Overrides for generation configuration (e.g., temperature).
|
||||
*/
|
||||
config?: Omit<
|
||||
GenerateContentConfig,
|
||||
| 'systemInstruction'
|
||||
| 'responseJsonSchema'
|
||||
| 'responseMimeType'
|
||||
| 'tools'
|
||||
| 'abortSignal'
|
||||
>;
|
||||
/** Signal for cancellation. */
|
||||
abortSignal: AbortSignal;
|
||||
/**
|
||||
* A unique ID for the prompt, used for logging/telemetry correlation.
|
||||
*/
|
||||
promptId: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* A client dedicated to stateless, utility-focused LLM calls.
|
||||
*/
|
||||
export class BaseLlmClient {
|
||||
// Default configuration for utility tasks
|
||||
private readonly defaultUtilityConfig: GenerateContentConfig = {
|
||||
temperature: 0,
|
||||
topP: 1,
|
||||
};
|
||||
|
||||
constructor(
|
||||
private readonly contentGenerator: ContentGenerator,
|
||||
private readonly config: Config,
|
||||
) {}
|
||||
|
||||
async generateJson(
|
||||
options: GenerateJsonOptions,
|
||||
): Promise<Record<string, unknown>> {
|
||||
const {
|
||||
contents,
|
||||
schema,
|
||||
model,
|
||||
abortSignal,
|
||||
systemInstruction,
|
||||
promptId,
|
||||
} = options;
|
||||
|
||||
const requestConfig: GenerateContentConfig = {
|
||||
abortSignal,
|
||||
...this.defaultUtilityConfig,
|
||||
...options.config,
|
||||
...(systemInstruction && { systemInstruction }),
|
||||
responseJsonSchema: schema,
|
||||
responseMimeType: 'application/json',
|
||||
};
|
||||
|
||||
try {
|
||||
const apiCall = () =>
|
||||
this.contentGenerator.generateContent(
|
||||
{
|
||||
model,
|
||||
config: requestConfig,
|
||||
contents,
|
||||
},
|
||||
promptId,
|
||||
);
|
||||
|
||||
const result = await retryWithBackoff(apiCall);
|
||||
|
||||
let text = getResponseText(result)?.trim();
|
||||
if (!text) {
|
||||
const error = new Error(
|
||||
'API returned an empty response for generateJson.',
|
||||
);
|
||||
await reportError(
|
||||
error,
|
||||
'Error in generateJson: API returned an empty response.',
|
||||
contents,
|
||||
'generateJson-empty-response',
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
|
||||
text = this.cleanJsonResponse(text, model);
|
||||
|
||||
try {
|
||||
return JSON.parse(text);
|
||||
} catch (parseError) {
|
||||
const error = new Error(
|
||||
`Failed to parse API response as JSON: ${getErrorMessage(parseError)}`,
|
||||
);
|
||||
await reportError(
|
||||
parseError,
|
||||
'Failed to parse JSON response from generateJson.',
|
||||
{
|
||||
responseTextFailedToParse: text,
|
||||
originalRequestContents: contents,
|
||||
},
|
||||
'generateJson-parse',
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
} catch (error) {
|
||||
if (abortSignal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
if (
|
||||
error instanceof Error &&
|
||||
(error.message === 'API returned an empty response for generateJson.' ||
|
||||
error.message.startsWith('Failed to parse API response as JSON:'))
|
||||
) {
|
||||
// We perform this check so that we don't report these again.
|
||||
} else {
|
||||
await reportError(
|
||||
error,
|
||||
'Error generating JSON content via API.',
|
||||
contents,
|
||||
'generateJson-api',
|
||||
);
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
`Failed to generate JSON content: ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private cleanJsonResponse(text: string, model: string): string {
|
||||
const prefix = '```json';
|
||||
const suffix = '```';
|
||||
if (text.startsWith(prefix) && text.endsWith(suffix)) {
|
||||
logMalformedJsonResponse(
|
||||
this.config,
|
||||
new MalformedJsonResponseEvent(model),
|
||||
);
|
||||
return text.substring(prefix.length, text.length - suffix.length).trim();
|
||||
}
|
||||
return text;
|
||||
}
|
||||
}
|
||||
@@ -50,6 +50,7 @@ export * from './utils/workspaceContext.js';
|
||||
export * from './utils/ignorePatterns.js';
|
||||
export * from './utils/partUtils.js';
|
||||
export * from './utils/ide-trust.js';
|
||||
export * from './utils/promptIdContext.js';
|
||||
|
||||
// Export services
|
||||
export * from './services/fileDiscoveryService.js';
|
||||
|
||||
@@ -60,6 +60,7 @@ import { ApprovalMode, type Config } from '../config/config.js';
|
||||
import { type Content, type Part, type SchemaUnion } from '@google/genai';
|
||||
import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js';
|
||||
import { StandardFileSystemService } from '../services/fileSystemService.js';
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
|
||||
describe('SmartEditTool', () => {
|
||||
let tool: SmartEditTool;
|
||||
@@ -67,6 +68,7 @@ describe('SmartEditTool', () => {
|
||||
let rootDir: string;
|
||||
let mockConfig: Config;
|
||||
let geminiClient: any;
|
||||
let baseLlmClient: BaseLlmClient;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
@@ -78,8 +80,13 @@ describe('SmartEditTool', () => {
|
||||
generateJson: mockGenerateJson,
|
||||
};
|
||||
|
||||
baseLlmClient = {
|
||||
generateJson: mockGenerateJson,
|
||||
} as unknown as BaseLlmClient;
|
||||
|
||||
mockConfig = {
|
||||
getGeminiClient: vi.fn().mockReturnValue(geminiClient),
|
||||
getBaseLlmClient: vi.fn().mockReturnValue(baseLlmClient),
|
||||
getTargetDir: () => rootDir,
|
||||
getApprovalMode: vi.fn(),
|
||||
setApprovalMode: vi.fn(),
|
||||
|
||||
@@ -310,7 +310,7 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
|
||||
params.new_string,
|
||||
initialError.raw,
|
||||
currentContent,
|
||||
this.config.getGeminiClient(),
|
||||
this.config.getBaseLlmClient(),
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
|
||||
203
packages/core/src/utils/llm-edit-fixer.test.ts
Normal file
203
packages/core/src/utils/llm-edit-fixer.test.ts
Normal file
@@ -0,0 +1,203 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
FixLLMEditWithInstruction,
|
||||
resetLlmEditFixerCaches_TEST_ONLY,
|
||||
type SearchReplaceEdit,
|
||||
} from './llm-edit-fixer.js';
|
||||
import { promptIdContext } from './promptIdContext.js';
|
||||
import type { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
|
||||
// Mock the BaseLlmClient
|
||||
const mockGenerateJson = vi.fn();
|
||||
const mockBaseLlmClient = {
|
||||
generateJson: mockGenerateJson,
|
||||
} as unknown as BaseLlmClient;
|
||||
|
||||
describe('FixLLMEditWithInstruction', () => {
|
||||
const instruction = 'Replace the title';
|
||||
const old_string = '<h1>Old Title</h1>';
|
||||
const new_string = '<h1>New Title</h1>';
|
||||
const error = 'String not found';
|
||||
const current_content = '<body><h1>Old Title</h1></body>';
|
||||
const abortController = new AbortController();
|
||||
const abortSignal = abortController.signal;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
resetLlmEditFixerCaches_TEST_ONLY(); // Ensure cache is cleared before each test
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers(); // Reset timers after each test
|
||||
});
|
||||
|
||||
const mockApiResponse: SearchReplaceEdit = {
|
||||
search: '<h1>Old Title</h1>',
|
||||
replace: '<h1>New Title</h1>',
|
||||
noChangesRequired: false,
|
||||
explanation: 'The original search was correct.',
|
||||
};
|
||||
|
||||
it('should use the promptId from the AsyncLocalStorage context when available', async () => {
|
||||
const testPromptId = 'test-prompt-id-12345';
|
||||
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||
|
||||
await promptIdContext.run(testPromptId, async () => {
|
||||
await FixLLMEditWithInstruction(
|
||||
instruction,
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
});
|
||||
|
||||
// Verify that generateJson was called with the promptId from the context
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockGenerateJson).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
promptId: testPromptId,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should generate and use a fallback promptId when context is not available', async () => {
|
||||
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, 'warn')
|
||||
.mockImplementation(() => {});
|
||||
|
||||
// Run the function outside of any context
|
||||
await FixLLMEditWithInstruction(
|
||||
instruction,
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
// Verify the warning was logged
|
||||
expect(consoleWarnSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining(
|
||||
'Could not find promptId in context. This is unexpected. Using a fallback ID: llm-fixer-fallback-',
|
||||
),
|
||||
);
|
||||
|
||||
// Verify that generateJson was called with the generated fallback promptId
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(mockGenerateJson).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
promptId: expect.stringContaining('llm-fixer-fallback-'),
|
||||
}),
|
||||
);
|
||||
|
||||
// Restore mocks
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should construct the user prompt correctly', async () => {
|
||||
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||
const promptId = 'test-prompt-id-prompt-construction';
|
||||
|
||||
await promptIdContext.run(promptId, async () => {
|
||||
await FixLLMEditWithInstruction(
|
||||
instruction,
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
});
|
||||
|
||||
const generateJsonCall = mockGenerateJson.mock.calls[0][0];
|
||||
const userPromptContent = generateJsonCall.contents[0].parts[0].text;
|
||||
|
||||
expect(userPromptContent).toContain(
|
||||
`<instruction>\n${instruction}\n</instruction>`,
|
||||
);
|
||||
expect(userPromptContent).toContain(`<search>\n${old_string}\n</search>`);
|
||||
expect(userPromptContent).toContain(`<replace>\n${new_string}\n</replace>`);
|
||||
expect(userPromptContent).toContain(`<error>\n${error}\n</error>`);
|
||||
expect(userPromptContent).toContain(
|
||||
`<file_content>\n${current_content}\n</file_content>`,
|
||||
);
|
||||
});
|
||||
|
||||
it('should return a cached result on subsequent identical calls', async () => {
|
||||
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||
const testPromptId = 'test-prompt-id-caching';
|
||||
|
||||
await promptIdContext.run(testPromptId, async () => {
|
||||
// First call - should call the API
|
||||
const result1 = await FixLLMEditWithInstruction(
|
||||
instruction,
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
// Second call with identical parameters - should hit the cache
|
||||
const result2 = await FixLLMEditWithInstruction(
|
||||
instruction,
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
expect(result1).toEqual(mockApiResponse);
|
||||
expect(result2).toEqual(mockApiResponse);
|
||||
// Verify the underlying service was only called ONCE
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
it('should not use cache for calls with different parameters', async () => {
|
||||
mockGenerateJson.mockResolvedValue(mockApiResponse);
|
||||
const testPromptId = 'test-prompt-id-cache-miss';
|
||||
|
||||
await promptIdContext.run(testPromptId, async () => {
|
||||
// First call
|
||||
await FixLLMEditWithInstruction(
|
||||
instruction,
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
// Second call with a different instruction
|
||||
await FixLLMEditWithInstruction(
|
||||
'A different instruction',
|
||||
old_string,
|
||||
new_string,
|
||||
error,
|
||||
current_content,
|
||||
mockBaseLlmClient,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
// Verify the underlying service was called TWICE
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -5,9 +5,10 @@
|
||||
*/
|
||||
|
||||
import { type Content, Type } from '@google/genai';
|
||||
import { type GeminiClient } from '../core/client.js';
|
||||
import { type BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import { LruCache } from './LruCache.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { promptIdContext } from './promptIdContext.js';
|
||||
|
||||
const MAX_CACHE_SIZE = 50;
|
||||
|
||||
@@ -93,8 +94,9 @@ const editCorrectionWithInstructionCache = new LruCache<
|
||||
* @param new_string The original replacement string.
|
||||
* @param error The error that occurred during the initial edit.
|
||||
* @param current_content The current content of the file.
|
||||
* @param geminiClient The Gemini client to use for the LLM call.
|
||||
* @param baseLlmClient The BaseLlmClient to use for the LLM call.
|
||||
* @param abortSignal An abort signal to cancel the operation.
|
||||
* @param promptId A unique ID for the prompt.
|
||||
* @returns A new search and replace pair.
|
||||
*/
|
||||
export async function FixLLMEditWithInstruction(
|
||||
@@ -103,9 +105,17 @@ export async function FixLLMEditWithInstruction(
|
||||
new_string: string,
|
||||
error: string,
|
||||
current_content: string,
|
||||
geminiClient: GeminiClient,
|
||||
baseLlmClient: BaseLlmClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<SearchReplaceEdit> {
|
||||
let promptId = promptIdContext.getStore();
|
||||
if (!promptId) {
|
||||
promptId = `llm-fixer-fallback-${Date.now()}-${Math.random().toString(16).slice(2)}`;
|
||||
console.warn(
|
||||
`Could not find promptId in context. This is unexpected. Using a fallback ID: ${promptId}`,
|
||||
);
|
||||
}
|
||||
|
||||
const cacheKey = `${instruction}---${old_string}---${new_string}--${current_content}--${error}`;
|
||||
const cachedResult = editCorrectionWithInstructionCache.get(cacheKey);
|
||||
if (cachedResult) {
|
||||
@@ -120,21 +130,18 @@ export async function FixLLMEditWithInstruction(
|
||||
const contents: Content[] = [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
text: `${EDIT_SYS_PROMPT}
|
||||
${userPrompt}`,
|
||||
},
|
||||
],
|
||||
parts: [{ text: userPrompt }],
|
||||
},
|
||||
];
|
||||
|
||||
const result = (await geminiClient.generateJson(
|
||||
const result = (await baseLlmClient.generateJson({
|
||||
contents,
|
||||
SearchReplaceEditSchema,
|
||||
schema: SearchReplaceEditSchema,
|
||||
abortSignal,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
)) as unknown as SearchReplaceEdit;
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
systemInstruction: EDIT_SYS_PROMPT,
|
||||
promptId,
|
||||
})) as unknown as SearchReplaceEdit;
|
||||
|
||||
editCorrectionWithInstructionCache.set(cacheKey, result);
|
||||
return result;
|
||||
|
||||
9
packages/core/src/utils/promptIdContext.ts
Normal file
9
packages/core/src/utils/promptIdContext.ts
Normal file
@@ -0,0 +1,9 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { AsyncLocalStorage } from 'node:async_hooks';
|
||||
|
||||
export const promptIdContext = new AsyncLocalStorage<string>();
|
||||
Reference in New Issue
Block a user