diff --git a/packages/core/src/core/baseLlmClient.test.ts b/packages/core/src/core/baseLlmClient.test.ts index ade925a181..8405aec78a 100644 --- a/packages/core/src/core/baseLlmClient.test.ts +++ b/packages/core/src/core/baseLlmClient.test.ts @@ -180,6 +180,36 @@ describe('BaseLlmClient', () => { customPromptId, ); }); + + it('should pass maxAttempts to retryWithBackoff when provided', async () => { + const mockResponse = createMockResponse('{"color": "cyan"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + const customMaxAttempts = 3; + + const options: GenerateJsonOptions = { + ...defaultOptions, + maxAttempts: customMaxAttempts, + }; + + await client.generateJson(options); + + expect(retryWithBackoff).toHaveBeenCalledTimes(1); + expect(retryWithBackoff).toHaveBeenCalledWith(expect.any(Function), { + maxAttempts: customMaxAttempts, + }); + }); + + it('should call retryWithBackoff without maxAttempts when not provided', async () => { + const mockResponse = createMockResponse('{"color": "indigo"}'); + mockGenerateContent.mockResolvedValue(mockResponse); + + // No maxAttempts in defaultOptions + await client.generateJson(defaultOptions); + + expect(retryWithBackoff).toHaveBeenCalledWith(expect.any(Function), { + maxAttempts: undefined, + }); + }); }); describe('generateJson - Response Cleaning', () => { diff --git a/packages/core/src/core/baseLlmClient.ts b/packages/core/src/core/baseLlmClient.ts index 8ce63540fc..e76cb4cfb9 100644 --- a/packages/core/src/core/baseLlmClient.ts +++ b/packages/core/src/core/baseLlmClient.ts @@ -51,6 +51,10 @@ export interface GenerateJsonOptions { * A unique ID for the prompt, used for logging/telemetry correlation. */ promptId: string; + /** + * The maximum number of attempts for the request. + */ + maxAttempts?: number; } /** @@ -78,6 +82,7 @@ export class BaseLlmClient { abortSignal, systemInstruction, promptId, + maxAttempts, } = options; const requestConfig: GenerateContentConfig = { @@ -100,7 +105,7 @@ export class BaseLlmClient { promptId, ); - const result = await retryWithBackoff(apiCall); + const result = await retryWithBackoff(apiCall, { maxAttempts }); let text = getResponseText(result)?.trim(); if (!text) { diff --git a/packages/core/src/utils/llm-edit-fixer.ts b/packages/core/src/utils/llm-edit-fixer.ts index a4b4b131c0..7b186dfde9 100644 --- a/packages/core/src/utils/llm-edit-fixer.ts +++ b/packages/core/src/utils/llm-edit-fixer.ts @@ -141,6 +141,7 @@ export async function FixLLMEditWithInstruction( model: DEFAULT_GEMINI_FLASH_MODEL, systemInstruction: EDIT_SYS_PROMPT, promptId, + maxAttempts: 1, })) as unknown as SearchReplaceEdit; editCorrectionWithInstructionCache.set(cacheKey, result); diff --git a/packages/core/src/utils/retry.test.ts b/packages/core/src/utils/retry.test.ts index 180fe218c1..78c4100edb 100644 --- a/packages/core/src/utils/retry.test.ts +++ b/packages/core/src/utils/retry.test.ts @@ -99,6 +99,23 @@ describe('retryWithBackoff', () => { expect(mockFn).toHaveBeenCalledTimes(3); }); + it('should default to 5 maxAttempts if no options are provided', async () => { + // This function will fail more than 5 times to ensure all retries are used. + const mockFn = createFailingFunction(10); + + const promise = retryWithBackoff(mockFn); + + // Expect it to fail with the error from the 5th attempt. + // eslint-disable-next-line vitest/valid-expect + const assertionPromise = expect(promise).rejects.toThrow( + 'Simulated error attempt 5', + ); + await vi.runAllTimersAsync(); + await assertionPromise; + + expect(mockFn).toHaveBeenCalledTimes(5); + }); + it('should not retry if shouldRetry returns false', async () => { const mockFn = vi.fn(async () => { throw new NonRetryableError('Non-retryable error'); @@ -114,6 +131,18 @@ describe('retryWithBackoff', () => { expect(mockFn).toHaveBeenCalledTimes(1); }); + it('should throw an error if maxAttempts is not a positive number', async () => { + const mockFn = createFailingFunction(1); + + // Test with 0 + await expect(retryWithBackoff(mockFn, { maxAttempts: 0 })).rejects.toThrow( + 'maxAttempts must be a positive number.', + ); + + // The function should not be called at all if validation fails + expect(mockFn).not.toHaveBeenCalled(); + }); + it('should use default shouldRetry if not provided, retrying on 429', async () => { const mockFn = vi.fn(async () => { const error = new Error('Too Many Requests') as any; diff --git a/packages/core/src/utils/retry.ts b/packages/core/src/utils/retry.ts index 8130088203..50f113967a 100644 --- a/packages/core/src/utils/retry.ts +++ b/packages/core/src/utils/retry.ts @@ -74,6 +74,10 @@ export async function retryWithBackoff( fn: () => Promise, options?: Partial, ): Promise { + if (options?.maxAttempts !== undefined && options.maxAttempts <= 0) { + throw new Error('maxAttempts must be a positive number.'); + } + const { maxAttempts, initialDelayMs,