mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-13 23:51:16 -07:00
refactor(core): Unify retry logic and remove schema depth check (#10453)
This commit is contained in:
@@ -907,14 +907,26 @@ describe('GeminiChat', () => {
|
||||
describe('API error retry behavior', () => {
|
||||
beforeEach(() => {
|
||||
// Use a more direct mock for retry testing
|
||||
mockRetryWithBackoff.mockImplementation(async (apiCall, options) => {
|
||||
mockRetryWithBackoff.mockImplementation(async (apiCall) => {
|
||||
try {
|
||||
return await apiCall();
|
||||
} catch (error) {
|
||||
if (
|
||||
options?.shouldRetryOnError &&
|
||||
options.shouldRetryOnError(error)
|
||||
) {
|
||||
// Simulate the logic of defaultShouldRetry for ApiError
|
||||
let shouldRetry = false;
|
||||
if (error instanceof ApiError && error.message) {
|
||||
if (
|
||||
error.status === 429 ||
|
||||
(error.status >= 500 && error.status < 600)
|
||||
) {
|
||||
shouldRetry = true;
|
||||
}
|
||||
// Explicitly don't retry on these
|
||||
if (error.status === 400) {
|
||||
shouldRetry = false;
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldRetry) {
|
||||
// Try again
|
||||
return await apiCall();
|
||||
}
|
||||
@@ -995,36 +1007,6 @@ describe('GeminiChat', () => {
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it('should not retry on schema depth errors', async () => {
|
||||
const schemaError = new ApiError({
|
||||
message: 'Request failed: maximum schema depth exceeded',
|
||||
status: 500,
|
||||
});
|
||||
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockRejectedValue(
|
||||
schemaError,
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
'test-model',
|
||||
{ message: 'test' },
|
||||
'prompt-id-schema',
|
||||
);
|
||||
|
||||
await expect(
|
||||
(async () => {
|
||||
for await (const _ of stream) {
|
||||
/* consume stream */
|
||||
}
|
||||
})(),
|
||||
).rejects.toThrow(schemaError);
|
||||
|
||||
// Should only be called once (no retry)
|
||||
expect(
|
||||
mockContentGenerator.generateContentStream,
|
||||
).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should retry on 5xx server errors', async () => {
|
||||
const error500 = new ApiError({
|
||||
message: 'Internal Server Error 500',
|
||||
|
||||
@@ -15,7 +15,6 @@ import {
|
||||
type Part,
|
||||
type Tool,
|
||||
FinishReason,
|
||||
ApiError,
|
||||
} from '@google/genai';
|
||||
import { toParts } from '../code_assist/converter.js';
|
||||
import { createUserContent } from '@google/genai';
|
||||
@@ -376,15 +375,6 @@ export class GeminiChat {
|
||||
) => await handleFallback(this.config, model, authType, error);
|
||||
|
||||
const streamResponse = await retryWithBackoff(apiCall, {
|
||||
shouldRetryOnError: (error: unknown) => {
|
||||
if (error instanceof ApiError && error.message) {
|
||||
if (error.status === 400) return false;
|
||||
if (isSchemaDepthError(error.message)) return false;
|
||||
if (error.status === 429) return true;
|
||||
if (error.status >= 500 && error.status < 600) return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
});
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { ApiError } from '@google/genai';
|
||||
import type { HttpError } from './retry.js';
|
||||
import { retryWithBackoff } from './retry.js';
|
||||
import { setSimulate429 } from './testUtils.js';
|
||||
@@ -80,22 +81,13 @@ describe('retryWithBackoff', () => {
|
||||
initialDelayMs: 10,
|
||||
});
|
||||
|
||||
// 2. IMPORTANT: Attach the rejection expectation to the promise *immediately*.
|
||||
// This ensures a 'catch' handler is present before the promise can reject.
|
||||
// The result is a new promise that resolves when the assertion is met.
|
||||
// eslint-disable-next-line vitest/valid-expect
|
||||
const assertionPromise = expect(promise).rejects.toThrow(
|
||||
'Simulated error attempt 3',
|
||||
);
|
||||
// 2. Run timers and await expectation in parallel.
|
||||
await Promise.all([
|
||||
expect(promise).rejects.toThrow('Simulated error attempt 3'),
|
||||
vi.runAllTimersAsync(),
|
||||
]);
|
||||
|
||||
// 3. Now, advance the timers. This will trigger the retries and the
|
||||
// eventual rejection. The handler attached in step 2 will catch it.
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
// 4. Await the assertion promise itself to ensure the test was successful.
|
||||
await assertionPromise;
|
||||
|
||||
// 5. Finally, assert the number of calls.
|
||||
// 3. Finally, assert the number of calls.
|
||||
expect(mockFn).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
@@ -106,12 +98,10 @@ describe('retryWithBackoff', () => {
|
||||
const promise = retryWithBackoff(mockFn);
|
||||
|
||||
// Expect it to fail with the error from the 5th attempt.
|
||||
// eslint-disable-next-line vitest/valid-expect
|
||||
const assertionPromise = expect(promise).rejects.toThrow(
|
||||
'Simulated error attempt 5',
|
||||
);
|
||||
await vi.runAllTimersAsync();
|
||||
await assertionPromise;
|
||||
await Promise.all([
|
||||
expect(promise).rejects.toThrow('Simulated error attempt 5'),
|
||||
vi.runAllTimersAsync(),
|
||||
]);
|
||||
|
||||
expect(mockFn).toHaveBeenCalledTimes(5);
|
||||
});
|
||||
@@ -123,12 +113,10 @@ describe('retryWithBackoff', () => {
|
||||
const promise = retryWithBackoff(mockFn, { maxAttempts: undefined });
|
||||
|
||||
// Expect it to fail with the error from the 5th attempt.
|
||||
// eslint-disable-next-line vitest/valid-expect
|
||||
const assertionPromise = expect(promise).rejects.toThrow(
|
||||
'Simulated error attempt 5',
|
||||
);
|
||||
await vi.runAllTimersAsync();
|
||||
await assertionPromise;
|
||||
await Promise.all([
|
||||
expect(promise).rejects.toThrow('Simulated error attempt 5'),
|
||||
vi.runAllTimersAsync(),
|
||||
]);
|
||||
|
||||
expect(mockFn).toHaveBeenCalledTimes(5);
|
||||
});
|
||||
@@ -161,7 +149,38 @@ describe('retryWithBackoff', () => {
|
||||
expect(mockFn).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use default shouldRetry if not provided, retrying on 429', async () => {
|
||||
it('should use default shouldRetry if not provided, retrying on ApiError 429', async () => {
|
||||
const mockFn = vi.fn(async () => {
|
||||
throw new ApiError({ message: 'Too Many Requests', status: 429 });
|
||||
});
|
||||
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
maxAttempts: 2,
|
||||
initialDelayMs: 10,
|
||||
});
|
||||
|
||||
await Promise.all([
|
||||
expect(promise).rejects.toThrow('Too Many Requests'),
|
||||
vi.runAllTimersAsync(),
|
||||
]);
|
||||
|
||||
expect(mockFn).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should use default shouldRetry if not provided, not retrying on ApiError 400', async () => {
|
||||
const mockFn = vi.fn(async () => {
|
||||
throw new ApiError({ message: 'Bad Request', status: 400 });
|
||||
});
|
||||
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
maxAttempts: 2,
|
||||
initialDelayMs: 10,
|
||||
});
|
||||
await expect(promise).rejects.toThrow('Bad Request');
|
||||
expect(mockFn).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it('should use default shouldRetry if not provided, retrying on generic error with status 429', async () => {
|
||||
const mockFn = vi.fn(async () => {
|
||||
const error = new Error('Too Many Requests') as any;
|
||||
error.status = 429;
|
||||
@@ -173,20 +192,16 @@ describe('retryWithBackoff', () => {
|
||||
initialDelayMs: 10,
|
||||
});
|
||||
|
||||
// Attach the rejection expectation *before* running timers
|
||||
const assertionPromise =
|
||||
expect(promise).rejects.toThrow('Too Many Requests'); // eslint-disable-line vitest/valid-expect
|
||||
|
||||
// Run timers to trigger retries and eventual rejection
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
// Await the assertion
|
||||
await assertionPromise;
|
||||
// Run timers and await expectation in parallel.
|
||||
await Promise.all([
|
||||
expect(promise).rejects.toThrow('Too Many Requests'),
|
||||
vi.runAllTimersAsync(),
|
||||
]);
|
||||
|
||||
expect(mockFn).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should use default shouldRetry if not provided, not retrying on 400', async () => {
|
||||
it('should use default shouldRetry if not provided, not retrying on generic error with status 400', async () => {
|
||||
const mockFn = vi.fn(async () => {
|
||||
const error = new Error('Bad Request') as any;
|
||||
error.status = 400;
|
||||
@@ -242,11 +257,11 @@ describe('retryWithBackoff', () => {
|
||||
|
||||
// We expect rejections as mockFn fails 5 times
|
||||
const promise1 = runRetry();
|
||||
// Attach the rejection expectation *before* running timers
|
||||
// eslint-disable-next-line vitest/valid-expect
|
||||
const assertionPromise1 = expect(promise1).rejects.toThrow();
|
||||
await vi.runAllTimersAsync(); // Advance for the delay in the first runRetry
|
||||
await assertionPromise1;
|
||||
// Run timers and await expectation in parallel.
|
||||
await Promise.all([
|
||||
expect(promise1).rejects.toThrow(),
|
||||
vi.runAllTimersAsync(),
|
||||
]);
|
||||
|
||||
const firstDelaySet = setTimeoutSpy.mock.calls.map(
|
||||
(call) => call[1] as number,
|
||||
@@ -257,11 +272,11 @@ describe('retryWithBackoff', () => {
|
||||
mockFn = createFailingFunction(5); // Re-initialize with 5 failures
|
||||
|
||||
const promise2 = runRetry();
|
||||
// Attach the rejection expectation *before* running timers
|
||||
// eslint-disable-next-line vitest/valid-expect
|
||||
const assertionPromise2 = expect(promise2).rejects.toThrow();
|
||||
await vi.runAllTimersAsync(); // Advance for the delay in the second runRetry
|
||||
await assertionPromise2;
|
||||
// Run timers and await expectation in parallel.
|
||||
await Promise.all([
|
||||
expect(promise2).rejects.toThrow(),
|
||||
vi.runAllTimersAsync(),
|
||||
]);
|
||||
|
||||
const secondDelaySet = setTimeoutSpy.mock.calls.map(
|
||||
(call) => call[1] as number,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import type { GenerateContentResponse } from '@google/genai';
|
||||
import { ApiError } from '@google/genai';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import {
|
||||
isProQuotaExceededError,
|
||||
@@ -42,17 +43,19 @@ const DEFAULT_RETRY_OPTIONS: RetryOptions = {
|
||||
* @returns True if the error is a transient error, false otherwise.
|
||||
*/
|
||||
function defaultShouldRetry(error: Error | unknown): boolean {
|
||||
// Check for common transient error status codes either in message or a status property
|
||||
if (error && typeof (error as { status?: number }).status === 'number') {
|
||||
const status = (error as { status: number }).status;
|
||||
if (status === 429 || (status >= 500 && status < 600)) {
|
||||
return true;
|
||||
}
|
||||
// Priority check for ApiError
|
||||
if (error instanceof ApiError) {
|
||||
// Explicitly do not retry 400 (Bad Request)
|
||||
if (error.status === 400) return false;
|
||||
return error.status === 429 || (error.status >= 500 && error.status < 600);
|
||||
}
|
||||
if (error instanceof Error && error.message) {
|
||||
if (error.message.includes('429')) return true;
|
||||
if (error.message.match(/5\d{2}/)) return true;
|
||||
|
||||
// Check for status using helper (handles other error shapes)
|
||||
const status = getErrorStatus(error);
|
||||
if (status !== undefined) {
|
||||
return status === 429 || (status >= 500 && status < 600);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user