diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 55d0af1400..3e31cb0365 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -907,14 +907,26 @@ describe('GeminiChat', () => { describe('API error retry behavior', () => { beforeEach(() => { // Use a more direct mock for retry testing - mockRetryWithBackoff.mockImplementation(async (apiCall, options) => { + mockRetryWithBackoff.mockImplementation(async (apiCall) => { try { return await apiCall(); } catch (error) { - if ( - options?.shouldRetryOnError && - options.shouldRetryOnError(error) - ) { + // Simulate the logic of defaultShouldRetry for ApiError + let shouldRetry = false; + if (error instanceof ApiError && error.message) { + if ( + error.status === 429 || + (error.status >= 500 && error.status < 600) + ) { + shouldRetry = true; + } + // Explicitly don't retry on these + if (error.status === 400) { + shouldRetry = false; + } + } + + if (shouldRetry) { // Try again return await apiCall(); } @@ -995,36 +1007,6 @@ describe('GeminiChat', () => { ).toBe(true); }); - it('should not retry on schema depth errors', async () => { - const schemaError = new ApiError({ - message: 'Request failed: maximum schema depth exceeded', - status: 500, - }); - - vi.mocked(mockContentGenerator.generateContentStream).mockRejectedValue( - schemaError, - ); - - const stream = await chat.sendMessageStream( - 'test-model', - { message: 'test' }, - 'prompt-id-schema', - ); - - await expect( - (async () => { - for await (const _ of stream) { - /* consume stream */ - } - })(), - ).rejects.toThrow(schemaError); - - // Should only be called once (no retry) - expect( - mockContentGenerator.generateContentStream, - ).toHaveBeenCalledTimes(1); - }); - it('should retry on 5xx server errors', async () => { const error500 = new ApiError({ message: 'Internal Server Error 500', diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 18353ced0e..0eff9d9d0f 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -15,7 +15,6 @@ import { type Part, type Tool, FinishReason, - ApiError, } from '@google/genai'; import { toParts } from '../code_assist/converter.js'; import { createUserContent } from '@google/genai'; @@ -376,15 +375,6 @@ export class GeminiChat { ) => await handleFallback(this.config, model, authType, error); const streamResponse = await retryWithBackoff(apiCall, { - shouldRetryOnError: (error: unknown) => { - if (error instanceof ApiError && error.message) { - if (error.status === 400) return false; - if (isSchemaDepthError(error.message)) return false; - if (error.status === 429) return true; - if (error.status >= 500 && error.status < 600) return true; - } - return false; - }, onPersistent429: onPersistent429Callback, authType: this.config.getContentGeneratorConfig()?.authType, }); diff --git a/packages/core/src/utils/retry.test.ts b/packages/core/src/utils/retry.test.ts index 6a011f9a7a..c50459edc5 100644 --- a/packages/core/src/utils/retry.test.ts +++ b/packages/core/src/utils/retry.test.ts @@ -6,6 +6,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { ApiError } from '@google/genai'; import type { HttpError } from './retry.js'; import { retryWithBackoff } from './retry.js'; import { setSimulate429 } from './testUtils.js'; @@ -80,22 +81,13 @@ describe('retryWithBackoff', () => { initialDelayMs: 10, }); - // 2. IMPORTANT: Attach the rejection expectation to the promise *immediately*. - // This ensures a 'catch' handler is present before the promise can reject. - // The result is a new promise that resolves when the assertion is met. - // eslint-disable-next-line vitest/valid-expect - const assertionPromise = expect(promise).rejects.toThrow( - 'Simulated error attempt 3', - ); + // 2. Run timers and await expectation in parallel. + await Promise.all([ + expect(promise).rejects.toThrow('Simulated error attempt 3'), + vi.runAllTimersAsync(), + ]); - // 3. Now, advance the timers. This will trigger the retries and the - // eventual rejection. The handler attached in step 2 will catch it. - await vi.runAllTimersAsync(); - - // 4. Await the assertion promise itself to ensure the test was successful. - await assertionPromise; - - // 5. Finally, assert the number of calls. + // 3. Finally, assert the number of calls. expect(mockFn).toHaveBeenCalledTimes(3); }); @@ -106,12 +98,10 @@ describe('retryWithBackoff', () => { const promise = retryWithBackoff(mockFn); // Expect it to fail with the error from the 5th attempt. - // eslint-disable-next-line vitest/valid-expect - const assertionPromise = expect(promise).rejects.toThrow( - 'Simulated error attempt 5', - ); - await vi.runAllTimersAsync(); - await assertionPromise; + await Promise.all([ + expect(promise).rejects.toThrow('Simulated error attempt 5'), + vi.runAllTimersAsync(), + ]); expect(mockFn).toHaveBeenCalledTimes(5); }); @@ -123,12 +113,10 @@ describe('retryWithBackoff', () => { const promise = retryWithBackoff(mockFn, { maxAttempts: undefined }); // Expect it to fail with the error from the 5th attempt. - // eslint-disable-next-line vitest/valid-expect - const assertionPromise = expect(promise).rejects.toThrow( - 'Simulated error attempt 5', - ); - await vi.runAllTimersAsync(); - await assertionPromise; + await Promise.all([ + expect(promise).rejects.toThrow('Simulated error attempt 5'), + vi.runAllTimersAsync(), + ]); expect(mockFn).toHaveBeenCalledTimes(5); }); @@ -161,7 +149,38 @@ describe('retryWithBackoff', () => { expect(mockFn).not.toHaveBeenCalled(); }); - it('should use default shouldRetry if not provided, retrying on 429', async () => { + it('should use default shouldRetry if not provided, retrying on ApiError 429', async () => { + const mockFn = vi.fn(async () => { + throw new ApiError({ message: 'Too Many Requests', status: 429 }); + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 2, + initialDelayMs: 10, + }); + + await Promise.all([ + expect(promise).rejects.toThrow('Too Many Requests'), + vi.runAllTimersAsync(), + ]); + + expect(mockFn).toHaveBeenCalledTimes(2); + }); + + it('should use default shouldRetry if not provided, not retrying on ApiError 400', async () => { + const mockFn = vi.fn(async () => { + throw new ApiError({ message: 'Bad Request', status: 400 }); + }); + + const promise = retryWithBackoff(mockFn, { + maxAttempts: 2, + initialDelayMs: 10, + }); + await expect(promise).rejects.toThrow('Bad Request'); + expect(mockFn).toHaveBeenCalledTimes(1); + }); + + it('should use default shouldRetry if not provided, retrying on generic error with status 429', async () => { const mockFn = vi.fn(async () => { const error = new Error('Too Many Requests') as any; error.status = 429; @@ -173,20 +192,16 @@ describe('retryWithBackoff', () => { initialDelayMs: 10, }); - // Attach the rejection expectation *before* running timers - const assertionPromise = - expect(promise).rejects.toThrow('Too Many Requests'); // eslint-disable-line vitest/valid-expect - - // Run timers to trigger retries and eventual rejection - await vi.runAllTimersAsync(); - - // Await the assertion - await assertionPromise; + // Run timers and await expectation in parallel. + await Promise.all([ + expect(promise).rejects.toThrow('Too Many Requests'), + vi.runAllTimersAsync(), + ]); expect(mockFn).toHaveBeenCalledTimes(2); }); - it('should use default shouldRetry if not provided, not retrying on 400', async () => { + it('should use default shouldRetry if not provided, not retrying on generic error with status 400', async () => { const mockFn = vi.fn(async () => { const error = new Error('Bad Request') as any; error.status = 400; @@ -242,11 +257,11 @@ describe('retryWithBackoff', () => { // We expect rejections as mockFn fails 5 times const promise1 = runRetry(); - // Attach the rejection expectation *before* running timers - // eslint-disable-next-line vitest/valid-expect - const assertionPromise1 = expect(promise1).rejects.toThrow(); - await vi.runAllTimersAsync(); // Advance for the delay in the first runRetry - await assertionPromise1; + // Run timers and await expectation in parallel. + await Promise.all([ + expect(promise1).rejects.toThrow(), + vi.runAllTimersAsync(), + ]); const firstDelaySet = setTimeoutSpy.mock.calls.map( (call) => call[1] as number, @@ -257,11 +272,11 @@ describe('retryWithBackoff', () => { mockFn = createFailingFunction(5); // Re-initialize with 5 failures const promise2 = runRetry(); - // Attach the rejection expectation *before* running timers - // eslint-disable-next-line vitest/valid-expect - const assertionPromise2 = expect(promise2).rejects.toThrow(); - await vi.runAllTimersAsync(); // Advance for the delay in the second runRetry - await assertionPromise2; + // Run timers and await expectation in parallel. + await Promise.all([ + expect(promise2).rejects.toThrow(), + vi.runAllTimersAsync(), + ]); const secondDelaySet = setTimeoutSpy.mock.calls.map( (call) => call[1] as number, diff --git a/packages/core/src/utils/retry.ts b/packages/core/src/utils/retry.ts index 54d91f7d2d..091c355a12 100644 --- a/packages/core/src/utils/retry.ts +++ b/packages/core/src/utils/retry.ts @@ -5,6 +5,7 @@ */ import type { GenerateContentResponse } from '@google/genai'; +import { ApiError } from '@google/genai'; import { AuthType } from '../core/contentGenerator.js'; import { isProQuotaExceededError, @@ -42,17 +43,19 @@ const DEFAULT_RETRY_OPTIONS: RetryOptions = { * @returns True if the error is a transient error, false otherwise. */ function defaultShouldRetry(error: Error | unknown): boolean { - // Check for common transient error status codes either in message or a status property - if (error && typeof (error as { status?: number }).status === 'number') { - const status = (error as { status: number }).status; - if (status === 429 || (status >= 500 && status < 600)) { - return true; - } + // Priority check for ApiError + if (error instanceof ApiError) { + // Explicitly do not retry 400 (Bad Request) + if (error.status === 400) return false; + return error.status === 429 || (error.status >= 500 && error.status < 600); } - if (error instanceof Error && error.message) { - if (error.message.includes('429')) return true; - if (error.message.match(/5\d{2}/)) return true; + + // Check for status using helper (handles other error shapes) + const status = getErrorStatus(error); + if (status !== undefined) { + return status === 429 || (status >= 500 && status < 600); } + return false; }