From 1eaf21f6a288cd96bf648eb93d73634fd9b2c6df Mon Sep 17 00:00:00 2001 From: Abhi <43648792+abhipatel12@users.noreply.github.com> Date: Tue, 9 Sep 2025 01:14:15 -0400 Subject: [PATCH] refactor(core): Introduce LlmUtilityService and promptIdContext (#7952) --- packages/cli/src/nonInteractiveCli.ts | 195 ++++++------ packages/cli/src/ui/hooks/useGeminiStream.ts | 112 +++---- packages/core/src/config/config.test.ts | 59 ++++ packages/core/src/config/config.ts | 25 ++ packages/core/src/core/baseLlmClient.test.ts | 291 ++++++++++++++++++ packages/core/src/core/baseLlmClient.ts | 171 ++++++++++ packages/core/src/index.ts | 1 + packages/core/src/tools/smart-edit.test.ts | 7 + packages/core/src/tools/smart-edit.ts | 2 +- .../core/src/utils/llm-edit-fixer.test.ts | 203 ++++++++++++ packages/core/src/utils/llm-edit-fixer.ts | 33 +- packages/core/src/utils/promptIdContext.ts | 9 + 12 files changed, 943 insertions(+), 165 deletions(-) create mode 100644 packages/core/src/core/baseLlmClient.test.ts create mode 100644 packages/core/src/core/baseLlmClient.ts create mode 100644 packages/core/src/utils/llm-edit-fixer.test.ts create mode 100644 packages/core/src/utils/promptIdContext.ts diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index 73e8ae2371..ff33bd86ec 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -13,6 +13,7 @@ import { parseAndFormatApiError, FatalInputError, FatalTurnLimitedError, + promptIdContext, } from '@google/gemini-cli-core'; import type { Content, Part } from '@google/genai'; @@ -24,115 +25,117 @@ export async function runNonInteractive( input: string, prompt_id: string, ): Promise { - const consolePatcher = new ConsolePatcher({ - stderr: true, - debugMode: config.getDebugMode(), - }); - - try { - consolePatcher.patch(); - // Handle EPIPE errors when the output is piped to a command that closes early. - process.stdout.on('error', (err: NodeJS.ErrnoException) => { - if (err.code === 'EPIPE') { - // Exit gracefully if the pipe is closed. - process.exit(0); - } + return promptIdContext.run(prompt_id, async () => { + const consolePatcher = new ConsolePatcher({ + stderr: true, + debugMode: config.getDebugMode(), }); - const geminiClient = config.getGeminiClient(); + try { + consolePatcher.patch(); + // Handle EPIPE errors when the output is piped to a command that closes early. + process.stdout.on('error', (err: NodeJS.ErrnoException) => { + if (err.code === 'EPIPE') { + // Exit gracefully if the pipe is closed. + process.exit(0); + } + }); - const abortController = new AbortController(); + const geminiClient = config.getGeminiClient(); - const { processedQuery, shouldProceed } = await handleAtCommand({ - query: input, - config, - addItem: (_item, _timestamp) => 0, - onDebugMessage: () => {}, - messageId: Date.now(), - signal: abortController.signal, - }); + const abortController = new AbortController(); - if (!shouldProceed || !processedQuery) { - // An error occurred during @include processing (e.g., file not found). - // The error message is already logged by handleAtCommand. - throw new FatalInputError( - 'Exiting due to an error processing the @ command.', - ); - } + const { processedQuery, shouldProceed } = await handleAtCommand({ + query: input, + config, + addItem: (_item, _timestamp) => 0, + onDebugMessage: () => {}, + messageId: Date.now(), + signal: abortController.signal, + }); - let currentMessages: Content[] = [ - { role: 'user', parts: processedQuery as Part[] }, - ]; - - let turnCount = 0; - while (true) { - turnCount++; - if ( - config.getMaxSessionTurns() >= 0 && - turnCount > config.getMaxSessionTurns() - ) { - throw new FatalTurnLimitedError( - 'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.', + if (!shouldProceed || !processedQuery) { + // An error occurred during @include processing (e.g., file not found). + // The error message is already logged by handleAtCommand. + throw new FatalInputError( + 'Exiting due to an error processing the @ command.', ); } - const toolCallRequests: ToolCallRequestInfo[] = []; - const responseStream = geminiClient.sendMessageStream( - currentMessages[0]?.parts || [], - abortController.signal, - prompt_id, - ); + let currentMessages: Content[] = [ + { role: 'user', parts: processedQuery as Part[] }, + ]; - for await (const event of responseStream) { - if (abortController.signal.aborted) { - console.error('Operation cancelled.'); + let turnCount = 0; + while (true) { + turnCount++; + if ( + config.getMaxSessionTurns() >= 0 && + turnCount > config.getMaxSessionTurns() + ) { + throw new FatalTurnLimitedError( + 'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.', + ); + } + const toolCallRequests: ToolCallRequestInfo[] = []; + + const responseStream = geminiClient.sendMessageStream( + currentMessages[0]?.parts || [], + abortController.signal, + prompt_id, + ); + + for await (const event of responseStream) { + if (abortController.signal.aborted) { + console.error('Operation cancelled.'); + return; + } + + if (event.type === GeminiEventType.Content) { + process.stdout.write(event.value); + } else if (event.type === GeminiEventType.ToolCallRequest) { + toolCallRequests.push(event.value); + } + } + + if (toolCallRequests.length > 0) { + const toolResponseParts: Part[] = []; + for (const requestInfo of toolCallRequests) { + const toolResponse = await executeToolCall( + config, + requestInfo, + abortController.signal, + ); + + if (toolResponse.error) { + console.error( + `Error executing tool ${requestInfo.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`, + ); + } + + if (toolResponse.responseParts) { + toolResponseParts.push(...toolResponse.responseParts); + } + } + currentMessages = [{ role: 'user', parts: toolResponseParts }]; + } else { + process.stdout.write('\n'); // Ensure a final newline return; } - - if (event.type === GeminiEventType.Content) { - process.stdout.write(event.value); - } else if (event.type === GeminiEventType.ToolCallRequest) { - toolCallRequests.push(event.value); - } } - - if (toolCallRequests.length > 0) { - const toolResponseParts: Part[] = []; - for (const requestInfo of toolCallRequests) { - const toolResponse = await executeToolCall( - config, - requestInfo, - abortController.signal, - ); - - if (toolResponse.error) { - console.error( - `Error executing tool ${requestInfo.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`, - ); - } - - if (toolResponse.responseParts) { - toolResponseParts.push(...toolResponse.responseParts); - } - } - currentMessages = [{ role: 'user', parts: toolResponseParts }]; - } else { - process.stdout.write('\n'); // Ensure a final newline - return; + } catch (error) { + console.error( + parseAndFormatApiError( + error, + config.getContentGeneratorConfig()?.authType, + ), + ); + throw error; + } finally { + consolePatcher.cleanup(); + if (isTelemetrySdkInitialized()) { + await shutdownTelemetry(config); } } - } catch (error) { - console.error( - parseAndFormatApiError( - error, - config.getContentGeneratorConfig()?.authType, - ), - ); - throw error; - } finally { - consolePatcher.cleanup(); - if (isTelemetrySdkInitialized()) { - await shutdownTelemetry(config); - } - } + }); } diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index edaf33d547..ac92068edb 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -33,6 +33,7 @@ import { parseAndFormatApiError, getCodeAssistServer, UserTierId, + promptIdContext, } from '@google/gemini-cli-core'; import { type Part, type PartListUnion, FinishReason } from '@google/genai'; import type { @@ -705,71 +706,72 @@ export const useGeminiStream = ( if (!prompt_id) { prompt_id = config.getSessionId() + '########' + getPromptCount(); } - - const { queryToSend, shouldProceed } = await prepareQueryForGemini( - query, - userMessageTimestamp, - abortSignal, - prompt_id!, - ); - - if (!shouldProceed || queryToSend === null) { - return; - } - - if (!options?.isContinuation) { - startNewPrompt(); - setThought(null); // Reset thought when starting a new prompt - } - - setIsResponding(true); - setInitError(null); - - try { - const stream = geminiClient.sendMessageStream( - queryToSend, - abortSignal, - prompt_id!, - ); - const processingStatus = await processGeminiStreamEvents( - stream, + return promptIdContext.run(prompt_id, async () => { + const { queryToSend, shouldProceed } = await prepareQueryForGemini( + query, userMessageTimestamp, abortSignal, + prompt_id, ); - if (processingStatus === StreamProcessingStatus.UserCancelled) { + if (!shouldProceed || queryToSend === null) { return; } - if (pendingHistoryItemRef.current) { - addItem(pendingHistoryItemRef.current, userMessageTimestamp); - setPendingHistoryItem(null); + if (!options?.isContinuation) { + startNewPrompt(); + setThought(null); // Reset thought when starting a new prompt } - if (loopDetectedRef.current) { - loopDetectedRef.current = false; - handleLoopDetectedEvent(); - } - } catch (error: unknown) { - if (error instanceof UnauthorizedError) { - onAuthError('Session expired or is unauthorized.'); - } else if (!isNodeError(error) || error.name !== 'AbortError') { - addItem( - { - type: MessageType.ERROR, - text: parseAndFormatApiError( - getErrorMessage(error) || 'Unknown error', - config.getContentGeneratorConfig()?.authType, - undefined, - config.getModel(), - DEFAULT_GEMINI_FLASH_MODEL, - ), - }, - userMessageTimestamp, + + setIsResponding(true); + setInitError(null); + + try { + const stream = geminiClient.sendMessageStream( + queryToSend, + abortSignal, + prompt_id, ); + const processingStatus = await processGeminiStreamEvents( + stream, + userMessageTimestamp, + abortSignal, + ); + + if (processingStatus === StreamProcessingStatus.UserCancelled) { + return; + } + + if (pendingHistoryItemRef.current) { + addItem(pendingHistoryItemRef.current, userMessageTimestamp); + setPendingHistoryItem(null); + } + if (loopDetectedRef.current) { + loopDetectedRef.current = false; + handleLoopDetectedEvent(); + } + } catch (error: unknown) { + if (error instanceof UnauthorizedError) { + onAuthError('Session expired or is unauthorized.'); + } else if (!isNodeError(error) || error.name !== 'AbortError') { + addItem( + { + type: MessageType.ERROR, + text: parseAndFormatApiError( + getErrorMessage(error) || 'Unknown error', + config.getContentGeneratorConfig()?.authType, + undefined, + config.getModel(), + DEFAULT_GEMINI_FLASH_MODEL, + ), + }, + userMessageTimestamp, + ); + } + } finally { + setIsResponding(false); } - } finally { - setIsResponding(false); - } + }); }, [ streamingState, diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 13e051c0b9..81096b69d5 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -122,6 +122,10 @@ vi.mock('../ide/ide-client.js', () => ({ }, })); +import { BaseLlmClient } from '../core/baseLlmClient.js'; + +vi.mock('../core/baseLlmClient.js'); + describe('Server Config (config.ts)', () => { const MODEL = 'gemini-pro'; const SANDBOX: SandboxConfig = { @@ -774,3 +778,58 @@ describe('setApprovalMode with folder trust', () => { }); }); }); + +describe('BaseLlmClient Lifecycle', () => { + const MODEL = 'gemini-pro'; + const SANDBOX: SandboxConfig = { + command: 'docker', + image: 'gemini-cli-sandbox', + }; + const TARGET_DIR = '/path/to/target'; + const DEBUG_MODE = false; + const QUESTION = 'test question'; + const FULL_CONTEXT = false; + const USER_MEMORY = 'Test User Memory'; + const TELEMETRY_SETTINGS = { enabled: false }; + const EMBEDDING_MODEL = 'gemini-embedding'; + const SESSION_ID = 'test-session-id'; + const baseParams: ConfigParameters = { + cwd: '/tmp', + embeddingModel: EMBEDDING_MODEL, + sandbox: SANDBOX, + targetDir: TARGET_DIR, + debugMode: DEBUG_MODE, + question: QUESTION, + fullContext: FULL_CONTEXT, + userMemory: USER_MEMORY, + telemetry: TELEMETRY_SETTINGS, + sessionId: SESSION_ID, + model: MODEL, + usageStatisticsEnabled: false, + }; + + it('should throw an error if getBaseLlmClient is called before refreshAuth', () => { + const config = new Config(baseParams); + expect(() => config.getBaseLlmClient()).toThrow( + 'BaseLlmClient not initialized. Ensure authentication has occurred and ContentGenerator is ready.', + ); + }); + + it('should successfully initialize BaseLlmClient after refreshAuth is called', async () => { + const config = new Config(baseParams); + const authType = AuthType.USE_GEMINI; + const mockContentConfig = { model: 'gemini-flash', apiKey: 'test-key' }; + + vi.mocked(createContentGeneratorConfig).mockReturnValue(mockContentConfig); + + await config.refreshAuth(authType); + + // Should not throw + const llmService = config.getBaseLlmClient(); + expect(llmService).toBeDefined(); + expect(BaseLlmClient).toHaveBeenCalledWith( + config.getContentGenerator(), + config, + ); + }); +}); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 746f1051b2..c4dfc3a56c 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -31,6 +31,7 @@ import { ReadManyFilesTool } from '../tools/read-many-files.js'; import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js'; import { WebSearchTool } from '../tools/web-search.js'; import { GeminiClient } from '../core/client.js'; +import { BaseLlmClient } from '../core/baseLlmClient.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { GitService } from '../services/gitService.js'; import type { TelemetryTarget } from '../telemetry/index.js'; @@ -257,6 +258,7 @@ export class Config { private readonly telemetrySettings: TelemetrySettings; private readonly usageStatisticsEnabled: boolean; private geminiClient!: GeminiClient; + private baseLlmClient!: BaseLlmClient; private readonly fileFiltering: { respectGitIgnore: boolean; respectGeminiIgnore: boolean; @@ -455,6 +457,9 @@ export class Config { // Only assign to instance properties after successful initialization this.contentGeneratorConfig = newContentGeneratorConfig; + // Initialize BaseLlmClient now that the ContentGenerator is available + this.baseLlmClient = new BaseLlmClient(this.contentGenerator, this); + // Reset the session flag since we're explicitly changing auth and using default model this.inFallbackMode = false; } @@ -463,6 +468,26 @@ export class Config { return this.contentGenerator?.userTier; } + /** + * Provides access to the BaseLlmClient for stateless LLM operations. + */ + getBaseLlmClient(): BaseLlmClient { + if (!this.baseLlmClient) { + // Handle cases where initialization might be deferred or authentication failed + if (this.contentGenerator) { + this.baseLlmClient = new BaseLlmClient( + this.getContentGenerator(), + this, + ); + } else { + throw new Error( + 'BaseLlmClient not initialized. Ensure authentication has occurred and ContentGenerator is ready.', + ); + } + } + return this.baseLlmClient; + } + getSessionId(): string { return this.sessionId; } diff --git a/packages/core/src/core/baseLlmClient.test.ts b/packages/core/src/core/baseLlmClient.test.ts new file mode 100644 index 0000000000..1b1787f5fd --- /dev/null +++ b/packages/core/src/core/baseLlmClient.test.ts @@ -0,0 +1,291 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mocked, +} from 'vitest'; + +import type { GenerateContentResponse } from '@google/genai'; +import { BaseLlmClient, type GenerateJsonOptions } from './baseLlmClient.js'; +import type { ContentGenerator } from './contentGenerator.js'; +import type { Config } from '../config/config.js'; +import { AuthType } from './contentGenerator.js'; +import { reportError } from '../utils/errorReporting.js'; +import { logMalformedJsonResponse } from '../telemetry/loggers.js'; +import { retryWithBackoff } from '../utils/retry.js'; +import { MalformedJsonResponseEvent } from '../telemetry/types.js'; +import { getErrorMessage } from '../utils/errors.js'; + +vi.mock('../utils/errorReporting.js'); +vi.mock('../telemetry/loggers.js'); +vi.mock('../utils/errors.js', async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + getErrorMessage: vi.fn((e) => (e instanceof Error ? e.message : String(e))), + }; +}); + +vi.mock('../utils/retry.js', () => ({ + retryWithBackoff: vi.fn(async (fn) => await fn()), +})); + +const mockGenerateContent = vi.fn(); + +const mockContentGenerator = { + generateContent: mockGenerateContent, +} as unknown as Mocked; + +const mockConfig = { + getSessionId: vi.fn().mockReturnValue('test-session-id'), + getContentGeneratorConfig: vi + .fn() + .mockReturnValue({ authType: AuthType.USE_GEMINI }), +} as unknown as Mocked; + +// Helper to create a mock GenerateContentResponse +const createMockResponse = (text: string): GenerateContentResponse => + ({ + candidates: [{ content: { role: 'model', parts: [{ text }] }, index: 0 }], + }) as GenerateContentResponse; + +describe('BaseLlmClient', () => { + let client: BaseLlmClient; + let abortController: AbortController; + let defaultOptions: GenerateJsonOptions; + + beforeEach(() => { + vi.clearAllMocks(); + // Reset the mocked implementation for getErrorMessage for accurate error message assertions + vi.mocked(getErrorMessage).mockImplementation((e) => + e instanceof Error ? e.message : String(e), + ); + client = new BaseLlmClient(mockContentGenerator, mockConfig); + abortController = new AbortController(); + defaultOptions = { + contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }], + schema: { type: 'object', properties: { color: { type: 'string' } } }, + model: 'test-model', + abortSignal: abortController.signal, + promptId: 'test-prompt-id', + }; + }); + + afterEach(() => { + abortController.abort(); + }); + + describe('generateJson - Success Scenarios', () => { + it('should call generateContent with correct parameters, defaults, and utilize retry mechanism', async () => { + const mockResponse = createMockResponse('{"color": "blue"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + + const result = await client.generateJson(defaultOptions); + + expect(result).toEqual({ color: 'blue' }); + + // Ensure the retry mechanism was engaged + expect(retryWithBackoff).toHaveBeenCalledTimes(1); + + // Validate the parameters passed to the underlying generator + expect(mockGenerateContent).toHaveBeenCalledTimes(1); + expect(mockGenerateContent).toHaveBeenCalledWith( + { + model: 'test-model', + contents: defaultOptions.contents, + config: { + abortSignal: defaultOptions.abortSignal, + temperature: 0, + topP: 1, + responseJsonSchema: defaultOptions.schema, + responseMimeType: 'application/json', + // Crucial: systemInstruction should NOT be in the config object if not provided + }, + }, + 'test-prompt-id', + ); + }); + + it('should respect configuration overrides', async () => { + const mockResponse = createMockResponse('{"color": "red"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + + const options: GenerateJsonOptions = { + ...defaultOptions, + config: { temperature: 0.8, topK: 10 }, + }; + + await client.generateJson(options); + + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.objectContaining({ + config: expect.objectContaining({ + temperature: 0.8, + topP: 1, // Default should remain if not overridden + topK: 10, + }), + }), + expect.any(String), + ); + }); + + it('should include system instructions when provided', async () => { + const mockResponse = createMockResponse('{"color": "green"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + const systemInstruction = 'You are a helpful assistant.'; + + const options: GenerateJsonOptions = { + ...defaultOptions, + systemInstruction, + }; + + await client.generateJson(options); + + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.objectContaining({ + config: expect.objectContaining({ + systemInstruction, + }), + }), + expect.any(String), + ); + }); + + it('should use the provided promptId', async () => { + const mockResponse = createMockResponse('{"color": "yellow"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + const customPromptId = 'custom-id-123'; + + const options: GenerateJsonOptions = { + ...defaultOptions, + promptId: customPromptId, + }; + + await client.generateJson(options); + + expect(mockGenerateContent).toHaveBeenCalledWith( + expect.any(Object), + customPromptId, + ); + }); + }); + + describe('generateJson - Response Cleaning', () => { + it('should clean JSON wrapped in markdown backticks and log telemetry', async () => { + const malformedResponse = '```json\n{"color": "purple"}\n```'; + mockGenerateContent.mockResolvedValue( + createMockResponse(malformedResponse), + ); + + const result = await client.generateJson(defaultOptions); + + expect(result).toEqual({ color: 'purple' }); + expect(logMalformedJsonResponse).toHaveBeenCalledTimes(1); + expect(logMalformedJsonResponse).toHaveBeenCalledWith( + mockConfig, + expect.any(MalformedJsonResponseEvent), + ); + // Validate the telemetry event content + const event = vi.mocked(logMalformedJsonResponse).mock + .calls[0][1] as MalformedJsonResponseEvent; + expect(event.model).toBe('test-model'); + }); + + it('should handle extra whitespace correctly without logging malformed telemetry', async () => { + const responseWithWhitespace = ' \n {"color": "orange"} \n'; + mockGenerateContent.mockResolvedValue( + createMockResponse(responseWithWhitespace), + ); + + const result = await client.generateJson(defaultOptions); + + expect(result).toEqual({ color: 'orange' }); + expect(logMalformedJsonResponse).not.toHaveBeenCalled(); + }); + }); + + describe('generateJson - Error Handling', () => { + it('should throw and report error for empty response', async () => { + mockGenerateContent.mockResolvedValue(createMockResponse('')); + + // The final error message includes the prefix added by the client's outer catch block. + await expect(client.generateJson(defaultOptions)).rejects.toThrow( + 'Failed to generate JSON content: API returned an empty response for generateJson.', + ); + + // Verify error reporting details + expect(reportError).toHaveBeenCalledTimes(1); + expect(reportError).toHaveBeenCalledWith( + expect.any(Error), + 'Error in generateJson: API returned an empty response.', + defaultOptions.contents, + 'generateJson-empty-response', + ); + }); + + it('should throw and report error for invalid JSON syntax', async () => { + const invalidJson = '{"color": "blue"'; // missing closing brace + mockGenerateContent.mockResolvedValue(createMockResponse(invalidJson)); + + await expect(client.generateJson(defaultOptions)).rejects.toThrow( + /^Failed to generate JSON content: Failed to parse API response as JSON:/, + ); + + expect(reportError).toHaveBeenCalledTimes(1); + expect(reportError).toHaveBeenCalledWith( + expect.any(Error), + 'Failed to parse JSON response from generateJson.', + expect.objectContaining({ responseTextFailedToParse: invalidJson }), + 'generateJson-parse', + ); + }); + + it('should throw and report generic API errors', async () => { + const apiError = new Error('Service Unavailable (503)'); + // Simulate the generator failing + mockGenerateContent.mockRejectedValue(apiError); + + await expect(client.generateJson(defaultOptions)).rejects.toThrow( + 'Failed to generate JSON content: Service Unavailable (503)', + ); + + // Verify generic error reporting + expect(reportError).toHaveBeenCalledTimes(1); + expect(reportError).toHaveBeenCalledWith( + apiError, + 'Error generating JSON content via API.', + defaultOptions.contents, + 'generateJson-api', + ); + }); + + it('should throw immediately without reporting if aborted', async () => { + const abortError = new DOMException('Aborted', 'AbortError'); + + // Simulate abortion happening during the API call + mockGenerateContent.mockImplementation(() => { + abortController.abort(); // Ensure the signal is aborted when the service checks + throw abortError; + }); + + const options = { + ...defaultOptions, + abortSignal: abortController.signal, + }; + + await expect(client.generateJson(options)).rejects.toThrow(abortError); + + // Crucially, it should not report a cancellation as an application error + expect(reportError).not.toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/core/src/core/baseLlmClient.ts b/packages/core/src/core/baseLlmClient.ts new file mode 100644 index 0000000000..25a92dabdd --- /dev/null +++ b/packages/core/src/core/baseLlmClient.ts @@ -0,0 +1,171 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Content, GenerateContentConfig, Part } from '@google/genai'; +import type { Config } from '../config/config.js'; +import type { ContentGenerator } from './contentGenerator.js'; +import { getResponseText } from '../utils/partUtils.js'; +import { reportError } from '../utils/errorReporting.js'; +import { getErrorMessage } from '../utils/errors.js'; +import { logMalformedJsonResponse } from '../telemetry/loggers.js'; +import { MalformedJsonResponseEvent } from '../telemetry/types.js'; +import { retryWithBackoff } from '../utils/retry.js'; + +/** + * Options for the generateJson utility function. + */ +export interface GenerateJsonOptions { + /** The input prompt or history. */ + contents: Content[]; + /** The required JSON schema for the output. */ + schema: Record; + /** The specific model to use for this task. */ + model: string; + /** + * Task-specific system instructions. + * If omitted, no system instruction is sent. + */ + systemInstruction?: string | Part | Part[] | Content; + /** + * Overrides for generation configuration (e.g., temperature). + */ + config?: Omit< + GenerateContentConfig, + | 'systemInstruction' + | 'responseJsonSchema' + | 'responseMimeType' + | 'tools' + | 'abortSignal' + >; + /** Signal for cancellation. */ + abortSignal: AbortSignal; + /** + * A unique ID for the prompt, used for logging/telemetry correlation. + */ + promptId: string; +} + +/** + * A client dedicated to stateless, utility-focused LLM calls. + */ +export class BaseLlmClient { + // Default configuration for utility tasks + private readonly defaultUtilityConfig: GenerateContentConfig = { + temperature: 0, + topP: 1, + }; + + constructor( + private readonly contentGenerator: ContentGenerator, + private readonly config: Config, + ) {} + + async generateJson( + options: GenerateJsonOptions, + ): Promise> { + const { + contents, + schema, + model, + abortSignal, + systemInstruction, + promptId, + } = options; + + const requestConfig: GenerateContentConfig = { + abortSignal, + ...this.defaultUtilityConfig, + ...options.config, + ...(systemInstruction && { systemInstruction }), + responseJsonSchema: schema, + responseMimeType: 'application/json', + }; + + try { + const apiCall = () => + this.contentGenerator.generateContent( + { + model, + config: requestConfig, + contents, + }, + promptId, + ); + + const result = await retryWithBackoff(apiCall); + + let text = getResponseText(result)?.trim(); + if (!text) { + const error = new Error( + 'API returned an empty response for generateJson.', + ); + await reportError( + error, + 'Error in generateJson: API returned an empty response.', + contents, + 'generateJson-empty-response', + ); + throw error; + } + + text = this.cleanJsonResponse(text, model); + + try { + return JSON.parse(text); + } catch (parseError) { + const error = new Error( + `Failed to parse API response as JSON: ${getErrorMessage(parseError)}`, + ); + await reportError( + parseError, + 'Failed to parse JSON response from generateJson.', + { + responseTextFailedToParse: text, + originalRequestContents: contents, + }, + 'generateJson-parse', + ); + throw error; + } + } catch (error) { + if (abortSignal.aborted) { + throw error; + } + + if ( + error instanceof Error && + (error.message === 'API returned an empty response for generateJson.' || + error.message.startsWith('Failed to parse API response as JSON:')) + ) { + // We perform this check so that we don't report these again. + } else { + await reportError( + error, + 'Error generating JSON content via API.', + contents, + 'generateJson-api', + ); + } + + throw new Error( + `Failed to generate JSON content: ${getErrorMessage(error)}`, + ); + } + } + + private cleanJsonResponse(text: string, model: string): string { + const prefix = '```json'; + const suffix = '```'; + if (text.startsWith(prefix) && text.endsWith(suffix)) { + logMalformedJsonResponse( + this.config, + new MalformedJsonResponseEvent(model), + ); + return text.substring(prefix.length, text.length - suffix.length).trim(); + } + return text; + } +} diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 047e43a529..aa49b0df30 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -50,6 +50,7 @@ export * from './utils/workspaceContext.js'; export * from './utils/ignorePatterns.js'; export * from './utils/partUtils.js'; export * from './utils/ide-trust.js'; +export * from './utils/promptIdContext.js'; // Export services export * from './services/fileDiscoveryService.js'; diff --git a/packages/core/src/tools/smart-edit.test.ts b/packages/core/src/tools/smart-edit.test.ts index 132d993306..c72fcb48df 100644 --- a/packages/core/src/tools/smart-edit.test.ts +++ b/packages/core/src/tools/smart-edit.test.ts @@ -60,6 +60,7 @@ import { ApprovalMode, type Config } from '../config/config.js'; import { type Content, type Part, type SchemaUnion } from '@google/genai'; import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js'; import { StandardFileSystemService } from '../services/fileSystemService.js'; +import type { BaseLlmClient } from '../core/baseLlmClient.js'; describe('SmartEditTool', () => { let tool: SmartEditTool; @@ -67,6 +68,7 @@ describe('SmartEditTool', () => { let rootDir: string; let mockConfig: Config; let geminiClient: any; + let baseLlmClient: BaseLlmClient; beforeEach(() => { vi.restoreAllMocks(); @@ -78,8 +80,13 @@ describe('SmartEditTool', () => { generateJson: mockGenerateJson, }; + baseLlmClient = { + generateJson: mockGenerateJson, + } as unknown as BaseLlmClient; + mockConfig = { getGeminiClient: vi.fn().mockReturnValue(geminiClient), + getBaseLlmClient: vi.fn().mockReturnValue(baseLlmClient), getTargetDir: () => rootDir, getApprovalMode: vi.fn(), setApprovalMode: vi.fn(), diff --git a/packages/core/src/tools/smart-edit.ts b/packages/core/src/tools/smart-edit.ts index 3647b73246..6291296f13 100644 --- a/packages/core/src/tools/smart-edit.ts +++ b/packages/core/src/tools/smart-edit.ts @@ -310,7 +310,7 @@ class EditToolInvocation implements ToolInvocation { params.new_string, initialError.raw, currentContent, - this.config.getGeminiClient(), + this.config.getBaseLlmClient(), abortSignal, ); diff --git a/packages/core/src/utils/llm-edit-fixer.test.ts b/packages/core/src/utils/llm-edit-fixer.test.ts new file mode 100644 index 0000000000..4c236ad342 --- /dev/null +++ b/packages/core/src/utils/llm-edit-fixer.test.ts @@ -0,0 +1,203 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + FixLLMEditWithInstruction, + resetLlmEditFixerCaches_TEST_ONLY, + type SearchReplaceEdit, +} from './llm-edit-fixer.js'; +import { promptIdContext } from './promptIdContext.js'; +import type { BaseLlmClient } from '../core/baseLlmClient.js'; + +// Mock the BaseLlmClient +const mockGenerateJson = vi.fn(); +const mockBaseLlmClient = { + generateJson: mockGenerateJson, +} as unknown as BaseLlmClient; + +describe('FixLLMEditWithInstruction', () => { + const instruction = 'Replace the title'; + const old_string = '

Old Title

'; + const new_string = '

New Title

'; + const error = 'String not found'; + const current_content = '

Old Title

'; + const abortController = new AbortController(); + const abortSignal = abortController.signal; + + beforeEach(() => { + vi.clearAllMocks(); + resetLlmEditFixerCaches_TEST_ONLY(); // Ensure cache is cleared before each test + }); + + afterEach(() => { + vi.useRealTimers(); // Reset timers after each test + }); + + const mockApiResponse: SearchReplaceEdit = { + search: '

Old Title

', + replace: '

New Title

', + noChangesRequired: false, + explanation: 'The original search was correct.', + }; + + it('should use the promptId from the AsyncLocalStorage context when available', async () => { + const testPromptId = 'test-prompt-id-12345'; + mockGenerateJson.mockResolvedValue(mockApiResponse); + + await promptIdContext.run(testPromptId, async () => { + await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + }); + + // Verify that generateJson was called with the promptId from the context + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(mockGenerateJson).toHaveBeenCalledWith( + expect.objectContaining({ + promptId: testPromptId, + }), + ); + }); + + it('should generate and use a fallback promptId when context is not available', async () => { + mockGenerateJson.mockResolvedValue(mockApiResponse); + const consoleWarnSpy = vi + .spyOn(console, 'warn') + .mockImplementation(() => {}); + + // Run the function outside of any context + await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + + // Verify the warning was logged + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining( + 'Could not find promptId in context. This is unexpected. Using a fallback ID: llm-fixer-fallback-', + ), + ); + + // Verify that generateJson was called with the generated fallback promptId + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + expect(mockGenerateJson).toHaveBeenCalledWith( + expect.objectContaining({ + promptId: expect.stringContaining('llm-fixer-fallback-'), + }), + ); + + // Restore mocks + consoleWarnSpy.mockRestore(); + }); + + it('should construct the user prompt correctly', async () => { + mockGenerateJson.mockResolvedValue(mockApiResponse); + const promptId = 'test-prompt-id-prompt-construction'; + + await promptIdContext.run(promptId, async () => { + await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + }); + + const generateJsonCall = mockGenerateJson.mock.calls[0][0]; + const userPromptContent = generateJsonCall.contents[0].parts[0].text; + + expect(userPromptContent).toContain( + `\n${instruction}\n`, + ); + expect(userPromptContent).toContain(`\n${old_string}\n`); + expect(userPromptContent).toContain(`\n${new_string}\n`); + expect(userPromptContent).toContain(`\n${error}\n`); + expect(userPromptContent).toContain( + `\n${current_content}\n`, + ); + }); + + it('should return a cached result on subsequent identical calls', async () => { + mockGenerateJson.mockResolvedValue(mockApiResponse); + const testPromptId = 'test-prompt-id-caching'; + + await promptIdContext.run(testPromptId, async () => { + // First call - should call the API + const result1 = await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + + // Second call with identical parameters - should hit the cache + const result2 = await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + + expect(result1).toEqual(mockApiResponse); + expect(result2).toEqual(mockApiResponse); + // Verify the underlying service was only called ONCE + expect(mockGenerateJson).toHaveBeenCalledTimes(1); + }); + }); + + it('should not use cache for calls with different parameters', async () => { + mockGenerateJson.mockResolvedValue(mockApiResponse); + const testPromptId = 'test-prompt-id-cache-miss'; + + await promptIdContext.run(testPromptId, async () => { + // First call + await FixLLMEditWithInstruction( + instruction, + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + + // Second call with a different instruction + await FixLLMEditWithInstruction( + 'A different instruction', + old_string, + new_string, + error, + current_content, + mockBaseLlmClient, + abortSignal, + ); + + // Verify the underlying service was called TWICE + expect(mockGenerateJson).toHaveBeenCalledTimes(2); + }); + }); +}); diff --git a/packages/core/src/utils/llm-edit-fixer.ts b/packages/core/src/utils/llm-edit-fixer.ts index 95496d4779..a4b4b131c0 100644 --- a/packages/core/src/utils/llm-edit-fixer.ts +++ b/packages/core/src/utils/llm-edit-fixer.ts @@ -5,9 +5,10 @@ */ import { type Content, Type } from '@google/genai'; -import { type GeminiClient } from '../core/client.js'; +import { type BaseLlmClient } from '../core/baseLlmClient.js'; import { LruCache } from './LruCache.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { promptIdContext } from './promptIdContext.js'; const MAX_CACHE_SIZE = 50; @@ -93,8 +94,9 @@ const editCorrectionWithInstructionCache = new LruCache< * @param new_string The original replacement string. * @param error The error that occurred during the initial edit. * @param current_content The current content of the file. - * @param geminiClient The Gemini client to use for the LLM call. + * @param baseLlmClient The BaseLlmClient to use for the LLM call. * @param abortSignal An abort signal to cancel the operation. + * @param promptId A unique ID for the prompt. * @returns A new search and replace pair. */ export async function FixLLMEditWithInstruction( @@ -103,9 +105,17 @@ export async function FixLLMEditWithInstruction( new_string: string, error: string, current_content: string, - geminiClient: GeminiClient, + baseLlmClient: BaseLlmClient, abortSignal: AbortSignal, ): Promise { + let promptId = promptIdContext.getStore(); + if (!promptId) { + promptId = `llm-fixer-fallback-${Date.now()}-${Math.random().toString(16).slice(2)}`; + console.warn( + `Could not find promptId in context. This is unexpected. Using a fallback ID: ${promptId}`, + ); + } + const cacheKey = `${instruction}---${old_string}---${new_string}--${current_content}--${error}`; const cachedResult = editCorrectionWithInstructionCache.get(cacheKey); if (cachedResult) { @@ -120,21 +130,18 @@ export async function FixLLMEditWithInstruction( const contents: Content[] = [ { role: 'user', - parts: [ - { - text: `${EDIT_SYS_PROMPT} -${userPrompt}`, - }, - ], + parts: [{ text: userPrompt }], }, ]; - const result = (await geminiClient.generateJson( + const result = (await baseLlmClient.generateJson({ contents, - SearchReplaceEditSchema, + schema: SearchReplaceEditSchema, abortSignal, - DEFAULT_GEMINI_FLASH_MODEL, - )) as unknown as SearchReplaceEdit; + model: DEFAULT_GEMINI_FLASH_MODEL, + systemInstruction: EDIT_SYS_PROMPT, + promptId, + })) as unknown as SearchReplaceEdit; editCorrectionWithInstructionCache.set(cacheKey, result); return result; diff --git a/packages/core/src/utils/promptIdContext.ts b/packages/core/src/utils/promptIdContext.ts new file mode 100644 index 0000000000..6344bd0b83 --- /dev/null +++ b/packages/core/src/utils/promptIdContext.ts @@ -0,0 +1,9 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { AsyncLocalStorage } from 'node:async_hooks'; + +export const promptIdContext = new AsyncLocalStorage();