mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-21 10:34:35 -07:00
feat(modelAvailabilityService): integrate model availability service into backend logic (#14470)
This commit is contained in:
@@ -17,6 +17,9 @@ import {
|
||||
RetryableQuotaError,
|
||||
} from './googleQuotaErrors.js';
|
||||
import { PREVIEW_GEMINI_MODEL } from '../config/models.js';
|
||||
import type { ModelPolicy } from '../availability/modelPolicy.js';
|
||||
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
|
||||
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
|
||||
|
||||
// Helper to create a mock function that fails a certain number of times
|
||||
const createFailingFunction = (
|
||||
@@ -104,7 +107,6 @@ describe('retryWithBackoff', () => {
|
||||
|
||||
const promise = retryWithBackoff(mockFn);
|
||||
|
||||
// Expect it to fail with the error from the 3rd (final) attempt.
|
||||
await Promise.all([
|
||||
expect(promise).rejects.toThrow('Simulated error attempt 3'),
|
||||
vi.runAllTimersAsync(),
|
||||
@@ -566,4 +568,171 @@ describe('retryWithBackoff', () => {
|
||||
);
|
||||
expect(mockFn).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
describe('Availability Context Integration', () => {
|
||||
let mockService: ModelAvailabilityService;
|
||||
let mockPolicy1: ModelPolicy;
|
||||
let mockPolicy2: ModelPolicy;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useRealTimers();
|
||||
mockService = createAvailabilityServiceMock();
|
||||
|
||||
mockPolicy1 = {
|
||||
model: 'model-1',
|
||||
actions: {},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal',
|
||||
transient: 'sticky_retry',
|
||||
},
|
||||
};
|
||||
|
||||
mockPolicy2 = {
|
||||
model: 'model-2',
|
||||
actions: {},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal',
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
it('updates availability context per attempt and applies transitions to the correct policy', async () => {
|
||||
const error = new TerminalQuotaError(
|
||||
'quota exceeded',
|
||||
{ code: 429, message: 'quota', details: [] },
|
||||
10,
|
||||
);
|
||||
|
||||
const fn = vi.fn().mockImplementation(async () => {
|
||||
throw error; // Always fail with quota
|
||||
});
|
||||
|
||||
const onPersistent429 = vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce('model-2') // First fallback success
|
||||
.mockResolvedValueOnce(null); // Second fallback fails (give up)
|
||||
|
||||
// Context provider returns policy1 first, then policy2
|
||||
const getContext = vi
|
||||
.fn()
|
||||
.mockReturnValueOnce({ service: mockService, policy: mockPolicy1 })
|
||||
.mockReturnValueOnce({ service: mockService, policy: mockPolicy2 });
|
||||
|
||||
await expect(
|
||||
retryWithBackoff(fn, {
|
||||
maxAttempts: 3,
|
||||
initialDelayMs: 1,
|
||||
getAvailabilityContext: getContext,
|
||||
onPersistent429,
|
||||
authType: AuthType.LOGIN_WITH_GOOGLE,
|
||||
}),
|
||||
).rejects.toThrow(TerminalQuotaError);
|
||||
|
||||
// Verify both models were marked terminal with reason 'quota'
|
||||
expect(mockService.markTerminal).toHaveBeenCalledWith('model-1', 'quota');
|
||||
expect(mockService.markTerminal).toHaveBeenCalledWith('model-2', 'quota');
|
||||
|
||||
// Verify call order: model-1 is marked first, then model-2
|
||||
expect(mockService.markTerminal).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
'model-1',
|
||||
'quota',
|
||||
);
|
||||
expect(mockService.markTerminal).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
'model-2',
|
||||
'quota',
|
||||
);
|
||||
});
|
||||
|
||||
it('marks sticky_retry after retries are exhausted for transient failures', async () => {
|
||||
const transientError = new RetryableQuotaError(
|
||||
'transient error',
|
||||
{ code: 429, message: 'transient', details: [] },
|
||||
0,
|
||||
);
|
||||
|
||||
const fn = vi.fn().mockRejectedValue(transientError);
|
||||
|
||||
const getContext = vi
|
||||
.fn()
|
||||
.mockReturnValue({ service: mockService, policy: mockPolicy1 });
|
||||
|
||||
vi.useFakeTimers();
|
||||
const promise = retryWithBackoff(fn, {
|
||||
maxAttempts: 3,
|
||||
getAvailabilityContext: getContext,
|
||||
initialDelayMs: 1,
|
||||
maxDelayMs: 1,
|
||||
}).catch((err) => err);
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
const result = await promise;
|
||||
expect(result).toBe(transientError);
|
||||
|
||||
expect(fn).toHaveBeenCalledTimes(3);
|
||||
expect(mockService.markRetryOncePerTurn).toHaveBeenCalledWith('model-1');
|
||||
expect(mockService.markRetryOncePerTurn).toHaveBeenCalledTimes(1);
|
||||
expect(mockService.markTerminal).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('maps different failure kinds to correct terminal reasons', async () => {
|
||||
const quotaError = new TerminalQuotaError(
|
||||
'quota',
|
||||
{ code: 429, message: 'q', details: [] },
|
||||
10,
|
||||
);
|
||||
const notFoundError = new ModelNotFoundError('not found', 404);
|
||||
const genericError = new Error('unknown error');
|
||||
|
||||
const fn = vi
|
||||
.fn()
|
||||
.mockRejectedValueOnce(quotaError)
|
||||
.mockRejectedValueOnce(notFoundError)
|
||||
.mockRejectedValueOnce(genericError);
|
||||
|
||||
const policy: ModelPolicy = {
|
||||
model: 'model-1',
|
||||
actions: {},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal', // from quotaError
|
||||
not_found: 'terminal', // from notFoundError
|
||||
unknown: 'terminal', // from genericError
|
||||
},
|
||||
};
|
||||
|
||||
const getContext = vi
|
||||
.fn()
|
||||
.mockReturnValue({ service: mockService, policy });
|
||||
|
||||
// Run for quotaError
|
||||
await retryWithBackoff(fn, {
|
||||
maxAttempts: 1,
|
||||
getAvailabilityContext: getContext,
|
||||
}).catch(() => {});
|
||||
expect(mockService.markTerminal).toHaveBeenCalledWith('model-1', 'quota');
|
||||
|
||||
// Run for notFoundError
|
||||
await retryWithBackoff(fn, {
|
||||
maxAttempts: 1,
|
||||
getAvailabilityContext: getContext,
|
||||
}).catch(() => {});
|
||||
expect(mockService.markTerminal).toHaveBeenCalledWith(
|
||||
'model-1',
|
||||
'capacity',
|
||||
);
|
||||
|
||||
// Run for genericError
|
||||
await retryWithBackoff(fn, {
|
||||
maxAttempts: 1,
|
||||
getAvailabilityContext: getContext,
|
||||
}).catch(() => {});
|
||||
expect(mockService.markTerminal).toHaveBeenCalledWith(
|
||||
'model-1',
|
||||
'capacity',
|
||||
);
|
||||
|
||||
expect(mockService.markTerminal).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -8,13 +8,18 @@ import type { GenerateContentResponse } from '@google/genai';
|
||||
import { ApiError } from '@google/genai';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import {
|
||||
classifyGoogleError,
|
||||
RetryableQuotaError,
|
||||
TerminalQuotaError,
|
||||
RetryableQuotaError,
|
||||
classifyGoogleError,
|
||||
} from './googleQuotaErrors.js';
|
||||
import { delay, createAbortError } from './delay.js';
|
||||
import { debugLogger } from './debugLogger.js';
|
||||
import { getErrorStatus, ModelNotFoundError } from './httpErrors.js';
|
||||
import type { RetryAvailabilityContext } from '../availability/modelPolicy.js';
|
||||
import { classifyFailureKind } from '../availability/errorClassification.js';
|
||||
import { applyAvailabilityTransition } from '../availability/policyHelpers.js';
|
||||
|
||||
export type { RetryAvailabilityContext };
|
||||
|
||||
export interface RetryOptions {
|
||||
maxAttempts: number;
|
||||
@@ -29,6 +34,7 @@ export interface RetryOptions {
|
||||
authType?: string;
|
||||
retryFetchErrors?: boolean;
|
||||
signal?: AbortSignal;
|
||||
getAvailabilityContext?: () => RetryAvailabilityContext | undefined;
|
||||
}
|
||||
|
||||
const DEFAULT_RETRY_OPTIONS: RetryOptions = {
|
||||
@@ -145,6 +151,7 @@ export async function retryWithBackoff<T>(
|
||||
shouldRetryOnContent,
|
||||
retryFetchErrors,
|
||||
signal,
|
||||
getAvailabilityContext,
|
||||
} = {
|
||||
...DEFAULT_RETRY_OPTIONS,
|
||||
shouldRetryOnError: isRetryableError,
|
||||
@@ -173,6 +180,11 @@ export async function retryWithBackoff<T>(
|
||||
continue;
|
||||
}
|
||||
|
||||
const successContext = getAvailabilityContext?.();
|
||||
if (successContext) {
|
||||
successContext.service.markHealthy(successContext.policy.model);
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
@@ -180,6 +192,13 @@ export async function retryWithBackoff<T>(
|
||||
}
|
||||
|
||||
const classifiedError = classifyGoogleError(error);
|
||||
const failureKind = classifyFailureKind(classifiedError);
|
||||
const appliedImmediate =
|
||||
failureKind === 'terminal' || failureKind === 'not_found';
|
||||
if (appliedImmediate) {
|
||||
applyAvailabilityTransition(getAvailabilityContext, failureKind);
|
||||
}
|
||||
|
||||
const errorCode = getErrorStatus(error);
|
||||
|
||||
if (
|
||||
@@ -201,6 +220,7 @@ export async function retryWithBackoff<T>(
|
||||
debugLogger.warn('Fallback to Flash model failed:', fallbackError);
|
||||
}
|
||||
}
|
||||
// Terminal/not_found already recorded; nothing else to mark here.
|
||||
throw classifiedError; // Throw if no fallback or fallback failed.
|
||||
}
|
||||
|
||||
@@ -224,6 +244,9 @@ export async function retryWithBackoff<T>(
|
||||
console.warn('Model fallback failed:', fallbackError);
|
||||
}
|
||||
}
|
||||
if (!appliedImmediate) {
|
||||
applyAvailabilityTransition(getAvailabilityContext, failureKind);
|
||||
}
|
||||
throw classifiedError instanceof RetryableQuotaError
|
||||
? classifiedError
|
||||
: error;
|
||||
@@ -253,6 +276,9 @@ export async function retryWithBackoff<T>(
|
||||
attempt >= maxAttempts ||
|
||||
!shouldRetryOnError(error as Error, retryFetchErrors)
|
||||
) {
|
||||
if (!appliedImmediate) {
|
||||
applyAvailabilityTransition(getAvailabilityContext, failureKind);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user