mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-14 22:02:59 -07:00
feat: centralize maxAttempts configuration via ExperimentFlags
This commit centralizes the retry attempt limit so that it is driven by the `ExperimentFlags.MAX_ATTEMPTS` flag or the user configuration, rather than being hardcoded throughout the codebase. The retry logic in `baseLlmClient`, `geminiChat`, `client`, and `web-fetch` now retrieves the `maxAttempts` setting directly from `Config`. It also removes the previous 10-attempt cap applied during `Config` initialization, so that tests simulating higher retry limits pass.
This commit is contained in:
@@ -20,6 +20,7 @@ export const ExperimentFlags = {
|
||||
PRO_MODEL_NO_ACCESS: 45768879,
|
||||
GEMINI_3_1_FLASH_LITE_LAUNCHED: 45771641,
|
||||
DEFAULT_REQUEST_TIMEOUT: 45773134,
|
||||
MAX_ATTEMPTS: 45774515,
|
||||
} as const;
|
||||
|
||||
export type ExperimentFlagName =
|
||||
|
||||
@@ -305,6 +305,53 @@ describe('Server Config (config.ts)', () => {
|
||||
});
|
||||
expect(config.getMaxAttempts()).toBe(DEFAULT_MAX_ATTEMPTS);
|
||||
});
|
||||
|
||||
it('should use experiment flag if present and valid', () => {
|
||||
const config = new Config({
|
||||
...baseParams,
|
||||
experiments: {
|
||||
flags: {
|
||||
[ExperimentFlags.MAX_ATTEMPTS]: {
|
||||
intValue: '15',
|
||||
},
|
||||
},
|
||||
experimentIds: [],
|
||||
},
|
||||
});
|
||||
expect(config.getMaxAttempts()).toBe(15);
|
||||
});
|
||||
|
||||
it('should fallback to maxAttempts if experiment flag is invalid', () => {
|
||||
const config = new Config({
|
||||
...baseParams,
|
||||
maxAttempts: 5,
|
||||
experiments: {
|
||||
flags: {
|
||||
[ExperimentFlags.MAX_ATTEMPTS]: {
|
||||
intValue: 'abc',
|
||||
},
|
||||
},
|
||||
experimentIds: [],
|
||||
},
|
||||
});
|
||||
expect(config.getMaxAttempts()).toBe(5);
|
||||
});
|
||||
|
||||
it('should fallback to maxAttempts if experiment flag is non-positive', () => {
|
||||
const config = new Config({
|
||||
...baseParams,
|
||||
maxAttempts: 5,
|
||||
experiments: {
|
||||
flags: {
|
||||
[ExperimentFlags.MAX_ATTEMPTS]: {
|
||||
intValue: '0',
|
||||
},
|
||||
},
|
||||
experimentIds: [],
|
||||
},
|
||||
});
|
||||
expect(config.getMaxAttempts()).toBe(5);
|
||||
});
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
|
||||
@@ -3318,6 +3318,14 @@ export class Config implements McpContext, AgentLoopContext {
|
||||
}
|
||||
|
||||
/**
 * Resolves the maximum number of attempts to use for retried operations.
 *
 * The `ExperimentFlags.MAX_ATTEMPTS` experiment flag takes precedence when its
 * `intValue` is a plain, positive decimal integer. Any other flag value
 * (missing, non-numeric, fractional, exponent notation, trailing garbage,
 * zero, negative, or too large to represent safely) is ignored and
 * `this.maxAttempts` is returned instead.
 *
 * @returns The attempt count from the experiment flag when valid, otherwise
 *   `this.maxAttempts`.
 */
