refactor(core): Introduce LlmUtilityService and promptIdContext (#7952)

This commit is contained in:
Abhi
2025-09-09 01:14:15 -04:00
committed by GitHub
parent 471cbcd450
commit 1eaf21f6a2
12 changed files with 943 additions and 165 deletions

View File

@@ -13,6 +13,7 @@ import {
parseAndFormatApiError,
FatalInputError,
FatalTurnLimitedError,
promptIdContext,
} from '@google/gemini-cli-core';
import type { Content, Part } from '@google/genai';
@@ -24,115 +25,117 @@ export async function runNonInteractive(
input: string,
prompt_id: string,
): Promise<void> {
const consolePatcher = new ConsolePatcher({
stderr: true,
debugMode: config.getDebugMode(),
});
try {
consolePatcher.patch();
// Handle EPIPE errors when the output is piped to a command that closes early.
process.stdout.on('error', (err: NodeJS.ErrnoException) => {
if (err.code === 'EPIPE') {
// Exit gracefully if the pipe is closed.
process.exit(0);
}
return promptIdContext.run(prompt_id, async () => {
const consolePatcher = new ConsolePatcher({
stderr: true,
debugMode: config.getDebugMode(),
});
const geminiClient = config.getGeminiClient();
try {
consolePatcher.patch();
// Handle EPIPE errors when the output is piped to a command that closes early.
process.stdout.on('error', (err: NodeJS.ErrnoException) => {
if (err.code === 'EPIPE') {
// Exit gracefully if the pipe is closed.
process.exit(0);
}
});
const abortController = new AbortController();
const geminiClient = config.getGeminiClient();
const { processedQuery, shouldProceed } = await handleAtCommand({
query: input,
config,
addItem: (_item, _timestamp) => 0,
onDebugMessage: () => {},
messageId: Date.now(),
signal: abortController.signal,
});
const abortController = new AbortController();
if (!shouldProceed || !processedQuery) {
// An error occurred during @include processing (e.g., file not found).
// The error message is already logged by handleAtCommand.
throw new FatalInputError(
'Exiting due to an error processing the @ command.',
);
}
const { processedQuery, shouldProceed } = await handleAtCommand({
query: input,
config,
addItem: (_item, _timestamp) => 0,
onDebugMessage: () => {},
messageId: Date.now(),
signal: abortController.signal,
});
let currentMessages: Content[] = [
{ role: 'user', parts: processedQuery as Part[] },
];
let turnCount = 0;
while (true) {
turnCount++;
if (
config.getMaxSessionTurns() >= 0 &&
turnCount > config.getMaxSessionTurns()
) {
throw new FatalTurnLimitedError(
'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.',
if (!shouldProceed || !processedQuery) {
// An error occurred during @include processing (e.g., file not found).
// The error message is already logged by handleAtCommand.
throw new FatalInputError(
'Exiting due to an error processing the @ command.',
);
}
const toolCallRequests: ToolCallRequestInfo[] = [];
const responseStream = geminiClient.sendMessageStream(
currentMessages[0]?.parts || [],
abortController.signal,
prompt_id,
);
let currentMessages: Content[] = [
{ role: 'user', parts: processedQuery as Part[] },
];
for await (const event of responseStream) {
if (abortController.signal.aborted) {
console.error('Operation cancelled.');
let turnCount = 0;
while (true) {
turnCount++;
if (
config.getMaxSessionTurns() >= 0 &&
turnCount > config.getMaxSessionTurns()
) {
throw new FatalTurnLimitedError(
'Reached max session turns for this session. Increase the number of turns by specifying maxSessionTurns in settings.json.',
);
}
const toolCallRequests: ToolCallRequestInfo[] = [];
const responseStream = geminiClient.sendMessageStream(
currentMessages[0]?.parts || [],
abortController.signal,
prompt_id,
);
for await (const event of responseStream) {
if (abortController.signal.aborted) {
console.error('Operation cancelled.');
return;
}
if (event.type === GeminiEventType.Content) {
process.stdout.write(event.value);
} else if (event.type === GeminiEventType.ToolCallRequest) {
toolCallRequests.push(event.value);
}
}
if (toolCallRequests.length > 0) {
const toolResponseParts: Part[] = [];
for (const requestInfo of toolCallRequests) {
const toolResponse = await executeToolCall(
config,
requestInfo,
abortController.signal,
);
if (toolResponse.error) {
console.error(
`Error executing tool ${requestInfo.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`,
);
}
if (toolResponse.responseParts) {
toolResponseParts.push(...toolResponse.responseParts);
}
}
currentMessages = [{ role: 'user', parts: toolResponseParts }];
} else {
process.stdout.write('\n'); // Ensure a final newline
return;
}
if (event.type === GeminiEventType.Content) {
process.stdout.write(event.value);
} else if (event.type === GeminiEventType.ToolCallRequest) {
toolCallRequests.push(event.value);
}
}
if (toolCallRequests.length > 0) {
const toolResponseParts: Part[] = [];
for (const requestInfo of toolCallRequests) {
const toolResponse = await executeToolCall(
config,
requestInfo,
abortController.signal,
);
if (toolResponse.error) {
console.error(
`Error executing tool ${requestInfo.name}: ${toolResponse.resultDisplay || toolResponse.error.message}`,
);
}
if (toolResponse.responseParts) {
toolResponseParts.push(...toolResponse.responseParts);
}
}
currentMessages = [{ role: 'user', parts: toolResponseParts }];
} else {
process.stdout.write('\n'); // Ensure a final newline
return;
} catch (error) {
console.error(
parseAndFormatApiError(
error,
config.getContentGeneratorConfig()?.authType,
),
);
throw error;
} finally {
consolePatcher.cleanup();
if (isTelemetrySdkInitialized()) {
await shutdownTelemetry(config);
}
}
} catch (error) {
console.error(
parseAndFormatApiError(
error,
config.getContentGeneratorConfig()?.authType,
),
);
throw error;
} finally {
consolePatcher.cleanup();
if (isTelemetrySdkInitialized()) {
await shutdownTelemetry(config);
}
}
});
}

View File

@@ -33,6 +33,7 @@ import {
parseAndFormatApiError,
getCodeAssistServer,
UserTierId,
promptIdContext,
} from '@google/gemini-cli-core';
import { type Part, type PartListUnion, FinishReason } from '@google/genai';
import type {
@@ -705,71 +706,72 @@ export const useGeminiStream = (
if (!prompt_id) {
prompt_id = config.getSessionId() + '########' + getPromptCount();
}
const { queryToSend, shouldProceed } = await prepareQueryForGemini(
query,
userMessageTimestamp,
abortSignal,
prompt_id!,
);
if (!shouldProceed || queryToSend === null) {
return;
}
if (!options?.isContinuation) {
startNewPrompt();
setThought(null); // Reset thought when starting a new prompt
}
setIsResponding(true);
setInitError(null);
try {
const stream = geminiClient.sendMessageStream(
queryToSend,
abortSignal,
prompt_id!,
);
const processingStatus = await processGeminiStreamEvents(
stream,
return promptIdContext.run(prompt_id, async () => {
const { queryToSend, shouldProceed } = await prepareQueryForGemini(
query,
userMessageTimestamp,
abortSignal,
prompt_id,
);
if (processingStatus === StreamProcessingStatus.UserCancelled) {
if (!shouldProceed || queryToSend === null) {
return;
}
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
setPendingHistoryItem(null);
if (!options?.isContinuation) {
startNewPrompt();
setThought(null); // Reset thought when starting a new prompt
}
if (loopDetectedRef.current) {
loopDetectedRef.current = false;
handleLoopDetectedEvent();
}
} catch (error: unknown) {
if (error instanceof UnauthorizedError) {
onAuthError('Session expired or is unauthorized.');
} else if (!isNodeError(error) || error.name !== 'AbortError') {
addItem(
{
type: MessageType.ERROR,
text: parseAndFormatApiError(
getErrorMessage(error) || 'Unknown error',
config.getContentGeneratorConfig()?.authType,
undefined,
config.getModel(),
DEFAULT_GEMINI_FLASH_MODEL,
),
},
userMessageTimestamp,
setIsResponding(true);
setInitError(null);
try {
const stream = geminiClient.sendMessageStream(
queryToSend,
abortSignal,
prompt_id,
);
const processingStatus = await processGeminiStreamEvents(
stream,
userMessageTimestamp,
abortSignal,
);
if (processingStatus === StreamProcessingStatus.UserCancelled) {
return;
}
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
setPendingHistoryItem(null);
}
if (loopDetectedRef.current) {
loopDetectedRef.current = false;
handleLoopDetectedEvent();
}
} catch (error: unknown) {
if (error instanceof UnauthorizedError) {
onAuthError('Session expired or is unauthorized.');
} else if (!isNodeError(error) || error.name !== 'AbortError') {
addItem(
{
type: MessageType.ERROR,
text: parseAndFormatApiError(
getErrorMessage(error) || 'Unknown error',
config.getContentGeneratorConfig()?.authType,
undefined,
config.getModel(),
DEFAULT_GEMINI_FLASH_MODEL,
),
},
userMessageTimestamp,
);
}
} finally {
setIsResponding(false);
}
} finally {
setIsResponding(false);
}
});
},
[
streamingState,

View File

@@ -122,6 +122,10 @@ vi.mock('../ide/ide-client.js', () => ({
},
}));
import { BaseLlmClient } from '../core/baseLlmClient.js';
vi.mock('../core/baseLlmClient.js');
describe('Server Config (config.ts)', () => {
const MODEL = 'gemini-pro';
const SANDBOX: SandboxConfig = {
@@ -774,3 +778,58 @@ describe('setApprovalMode with folder trust', () => {
});
});
});
describe('BaseLlmClient Lifecycle', () => {
const MODEL = 'gemini-pro';
const SANDBOX: SandboxConfig = {
command: 'docker',
image: 'gemini-cli-sandbox',
};
const TARGET_DIR = '/path/to/target';
const DEBUG_MODE = false;
const QUESTION = 'test question';
const FULL_CONTEXT = false;
const USER_MEMORY = 'Test User Memory';
const TELEMETRY_SETTINGS = { enabled: false };
const EMBEDDING_MODEL = 'gemini-embedding';
const SESSION_ID = 'test-session-id';
const baseParams: ConfigParameters = {
cwd: '/tmp',
embeddingModel: EMBEDDING_MODEL,
sandbox: SANDBOX,
targetDir: TARGET_DIR,
debugMode: DEBUG_MODE,
question: QUESTION,
fullContext: FULL_CONTEXT,
userMemory: USER_MEMORY,
telemetry: TELEMETRY_SETTINGS,
sessionId: SESSION_ID,
model: MODEL,
usageStatisticsEnabled: false,
};
it('should throw an error if getBaseLlmClient is called before refreshAuth', () => {
const config = new Config(baseParams);
expect(() => config.getBaseLlmClient()).toThrow(
'BaseLlmClient not initialized. Ensure authentication has occurred and ContentGenerator is ready.',
);
});
it('should successfully initialize BaseLlmClient after refreshAuth is called', async () => {
const config = new Config(baseParams);
const authType = AuthType.USE_GEMINI;
const mockContentConfig = { model: 'gemini-flash', apiKey: 'test-key' };
vi.mocked(createContentGeneratorConfig).mockReturnValue(mockContentConfig);
await config.refreshAuth(authType);
// Should not throw
const llmService = config.getBaseLlmClient();
expect(llmService).toBeDefined();
expect(BaseLlmClient).toHaveBeenCalledWith(
config.getContentGenerator(),
config,
);
});
});

View File

@@ -31,6 +31,7 @@ import { ReadManyFilesTool } from '../tools/read-many-files.js';
import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js';
import { WebSearchTool } from '../tools/web-search.js';
import { GeminiClient } from '../core/client.js';
import { BaseLlmClient } from '../core/baseLlmClient.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
import { GitService } from '../services/gitService.js';
import type { TelemetryTarget } from '../telemetry/index.js';
@@ -257,6 +258,7 @@ export class Config {
private readonly telemetrySettings: TelemetrySettings;
private readonly usageStatisticsEnabled: boolean;
private geminiClient!: GeminiClient;
private baseLlmClient!: BaseLlmClient;
private readonly fileFiltering: {
respectGitIgnore: boolean;
respectGeminiIgnore: boolean;
@@ -455,6 +457,9 @@ export class Config {
// Only assign to instance properties after successful initialization
this.contentGeneratorConfig = newContentGeneratorConfig;
// Initialize BaseLlmClient now that the ContentGenerator is available
this.baseLlmClient = new BaseLlmClient(this.contentGenerator, this);
// Reset the session flag since we're explicitly changing auth and using default model
this.inFallbackMode = false;
}
@@ -463,6 +468,26 @@ export class Config {
return this.contentGenerator?.userTier;
}
/**
* Provides access to the BaseLlmClient for stateless LLM operations.
*
* Lazily constructs the client from the current ContentGenerator when
* refreshAuth did not already create one.
*
* @returns The shared BaseLlmClient instance.
* @throws Error when no ContentGenerator exists yet, i.e. authentication
*   has not completed.
*/
getBaseLlmClient(): BaseLlmClient {
if (!this.baseLlmClient) {
// Handle cases where initialization might be deferred or authentication failed
if (this.contentGenerator) {
this.baseLlmClient = new BaseLlmClient(
this.getContentGenerator(),
this,
);
} else {
throw new Error(
'BaseLlmClient not initialized. Ensure authentication has occurred and ContentGenerator is ready.',
);
}
}
return this.baseLlmClient;
}
getSessionId(): string {
return this.sessionId;
}

View File

@@ -0,0 +1,291 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
describe,
it,
expect,
vi,
beforeEach,
afterEach,
type Mocked,
} from 'vitest';
import type { GenerateContentResponse } from '@google/genai';
import { BaseLlmClient, type GenerateJsonOptions } from './baseLlmClient.js';
import type { ContentGenerator } from './contentGenerator.js';
import type { Config } from '../config/config.js';
import { AuthType } from './contentGenerator.js';
import { reportError } from '../utils/errorReporting.js';
import { logMalformedJsonResponse } from '../telemetry/loggers.js';
import { retryWithBackoff } from '../utils/retry.js';
import { MalformedJsonResponseEvent } from '../telemetry/types.js';
import { getErrorMessage } from '../utils/errors.js';
vi.mock('../utils/errorReporting.js');
vi.mock('../telemetry/loggers.js');
vi.mock('../utils/errors.js', async (importOriginal) => {
const actual = await importOriginal<typeof import('../utils/errors.js')>();
return {
...actual,
getErrorMessage: vi.fn((e) => (e instanceof Error ? e.message : String(e))),
};
});
vi.mock('../utils/retry.js', () => ({
retryWithBackoff: vi.fn(async (fn) => await fn()),
}));
// Shared test doubles: a ContentGenerator whose generateContent is a spy, and
// a Config stub exposing only the members BaseLlmClient reads.
const mockGenerateContent = vi.fn();
const mockContentGenerator = {
generateContent: mockGenerateContent,
} as unknown as Mocked<ContentGenerator>;
const mockConfig = {
getSessionId: vi.fn().mockReturnValue('test-session-id'),
getContentGeneratorConfig: vi
.fn()
.mockReturnValue({ authType: AuthType.USE_GEMINI }),
} as unknown as Mocked<Config>;
// Helper to create a mock GenerateContentResponse
const createMockResponse = (text: string): GenerateContentResponse =>
({
candidates: [{ content: { role: 'model', parts: [{ text }] }, index: 0 }],
}) as GenerateContentResponse;
// Covers BaseLlmClient.generateJson: parameter defaulting, markdown-fence
// cleanup + telemetry, and the error-reporting/wrapping contract.
describe('BaseLlmClient', () => {
let client: BaseLlmClient;
let abortController: AbortController;
let defaultOptions: GenerateJsonOptions;
beforeEach(() => {
vi.clearAllMocks();
// Reset the mocked implementation for getErrorMessage for accurate error message assertions
vi.mocked(getErrorMessage).mockImplementation((e) =>
e instanceof Error ? e.message : String(e),
);
client = new BaseLlmClient(mockContentGenerator, mockConfig);
abortController = new AbortController();
// Baseline request; individual tests spread over this to vary one field.
defaultOptions = {
contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }],
schema: { type: 'object', properties: { color: { type: 'string' } } },
model: 'test-model',
abortSignal: abortController.signal,
promptId: 'test-prompt-id',
};
});
afterEach(() => {
abortController.abort();
});
describe('generateJson - Success Scenarios', () => {
it('should call generateContent with correct parameters, defaults, and utilize retry mechanism', async () => {
const mockResponse = createMockResponse('{"color": "blue"}');
mockGenerateContent.mockResolvedValue(mockResponse);
const result = await client.generateJson(defaultOptions);
expect(result).toEqual({ color: 'blue' });
// Ensure the retry mechanism was engaged
expect(retryWithBackoff).toHaveBeenCalledTimes(1);
// Validate the parameters passed to the underlying generator
expect(mockGenerateContent).toHaveBeenCalledTimes(1);
expect(mockGenerateContent).toHaveBeenCalledWith(
{
model: 'test-model',
contents: defaultOptions.contents,
config: {
abortSignal: defaultOptions.abortSignal,
temperature: 0,
topP: 1,
responseJsonSchema: defaultOptions.schema,
responseMimeType: 'application/json',
// Crucial: systemInstruction should NOT be in the config object if not provided
},
},
'test-prompt-id',
);
});
it('should respect configuration overrides', async () => {
const mockResponse = createMockResponse('{"color": "red"}');
mockGenerateContent.mockResolvedValue(mockResponse);
const options: GenerateJsonOptions = {
...defaultOptions,
config: { temperature: 0.8, topK: 10 },
};
await client.generateJson(options);
expect(mockGenerateContent).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
temperature: 0.8,
topP: 1, // Default should remain if not overridden
topK: 10,
}),
}),
expect.any(String),
);
});
it('should include system instructions when provided', async () => {
const mockResponse = createMockResponse('{"color": "green"}');
mockGenerateContent.mockResolvedValue(mockResponse);
const systemInstruction = 'You are a helpful assistant.';
const options: GenerateJsonOptions = {
...defaultOptions,
systemInstruction,
};
await client.generateJson(options);
expect(mockGenerateContent).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
systemInstruction,
}),
}),
expect.any(String),
);
});
it('should use the provided promptId', async () => {
const mockResponse = createMockResponse('{"color": "yellow"}');
mockGenerateContent.mockResolvedValue(mockResponse);
const customPromptId = 'custom-id-123';
const options: GenerateJsonOptions = {
...defaultOptions,
promptId: customPromptId,
};
await client.generateJson(options);
// promptId is forwarded as the second positional argument.
expect(mockGenerateContent).toHaveBeenCalledWith(
expect.any(Object),
customPromptId,
);
});
});
describe('generateJson - Response Cleaning', () => {
it('should clean JSON wrapped in markdown backticks and log telemetry', async () => {
const malformedResponse = '```json\n{"color": "purple"}\n```';
mockGenerateContent.mockResolvedValue(
createMockResponse(malformedResponse),
);
const result = await client.generateJson(defaultOptions);
expect(result).toEqual({ color: 'purple' });
expect(logMalformedJsonResponse).toHaveBeenCalledTimes(1);
expect(logMalformedJsonResponse).toHaveBeenCalledWith(
mockConfig,
expect.any(MalformedJsonResponseEvent),
);
// Validate the telemetry event content
const event = vi.mocked(logMalformedJsonResponse).mock
.calls[0][1] as MalformedJsonResponseEvent;
expect(event.model).toBe('test-model');
});
it('should handle extra whitespace correctly without logging malformed telemetry', async () => {
const responseWithWhitespace = ' \n {"color": "orange"} \n';
mockGenerateContent.mockResolvedValue(
createMockResponse(responseWithWhitespace),
);
const result = await client.generateJson(defaultOptions);
expect(result).toEqual({ color: 'orange' });
// Plain whitespace is trimmed, not treated as a malformed response.
expect(logMalformedJsonResponse).not.toHaveBeenCalled();
});
});
describe('generateJson - Error Handling', () => {
it('should throw and report error for empty response', async () => {
mockGenerateContent.mockResolvedValue(createMockResponse(''));
// The final error message includes the prefix added by the client's outer catch block.
await expect(client.generateJson(defaultOptions)).rejects.toThrow(
'Failed to generate JSON content: API returned an empty response for generateJson.',
);
// Verify error reporting details
expect(reportError).toHaveBeenCalledTimes(1);
expect(reportError).toHaveBeenCalledWith(
expect.any(Error),
'Error in generateJson: API returned an empty response.',
defaultOptions.contents,
'generateJson-empty-response',
);
});
it('should throw and report error for invalid JSON syntax', async () => {
const invalidJson = '{"color": "blue"'; // missing closing brace
mockGenerateContent.mockResolvedValue(createMockResponse(invalidJson));
await expect(client.generateJson(defaultOptions)).rejects.toThrow(
/^Failed to generate JSON content: Failed to parse API response as JSON:/,
);
expect(reportError).toHaveBeenCalledTimes(1);
expect(reportError).toHaveBeenCalledWith(
expect.any(Error),
'Failed to parse JSON response from generateJson.',
expect.objectContaining({ responseTextFailedToParse: invalidJson }),
'generateJson-parse',
);
});
it('should throw and report generic API errors', async () => {
const apiError = new Error('Service Unavailable (503)');
// Simulate the generator failing
mockGenerateContent.mockRejectedValue(apiError);
await expect(client.generateJson(defaultOptions)).rejects.toThrow(
'Failed to generate JSON content: Service Unavailable (503)',
);
// Verify generic error reporting
expect(reportError).toHaveBeenCalledTimes(1);
expect(reportError).toHaveBeenCalledWith(
apiError,
'Error generating JSON content via API.',
defaultOptions.contents,
'generateJson-api',
);
});
it('should throw immediately without reporting if aborted', async () => {
const abortError = new DOMException('Aborted', 'AbortError');
// Simulate abortion happening during the API call
mockGenerateContent.mockImplementation(() => {
abortController.abort(); // Ensure the signal is aborted when the service checks
throw abortError;
});
const options = {
...defaultOptions,
abortSignal: abortController.signal,
};
await expect(client.generateJson(options)).rejects.toThrow(abortError);
// Crucially, it should not report a cancellation as an application error
expect(reportError).not.toHaveBeenCalled();
});
});
});

View File

@@ -0,0 +1,171 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Content, GenerateContentConfig, Part } from '@google/genai';
import type { Config } from '../config/config.js';
import type { ContentGenerator } from './contentGenerator.js';
import { getResponseText } from '../utils/partUtils.js';
import { reportError } from '../utils/errorReporting.js';
import { getErrorMessage } from '../utils/errors.js';
import { logMalformedJsonResponse } from '../telemetry/loggers.js';
import { MalformedJsonResponseEvent } from '../telemetry/types.js';
import { retryWithBackoff } from '../utils/retry.js';
/**
* Options for the generateJson utility function.
*
* All fields are consumed by BaseLlmClient.generateJson; `schema` is sent as
* `responseJsonSchema` alongside a forced `application/json` MIME type.
*/
export interface GenerateJsonOptions {
/** The input prompt or history. */
contents: Content[];
/** The required JSON schema for the output. */
schema: Record<string, unknown>;
/** The specific model to use for this task. */
model: string;
/**
* Task-specific system instructions.
* If omitted, no system instruction is sent.
*/
systemInstruction?: string | Part | Part[] | Content;
/**
* Overrides for generation configuration (e.g., temperature).
* The Omit keeps callers from overriding the fields this client manages
* itself (system instruction, JSON response settings, abort signal).
*/
config?: Omit<
GenerateContentConfig,
| 'systemInstruction'
| 'responseJsonSchema'
| 'responseMimeType'
| 'tools'
| 'abortSignal'
>;
/** Signal for cancellation. */
abortSignal: AbortSignal;
/**
* A unique ID for the prompt, used for logging/telemetry correlation.
*/
promptId: string;
}
/**
 * A client dedicated to stateless, utility-focused LLM calls.
 *
 * Unlike the chat-oriented client, it keeps no conversation history and uses
 * deterministic defaults (temperature 0, topP 1) suited to utility tasks such
 * as structured JSON extraction.
 */
