mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-20 18:14:29 -07:00
feat(modelAvailabilityService): integrate model availability service into backend logic (#14470)
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
TerminalQuotaError,
|
||||
RetryableQuotaError,
|
||||
} from '../utils/googleQuotaErrors.js';
|
||||
import { ModelNotFoundError } from '../utils/httpErrors.js';
|
||||
import type { FailureKind } from './modelPolicy.js';
|
||||
|
||||
/**
 * Translates a thrown value into the `FailureKind` consumed by the model
 * availability policies.
 *
 * The error classes are tested in a fixed order and the first match wins;
 * a value that is an instance of none of them is classified as `'unknown'`.
 *
 * @param error The value that was thrown (or rejected) by a model call.
 * @returns The failure kind corresponding to `error`.
 */
export function classifyFailureKind(error: unknown): FailureKind {
  let kind: FailureKind = 'unknown';

  if (error instanceof TerminalQuotaError) {
    kind = 'terminal';
  } else if (error instanceof RetryableQuotaError) {
    kind = 'transient';
  } else if (error instanceof ModelNotFoundError) {
    kind = 'not_found';
  }

  return kind;
}
|
||||
@@ -4,7 +4,11 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { ModelHealthStatus, ModelId } from './modelAvailabilityService.js';
|
||||
import type {
|
||||
ModelAvailabilityService,
|
||||
ModelHealthStatus,
|
||||
ModelId,
|
||||
} from './modelAvailabilityService.js';
|
||||
|
||||
/**
|
||||
* Whether to prompt the user or fallback silently on a model API failure.
|
||||
@@ -49,3 +53,11 @@ export interface ModelPolicy {
|
||||
* The first model in the chain is the primary model.
|
||||
*/
|
||||
export type ModelPolicyChain = ModelPolicy[];
|
||||
|
||||
/**
 * What the retry logic needs in order to apply availability policies after a
 * failed model call.
 */