getMaxAttempts(): number {
  const flagVal =
    this.experiments?.flags?.[ExperimentFlags.MAX_ATTEMPTS]?.intValue;
  if (flagVal !== undefined) {
    const text = String(flagVal).trim();
    // `parseInt` on its own is too lenient: it reads '15abc' as 15, '5.9' as 5
    // and '1e3' as 1. Require the whole string to be decimal digits so that a
    // malformed flag falls back to the configured value instead of silently
    // overriding it with a truncated number.
    if (/^\+?\d+$/.test(text)) {
      const parsed = Number.parseInt(text, 10);
      // Reject zero and digit strings too long to be represented exactly.
      if (Number.isSafeInteger(parsed) && parsed > 0) {
        return parsed;
      }
    }
  }
  return this.maxAttempts;
}
|
||||
|
||||
|
||||
@@ -252,7 +252,7 @@ describe('BaseLlmClient', () => {
|
||||
expect(retryWithBackoff).toHaveBeenCalledWith(
|
||||
expect.any(Function),
|
||||
expect.objectContaining({
|
||||
maxAttempts: 5,
|
||||
maxAttempts: 3,
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
@@ -36,8 +36,6 @@ import {
|
||||
createAvailabilityContextProvider,
|
||||
} from '../availability/policyHelpers.js';
|
||||
|
||||
const DEFAULT_MAX_ATTEMPTS = 5;
|
||||
|
||||
/**
|
||||
* Options for the generateJson utility function.
|
||||
*/
|
||||
@@ -328,7 +326,9 @@ export class BaseLlmClient {
|
||||
return await retryWithBackoff(apiCall, {
|
||||
shouldRetryOnContent,
|
||||
maxAttempts:
|
||||
availabilityMaxAttempts ?? maxAttempts ?? DEFAULT_MAX_ATTEMPTS,
|
||||
availabilityMaxAttempts ??
|
||||
maxAttempts ??
|
||||
this.config.getMaxAttempts(),
|
||||
getAvailabilityContext,
|
||||
onPersistent429: this.config.isInteractive()
|
||||
? (authType, error) =>
|
||||
@@ -339,7 +339,9 @@ export class BaseLlmClient {
|
||||
retryFetchErrors: this.config.getRetryFetchErrors(),
|
||||
onRetry: (attempt, error, delayMs) => {
|
||||
const actualMaxAttempts =
|
||||
availabilityMaxAttempts ?? maxAttempts ?? DEFAULT_MAX_ATTEMPTS;
|
||||
availabilityMaxAttempts ??
|
||||
maxAttempts ??
|
||||
this.config.getMaxAttempts();
|
||||
const modelName = getDisplayString(currentModel);
|
||||
const errorType = getRetryErrorType(error);
|
||||
|
||||
|
||||
@@ -1133,7 +1133,7 @@ export class GeminiClient {
|
||||
onPersistent429: onPersistent429Callback,
|
||||
onValidationRequired: onValidationRequiredCallback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
maxAttempts: availabilityMaxAttempts,
|
||||
maxAttempts: availabilityMaxAttempts ?? this.config.getMaxAttempts(),
|
||||
retryFetchErrors: this.config.getRetryFetchErrors(),
|
||||
getAvailabilityContext,
|
||||
onRetry: (attempt, error, delayMs) => {
|
||||
|
||||
@@ -176,7 +176,7 @@ describe('GeminiChat', () => {
|
||||
},
|
||||
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
|
||||
getRetryFetchErrors: vi.fn().mockReturnValue(false),
|
||||
getMaxAttempts: vi.fn().mockReturnValue(10),
|
||||
getMaxAttempts: vi.fn().mockReturnValue(4),
|
||||
getUserTier: vi.fn().mockReturnValue(undefined),
|
||||
modelConfigService: {
|
||||
getResolvedConfig: vi.fn().mockImplementation((modelConfigKey) => {
|
||||
|
||||
@@ -78,8 +78,6 @@ export type StreamEvent =
|
||||
* Options for retrying mid-stream errors (e.g. invalid content or API disconnects).
|
||||
*/
|
||||
interface MidStreamRetryOptions {
|
||||
/** Total number of attempts to make (1 initial + N retries). */
|
||||
maxAttempts: number;
|
||||
/** The base delay in milliseconds for backoff. */
|
||||
initialDelayMs: number;
|
||||
/** Whether to use exponential backoff instead of linear. */
|
||||
@@ -87,7 +85,6 @@ interface MidStreamRetryOptions {
|
||||
}
|
||||
|
||||
const MID_STREAM_RETRY_OPTIONS: MidStreamRetryOptions = {
|
||||
maxAttempts: 4, // 1 initial call + 3 retries mid-stream
|
||||
initialDelayMs: 1000,
|
||||
useExponentialBackoff: true,
|
||||
};
|
||||
@@ -420,10 +417,8 @@ export class GeminiChat {
|
||||
: getRetryErrorType(error);
|
||||
|
||||
if (isContentError || (isRetryable && !signal.aborted)) {
|
||||
// The issue requests exactly 3 retries (4 attempts) for API errors during stream iteration.
|
||||
// Regardless of the global maxAttempts (e.g. 10), we only want to retry these mid-stream API errors
|
||||
// up to 3 times before finally throwing the error to the user.
|
||||
const maxMidStreamAttempts = MID_STREAM_RETRY_OPTIONS.maxAttempts;
|
||||
// We retry mid-stream API errors up to maxAttempts times before finally throwing the error to the user.
|
||||
const maxMidStreamAttempts = this.context.config.getMaxAttempts();
|
||||
|
||||
if (
|
||||
attempt < maxAttempts - 1 &&
|
||||
|
||||
@@ -309,6 +309,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
return res;
|
||||
},
|
||||
{
|
||||
maxAttempts: this.context.config.getMaxAttempts(),
|
||||
retryFetchErrors: this.context.config.getRetryFetchErrors(),
|
||||
onRetry: (attempt, error, delayMs) =>
|
||||
this.handleRetry(attempt, error, delayMs),
|
||||
@@ -643,6 +644,7 @@ ${aggregatedContent}
|
||||
return res;
|
||||
},
|
||||
{
|
||||
maxAttempts: this.context.config.getMaxAttempts(),
|
||||
retryFetchErrors: this.context.config.getRetryFetchErrors(),
|
||||
onRetry: (attempt, error, delayMs) =>
|
||||
this.handleRetry(attempt, error, delayMs),
|
||||
|
||||
@@ -9,12 +9,13 @@ import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { ApiError } from '@google/genai';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import { type HttpError, ModelNotFoundError } from './httpErrors.js';
|
||||
import { retryWithBackoff } from './retry.js';
|
||||
import { retryWithBackoff, isRetryableError } from './retry.js';
|
||||
import { setSimulate429 } from './testUtils.js';
|
||||
import { debugLogger } from './debugLogger.js';
|
||||
import {
|
||||
TerminalQuotaError,
|
||||
RetryableQuotaError,
|
||||
ValidationRequiredError,
|
||||
} from './googleQuotaErrors.js';
|
||||
import { PREVIEW_GEMINI_MODEL } from '../config/models.js';
|
||||
import type { ModelPolicy } from '../availability/modelPolicy.js';
|
||||
@@ -332,6 +333,81 @@ describe('retryWithBackoff', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should call onRetry callback on each retry', async () => {
|
||||
const mockFn = createFailingFunction(2);
|
||||
const onRetry = vi.fn();
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
maxAttempts: 3,
|
||||
initialDelayMs: 10,
|
||||
onRetry,
|
||||
});
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
await promise;
|
||||
expect(onRetry).toHaveBeenCalledTimes(2);
|
||||
expect(onRetry).toHaveBeenCalledWith(
|
||||
1,
|
||||
expect.any(Error),
|
||||
expect.any(Number),
|
||||
);
|
||||
expect(onRetry).toHaveBeenCalledWith(
|
||||
2,
|
||||
expect.any(Error),
|
||||
expect.any(Number),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle ValidationRequiredError using onValidationRequired', async () => {
|
||||
const error = new ValidationRequiredError('Validation required', {} as any);
|
||||
let validationCalled = false;
|
||||
const mockFn = vi.fn().mockImplementation(async () => {
|
||||
if (!validationCalled) {
|
||||
throw error;
|
||||
}
|
||||
return 'success';
|
||||
});
|
||||
|
||||
const onValidationRequired = vi.fn().mockImplementation(async () => {
|
||||
validationCalled = true;
|
||||
return 'verify';
|
||||
});
|
||||
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
maxAttempts: 3,
|
||||
initialDelayMs: 10,
|
||||
onValidationRequired,
|
||||
});
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
const result = await promise;
|
||||
expect(result).toBe('success');
|
||||
expect(onValidationRequired).toHaveBeenCalledWith(error);
|
||||
expect(mockFn).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should throw ValidationRequiredError if onValidationRequired returns cancel', async () => {
|
||||
const error = new ValidationRequiredError('Validation required', {} as any);
|
||||
const mockFn = vi.fn().mockImplementation(async () => {
|
||||
throw error;
|
||||
});
|
||||
|
||||
const onValidationRequired = vi.fn().mockResolvedValue('cancel');
|
||||
|
||||
const promise = retryWithBackoff(mockFn, {
|
||||
maxAttempts: 3,
|
||||
initialDelayMs: 10,
|
||||
onValidationRequired,
|
||||
});
|
||||
|
||||
await expect(promise).rejects.toThrow('Validation required');
|
||||
await vi.runAllTimersAsync();
|
||||
|
||||
expect(error.userHandled).toBe(true);
|
||||
expect(mockFn).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
describe('Fetch error retries', () => {
|
||||
it("should retry on 'fetch failed' when retryFetchErrors is true", async () => {
|
||||
const mockFn = vi.fn();
|
||||
@@ -886,3 +962,37 @@ describe('retryWithBackoff', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('isRetryableError', () => {
|
||||
it('should return true for 429 errors', () => {
|
||||
const error = new ApiError({ message: 'Quota exceeded', status: 429 });
|
||||
expect(isRetryableError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for 499 errors', () => {
|
||||
const error = new ApiError({
|
||||
message: 'Client closed request',
|
||||
status: 499,
|
||||
});
|
||||
expect(isRetryableError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return true for 500 errors', () => {
|
||||
const error = new ApiError({
|
||||
message: 'Internal Server Error',
|
||||
status: 500,
|
||||
});
|
||||
expect(isRetryableError(error)).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false for 400 errors', () => {
|
||||
const error = new ApiError({ message: 'Bad Request', status: 400 });
|
||||
expect(isRetryableError(error)).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true for network error codes like ECONNRESET', () => {
|
||||
const error = new Error('ECONNRESET');
|
||||
(error as any).code = 'ECONNRESET';
|
||||
expect(isRetryableError(error)).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -17,9 +17,18 @@ import { getErrorStatus, ModelNotFoundError } from './httpErrors.js';
|
||||
import type { RetryAvailabilityContext } from '../availability/modelPolicy.js';
|
||||
|
||||
export type { RetryAvailabilityContext };
|
||||
|
||||
/**
|
||||
* Global fallback for maximum retry attempts when not explicitly provided.
|
||||
* Most callers should use config.getMaxAttempts() instead.
|
||||
*/
|
||||
export const DEFAULT_MAX_ATTEMPTS = 10;
|
||||
|
||||
export interface RetryOptions {
|
||||
/**
|
||||
* Total number of attempts (1 initial + N retries).
|
||||
* Defaults to DEFAULT_MAX_ATTEMPTS (10) if not specified.
|
||||
*/
|
||||
maxAttempts: number;
|
||||
initialDelayMs: number;
|
||||
maxDelayMs: number;
|
||||
|
||||
Reference in New Issue
Block a user