export class BaseLlmClient {
  // Default configuration for utility tasks: deterministic sampling.
  private readonly defaultUtilityConfig: GenerateContentConfig = {
    temperature: 0,
    topP: 1,
  };

  constructor(
    private readonly contentGenerator: ContentGenerator,
    private readonly config: Config,
  ) {}

  /**
   * Generates JSON content conforming to the supplied schema.
   *
   * @param options Request options; see {@link GenerateJsonOptions}.
   * @returns The parsed JSON object from the model response.
   * @throws Error prefixed with "Failed to generate JSON content:" for empty
   *   responses, unparseable JSON, or API failures. When the abort signal has
   *   fired, the original error is rethrown unwrapped and unreported.
   */
  async generateJson(
    options: GenerateJsonOptions,
  ): Promise<Record<string, unknown>> {
    const {
      contents,
      schema,
      model,
      abortSignal,
      systemInstruction,
      promptId,
    } = options;

    const requestConfig: GenerateContentConfig = {
      abortSignal,
      ...this.defaultUtilityConfig,
      ...options.config,
      // Spread conditionally so systemInstruction is absent (not undefined)
      // from the request when the caller did not provide one.
      ...(systemInstruction && { systemInstruction }),
      responseJsonSchema: schema,
      responseMimeType: 'application/json',
    };

    // Tracks whether the failure was already sent to reportError so the outer
    // catch does not report it a second time. Using a flag avoids the fragile
    // alternative of matching on exact error-message strings.
    let alreadyReported = false;

    try {
      const apiCall = () =>
        this.contentGenerator.generateContent(
          {
            model,
            config: requestConfig,
            contents,
          },
          promptId,
        );

      const result = await retryWithBackoff(apiCall);

      let text = getResponseText(result)?.trim();
      if (!text) {
        const error = new Error(
          'API returned an empty response for generateJson.',
        );
        alreadyReported = true;
        await reportError(
          error,
          'Error in generateJson: API returned an empty response.',
          contents,
          'generateJson-empty-response',
        );
        throw error;
      }

      text = this.cleanJsonResponse(text, model);

      try {
        return JSON.parse(text);
      } catch (parseError) {
        const error = new Error(
          `Failed to parse API response as JSON: ${getErrorMessage(parseError)}`,
        );
        alreadyReported = true;
        await reportError(
          parseError,
          'Failed to parse JSON response from generateJson.',
          {
            responseTextFailedToParse: text,
            originalRequestContents: contents,
          },
          'generateJson-parse',
        );
        throw error;
      }
    } catch (error) {
      // Cancellation is not an application error: propagate it untouched.
      if (abortSignal.aborted) {
        throw error;
      }

      if (!alreadyReported) {
        await reportError(
          error,
          'Error generating JSON content via API.',
          contents,
          'generateJson-api',
        );
      }

      throw new Error(
        `Failed to generate JSON content: ${getErrorMessage(error)}`,
      );
    }
  }

