feat(modelAvailabilityService): integrate model availability service into backend logic (#14470)

This commit is contained in:
Adam Weidman
2025-12-08 06:44:34 -08:00
committed by GitHub
parent 7a72037572
commit 8f4f8baa81
20 changed files with 1611 additions and 119 deletions
@@ -0,0 +1,25 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
TerminalQuotaError,
RetryableQuotaError,
} from '../utils/googleQuotaErrors.js';
import { ModelNotFoundError } from '../utils/httpErrors.js';
import type { FailureKind } from './modelPolicy.js';
export function classifyFailureKind(error: unknown): FailureKind {
if (error instanceof TerminalQuotaError) {
return 'terminal';
}
if (error instanceof RetryableQuotaError) {
return 'transient';
}
if (error instanceof ModelNotFoundError) {
return 'not_found';
}
return 'unknown';
}
+13 -1
View File
@@ -4,7 +4,11 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { ModelHealthStatus, ModelId } from './modelAvailabilityService.js';
import type {
ModelAvailabilityService,
ModelHealthStatus,
ModelId,
} from './modelAvailabilityService.js';
/**
* Whether to prompt the user or fallback silently on a model API failure.
@@ -49,3 +53,11 @@ export interface ModelPolicy {
* The first model in the chain is the primary model.
*/
export type ModelPolicyChain = ModelPolicy[];
/**
* Context required by retry logic to apply availability policies on failure.
*/
export interface RetryAvailabilityContext {
service: ModelAvailabilityService;
policy: ModelPolicy;
}
@@ -51,6 +51,7 @@ const PREVIEW_CHAIN: ModelPolicyChain = [
definePolicy({
model: PREVIEW_GEMINI_MODEL,
stateTransitions: { transient: 'sticky_retry' },
actions: { transient: 'silent' },
}),
definePolicy({ model: DEFAULT_GEMINI_MODEL }),
definePolicy({ model: DEFAULT_GEMINI_FLASH_MODEL, isLastResort: true }),
@@ -4,38 +4,54 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import {
resolvePolicyChain,
buildFallbackPolicyContext,
applyModelSelection,
} from './policyHelpers.js';
import { createDefaultPolicy } from './policyCatalog.js';
import type { Config } from '../config/config.js';
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
({
getPreviewFeatures: () => false,
getUserTier: () => undefined,
getModel: () => 'gemini-2.5-pro',
isInFallbackMode: () => false,
...overrides,
}) as unknown as Config;
describe('policyHelpers', () => {
describe('resolvePolicyChain', () => {
it('inserts the active model when missing from the catalog', () => {
const config = {
getPreviewFeatures: () => false,
getUserTier: () => undefined,
const config = createMockConfig({
getModel: () => 'custom-model',
isInFallbackMode: () => false,
} as unknown as Config;
});
const chain = resolvePolicyChain(config);
expect(chain).toHaveLength(1);
expect(chain[0]?.model).toBe('custom-model');
});
it('leaves catalog order untouched when active model already present', () => {
const config = {
getPreviewFeatures: () => false,
getUserTier: () => undefined,
const config = createMockConfig({
getModel: () => 'gemini-2.5-pro',
isInFallbackMode: () => false,
} as unknown as Config;
});
const chain = resolvePolicyChain(config);
expect(chain[0]?.model).toBe('gemini-2.5-pro');
});
it('returns the default chain when active model is "auto"', () => {
const config = createMockConfig({
getModel: () => 'auto',
});
const chain = resolvePolicyChain(config);
// Expect default chain [Pro, Flash]
expect(chain).toHaveLength(2);
expect(chain[0]?.model).toBe('gemini-2.5-pro');
expect(chain[1]?.model).toBe('gemini-2.5-flash');
});
});
describe('buildFallbackPolicyContext', () => {
@@ -57,4 +73,112 @@ describe('policyHelpers', () => {
expect(context.candidates).toEqual(chain);
});
});
describe('applyModelSelection', () => {
const mockModelConfigService = {
getResolvedConfig: vi.fn(),
};
const mockAvailabilityService = {
selectFirstAvailable: vi.fn(),
consumeStickyAttempt: vi.fn(),
};
const createExtendedMockConfig = (
overrides: Partial<Config> = {},
): Config => {
const defaults = {
isModelAvailabilityServiceEnabled: () => true,
getModelAvailabilityService: () => mockAvailabilityService,
setActiveModel: vi.fn(),
modelConfigService: mockModelConfigService,
};
return createMockConfig({ ...defaults, ...overrides } as Partial<Config>);
};
beforeEach(() => {
vi.clearAllMocks();
});
it('returns requested model if availability service is disabled', () => {
const config = createExtendedMockConfig({
isModelAvailabilityServiceEnabled: () => false,
});
const result = applyModelSelection(config, 'gemini-pro');
expect(result.model).toBe('gemini-pro');
expect(config.setActiveModel).not.toHaveBeenCalled();
});
it('returns requested model if it is available', () => {
const config = createExtendedMockConfig();
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
selectedModel: 'gemini-pro',
});
const result = applyModelSelection(config, 'gemini-pro');
expect(result.model).toBe('gemini-pro');
expect(result.maxAttempts).toBeUndefined();
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
});
it('switches to backup model and updates config if requested is unavailable', () => {
const config = createExtendedMockConfig();
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
selectedModel: 'gemini-flash',
});
mockModelConfigService.getResolvedConfig.mockReturnValue({
generateContentConfig: { temperature: 0.1 },
});
const currentConfig = { temperature: 0.9, topP: 1 };
const result = applyModelSelection(config, 'gemini-pro', currentConfig);
expect(result.model).toBe('gemini-flash');
expect(result.config).toEqual({
temperature: 0.1,
topP: 1,
});
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
model: 'gemini-flash',
});
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash');
});
it('consumes sticky attempt if indicated', () => {
const config = createExtendedMockConfig();
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
selectedModel: 'gemini-pro',
attempts: 1,
});
const result = applyModelSelection(config, 'gemini-pro');
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
'gemini-pro',
);
expect(result.maxAttempts).toBe(1);
});
it('does not consume sticky attempt if consumeAttempt is false', () => {
const config = createExtendedMockConfig();
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
selectedModel: 'gemini-pro',
attempts: 1,
});
const result = applyModelSelection(
config,
'gemini-pro',
undefined,
undefined,
{
consumeAttempt: false,
},
);
expect(
mockAvailabilityService.consumeStickyAttempt,
).not.toHaveBeenCalled();
expect(result.maxAttempts).toBe(1);
});
});
});
+139 -9
View File
@@ -4,34 +4,47 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { GenerateContentConfig } from '@google/genai';
import type { Config } from '../config/config.js';
import type {
FailureKind,
FallbackAction,
ModelPolicy,
ModelPolicyChain,
RetryAvailabilityContext,
} from './modelPolicy.js';
import { createDefaultPolicy, getModelPolicyChain } from './policyCatalog.js';
import { getEffectiveModel } from '../config/models.js';
import { DEFAULT_GEMINI_MODEL, getEffectiveModel } from '../config/models.js';
import type { ModelSelectionResult } from './modelAvailabilityService.js';
/**
* Resolves the active policy chain for the given config, ensuring the
* user-selected active model is represented.
*/
export function resolvePolicyChain(config: Config): ModelPolicyChain {
export function resolvePolicyChain(
config: Config,
preferredModel?: string,
): ModelPolicyChain {
const chain = getModelPolicyChain({
previewEnabled: !!config.getPreviewFeatures(),
userTier: config.getUserTier(),
});
// TODO: This will be replaced when we get rid of Fallback Modes
const activeModel = getEffectiveModel(
config.isInFallbackMode(),
config.getModel(),
config.getPreviewFeatures(),
);
// TODO: This will be replaced when we get rid of Fallback Modes.
// Switch to getActiveModel()
const activeModel =
preferredModel ??
getEffectiveModel(
config.isInFallbackMode(),
config.getModel(),
config.getPreviewFeatures(),
);
if (activeModel === 'auto') {
return [...chain];
}
if (chain.some((policy) => policy.model === activeModel)) {
return chain;
return [...chain];
}
// If the user specified a model not in the default chain, we assume they want
@@ -68,3 +81,120 @@ export function resolvePolicyAction(
): FallbackAction {
return policy.actions?.[failureKind] ?? 'prompt';
}
/**
* Creates a context provider for retry logic that returns the availability
* sevice and resolves the current model's policy.
*
* @param modelGetter A function that returns the model ID currently being attempted.
* (Allows handling dynamic model changes during retries).
*/
export function createAvailabilityContextProvider(
config: Config,
modelGetter: () => string,
): () => RetryAvailabilityContext | undefined {
return () => {
if (!config.isModelAvailabilityServiceEnabled()) {
return undefined;
}
const service = config.getModelAvailabilityService();
const currentModel = modelGetter();
// Resolve the chain for the specific model we are attempting.
const chain = resolvePolicyChain(config, currentModel);
const policy = chain.find((p) => p.model === currentModel);
return policy ? { service, policy } : undefined;
};
}
/**
* Selects the model to use for an attempt via the availability service and
* returns the selection context.
*/
export function selectModelForAvailability(
config: Config,
requestedModel: string,
): ModelSelectionResult | undefined {
if (!config.isModelAvailabilityServiceEnabled()) {
return undefined;
}
const chain = resolvePolicyChain(config, requestedModel);
const selection = config
.getModelAvailabilityService()
.selectFirstAvailable(chain.map((p) => p.model));
if (selection.selectedModel) return selection;
const backupModel =
chain.find((p) => p.isLastResort)?.model ?? DEFAULT_GEMINI_MODEL;
return { selectedModel: backupModel, skipped: [] };
}
/**
* Applies the model availability selection logic, including side effects
* (setting active model, consuming sticky attempts) and config updates.
*/
export function applyModelSelection(
config: Config,
requestedModel: string,
currentConfig?: GenerateContentConfig,
overrideScope?: string,
options: { consumeAttempt?: boolean } = {},
): { model: string; config?: GenerateContentConfig; maxAttempts?: number } {
const selection = selectModelForAvailability(config, requestedModel);
if (!selection?.selectedModel) {
return { model: requestedModel, config: currentConfig };
}
const finalModel = selection.selectedModel;
let finalConfig = currentConfig;
// If model changed, re-resolve config
if (finalModel !== requestedModel) {
const { generateContentConfig } =
config.modelConfigService.getResolvedConfig({
overrideScope,
model: finalModel,
});
finalConfig = currentConfig
? { ...currentConfig, ...generateContentConfig }
: generateContentConfig;
}
config.setActiveModel(finalModel);
if (selection.attempts && options.consumeAttempt !== false) {
config.getModelAvailabilityService().consumeStickyAttempt(finalModel);
}
return {
model: finalModel,
config: finalConfig,
maxAttempts: selection.attempts,
};
}
export function applyAvailabilityTransition(
getContext: (() => RetryAvailabilityContext | undefined) | undefined,
failureKind: FailureKind,
): void {
const context = getContext?.();
if (!context) return;
const transition = context.policy.stateTransitions?.[failureKind];
if (!transition) return;
if (transition === 'terminal') {
context.service.markTerminal(
context.policy.model,
failureKind === 'terminal' ? 'quota' : 'capacity',
);
} else if (transition === 'sticky_retry') {
context.service.markRetryOncePerTurn(context.policy.model);
}
}
@@ -0,0 +1,30 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi } from 'vitest';
import type {
ModelAvailabilityService,
ModelSelectionResult,
} from './modelAvailabilityService.js';
/**
* Test helper to create a fully mocked ModelAvailabilityService.
*/
export function createAvailabilityServiceMock(
selection: ModelSelectionResult = { selectedModel: null, skipped: [] },
): ModelAvailabilityService {
const service = {
markTerminal: vi.fn(),
markHealthy: vi.fn(),
markRetryOncePerTurn: vi.fn(),
consumeStickyAttempt: vi.fn(),
snapshot: vi.fn(),
resetTurn: vi.fn(),
selectFirstAvailable: vi.fn().mockReturnValue(selection),
};
return service as unknown as ModelAvailabilityService;
}