feat(modelAvailabilityService): integrate model availability service into backend logic (#14470)

This commit is contained in:
Adam Weidman
2025-12-08 06:44:34 -08:00
committed by GitHub
parent 7a72037572
commit 8f4f8baa81
20 changed files with 1611 additions and 119 deletions
+170 -1
View File
@@ -17,6 +17,9 @@ import {
RetryableQuotaError,
} from './googleQuotaErrors.js';
import { PREVIEW_GEMINI_MODEL } from '../config/models.js';
import type { ModelPolicy } from '../availability/modelPolicy.js';
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
// Helper to create a mock function that fails a certain number of times
const createFailingFunction = (
@@ -104,7 +107,6 @@ describe('retryWithBackoff', () => {
const promise = retryWithBackoff(mockFn);
// Expect it to fail with the error from the 5th attempt.
await Promise.all([
expect(promise).rejects.toThrow('Simulated error attempt 3'),
vi.runAllTimersAsync(),
@@ -566,4 +568,171 @@ describe('retryWithBackoff', () => {
);
expect(mockFn).toHaveBeenCalledTimes(2);
});
describe('Availability Context Integration', () => {
let mockService: ModelAvailabilityService;
let mockPolicy1: ModelPolicy;
let mockPolicy2: ModelPolicy;
beforeEach(() => {
vi.useRealTimers();
mockService = createAvailabilityServiceMock();
mockPolicy1 = {
model: 'model-1',
actions: {},
stateTransitions: {
terminal: 'terminal',
transient: 'sticky_retry',
},
};
mockPolicy2 = {
model: 'model-2',
actions: {},
stateTransitions: {
terminal: 'terminal',
},
};
});
it('updates availability context per attempt and applies transitions to the correct policy', async () => {
const error = new TerminalQuotaError(
'quota exceeded',
{ code: 429, message: 'quota', details: [] },
10,
);
const fn = vi.fn().mockImplementation(async () => {
throw error; // Always fail with quota
});
const onPersistent429 = vi
.fn()
.mockResolvedValueOnce('model-2') // First fallback success
.mockResolvedValueOnce(null); // Second fallback fails (give up)
// Context provider returns policy1 first, then policy2
const getContext = vi
.fn()
.mockReturnValueOnce({ service: mockService, policy: mockPolicy1 })
.mockReturnValueOnce({ service: mockService, policy: mockPolicy2 });
await expect(
retryWithBackoff(fn, {
maxAttempts: 3,
initialDelayMs: 1,
getAvailabilityContext: getContext,
onPersistent429,
authType: AuthType.LOGIN_WITH_GOOGLE,
}),
).rejects.toThrow(TerminalQuotaError);
// Verify failures
expect(mockService.markTerminal).toHaveBeenCalledWith('model-1', 'quota');
expect(mockService.markTerminal).toHaveBeenCalledWith('model-2', 'quota');
// Verify sequences
expect(mockService.markTerminal).toHaveBeenNthCalledWith(
1,
'model-1',
'quota',
);
expect(mockService.markTerminal).toHaveBeenNthCalledWith(
2,
'model-2',
'quota',
);
});
it('marks sticky_retry after retries are exhausted for transient failures', async () => {
const transientError = new RetryableQuotaError(
'transient error',
{ code: 429, message: 'transient', details: [] },
0,
);
const fn = vi.fn().mockRejectedValue(transientError);
const getContext = vi
.fn()
.mockReturnValue({ service: mockService, policy: mockPolicy1 });
vi.useFakeTimers();
const promise = retryWithBackoff(fn, {
maxAttempts: 3,
getAvailabilityContext: getContext,
initialDelayMs: 1,
maxDelayMs: 1,
}).catch((err) => err);
await vi.runAllTimersAsync();
const result = await promise;
expect(result).toBe(transientError);
expect(fn).toHaveBeenCalledTimes(3);
expect(mockService.markRetryOncePerTurn).toHaveBeenCalledWith('model-1');
expect(mockService.markRetryOncePerTurn).toHaveBeenCalledTimes(1);
expect(mockService.markTerminal).not.toHaveBeenCalled();
});
it('maps different failure kinds to correct terminal reasons', async () => {
const quotaError = new TerminalQuotaError(
'quota',
{ code: 429, message: 'q', details: [] },
10,
);
const notFoundError = new ModelNotFoundError('not found', 404);
const genericError = new Error('unknown error');
const fn = vi
.fn()
.mockRejectedValueOnce(quotaError)
.mockRejectedValueOnce(notFoundError)
.mockRejectedValueOnce(genericError);
const policy: ModelPolicy = {
model: 'model-1',
actions: {},
stateTransitions: {
terminal: 'terminal', // from quotaError
not_found: 'terminal', // from notFoundError
unknown: 'terminal', // from genericError
},
};
const getContext = vi
.fn()
.mockReturnValue({ service: mockService, policy });
// Run for quotaError
await retryWithBackoff(fn, {
maxAttempts: 1,
getAvailabilityContext: getContext,
}).catch(() => {});
expect(mockService.markTerminal).toHaveBeenCalledWith('model-1', 'quota');
// Run for notFoundError
await retryWithBackoff(fn, {
maxAttempts: 1,
getAvailabilityContext: getContext,
}).catch(() => {});
expect(mockService.markTerminal).toHaveBeenCalledWith(
'model-1',
'capacity',
);
// Run for genericError
await retryWithBackoff(fn, {
maxAttempts: 1,
getAvailabilityContext: getContext,
}).catch(() => {});
expect(mockService.markTerminal).toHaveBeenCalledWith(
'model-1',
'capacity',
);
expect(mockService.markTerminal).toHaveBeenCalledTimes(3);
});
});
});
+28 -2
View File
@@ -8,13 +8,18 @@ import type { GenerateContentResponse } from '@google/genai';
import { ApiError } from '@google/genai';
import { AuthType } from '../core/contentGenerator.js';
import {
classifyGoogleError,
RetryableQuotaError,
TerminalQuotaError,
RetryableQuotaError,
classifyGoogleError,
} from './googleQuotaErrors.js';
import { delay, createAbortError } from './delay.js';
import { debugLogger } from './debugLogger.js';
import { getErrorStatus, ModelNotFoundError } from './httpErrors.js';
import type { RetryAvailabilityContext } from '../availability/modelPolicy.js';
import { classifyFailureKind } from '../availability/errorClassification.js';
import { applyAvailabilityTransition } from '../availability/policyHelpers.js';
export type { RetryAvailabilityContext };
export interface RetryOptions {
maxAttempts: number;
@@ -29,6 +34,7 @@ export interface RetryOptions {
authType?: string;
retryFetchErrors?: boolean;
signal?: AbortSignal;
getAvailabilityContext?: () => RetryAvailabilityContext | undefined;
}
const DEFAULT_RETRY_OPTIONS: RetryOptions = {
@@ -145,6 +151,7 @@ export async function retryWithBackoff<T>(
shouldRetryOnContent,
retryFetchErrors,
signal,
getAvailabilityContext,
} = {
...DEFAULT_RETRY_OPTIONS,
shouldRetryOnError: isRetryableError,
@@ -173,6 +180,11 @@ export async function retryWithBackoff<T>(
continue;
}
const successContext = getAvailabilityContext?.();
if (successContext) {
successContext.service.markHealthy(successContext.policy.model);
}
return result;
} catch (error) {
if (error instanceof Error && error.name === 'AbortError') {
@@ -180,6 +192,13 @@ export async function retryWithBackoff<T>(
}
const classifiedError = classifyGoogleError(error);
const failureKind = classifyFailureKind(classifiedError);
const appliedImmediate =
failureKind === 'terminal' || failureKind === 'not_found';
if (appliedImmediate) {
applyAvailabilityTransition(getAvailabilityContext, failureKind);
}
const errorCode = getErrorStatus(error);
if (
@@ -201,6 +220,7 @@ export async function retryWithBackoff<T>(
debugLogger.warn('Fallback to Flash model failed:', fallbackError);
}
}
// Terminal/not_found already recorded; nothing else to mark here.
throw classifiedError; // Throw if no fallback or fallback failed.
}
@@ -224,6 +244,9 @@ export async function retryWithBackoff<T>(
console.warn('Model fallback failed:', fallbackError);
}
}
if (!appliedImmediate) {
applyAvailabilityTransition(getAvailabilityContext, failureKind);
}
throw classifiedError instanceof RetryableQuotaError
? classifiedError
: error;
@@ -253,6 +276,9 @@ export async function retryWithBackoff<T>(
attempt >= maxAttempts ||
!shouldRetryOnError(error as Error, retryFetchErrors)
) {
if (!appliedImmediate) {
applyAvailabilityTransition(getAvailabilityContext, failureKind);
}
throw error;
}