export interface RetryAvailabilityContext {
  /** Policy describing how failures of the attempted model are handled. */
  policy: ModelPolicy;
  /** Availability service that records the resulting model state. */
  service: ModelAvailabilityService;
}
|
||||
|
||||
@@ -51,6 +51,7 @@ const PREVIEW_CHAIN: ModelPolicyChain = [
|
||||
definePolicy({
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
stateTransitions: { transient: 'sticky_retry' },
|
||||
actions: { transient: 'silent' },
|
||||
}),
|
||||
definePolicy({ model: DEFAULT_GEMINI_MODEL }),
|
||||
definePolicy({ model: DEFAULT_GEMINI_FLASH_MODEL, isLastResort: true }),
|
||||
|
||||
@@ -4,38 +4,54 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import {
|
||||
resolvePolicyChain,
|
||||
buildFallbackPolicyContext,
|
||||
applyModelSelection,
|
||||
} from './policyHelpers.js';
|
||||
import { createDefaultPolicy } from './policyCatalog.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
/**
 * Builds a stub `Config` for these tests.
 *
 * Only the getters listed in `baseline` are provided by default; anything in
 * `overrides` replaces the baseline member of the same name or adds a new one.
 */
const createMockConfig = (overrides: Partial<Config> = {}): Config => {
  const baseline = {
    getPreviewFeatures: () => false,
    getUserTier: () => undefined,
    getModel: () => 'gemini-2.5-pro',
    isInFallbackMode: () => false,
  };

  // Overrides are spread last so they take precedence over the baseline.
  return { ...baseline, ...overrides } as unknown as Config;
};
|
||||
|
||||
describe('policyHelpers', () => {
|
||||
describe('resolvePolicyChain', () => {
|
||||
it('inserts the active model when missing from the catalog', () => {
|
||||
const config = {
|
||||
getPreviewFeatures: () => false,
|
||||
getUserTier: () => undefined,
|
||||
const config = createMockConfig({
|
||||
getModel: () => 'custom-model',
|
||||
isInFallbackMode: () => false,
|
||||
} as unknown as Config;
|
||||
});
|
||||
const chain = resolvePolicyChain(config);
|
||||
expect(chain).toHaveLength(1);
|
||||
expect(chain[0]?.model).toBe('custom-model');
|
||||
});
|
||||
|
||||
it('leaves catalog order untouched when active model already present', () => {
|
||||
const config = {
|
||||
getPreviewFeatures: () => false,
|
||||
getUserTier: () => undefined,
|
||||
const config = createMockConfig({
|
||||
getModel: () => 'gemini-2.5-pro',
|
||||
isInFallbackMode: () => false,
|
||||
} as unknown as Config;
|
||||
});
|
||||
const chain = resolvePolicyChain(config);
|
||||
expect(chain[0]?.model).toBe('gemini-2.5-pro');
|
||||
});
|
||||
|
||||
it('returns the default chain when active model is "auto"', () => {
|
||||
const config = createMockConfig({
|
||||
getModel: () => 'auto',
|
||||
});
|
||||
const chain = resolvePolicyChain(config);
|
||||
|
||||
// Expect default chain [Pro, Flash]
|
||||
expect(chain).toHaveLength(2);
|
||||
expect(chain[0]?.model).toBe('gemini-2.5-pro');
|
||||
expect(chain[1]?.model).toBe('gemini-2.5-flash');
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildFallbackPolicyContext', () => {
|
||||
@@ -57,4 +73,112 @@ describe('policyHelpers', () => {
|
||||
expect(context.candidates).toEqual(chain);
|
||||
});
|
||||
});
|
||||
|
||||
describe('applyModelSelection', () => {
|
||||
const mockModelConfigService = {
|
||||
getResolvedConfig: vi.fn(),
|
||||
};
|
||||
|
||||
const mockAvailabilityService = {
|
||||
selectFirstAvailable: vi.fn(),
|
||||
consumeStickyAttempt: vi.fn(),
|
||||
};
|
||||
|
||||
const createExtendedMockConfig = (
|
||||
overrides: Partial<Config> = {},
|
||||
): Config => {
|
||||
const defaults = {
|
||||
isModelAvailabilityServiceEnabled: () => true,
|
||||
getModelAvailabilityService: () => mockAvailabilityService,
|
||||
setActiveModel: vi.fn(),
|
||||
modelConfigService: mockModelConfigService,
|
||||
};
|
||||
return createMockConfig({ ...defaults, ...overrides } as Partial<Config>);
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('returns requested model if availability service is disabled', () => {
|
||||
const config = createExtendedMockConfig({
|
||||
isModelAvailabilityServiceEnabled: () => false,
|
||||
});
|
||||
const result = applyModelSelection(config, 'gemini-pro');
|
||||
expect(result.model).toBe('gemini-pro');
|
||||
expect(config.setActiveModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('returns requested model if it is available', () => {
|
||||
const config = createExtendedMockConfig();
|
||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||
selectedModel: 'gemini-pro',
|
||||
});
|
||||
|
||||
const result = applyModelSelection(config, 'gemini-pro');
|
||||
expect(result.model).toBe('gemini-pro');
|
||||
expect(result.maxAttempts).toBeUndefined();
|
||||
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
|
||||
});
|
||||
|
||||
it('switches to backup model and updates config if requested is unavailable', () => {
|
||||
const config = createExtendedMockConfig();
|
||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||
selectedModel: 'gemini-flash',
|
||||
});
|
||||
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
||||
generateContentConfig: { temperature: 0.1 },
|
||||
});
|
||||
|
||||
const currentConfig = { temperature: 0.9, topP: 1 };
|
||||
const result = applyModelSelection(config, 'gemini-pro', currentConfig);
|
||||
|
||||
expect(result.model).toBe('gemini-flash');
|
||||
expect(result.config).toEqual({
|
||||
temperature: 0.1,
|
||||
topP: 1,
|
||||
});
|
||||
|
||||
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
|
||||
model: 'gemini-flash',
|
||||
});
|
||||
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash');
|
||||
});
|
||||
|
||||
it('consumes sticky attempt if indicated', () => {
|
||||
const config = createExtendedMockConfig();
|
||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||
selectedModel: 'gemini-pro',
|
||||
attempts: 1,
|
||||
});
|
||||
|
||||
const result = applyModelSelection(config, 'gemini-pro');
|
||||
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
|
||||
'gemini-pro',
|
||||
);
|
||||
expect(result.maxAttempts).toBe(1);
|
||||
});
|
||||
|
||||
it('does not consume sticky attempt if consumeAttempt is false', () => {
|
||||
const config = createExtendedMockConfig();
|
||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||
selectedModel: 'gemini-pro',
|
||||
attempts: 1,
|
||||
});
|
||||
|
||||
const result = applyModelSelection(
|
||||
config,
|
||||
'gemini-pro',
|
||||
undefined,
|
||||
undefined,
|
||||
{
|
||||
consumeAttempt: false,
|
||||
},
|
||||
);
|
||||
expect(
|
||||
mockAvailabilityService.consumeStickyAttempt,
|
||||
).not.toHaveBeenCalled();
|
||||
expect(result.maxAttempts).toBe(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,34 +4,47 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { GenerateContentConfig } from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type {
|
||||
FailureKind,
|
||||
FallbackAction,
|
||||
ModelPolicy,
|
||||
ModelPolicyChain,
|
||||
RetryAvailabilityContext,
|
||||
} from './modelPolicy.js';
|
||||
import { createDefaultPolicy, getModelPolicyChain } from './policyCatalog.js';
|
||||
import { getEffectiveModel } from '../config/models.js';
|
||||
import { DEFAULT_GEMINI_MODEL, getEffectiveModel } from '../config/models.js';
|
||||
import type { ModelSelectionResult } from './modelAvailabilityService.js';
|
||||
|
||||
/**
|
||||
* Resolves the active policy chain for the given config, ensuring the
|
||||
* user-selected active model is represented.
|
||||
*/
|
||||
export function resolvePolicyChain(config: Config): ModelPolicyChain {
|
||||
export function resolvePolicyChain(
|
||||
config: Config,
|
||||
preferredModel?: string,
|
||||
): ModelPolicyChain {
|
||||
const chain = getModelPolicyChain({
|
||||
previewEnabled: !!config.getPreviewFeatures(),
|
||||
userTier: config.getUserTier(),
|
||||
});
|
||||
// TODO: This will be replaced when we get rid of Fallback Modes
|
||||
const activeModel = getEffectiveModel(
|
||||
config.isInFallbackMode(),
|
||||
config.getModel(),
|
||||
config.getPreviewFeatures(),
|
||||
);
|
||||
// TODO: This will be replaced when we get rid of Fallback Modes.
|
||||
// Switch to getActiveModel()
|
||||
const activeModel =
|
||||
preferredModel ??
|
||||
getEffectiveModel(
|
||||
config.isInFallbackMode(),
|
||||
config.getModel(),
|
||||
config.getPreviewFeatures(),
|
||||
);
|
||||
|
||||
if (activeModel === 'auto') {
|
||||
return [...chain];
|
||||
}
|
||||
|
||||
if (chain.some((policy) => policy.model === activeModel)) {
|
||||
return chain;
|
||||
return [...chain];
|
||||
}
|
||||
|
||||
// If the user specified a model not in the default chain, we assume they want
|
||||
@@ -68,3 +81,120 @@ export function resolvePolicyAction(
|
||||
): FallbackAction {
|
||||
return policy.actions?.[failureKind] ?? 'prompt';
|
||||
}
|
||||
|
||||
/**
 * Builds the callback that retry logic invokes to obtain the availability
 * service together with the policy of the model being attempted.
 *
 * The model is re-read through `modelGetter` on every invocation of the
 * returned callback, so a model change between retries is picked up.
 *
 * @param config Supplies the availability service and its enabled flag.
 * @param modelGetter A function that returns the model ID currently being
 *   attempted.
 * @returns A provider that yields `{ service, policy }`, or `undefined` when
 *   the availability service is disabled or the resolved chain contains no
 *   policy for the current model.
 */
export function createAvailabilityContextProvider(
  config: Config,
  modelGetter: () => string,
): () => RetryAvailabilityContext | undefined {
  const provideContext = (): RetryAvailabilityContext | undefined => {
    if (!config.isModelAvailabilityServiceEnabled()) {
      return undefined;
    }

    const service = config.getModelAvailabilityService();
    const attemptedModel = modelGetter();

    // Resolve the chain for the specific model being attempted, then pick
    // that model's own entry out of it.
    const policy = resolvePolicyChain(config, attemptedModel).find(
      (candidate) => candidate.model === attemptedModel,
    );
    if (!policy) {
      return undefined;
    }

    return { service, policy };
  };

  return provideContext;
}
|
||||
|
||||
/**
 * Asks the availability service which model should be used for an attempt.
 *
 * The candidates are the models of the policy chain resolved for
 * `requestedModel`. When the service selects none of them, the result falls
 * back to the chain entry flagged `isLastResort`, or to
 * `DEFAULT_GEMINI_MODEL` if the chain has no such entry, with an empty
 * `skipped` list.
 *
 * @param config Supplies the availability service and its enabled flag.
 * @param requestedModel The model the caller would like to use.
 * @returns The selection, or `undefined` when the availability service is
 *   disabled.
 */
export function selectModelForAvailability(
  config: Config,
  requestedModel: string,
): ModelSelectionResult | undefined {
  if (!config.isModelAvailabilityServiceEnabled()) {
    return undefined;
  }

  const policies = resolvePolicyChain(config, requestedModel);
  const availability = config.getModelAvailabilityService();
  const picked = availability.selectFirstAvailable(
    policies.map((policy) => policy.model),
  );

  if (!picked.selectedModel) {
    // Nothing in the chain is currently available: use the last-resort entry.
    const lastResort = policies.find((policy) => policy.isLastResort);
    return {
      selectedModel: lastResort?.model ?? DEFAULT_GEMINI_MODEL,
      skipped: [],
    };
  }

  return picked;
}
|
||||
|
||||
/**
 * Runs availability-based model selection for a request and applies its side
 * effects.
 *
 * When a model is selected it is stored via `config.setActiveModel`, and if
 * the selection carries an attempt count the sticky attempt is consumed
 * unless `options.consumeAttempt` is `false`. If the selected model differs
 * from `requestedModel`, the generation config is re-resolved for the
 * selected model and layered over `currentConfig`.
 *
 * @param config Supplies the availability service, the model config service
 *   and the active-model setter.
 * @param requestedModel The model the caller asked for.
 * @param currentConfig Generation config prepared for `requestedModel`.
 * @param overrideScope Passed through to the model config resolution.
 * @param options `consumeAttempt: false` leaves the sticky attempt untouched.
 * @returns The model to use, the config to send with it and, when provided by
 *   the selection, the maximum number of attempts. When no selection is made
 *   the requested model and `currentConfig` are returned unchanged.
 */
export function applyModelSelection(
  config: Config,
  requestedModel: string,
  currentConfig?: GenerateContentConfig,
  overrideScope?: string,
  options: { consumeAttempt?: boolean } = {},
): { model: string; config?: GenerateContentConfig; maxAttempts?: number } {
  const selection = selectModelForAvailability(config, requestedModel);
  const chosenModel = selection?.selectedModel;

  if (!selection || !chosenModel) {
    return { model: requestedModel, config: currentConfig };
  }

  let effectiveConfig = currentConfig;

  if (chosenModel !== requestedModel) {
    // The model changed, so its generation config has to be resolved again.
    const resolved = config.modelConfigService.getResolvedConfig({
      overrideScope,
      model: chosenModel,
    });

    if (currentConfig) {
      // Settings resolved for the chosen model win over the caller's values.
      effectiveConfig = {
        ...currentConfig,
        ...resolved.generateContentConfig,
      };
    } else {
      effectiveConfig = resolved.generateContentConfig;
    }
  }

  config.setActiveModel(chosenModel);

  const skipConsume = !selection.attempts || options.consumeAttempt === false;
  if (!skipConsume) {
    config.getModelAvailabilityService().consumeStickyAttempt(chosenModel);
  }

  return {
    model: chosenModel,
    config: effectiveConfig,
    maxAttempts: selection.attempts,
  };
}
|
||||
|
||||
/**
 * Records a failure with the availability service according to the state
 * transition the current model's policy defines for that failure kind.
 *
 * Nothing happens when no context provider is given, when the provider
 * returns no context, or when the policy defines no transition for
 * `failureKind`.
 *
 * @param getContext Provider of the availability service and current policy.
 * @param failureKind The classified kind of the failure that occurred.
 */
export function applyAvailabilityTransition(
  getContext: (() => RetryAvailabilityContext | undefined) | undefined,
  failureKind: FailureKind,
): void {
  if (!getContext) {
    return;
  }

  const availability = getContext();
  if (!availability) {
    return;
  }

  switch (availability.policy.stateTransitions?.[failureKind]) {
    case 'terminal': {
      // A terminal failure kind is reported as quota; any other kind that the
      // policy maps to a terminal transition is reported as capacity.
      const reason = failureKind === 'terminal' ? 'quota' : 'capacity';
      availability.service.markTerminal(availability.policy.model, reason);
      break;
    }
    case 'sticky_retry':
      availability.service.markRetryOncePerTurn(availability.policy.model);
      break;
    default:
      break;
  }
}
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { vi } from 'vitest';
|
||||
import type {
|
||||
ModelAvailabilityService,
|
||||
ModelSelectionResult,
|
||||
} from './modelAvailabilityService.js';
|
||||
|
||||
/**
 * Test helper that returns a `ModelAvailabilityService` whose methods are all
 * vitest mock functions.
 *
 * @param selection Value returned by the mocked `selectFirstAvailable`;
 *   defaults to a result with no selected model and nothing skipped.
 * @returns The mocked service, cast to `ModelAvailabilityService`.
 */
export function createAvailabilityServiceMock(
  selection: ModelSelectionResult = { selectedModel: null, skipped: [] },
): ModelAvailabilityService {
  // The only stub with configured behaviour; the rest are bare mocks.
  const selectFirstAvailable = vi.fn().mockReturnValue(selection);

  return {
    markTerminal: vi.fn(),
    markHealthy: vi.fn(),
    markRetryOncePerTurn: vi.fn(),
    consumeStickyAttempt: vi.fn(),
    snapshot: vi.fn(),
    resetTurn: vi.fn(),
    selectFirstAvailable,
  } as unknown as ModelAvailabilityService;
}
|
||||
Reference in New Issue
Block a user