From 7a9ed4c20a1640fe08c571cf8bf45371f3e77e32 Mon Sep 17 00:00:00 2001 From: Coco Sheng Date: Tue, 12 May 2026 10:26:50 -0400 Subject: [PATCH] fix: respect explicit model selection after Flash quota exhaustion (#26759) (#26872) --- .../modelAvailabilityService.test.ts | 42 ++++++++++++++ .../availability/modelAvailabilityService.ts | 22 ++++--- .../src/availability/policyHelpers.test.ts | 5 +- .../core/src/availability/policyHelpers.ts | 57 ++++++++++++------- .../strategies/classifierStrategy.test.ts | 4 ++ .../routing/strategies/classifierStrategy.ts | 29 +++++++--- .../numericalClassifierStrategy.test.ts | 4 ++ .../strategies/numericalClassifierStrategy.ts | 29 +++++++--- packages/core/src/utils/modelUtils.test.ts | 32 +++++++++++ packages/core/src/utils/modelUtils.ts | 17 ++++++ 10 files changed, 197 insertions(+), 44 deletions(-) create mode 100644 packages/core/src/utils/modelUtils.test.ts create mode 100644 packages/core/src/utils/modelUtils.ts diff --git a/packages/core/src/availability/modelAvailabilityService.test.ts b/packages/core/src/availability/modelAvailabilityService.test.ts index 2dcc90f477..9d468c543c 100644 --- a/packages/core/src/availability/modelAvailabilityService.test.ts +++ b/packages/core/src/availability/modelAvailabilityService.test.ts @@ -168,4 +168,46 @@ describe('ModelAvailabilityService', () => { reason: 'quota', }); }); + + describe('prefix normalization', () => { + it('treats prefixed and non-prefixed models as identical when marking terminal', () => { + service.markTerminal('models/gemini-3.1-pro-preview', 'quota'); + + // Checking the non-prefixed version should show it as unavailable + expect(service.snapshot('gemini-3.1-pro-preview')).toEqual({ + available: false, + reason: 'quota', + }); + + // Checking the prefixed version should also show it as unavailable + expect(service.snapshot('models/gemini-3.1-pro-preview')).toEqual({ + available: false, + reason: 'quota', + }); + }); + + it('treats prefixed and non-prefixed models as identical when selecting', () => { + service.markTerminal('gemini-3-flash-preview', 'quota'); + + // Attempting to select the prefixed version should skip it because the base is exhausted + const result = service.selectFirstAvailable([ + 'models/gemini-3-flash-preview', + 'gemini-3.1-pro-preview', + ]); + + expect(result.selectedModel).toBe('gemini-3.1-pro-preview'); + expect(result.skipped).toEqual([ + { model: 'gemini-3-flash-preview', reason: 'quota' }, + ]); + }); + + it('treats prefixed and non-prefixed models as identical when marking healthy', () => { + service.markTerminal('gemini-3-flash-preview', 'quota'); + service.markHealthy('models/gemini-3-flash-preview'); + + expect(service.snapshot('gemini-3-flash-preview')).toEqual({ + available: true, + }); + }); + }); }); diff --git a/packages/core/src/availability/modelAvailabilityService.ts b/packages/core/src/availability/modelAvailabilityService.ts index 9ef83230ec..631de67193 100644 --- a/packages/core/src/availability/modelAvailabilityService.ts +++ b/packages/core/src/availability/modelAvailabilityService.ts @@ -39,21 +39,26 @@ export interface ModelSelectionResult { }>; } +import { normalizeModelId } from '../utils/modelUtils.js'; + export class ModelAvailabilityService { private readonly health = new Map(); - markTerminal(model: ModelId, reason: TerminalUnavailabilityReason) { + markTerminal(modelId: ModelId, reason: TerminalUnavailabilityReason) { + const model = normalizeModelId(modelId); this.setState(model, { status: 'terminal', reason, }); } - markHealthy(model: ModelId) { + markHealthy(modelId: ModelId) { + const model = normalizeModelId(modelId); this.clearState(model); } - markRetryOncePerTurn(model: ModelId, attempts: number = 1) { + markRetryOncePerTurn(modelId: ModelId, attempts: number = 1) { + const model = normalizeModelId(modelId); const currentState = this.health.get(model); // Do not override a terminal failure with a transient one. if (currentState?.status === 'terminal') { @@ -75,14 +80,16 @@ export class ModelAvailabilityService { }); } - consumeStickyAttempt(model: ModelId) { + consumeStickyAttempt(modelId: ModelId) { + const model = normalizeModelId(modelId); const state = this.health.get(model); if (state?.status === 'sticky_retry') { this.setState(model, { ...state, consumed: true }); } } - snapshot(model: ModelId): ModelAvailabilitySnapshot { + snapshot(modelId: ModelId): ModelAvailabilitySnapshot { + const model = normalizeModelId(modelId); const state = this.health.get(model); if (!state) { @@ -100,10 +107,11 @@ export class ModelAvailabilityService { return { available: true }; } - selectFirstAvailable(models: ModelId[]): ModelSelectionResult { + selectFirstAvailable(modelIds: ModelId[]): ModelSelectionResult { const skipped: ModelSelectionResult['skipped'] = []; - for (const model of models) { + for (const modelId of modelIds) { + const model = normalizeModelId(modelId); const snapshot = this.snapshot(model); if (snapshot.available) { const state = this.health.get(model); diff --git a/packages/core/src/availability/policyHelpers.test.ts b/packages/core/src/availability/policyHelpers.test.ts index 945de646e0..7db99bb1aa 100644 --- a/packages/core/src/availability/policyHelpers.test.ts +++ b/packages/core/src/availability/policyHelpers.test.ts @@ -96,10 +96,11 @@ describe('policyHelpers', () => { it('starts chain from preferredModel when model is "auto"', () => { const config = createMockConfig({ - getModel: () => DEFAULT_GEMINI_MODEL_AUTO, + getModel: () => 'auto', }); const chain = resolvePolicyChain(config, 'gemini-2.5-flash'); - expect(chain).toHaveLength(1); + // Due to Gemini 2.x wrapsAround, the chain will contain both flash and pro + expect(chain.length).toBeGreaterThanOrEqual(1); expect(chain[0]?.model).toBe('gemini-2.5-flash'); }); diff --git a/packages/core/src/availability/policyHelpers.ts b/packages/core/src/availability/policyHelpers.ts index 5d65a7598e..28447bd836 100644 --- a/packages/core/src/availability/policyHelpers.ts +++ b/packages/core/src/availability/policyHelpers.ts @@ -28,6 +28,7 @@ import { isGemini3Model, resolveModel, } from '../config/models.js'; +import { normalizeModelId } from '../utils/modelUtils.js'; import type { ModelSelectionResult } from './modelAvailabilityService.js'; import type { ModelConfigKey } from '../services/modelConfigService.js'; import { ApprovalMode } from '../policy/types.js'; @@ -41,9 +42,13 @@ export function resolvePolicyChain( preferredModel?: string, wrapsAround: boolean = false, ): ModelPolicyChain { - const modelFromConfig = - preferredModel ?? config.getActiveModel?.() ?? config.getModel(); - const configuredModel = config.getModel(); + const normalizedPreferredModel = preferredModel + ? normalizeModelId(preferredModel) + : undefined; + const modelFromConfig = normalizeModelId( + normalizedPreferredModel ?? config.getActiveModel?.() ?? config.getModel(), + ); + const configuredModel = normalizeModelId(config.getModel()); let chain: ModelPolicyChain | undefined; const useGemini31 = config.getGemini31LaunchedSync?.() ?? false; @@ -52,19 +57,29 @@ export function resolvePolicyChain( const useCustomToolModel = config.getUseCustomToolModelSync?.() ?? false; const hasAccessToPreview = config.getHasAccessToPreviewModel?.() ?? true; - const resolvedModel = resolveModel( - modelFromConfig, - useGemini31, - useGemini31FlashLite, - useCustomToolModel, - hasAccessToPreview, - config, + const resolvedModel = normalizeModelId( + resolveModel( + modelFromConfig, + useGemini31, + useGemini31FlashLite, + useCustomToolModel, + hasAccessToPreview, + config, + ), ); - const isAutoPreferred = preferredModel - ? isAutoModel(preferredModel, config) + const isAutoPreferred = normalizedPreferredModel + ? isAutoModel(normalizedPreferredModel, config) : false; const isAutoConfigured = isAutoModel(configuredModel, config); + // We always wrap around for Gemini 3 chains to ensure maximum availability + // between models in the same family (e.g. fallback to Pro if Flash is exhausted). + const effectiveWrapsAround = + wrapsAround || + isAutoPreferred || + isAutoConfigured || + isGemini3Model(resolvedModel, config); + // --- DYNAMIC PATH --- if (config.getExperimentalDynamicModelConfiguration?.() === true) { const context = { @@ -76,7 +91,7 @@ export function resolvePolicyChain( if (resolvedModel === DEFAULT_GEMINI_FLASH_LITE_MODEL) { chain = config.modelConfigService.resolveChain('lite', context); } else if ( - isGemini3Model(resolvedModel, config) || + isGemini3Model(normalizeModelId(resolvedModel), config) || isAutoPreferred || isAutoConfigured ) { @@ -96,7 +111,7 @@ export function resolvePolicyChain( const previewEnabled = hasAccessToPreview && (isGemini3Model(resolvedModel, config) || - preferredModel === PREVIEW_GEMINI_MODEL_AUTO || + normalizedPreferredModel === PREVIEW_GEMINI_MODEL_AUTO || configuredModel === PREVIEW_GEMINI_MODEL_AUTO); const autoPrefix = isAutoSelection ? 'auto-' : ''; const chainKey = previewEnabled ? 'preview' : 'default'; @@ -110,7 +125,7 @@ export function resolvePolicyChain( // No matching modelChains found, default to single model chain chain = createSingleModelChain(modelFromConfig); } - chain = applyDynamicSlicing(chain, resolvedModel, wrapsAround); + chain = applyDynamicSlicing(chain, resolvedModel, effectiveWrapsAround); } else { // --- LEGACY PATH --- @@ -125,7 +140,7 @@ export function resolvePolicyChain( if (hasAccessToPreview) { const previewEnabled = isGemini3Model(resolvedModel, config) || - preferredModel === PREVIEW_GEMINI_MODEL_AUTO || + normalizedPreferredModel === PREVIEW_GEMINI_MODEL_AUTO || configuredModel === PREVIEW_GEMINI_MODEL_AUTO; chain = getModelPolicyChain({ previewEnabled, @@ -150,7 +165,7 @@ export function resolvePolicyChain( } else { chain = createSingleModelChain(modelFromConfig); } - chain = applyDynamicSlicing(chain, resolvedModel, wrapsAround); + chain = applyDynamicSlicing(chain, resolvedModel, effectiveWrapsAround); } // Apply Unified Silent Injection for Plan Mode with defensive checks if (config?.getApprovalMode?.() === ApprovalMode.PLAN) { @@ -171,8 +186,9 @@ function applyDynamicSlicing( resolvedModel: string, wrapsAround: boolean, ): ModelPolicyChain { + const normalizedResolved = normalizeModelId(resolvedModel); const activeIndex = chain.findIndex( - (policy) => policy.model === resolvedModel, + (policy) => normalizeModelId(policy.model) === normalizedResolved, ); if (activeIndex !== -1) { return wrapsAround @@ -200,7 +216,10 @@ export function buildFallbackPolicyContext( failedPolicy?: ModelPolicy; candidates: ModelPolicy[]; } { - const index = chain.findIndex((policy) => policy.model === failedModel); + const normalizedFailed = normalizeModelId(failedModel); + const index = chain.findIndex( + (policy) => normalizeModelId(policy.model) === normalizedFailed, + ); if (index === -1) { return { failedPolicy: undefined, candidates: chain }; } diff --git a/packages/core/src/routing/strategies/classifierStrategy.test.ts b/packages/core/src/routing/strategies/classifierStrategy.test.ts index 373da6f144..a81cd53de3 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.test.ts @@ -27,6 +27,7 @@ import type { Content } from '@google/genai'; import type { ResolvedModelConfig } from '../../services/modelConfigService.js'; import { debugLogger } from '../../utils/debugLogger.js'; import { AuthType } from '../../core/contentGenerator.js'; +import { ModelAvailabilityService } from '../../availability/modelAvailabilityService.js'; vi.mock('../../core/baseLlmClient.js'); @@ -68,6 +69,9 @@ describe('ClassifierStrategy', () => { getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: AuthType.LOGIN_WITH_GOOGLE, }), + getModelAvailabilityService: vi + .fn() + .mockReturnValue(new ModelAvailabilityService()), } as unknown as Config; mockBaseLlmClient = { generateJson: vi.fn(), diff --git a/packages/core/src/routing/strategies/classifierStrategy.ts b/packages/core/src/routing/strategies/classifierStrategy.ts index 1dd09f4596..dda0f49665 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.ts @@ -20,6 +20,7 @@ import { isFunctionResponse, } from '../../utils/messageInspectors.js'; import { debugLogger } from '../../utils/debugLogger.js'; +import { normalizeModelId } from '../../utils/modelUtils.js'; import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; import { LlmRole } from '../../telemetry/types.js'; @@ -177,16 +178,28 @@ export class ClassifierStrategy implements RoutingStrategy { config.getGemini31FlashLiteLaunched(), config.getUseCustomToolModel(), ]); - const selectedModel = resolveClassifierModel( - model, - routerResponse.model_choice, - useGemini3_1, - useGemini3_1FlashLite, - useCustomToolModel, - config.getHasAccessToPreviewModel?.() ?? true, - config, + const selectedModel = normalizeModelId( + resolveClassifierModel( + normalizeModelId(model), + routerResponse.model_choice, + useGemini3_1, + useGemini3_1FlashLite, + useCustomToolModel, + config.getHasAccessToPreviewModel?.() ?? true, + config, + ), ); + const service = config.getModelAvailabilityService(); + const snapshot = service.snapshot(selectedModel); + + if (!snapshot.available) { + debugLogger.warn( + `[Routing] Classifier selected unavailable model ${selectedModel} (${snapshot.reason}). Bypassing.`, + ); + return null; + } + return { model: selectedModel, metadata: { diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts index f400dfc51b..fccd1c53eb 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts @@ -27,6 +27,7 @@ import type { ResolvedModelConfig } from '../../services/modelConfigService.js'; import { debugLogger } from '../../utils/debugLogger.js'; import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; import { AuthType } from '../../core/contentGenerator.js'; +import { ModelAvailabilityService } from '../../availability/modelAvailabilityService.js'; vi.mock('../../core/baseLlmClient.js'); @@ -71,6 +72,9 @@ describe('NumericalClassifierStrategy', () => { getContentGeneratorConfig: vi.fn().mockReturnValue({ authType: AuthType.LOGIN_WITH_GOOGLE, }), + getModelAvailabilityService: vi + .fn() + .mockReturnValue(new ModelAvailabilityService()), } as unknown as Config; mockBaseLlmClient = { generateJson: vi.fn(), diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts index 0e2401c8f1..a490601436 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts @@ -20,6 +20,7 @@ import { isFunctionResponse, } from '../../utils/messageInspectors.js'; import { debugLogger } from '../../utils/debugLogger.js'; +import { normalizeModelId } from '../../utils/modelUtils.js'; import type { LocalLiteRtLmClient } from '../../core/localLiteRtLmClient.js'; import { LlmRole } from '../../telemetry/types.js'; @@ -172,16 +173,28 @@ export class NumericalClassifierStrategy implements RoutingStrategy { config.getGemini31FlashLiteLaunched(), config.getUseCustomToolModel(), ]); - const selectedModel = resolveClassifierModel( - model, - modelAlias, - useGemini3_1, - useGemini3_1FlashLite, - useCustomToolModel, - config.getHasAccessToPreviewModel?.() ?? true, - config, + const selectedModel = normalizeModelId( + resolveClassifierModel( + normalizeModelId(model), + modelAlias, + useGemini3_1, + useGemini3_1FlashLite, + useCustomToolModel, + config.getHasAccessToPreviewModel?.() ?? true, + config, + ), ); + const service = config.getModelAvailabilityService(); + const snapshot = service.snapshot(selectedModel); + + if (!snapshot.available) { + debugLogger.warn( + `[Routing] Numerical classifier selected unavailable model ${selectedModel} (${snapshot.reason}). Bypassing.`, + ); + return null; + } + const latencyMs = Date.now() - startTime; return { diff --git a/packages/core/src/utils/modelUtils.test.ts b/packages/core/src/utils/modelUtils.test.ts new file mode 100644 index 0000000000..6745bf73cf --- /dev/null +++ b/packages/core/src/utils/modelUtils.test.ts @@ -0,0 +1,32 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { normalizeModelId } from './modelUtils.js'; + +describe('modelUtils', () => { + describe('normalizeModelId', () => { + it('should strip "models/" prefix if present', () => { + expect(normalizeModelId('models/gemini-3.1-pro-preview')).toBe( + 'gemini-3.1-pro-preview', + ); + expect(normalizeModelId('models/gemini-1.5-flash')).toBe( + 'gemini-1.5-flash', + ); + }); + + it('should leave model ID untouched if prefix is not present', () => { + expect(normalizeModelId('gemini-3.1-pro-preview')).toBe( + 'gemini-3.1-pro-preview', + ); + expect(normalizeModelId('auto')).toBe('auto'); + }); + + it('should handle empty string', () => { + expect(normalizeModelId('')).toBe(''); + }); + }); +}); diff --git a/packages/core/src/utils/modelUtils.ts b/packages/core/src/utils/modelUtils.ts new file mode 100644 index 0000000000..c85fd784df --- /dev/null +++ b/packages/core/src/utils/modelUtils.ts @@ -0,0 +1,17 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Strips the 'models/' prefix from a model ID if present. + * This ensures internal logic (like family matching) works correctly + * even when receiving formal resource names from the API. + * + * @param modelId The model identifier to normalize. + * @returns The model ID without the 'models/' prefix. + */ +export function normalizeModelId(modelId: string): string { + return modelId.startsWith('models/') ? modelId.slice(7) : modelId; +}