  /**
   * Strips a ```json ... ``` markdown fence from the model output; a fenced
   * response counts as malformed and is logged to telemetry.
   */
  private cleanJsonResponse(text: string, model: string): string {
    const prefix = '```json';
    const suffix = '```';
    if (text.startsWith(prefix) && text.endsWith(suffix)) {
      logMalformedJsonResponse(
        this.config,
        new MalformedJsonResponseEvent(model),
      );
      return text.substring(prefix.length, text.length - suffix.length).trim();
    }
    return text;
  }
}

View File

@@ -50,6 +50,7 @@ export * from './utils/workspaceContext.js';
export * from './utils/ignorePatterns.js';
export * from './utils/partUtils.js';
export * from './utils/ide-trust.js';
export * from './utils/promptIdContext.js';
// Export services
export * from './services/fileDiscoveryService.js';

View File

@@ -60,6 +60,7 @@ import { ApprovalMode, type Config } from '../config/config.js';
import { type Content, type Part, type SchemaUnion } from '@google/genai';
import { createMockWorkspaceContext } from '../test-utils/mockWorkspaceContext.js';
import { StandardFileSystemService } from '../services/fileSystemService.js';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
describe('SmartEditTool', () => {
let tool: SmartEditTool;
@@ -67,6 +68,7 @@ describe('SmartEditTool', () => {
let rootDir: string;
let mockConfig: Config;
let geminiClient: any;
let baseLlmClient: BaseLlmClient;
beforeEach(() => {
vi.restoreAllMocks();
@@ -78,8 +80,13 @@ describe('SmartEditTool', () => {
generateJson: mockGenerateJson,
};
baseLlmClient = {
generateJson: mockGenerateJson,
} as unknown as BaseLlmClient;
mockConfig = {
getGeminiClient: vi.fn().mockReturnValue(geminiClient),
getBaseLlmClient: vi.fn().mockReturnValue(baseLlmClient),
getTargetDir: () => rootDir,
getApprovalMode: vi.fn(),
setApprovalMode: vi.fn(),

View File

@@ -310,7 +310,7 @@ class EditToolInvocation implements ToolInvocation<EditToolParams, ToolResult> {
params.new_string,
initialError.raw,
currentContent,
this.config.getGeminiClient(),
this.config.getBaseLlmClient(),
abortSignal,
);

View File

@@ -0,0 +1,203 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
FixLLMEditWithInstruction,
resetLlmEditFixerCaches_TEST_ONLY,
type SearchReplaceEdit,
} from './llm-edit-fixer.js';
import { promptIdContext } from './promptIdContext.js';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
// Mock the BaseLlmClient — only generateJson is exercised by the fixer.
const mockGenerateJson = vi.fn();
const mockBaseLlmClient = {
generateJson: mockGenerateJson,
} as unknown as BaseLlmClient;
// Covers promptId resolution (context vs. fallback), prompt construction,
// and the LRU caching behavior of FixLLMEditWithInstruction.
describe('FixLLMEditWithInstruction', () => {
// Shared fixture: a failed search/replace edit against a small HTML body.
const instruction = 'Replace the title';
const old_string = '<h1>Old Title</h1>';
const new_string = '<h1>New Title</h1>';
const error = 'String not found';
const current_content = '<body><h1>Old Title</h1></body>';
const abortController = new AbortController();
const abortSignal = abortController.signal;
beforeEach(() => {
vi.clearAllMocks();
resetLlmEditFixerCaches_TEST_ONLY(); // Ensure cache is cleared before each test
});
afterEach(() => {
vi.useRealTimers(); // Reset timers after each test
});
const mockApiResponse: SearchReplaceEdit = {
search: '<h1>Old Title</h1>',
replace: '<h1>New Title</h1>',
noChangesRequired: false,
explanation: 'The original search was correct.',
};
it('should use the promptId from the AsyncLocalStorage context when available', async () => {
const testPromptId = 'test-prompt-id-12345';
mockGenerateJson.mockResolvedValue(mockApiResponse);
await promptIdContext.run(testPromptId, async () => {
await FixLLMEditWithInstruction(
instruction,
old_string,
new_string,
error,
current_content,
mockBaseLlmClient,
abortSignal,
);
});
// Verify that generateJson was called with the promptId from the context
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(mockGenerateJson).toHaveBeenCalledWith(
expect.objectContaining({
promptId: testPromptId,
}),
);
});
it('should generate and use a fallback promptId when context is not available', async () => {
mockGenerateJson.mockResolvedValue(mockApiResponse);
const consoleWarnSpy = vi
.spyOn(console, 'warn')
.mockImplementation(() => {});
// Run the function outside of any context
await FixLLMEditWithInstruction(
instruction,
old_string,
new_string,
error,
current_content,
mockBaseLlmClient,
abortSignal,
);
// Verify the warning was logged
expect(consoleWarnSpy).toHaveBeenCalledWith(
expect.stringContaining(
'Could not find promptId in context. This is unexpected. Using a fallback ID: llm-fixer-fallback-',
),
);
// Verify that generateJson was called with the generated fallback promptId
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
expect(mockGenerateJson).toHaveBeenCalledWith(
expect.objectContaining({
promptId: expect.stringContaining('llm-fixer-fallback-'),
}),
);
// Restore mocks
consoleWarnSpy.mockRestore();
});
it('should construct the user prompt correctly', async () => {
mockGenerateJson.mockResolvedValue(mockApiResponse);
const promptId = 'test-prompt-id-prompt-construction';
await promptIdContext.run(promptId, async () => {
await FixLLMEditWithInstruction(
instruction,
old_string,
new_string,
error,
current_content,
mockBaseLlmClient,
abortSignal,
);
});
// Inspect the first (and only) generateJson call's request payload.
const generateJsonCall = mockGenerateJson.mock.calls[0][0];
const userPromptContent = generateJsonCall.contents[0].parts[0].text;
expect(userPromptContent).toContain(
`<instruction>\n${instruction}\n</instruction>`,
);
expect(userPromptContent).toContain(`<search>\n${old_string}\n</search>`);
expect(userPromptContent).toContain(`<replace>\n${new_string}\n</replace>`);
expect(userPromptContent).toContain(`<error>\n${error}\n</error>`);
expect(userPromptContent).toContain(
`<file_content>\n${current_content}\n</file_content>`,
);
});
it('should return a cached result on subsequent identical calls', async () => {
mockGenerateJson.mockResolvedValue(mockApiResponse);
const testPromptId = 'test-prompt-id-caching';
await promptIdContext.run(testPromptId, async () => {
// First call - should call the API
const result1 = await FixLLMEditWithInstruction(
instruction,
old_string,
new_string,
error,
current_content,
mockBaseLlmClient,
abortSignal,
);
// Second call with identical parameters - should hit the cache
const result2 = await FixLLMEditWithInstruction(
instruction,
old_string,
new_string,
error,
current_content,
mockBaseLlmClient,
abortSignal,
);
expect(result1).toEqual(mockApiResponse);
expect(result2).toEqual(mockApiResponse);
// Verify the underlying service was only called ONCE
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
});
});
it('should not use cache for calls with different parameters', async () => {
mockGenerateJson.mockResolvedValue(mockApiResponse);
const testPromptId = 'test-prompt-id-cache-miss';
await promptIdContext.run(testPromptId, async () => {
// First call
await FixLLMEditWithInstruction(
instruction,
old_string,
new_string,
error,
current_content,
mockBaseLlmClient,
abortSignal,
);
// Second call with a different instruction
await FixLLMEditWithInstruction(
'A different instruction',
old_string,
new_string,
error,
current_content,
mockBaseLlmClient,
abortSignal,
);
// Verify the underlying service was called TWICE
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
});
});
});

View File

@@ -5,9 +5,10 @@
*/
import { type Content, Type } from '@google/genai';
import { type GeminiClient } from '../core/client.js';
import { type BaseLlmClient } from '../core/baseLlmClient.js';
import { LruCache } from './LruCache.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { promptIdContext } from './promptIdContext.js';
const MAX_CACHE_SIZE = 50;
@@ -93,8 +94,9 @@ const editCorrectionWithInstructionCache = new LruCache<
* @param new_string The original replacement string.
* @param error The error that occurred during the initial edit.
* @param current_content The current content of the file.
* @param geminiClient The Gemini client to use for the LLM call.
* @param baseLlmClient The BaseLlmClient to use for the LLM call.
* @param abortSignal An abort signal to cancel the operation.
* @param promptId A unique ID for the prompt.
* @returns A new search and replace pair.
*/
export async function FixLLMEditWithInstruction(
@@ -103,9 +105,17 @@ export async function FixLLMEditWithInstruction(
new_string: string,
error: string,
current_content: string,
geminiClient: GeminiClient,
baseLlmClient: BaseLlmClient,
abortSignal: AbortSignal,
): Promise<SearchReplaceEdit> {
let promptId = promptIdContext.getStore();
if (!promptId) {
promptId = `llm-fixer-fallback-${Date.now()}-${Math.random().toString(16).slice(2)}`;
console.warn(
`Could not find promptId in context. This is unexpected. Using a fallback ID: ${promptId}`,
);
}
const cacheKey = `${instruction}---${old_string}---${new_string}--${current_content}--${error}`;
const cachedResult = editCorrectionWithInstructionCache.get(cacheKey);
if (cachedResult) {
@@ -120,21 +130,18 @@ export async function FixLLMEditWithInstruction(
const contents: Content[] = [
{
role: 'user',
parts: [
{
text: `${EDIT_SYS_PROMPT}
${userPrompt}`,
},
],
parts: [{ text: userPrompt }],
},
];
const result = (await geminiClient.generateJson(
const result = (await baseLlmClient.generateJson({
contents,
SearchReplaceEditSchema,
schema: SearchReplaceEditSchema,
abortSignal,
DEFAULT_GEMINI_FLASH_MODEL,
)) as unknown as SearchReplaceEdit;
model: DEFAULT_GEMINI_FLASH_MODEL,
systemInstruction: EDIT_SYS_PROMPT,
promptId,
})) as unknown as SearchReplaceEdit;
editCorrectionWithInstructionCache.set(cacheKey, result);
return result;

View File

@@ -0,0 +1,9 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { AsyncLocalStorage } from 'node:async_hooks';
/**
 * AsyncLocalStorage carrying the current prompt ID across async call chains,
 * so deeply nested helpers can read it via promptIdContext.getStore() instead
 * of threading the ID through every function signature.
 */
export const promptIdContext = new AsyncLocalStorage<string>();