refactor(core): Unify retry logic and remove schema depth check (#10453)

This commit is contained in:
Sandy Tao
2025-10-03 09:08:27 -07:00
committed by GitHub
parent 667ca6d241
commit 3b92f127a9
4 changed files with 92 additions and 102 deletions

View File

@@ -907,14 +907,26 @@ describe('GeminiChat', () => {
describe('API error retry behavior', () => {
beforeEach(() => {
// Use a more direct mock for retry testing
mockRetryWithBackoff.mockImplementation(async (apiCall, options) => {
mockRetryWithBackoff.mockImplementation(async (apiCall) => {
try {
return await apiCall();
} catch (error) {
if (
options?.shouldRetryOnError &&
options.shouldRetryOnError(error)
) {
// Simulate the logic of defaultShouldRetry for ApiError
let shouldRetry = false;
if (error instanceof ApiError && error.message) {
if (
error.status === 429 ||
(error.status >= 500 && error.status < 600)
) {
shouldRetry = true;
}
// Explicitly don't retry on 400 (Bad Request)
if (error.status === 400) {
shouldRetry = false;
}
}
if (shouldRetry) {
// Try again
return await apiCall();
}
@@ -995,36 +1007,6 @@ describe('GeminiChat', () => {
).toBe(true);
});
it('should not retry on schema depth errors', async () => {
const schemaError = new ApiError({
message: 'Request failed: maximum schema depth exceeded',
status: 500,
});
vi.mocked(mockContentGenerator.generateContentStream).mockRejectedValue(
schemaError,
);
const stream = await chat.sendMessageStream(
'test-model',
{ message: 'test' },
'prompt-id-schema',
);
await expect(
(async () => {
for await (const _ of stream) {
/* consume stream */
}
})(),
).rejects.toThrow(schemaError);
// Should only be called once (no retry)
expect(
mockContentGenerator.generateContentStream,
).toHaveBeenCalledTimes(1);
});
it('should retry on 5xx server errors', async () => {
const error500 = new ApiError({
message: 'Internal Server Error 500',

View File

@@ -15,7 +15,6 @@ import {
type Part,
type Tool,
FinishReason,
ApiError,
} from '@google/genai';
import { toParts } from '../code_assist/converter.js';
import { createUserContent } from '@google/genai';
@@ -376,15 +375,6 @@ export class GeminiChat {
) => await handleFallback(this.config, model, authType, error);
const streamResponse = await retryWithBackoff(apiCall, {
shouldRetryOnError: (error: unknown) => {
if (error instanceof ApiError && error.message) {
if (error.status === 400) return false;
if (isSchemaDepthError(error.message)) return false;
if (error.status === 429) return true;
if (error.status >= 500 && error.status < 600) return true;
}
return false;
},
onPersistent429: onPersistent429Callback,
authType: this.config.getContentGeneratorConfig()?.authType,
});

View File

@@ -6,6 +6,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import { ApiError } from '@google/genai';
import type { HttpError } from './retry.js';
import { retryWithBackoff } from './retry.js';
import { setSimulate429 } from './testUtils.js';
@@ -80,22 +81,13 @@ describe('retryWithBackoff', () => {
initialDelayMs: 10,
});
// 2. IMPORTANT: Attach the rejection expectation to the promise *immediately*.
// This ensures a 'catch' handler is present before the promise can reject.
// The result is a new promise that resolves when the assertion is met.
// eslint-disable-next-line vitest/valid-expect
const assertionPromise = expect(promise).rejects.toThrow(
'Simulated error attempt 3',
);
// 2. Run timers and await expectation in parallel.
await Promise.all([
expect(promise).rejects.toThrow('Simulated error attempt 3'),
vi.runAllTimersAsync(),
]);
// 3. Now, advance the timers. This will trigger the retries and the
// eventual rejection. The handler attached in step 2 will catch it.
await vi.runAllTimersAsync();
// 4. Await the assertion promise itself to ensure the test was successful.
await assertionPromise;
// 5. Finally, assert the number of calls.
// 3. Finally, assert the number of calls.
expect(mockFn).toHaveBeenCalledTimes(3);
});
@@ -106,12 +98,10 @@ describe('retryWithBackoff', () => {
const promise = retryWithBackoff(mockFn);
// Expect it to fail with the error from the 5th attempt.
// eslint-disable-next-line vitest/valid-expect
const assertionPromise = expect(promise).rejects.toThrow(
'Simulated error attempt 5',
);
await vi.runAllTimersAsync();
await assertionPromise;
await Promise.all([
expect(promise).rejects.toThrow('Simulated error attempt 5'),
vi.runAllTimersAsync(),
]);
expect(mockFn).toHaveBeenCalledTimes(5);
});
@@ -123,12 +113,10 @@ describe('retryWithBackoff', () => {
const promise = retryWithBackoff(mockFn, { maxAttempts: undefined });
// Expect it to fail with the error from the 5th attempt.
// eslint-disable-next-line vitest/valid-expect
const assertionPromise = expect(promise).rejects.toThrow(
'Simulated error attempt 5',
);
await vi.runAllTimersAsync();
await assertionPromise;
await Promise.all([
expect(promise).rejects.toThrow('Simulated error attempt 5'),
vi.runAllTimersAsync(),
]);
expect(mockFn).toHaveBeenCalledTimes(5);
});
@@ -161,7 +149,38 @@ describe('retryWithBackoff', () => {
expect(mockFn).not.toHaveBeenCalled();
});
it('should use default shouldRetry if not provided, retrying on 429', async () => {
it('should use default shouldRetry if not provided, retrying on ApiError 429', async () => {
const mockFn = vi.fn(async () => {
throw new ApiError({ message: 'Too Many Requests', status: 429 });
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 2,
initialDelayMs: 10,
});
await Promise.all([
expect(promise).rejects.toThrow('Too Many Requests'),
vi.runAllTimersAsync(),
]);
expect(mockFn).toHaveBeenCalledTimes(2);
});
it('should use default shouldRetry if not provided, not retrying on ApiError 400', async () => {
const mockFn = vi.fn(async () => {
throw new ApiError({ message: 'Bad Request', status: 400 });
});
const promise = retryWithBackoff(mockFn, {
maxAttempts: 2,
initialDelayMs: 10,
});
await expect(promise).rejects.toThrow('Bad Request');
expect(mockFn).toHaveBeenCalledTimes(1);
});
it('should use default shouldRetry if not provided, retrying on generic error with status 429', async () => {
const mockFn = vi.fn(async () => {
const error = new Error('Too Many Requests') as any;
error.status = 429;
@@ -173,20 +192,16 @@ describe('retryWithBackoff', () => {
initialDelayMs: 10,
});
// Attach the rejection expectation *before* running timers
const assertionPromise =
expect(promise).rejects.toThrow('Too Many Requests'); // eslint-disable-line vitest/valid-expect
// Run timers to trigger retries and eventual rejection
await vi.runAllTimersAsync();
// Await the assertion
await assertionPromise;
// Run timers and await expectation in parallel.
await Promise.all([
expect(promise).rejects.toThrow('Too Many Requests'),
vi.runAllTimersAsync(),
]);
expect(mockFn).toHaveBeenCalledTimes(2);
});
it('should use default shouldRetry if not provided, not retrying on 400', async () => {
it('should use default shouldRetry if not provided, not retrying on generic error with status 400', async () => {
const mockFn = vi.fn(async () => {
const error = new Error('Bad Request') as any;
error.status = 400;
@@ -242,11 +257,11 @@ describe('retryWithBackoff', () => {
// We expect rejections as mockFn fails 5 times
const promise1 = runRetry();
// Attach the rejection expectation *before* running timers
// eslint-disable-next-line vitest/valid-expect
const assertionPromise1 = expect(promise1).rejects.toThrow();
await vi.runAllTimersAsync(); // Advance for the delay in the first runRetry
await assertionPromise1;
// Run timers and await expectation in parallel.
await Promise.all([
expect(promise1).rejects.toThrow(),
vi.runAllTimersAsync(),
]);
const firstDelaySet = setTimeoutSpy.mock.calls.map(
(call) => call[1] as number,
@@ -257,11 +272,11 @@ describe('retryWithBackoff', () => {
mockFn = createFailingFunction(5); // Re-initialize with 5 failures
const promise2 = runRetry();
// Attach the rejection expectation *before* running timers
// eslint-disable-next-line vitest/valid-expect
const assertionPromise2 = expect(promise2).rejects.toThrow();
await vi.runAllTimersAsync(); // Advance for the delay in the second runRetry
await assertionPromise2;
// Run timers and await expectation in parallel.
await Promise.all([
expect(promise2).rejects.toThrow(),
vi.runAllTimersAsync(),
]);
const secondDelaySet = setTimeoutSpy.mock.calls.map(
(call) => call[1] as number,

View File

@@ -5,6 +5,7 @@
*/
import type { GenerateContentResponse } from '@google/genai';
import { ApiError } from '@google/genai';
import { AuthType } from '../core/contentGenerator.js';
import {
isProQuotaExceededError,
@@ -42,17 +43,19 @@ const DEFAULT_RETRY_OPTIONS: RetryOptions = {
* @returns True if the error is a transient error, false otherwise.
*/
function defaultShouldRetry(error: Error | unknown): boolean {
// Check for common transient error status codes either in message or a status property
if (error && typeof (error as { status?: number }).status === 'number') {
const status = (error as { status: number }).status;
if (status === 429 || (status >= 500 && status < 600)) {
return true;
}
// Priority check for ApiError
if (error instanceof ApiError) {
// Explicitly do not retry 400 (Bad Request)
if (error.status === 400) return false;
return error.status === 429 || (error.status >= 500 && error.status < 600);
}
if (error instanceof Error && error.message) {
if (error.message.includes('429')) return true;
if (error.message.match(/5\d{2}/)) return true;
// Check for status using helper (handles other error shapes)
const status = getErrorStatus(error);
if (status !== undefined) {
return status === 429 || (status >= 500 && status < 600);
}
return false;
}