From 8f4f8baa81d5b3696618d010990924a2dec7cc02 Mon Sep 17 00:00:00 2001 From: Adam Weidman <65992621+adamfweidman@users.noreply.github.com> Date: Mon, 8 Dec 2025 06:44:34 -0800 Subject: [PATCH] feat(modelAvailabilityService): integrate model availability service into backend logic (#14470) --- .../src/availability/errorClassification.ts | 25 ++ packages/core/src/availability/modelPolicy.ts | 14 +- .../core/src/availability/policyCatalog.ts | 1 + .../src/availability/policyHelpers.test.ts | 146 +++++++++- .../core/src/availability/policyHelpers.ts | 148 +++++++++- packages/core/src/availability/testUtils.ts | 30 ++ packages/core/src/config/config.test.ts | 39 +++ packages/core/src/config/config.ts | 19 ++ packages/core/src/core/baseLlmClient.test.ts | 271 ++++++++++++++++- packages/core/src/core/baseLlmClient.ts | 53 +++- packages/core/src/core/client.test.ts | 175 ++++++++++- packages/core/src/core/client.ts | 79 ++++- packages/core/src/core/geminiChat.test.ts | 274 +++++++++++++++++- packages/core/src/core/geminiChat.ts | 89 ++++-- .../src/core/geminiChat_network_retry.test.ts | 1 + packages/core/src/fallback/handler.test.ts | 79 +++-- packages/core/src/fallback/handler.ts | 63 ++-- .../services/modelConfigServiceTestUtils.ts | 23 ++ packages/core/src/utils/retry.test.ts | 171 ++++++++++- packages/core/src/utils/retry.ts | 30 +- 20 files changed, 1611 insertions(+), 119 deletions(-) create mode 100644 packages/core/src/availability/errorClassification.ts create mode 100644 packages/core/src/availability/testUtils.ts create mode 100644 packages/core/src/services/modelConfigServiceTestUtils.ts diff --git a/packages/core/src/availability/errorClassification.ts b/packages/core/src/availability/errorClassification.ts new file mode 100644 index 0000000000..1a50fe68d4 --- /dev/null +++ b/packages/core/src/availability/errorClassification.ts @@ -0,0 +1,25 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + TerminalQuotaError, + RetryableQuotaError, +} from '../utils/googleQuotaErrors.js'; +import { ModelNotFoundError } from '../utils/httpErrors.js'; +import type { FailureKind } from './modelPolicy.js'; + +export function classifyFailureKind(error: unknown): FailureKind { + if (error instanceof TerminalQuotaError) { + return 'terminal'; + } + if (error instanceof RetryableQuotaError) { + return 'transient'; + } + if (error instanceof ModelNotFoundError) { + return 'not_found'; + } + return 'unknown'; +} diff --git a/packages/core/src/availability/modelPolicy.ts b/packages/core/src/availability/modelPolicy.ts index 296047edb6..199c13d7a5 100644 --- a/packages/core/src/availability/modelPolicy.ts +++ b/packages/core/src/availability/modelPolicy.ts @@ -4,7 +4,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { ModelHealthStatus, ModelId } from './modelAvailabilityService.js'; +import type { + ModelAvailabilityService, + ModelHealthStatus, + ModelId, +} from './modelAvailabilityService.js'; /** * Whether to prompt the user or fallback silently on a model API failure. @@ -49,3 +53,11 @@ export interface ModelPolicy { * The first model in the chain is the primary model. */ export type ModelPolicyChain = ModelPolicy[]; + +/** + * Context required by retry logic to apply availability policies on failure. + */ +export interface RetryAvailabilityContext { + service: ModelAvailabilityService; + policy: ModelPolicy; +} diff --git a/packages/core/src/availability/policyCatalog.ts b/packages/core/src/availability/policyCatalog.ts index 257eb7e072..6a0d9d07c7 100644 --- a/packages/core/src/availability/policyCatalog.ts +++ b/packages/core/src/availability/policyCatalog.ts @@ -51,6 +51,7 @@ const PREVIEW_CHAIN: ModelPolicyChain = [ definePolicy({ model: PREVIEW_GEMINI_MODEL, stateTransitions: { transient: 'sticky_retry' }, + actions: { transient: 'silent' }, }), definePolicy({ model: DEFAULT_GEMINI_MODEL }), definePolicy({ model: DEFAULT_GEMINI_FLASH_MODEL, isLastResort: true }), diff --git a/packages/core/src/availability/policyHelpers.test.ts b/packages/core/src/availability/policyHelpers.test.ts index d1b7f1fd73..8a5455b097 100644 --- a/packages/core/src/availability/policyHelpers.test.ts +++ b/packages/core/src/availability/policyHelpers.test.ts @@ -4,38 +4,54 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect } from 'vitest'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; import { resolvePolicyChain, buildFallbackPolicyContext, + applyModelSelection, } from './policyHelpers.js'; import { createDefaultPolicy } from './policyCatalog.js'; import type { Config } from '../config/config.js'; +const createMockConfig = (overrides: Partial = {}): Config => + ({ + getPreviewFeatures: () => false, + getUserTier: () => undefined, + getModel: () => 'gemini-2.5-pro', + isInFallbackMode: () => false, + ...overrides, + }) as unknown as Config; + describe('policyHelpers', () => { describe('resolvePolicyChain', () => { it('inserts the active model when missing from the catalog', () => { - const config = { - getPreviewFeatures: () => false, - getUserTier: () => undefined, + const config = createMockConfig({ getModel: () => 'custom-model', - isInFallbackMode: () => false, - } as unknown as Config; + }); const chain = resolvePolicyChain(config); expect(chain).toHaveLength(1); expect(chain[0]?.model).toBe('custom-model'); }); it('leaves catalog order untouched when active model already present', () => { - const config = { - getPreviewFeatures: () => false, - getUserTier: () => undefined, + const config = createMockConfig({ getModel: () => 'gemini-2.5-pro', - isInFallbackMode: () => false, - } as unknown as Config; + }); const chain = resolvePolicyChain(config); expect(chain[0]?.model).toBe('gemini-2.5-pro'); }); + + it('returns the default chain when active model is "auto"', () => { + const config = createMockConfig({ + getModel: () => 'auto', + }); + const chain = resolvePolicyChain(config); + + // Expect default chain [Pro, Flash] + expect(chain).toHaveLength(2); + expect(chain[0]?.model).toBe('gemini-2.5-pro'); + expect(chain[1]?.model).toBe('gemini-2.5-flash'); + }); }); describe('buildFallbackPolicyContext', () => { @@ -57,4 +73,112 @@ describe('policyHelpers', () => { expect(context.candidates).toEqual(chain); }); }); + + describe('applyModelSelection', () => { + const mockModelConfigService = { + getResolvedConfig: vi.fn(), + }; + + const mockAvailabilityService = { + selectFirstAvailable: vi.fn(), + consumeStickyAttempt: vi.fn(), + }; + + const createExtendedMockConfig = ( + overrides: Partial = {}, + ): Config => { + const defaults = { + isModelAvailabilityServiceEnabled: () => true, + getModelAvailabilityService: () => mockAvailabilityService, + setActiveModel: vi.fn(), + modelConfigService: mockModelConfigService, + }; + return createMockConfig({ ...defaults, ...overrides } as Partial); + }; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('returns requested model if availability service is disabled', () => { + const config = createExtendedMockConfig({ + isModelAvailabilityServiceEnabled: () => false, + }); + const result = applyModelSelection(config, 'gemini-pro'); + expect(result.model).toBe('gemini-pro'); + expect(config.setActiveModel).not.toHaveBeenCalled(); + }); + + it('returns requested model if it is available', () => { + const config = createExtendedMockConfig(); + mockAvailabilityService.selectFirstAvailable.mockReturnValue({ + selectedModel: 'gemini-pro', + }); + + const result = applyModelSelection(config, 'gemini-pro'); + expect(result.model).toBe('gemini-pro'); + expect(result.maxAttempts).toBeUndefined(); + expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro'); + }); + + it('switches to backup model and updates config if requested is unavailable', () => { + const config = createExtendedMockConfig(); + mockAvailabilityService.selectFirstAvailable.mockReturnValue({ + selectedModel: 'gemini-flash', + }); + mockModelConfigService.getResolvedConfig.mockReturnValue({ + generateContentConfig: { temperature: 0.1 }, + }); + + const currentConfig = { temperature: 0.9, topP: 1 }; + const result = applyModelSelection(config, 'gemini-pro', currentConfig); + + expect(result.model).toBe('gemini-flash'); + expect(result.config).toEqual({ + temperature: 0.1, + topP: 1, + }); + + expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({ + model: 'gemini-flash', + }); + expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash'); + }); + + it('consumes sticky attempt if indicated', () => { + const config = createExtendedMockConfig(); + mockAvailabilityService.selectFirstAvailable.mockReturnValue({ + selectedModel: 'gemini-pro', + attempts: 1, + }); + + const result = applyModelSelection(config, 'gemini-pro'); + expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith( + 'gemini-pro', + ); + expect(result.maxAttempts).toBe(1); + }); + + it('does not consume sticky attempt if consumeAttempt is false', () => { + const config = createExtendedMockConfig(); + mockAvailabilityService.selectFirstAvailable.mockReturnValue({ + selectedModel: 'gemini-pro', + attempts: 1, + }); + + const result = applyModelSelection( + config, + 'gemini-pro', + undefined, + undefined, + { + consumeAttempt: false, + }, + ); + expect( + mockAvailabilityService.consumeStickyAttempt, + ).not.toHaveBeenCalled(); + expect(result.maxAttempts).toBe(1); + }); + }); }); diff --git a/packages/core/src/availability/policyHelpers.ts b/packages/core/src/availability/policyHelpers.ts index c01a9ab890..ee0de84147 100644 --- a/packages/core/src/availability/policyHelpers.ts +++ b/packages/core/src/availability/policyHelpers.ts @@ -4,34 +4,47 @@ * SPDX-License-Identifier: Apache-2.0 */ +import type { GenerateContentConfig } from '@google/genai'; import type { Config } from '../config/config.js'; import type { FailureKind, FallbackAction, ModelPolicy, ModelPolicyChain, + RetryAvailabilityContext, } from './modelPolicy.js'; import { createDefaultPolicy, getModelPolicyChain } from './policyCatalog.js'; -import { getEffectiveModel } from '../config/models.js'; +import { DEFAULT_GEMINI_MODEL, getEffectiveModel } from '../config/models.js'; +import type { ModelSelectionResult } from './modelAvailabilityService.js'; /** * Resolves the active policy chain for the given config, ensuring the * user-selected active model is represented. */ -export function resolvePolicyChain(config: Config): ModelPolicyChain { +export function resolvePolicyChain( + config: Config, + preferredModel?: string, +): ModelPolicyChain { const chain = getModelPolicyChain({ previewEnabled: !!config.getPreviewFeatures(), userTier: config.getUserTier(), }); - // TODO: This will be replaced when we get rid of Fallback Modes - const activeModel = getEffectiveModel( - config.isInFallbackMode(), - config.getModel(), - config.getPreviewFeatures(), - ); + // TODO: This will be replaced when we get rid of Fallback Modes. + // Switch to getActiveModel() + const activeModel = + preferredModel ?? + getEffectiveModel( + config.isInFallbackMode(), + config.getModel(), + config.getPreviewFeatures(), + ); + + if (activeModel === 'auto') { + return [...chain]; + } if (chain.some((policy) => policy.model === activeModel)) { - return chain; + return [...chain]; } // If the user specified a model not in the default chain, we assume they want @@ -68,3 +81,120 @@ export function resolvePolicyAction( ): FallbackAction { return policy.actions?.[failureKind] ?? 'prompt'; } + +/** + * Creates a context provider for retry logic that returns the availability + * sevice and resolves the current model's policy. + * + * @param modelGetter A function that returns the model ID currently being attempted. + * (Allows handling dynamic model changes during retries). + */ +export function createAvailabilityContextProvider( + config: Config, + modelGetter: () => string, +): () => RetryAvailabilityContext | undefined { + return () => { + if (!config.isModelAvailabilityServiceEnabled()) { + return undefined; + } + const service = config.getModelAvailabilityService(); + const currentModel = modelGetter(); + + // Resolve the chain for the specific model we are attempting. + const chain = resolvePolicyChain(config, currentModel); + const policy = chain.find((p) => p.model === currentModel); + + return policy ? { service, policy } : undefined; + }; +} + +/** + * Selects the model to use for an attempt via the availability service and + * returns the selection context. + */ +export function selectModelForAvailability( + config: Config, + requestedModel: string, +): ModelSelectionResult | undefined { + if (!config.isModelAvailabilityServiceEnabled()) { + return undefined; + } + + const chain = resolvePolicyChain(config, requestedModel); + const selection = config + .getModelAvailabilityService() + .selectFirstAvailable(chain.map((p) => p.model)); + + if (selection.selectedModel) return selection; + + const backupModel = + chain.find((p) => p.isLastResort)?.model ?? DEFAULT_GEMINI_MODEL; + + return { selectedModel: backupModel, skipped: [] }; +} + +/** + * Applies the model availability selection logic, including side effects + * (setting active model, consuming sticky attempts) and config updates. + */ +export function applyModelSelection( + config: Config, + requestedModel: string, + currentConfig?: GenerateContentConfig, + overrideScope?: string, + options: { consumeAttempt?: boolean } = {}, +): { model: string; config?: GenerateContentConfig; maxAttempts?: number } { + const selection = selectModelForAvailability(config, requestedModel); + + if (!selection?.selectedModel) { + return { model: requestedModel, config: currentConfig }; + } + + const finalModel = selection.selectedModel; + let finalConfig = currentConfig; + + // If model changed, re-resolve config + if (finalModel !== requestedModel) { + const { generateContentConfig } = + config.modelConfigService.getResolvedConfig({ + overrideScope, + model: finalModel, + }); + + finalConfig = currentConfig + ? { ...currentConfig, ...generateContentConfig } + : generateContentConfig; + } + + config.setActiveModel(finalModel); + + if (selection.attempts && options.consumeAttempt !== false) { + config.getModelAvailabilityService().consumeStickyAttempt(finalModel); + } + + return { + model: finalModel, + config: finalConfig, + maxAttempts: selection.attempts, + }; +} + +export function applyAvailabilityTransition( + getContext: (() => RetryAvailabilityContext | undefined) | undefined, + failureKind: FailureKind, +): void { + const context = getContext?.(); + if (!context) return; + + const transition = context.policy.stateTransitions?.[failureKind]; + if (!transition) return; + + if (transition === 'terminal') { + context.service.markTerminal( + context.policy.model, + failureKind === 'terminal' ? 'quota' : 'capacity', + ); + } else if (transition === 'sticky_retry') { + context.service.markRetryOncePerTurn(context.policy.model); + } +} diff --git a/packages/core/src/availability/testUtils.ts b/packages/core/src/availability/testUtils.ts new file mode 100644 index 0000000000..8b76c0f053 --- /dev/null +++ b/packages/core/src/availability/testUtils.ts @@ -0,0 +1,30 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { vi } from 'vitest'; +import type { + ModelAvailabilityService, + ModelSelectionResult, +} from './modelAvailabilityService.js'; + +/** + * Test helper to create a fully mocked ModelAvailabilityService. + */ +export function createAvailabilityServiceMock( + selection: ModelSelectionResult = { selectedModel: null, skipped: [] }, +): ModelAvailabilityService { + const service = { + markTerminal: vi.fn(), + markHealthy: vi.fn(), + markRetryOncePerTurn: vi.fn(), + consumeStickyAttempt: vi.fn(), + snapshot: vi.fn(), + resetTurn: vi.fn(), + selectFirstAvailable: vi.fn().mockReturnValue(selection), + }; + + return service as unknown as ModelAvailabilityService; +} diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index fd97a2aaa6..c67f05f888 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -1632,3 +1632,42 @@ describe('Config setExperiments logging', () => { debugSpy.mockRestore(); }); }); + +describe('Availability Service Integration', () => { + const baseModel = 'test-model'; + const baseParams: ConfigParameters = { + sessionId: 'test', + targetDir: '.', + debugMode: false, + model: baseModel, + cwd: '.', + }; + + it('setActiveModel updates active model and emits event', async () => { + const config = new Config(baseParams); + const model1 = 'model1'; + const model2 = 'model2'; + + config.setActiveModel(model1); + expect(config.getActiveModel()).toBe(model1); + expect(mockCoreEvents.emitModelChanged).toHaveBeenCalledWith(model1); + + config.setActiveModel(model2); + expect(config.getActiveModel()).toBe(model2); + expect(mockCoreEvents.emitModelChanged).toHaveBeenCalledWith(model2); + }); + + it('getActiveModel defaults to configured model if not set', () => { + const config = new Config(baseParams); + expect(config.getActiveModel()).toBe(baseModel); + }); + + it('resetTurn delegates to availability service', () => { + const config = new Config(baseParams); + const service = config.getModelAvailabilityService(); + const spy = vi.spyOn(service, 'resetTurn'); + + config.resetTurn(); + expect(spy).toHaveBeenCalled(); + }); +}); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 81adcf30ec..2d26e29966 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -382,6 +382,7 @@ export class Config { private ideMode: boolean; private inFallbackMode = false; + private _activeModel: string; private readonly maxSessionTurns: number; private readonly listSessions: boolean; private readonly deleteSession: string | undefined; @@ -504,6 +505,7 @@ export class Config { this.fileDiscoveryService = params.fileDiscoveryService ?? null; this.bugCommand = params.bugCommand; this.model = params.model; + this._activeModel = params.model; this.enableModelAvailabilityService = params.enableModelAvailabilityService ?? false; this.enableAgents = params.enableAgents ?? false; @@ -810,11 +812,28 @@ export class Config { setModel(newModel: string): void { if (this.model !== newModel || this.inFallbackMode) { this.model = newModel; + // When the user explicitly sets a model, that becomes the active model. + this._activeModel = newModel; coreEvents.emitModelChanged(newModel); } this.setFallbackMode(false); } + getActiveModel(): string { + return this._activeModel ?? this.model; + } + + setActiveModel(model: string): void { + if (this._activeModel !== model) { + this._activeModel = model; + coreEvents.emitModelChanged(model); + } + } + + resetTurn(): void { + this.modelAvailabilityService.resetTurn(); + } + isInFallbackMode(): boolean { return this.inFallbackMode; } diff --git a/packages/core/src/core/baseLlmClient.test.ts b/packages/core/src/core/baseLlmClient.test.ts index d1596451c1..4027b070e8 100644 --- a/packages/core/src/core/baseLlmClient.test.ts +++ b/packages/core/src/core/baseLlmClient.test.ts @@ -15,9 +15,12 @@ import { type Mock, } from 'vitest'; -import type { GenerateContentResponse } from '@google/genai'; import { BaseLlmClient, type GenerateJsonOptions } from './baseLlmClient.js'; import type { ContentGenerator } from './contentGenerator.js'; +import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js'; +import { createAvailabilityServiceMock } from '../availability/testUtils.js'; +import type { GenerateContentOptions } from './baseLlmClient.js'; +import type { GenerateContentResponse } from '@google/genai'; import type { Config } from '../config/config.js'; import { AuthType } from './contentGenerator.js'; import { reportError } from '../utils/errorReporting.js'; @@ -25,6 +28,8 @@ import { logMalformedJsonResponse } from '../telemetry/loggers.js'; import { retryWithBackoff } from '../utils/retry.js'; import { MalformedJsonResponseEvent } from '../telemetry/types.js'; import { getErrorMessage } from '../utils/errors.js'; +import type { ModelConfigService } from '../services/modelConfigService.js'; +import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js'; vi.mock('../utils/errorReporting.js'); vi.mock('../telemetry/loggers.js'); @@ -58,6 +63,11 @@ vi.mock('../utils/retry.js', () => ({ } } + const context = options?.getAvailabilityContext?.(); + if (context) { + context.service.markHealthy(context.policy.model); + } + return result; }), })); @@ -97,14 +107,18 @@ describe('BaseLlmClient', () => { getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'), isInteractive: vi.fn().mockReturnValue(false), modelConfigService: { - getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({ - model, - generateContentConfig: { - temperature: 0, - topP: 1, - }, - })), - }, + getResolvedConfig: vi + .fn() + .mockImplementation(({ model }) => makeResolvedModelConfig(model)), + } as unknown as ModelConfigService, + isModelAvailabilityServiceEnabled: vi.fn().mockReturnValue(false), + getModelAvailabilityService: vi.fn(), + setActiveModel: vi.fn(), + getPreviewFeatures: vi.fn().mockReturnValue(false), + getUserTier: vi.fn().mockReturnValue(undefined), + isInFallbackMode: vi.fn().mockReturnValue(false), + getModel: vi.fn().mockReturnValue('test-model'), + getActiveModel: vi.fn().mockReturnValue('test-model'), } as unknown as Mocked; client = new BaseLlmClient(mockContentGenerator, mockConfig); @@ -593,4 +607,243 @@ describe('BaseLlmClient', () => { ); }); }); + + describe('Availability Service Integration', () => { + let mockAvailabilityService: ModelAvailabilityService; + let contentOptions: GenerateContentOptions; + let jsonOptions: GenerateJsonOptions; + + beforeEach(() => { + mockConfig.isModelAvailabilityServiceEnabled = vi + .fn() + .mockReturnValue(true); + + mockAvailabilityService = createAvailabilityServiceMock({ + selectedModel: 'test-model', + skipped: [], + }); + + // Reflect setActiveModel into getActiveModel so availability-driven updates + // are visible to the client under test. + mockConfig.getActiveModel = vi.fn().mockReturnValue('test-model'); + mockConfig.setActiveModel = vi.fn((model: string) => { + vi.mocked(mockConfig.getActiveModel).mockReturnValue(model); + }); + + vi.spyOn(mockConfig, 'getModelAvailabilityService').mockReturnValue( + mockAvailabilityService, + ); + + contentOptions = { + modelConfigKey: { model: 'test-model' }, + contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }], + abortSignal: abortController.signal, + promptId: 'content-prompt-id', + }; + + jsonOptions = { + ...defaultOptions, + promptId: 'json-prompt-id', + }; + }); + + it('should preserve legacy behavior when availability is disabled', async () => { + mockConfig.isModelAvailabilityServiceEnabled = vi + .fn() + .mockReturnValue(false); + mockGenerateContent.mockResolvedValue( + createMockResponse('Some text response'), + ); + + await client.generateContent(contentOptions); + + expect( + mockAvailabilityService.selectFirstAvailable, + ).not.toHaveBeenCalled(); + expect(mockConfig.setActiveModel).not.toHaveBeenCalled(); + expect(mockAvailabilityService.markHealthy).not.toHaveBeenCalled(); + }); + + it('should mark model as healthy on success', async () => { + const successfulModel = 'gemini-pro'; + vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({ + selectedModel: successfulModel, + skipped: [], + }); + mockGenerateContent.mockResolvedValue( + createMockResponse('Some text response'), + ); + + await client.generateContent({ + ...contentOptions, + modelConfigKey: { model: successfulModel }, + }); + + expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith( + successfulModel, + ); + }); + + it('marks the final attempted model healthy after a retry with availability enabled', async () => { + const firstModel = 'gemini-pro'; + const fallbackModel = 'gemini-flash'; + vi.mocked(mockAvailabilityService.selectFirstAvailable) + .mockReturnValueOnce({ selectedModel: firstModel, skipped: [] }) + .mockReturnValueOnce({ selectedModel: fallbackModel, skipped: [] }); + + mockGenerateContent + .mockResolvedValueOnce(createMockResponse('retry-me')) + .mockResolvedValueOnce(createMockResponse('final-response')); + + // Run the real retryWithBackoff (with fake timers) to exercise the retry path + vi.useFakeTimers(); + + const retryPromise = client.generateContent({ + ...contentOptions, + modelConfigKey: { model: firstModel }, + maxAttempts: 2, + }); + + await vi.runAllTimersAsync(); + await retryPromise; + + await client.generateContent({ + ...contentOptions, + modelConfigKey: { model: firstModel }, + maxAttempts: 2, + }); + + expect(mockConfig.setActiveModel).toHaveBeenCalledWith(firstModel); + expect(mockConfig.setActiveModel).toHaveBeenCalledWith(fallbackModel); + expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith( + fallbackModel, + ); + expect(mockGenerateContent).toHaveBeenLastCalledWith( + expect.objectContaining({ model: fallbackModel }), + expect.any(String), + ); + }); + + it('should consume sticky attempt if selection has attempts', async () => { + const stickyModel = 'gemini-pro-sticky'; + vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({ + selectedModel: stickyModel, + attempts: 1, + skipped: [], + }); + mockGenerateContent.mockResolvedValue( + createMockResponse('Some text response'), + ); + vi.mocked(retryWithBackoff).mockImplementation(async (fn, options) => { + const result = await fn(); + const context = options?.getAvailabilityContext?.(); + if (context) { + context.service.markHealthy(context.policy.model); + } + return result; + }); + + await client.generateContent({ + ...contentOptions, + modelConfigKey: { model: stickyModel }, + }); + + expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith( + stickyModel, + ); + expect(retryWithBackoff).toHaveBeenCalledWith( + expect.any(Function), + expect.objectContaining({ maxAttempts: 1 }), + ); + }); + + it('should mark healthy and honor availability selection when using generateJson', async () => { + const availableModel = 'gemini-json-pro'; + vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({ + selectedModel: availableModel, + skipped: [], + }); + mockGenerateContent.mockResolvedValue( + createMockResponse('{"color":"violet"}'), + ); + vi.mocked(retryWithBackoff).mockImplementation(async (fn, options) => { + const result = await fn(); + const context = options?.getAvailabilityContext?.(); + if (context) { + context.service.markHealthy(context.policy.model); + } + return result; + }); + + const result = await client.generateJson(jsonOptions); + + expect(result).toEqual({ color: 'violet' }); + expect(mockConfig.setActiveModel).toHaveBeenCalledWith(availableModel); + expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith( + availableModel, + ); + expect(mockGenerateContent).toHaveBeenLastCalledWith( + expect.objectContaining({ model: availableModel }), + jsonOptions.promptId, + ); + }); + + it('should refresh configuration when model changes mid-retry', async () => { + const firstModel = 'gemini-pro'; + const fallbackModel = 'gemini-flash'; + + // Provide distinct configs per model + const getResolvedConfigMock = vi.mocked( + mockConfig.modelConfigService.getResolvedConfig, + ); + getResolvedConfigMock + .mockReturnValueOnce( + makeResolvedModelConfig(firstModel, { temperature: 0.1 }), + ) + .mockReturnValueOnce( + makeResolvedModelConfig(fallbackModel, { temperature: 0.9 }), + ); + + // Availability selects the first model initially + vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({ + selectedModel: firstModel, + skipped: [], + }); + + // Change active model after the first attempt + let activeModel = firstModel; + mockConfig.setActiveModel = vi.fn(); // Prevent setActiveModel from resetting getActiveModel mock + mockConfig.getActiveModel.mockImplementation(() => activeModel); + + // First response empty -> triggers retry; second response valid + mockGenerateContent + .mockResolvedValueOnce(createMockResponse('')) + .mockResolvedValueOnce(createMockResponse('final-response')); + + // Custom retry to force two attempts + vi.mocked(retryWithBackoff).mockImplementation(async (fn, options) => { + const first = (await fn()) as GenerateContentResponse; + if (options?.shouldRetryOnContent?.(first)) { + activeModel = fallbackModel; // simulate handler switching active model before retry + return (await fn()) as GenerateContentResponse; + } + return first; + }); + + await client.generateContent({ + ...contentOptions, + modelConfigKey: { model: firstModel }, + maxAttempts: 2, + }); + + expect(mockGenerateContent).toHaveBeenCalledTimes(2); + const secondCall = mockGenerateContent.mock.calls[1]?.[0]; + + expect( + mockConfig.modelConfigService.getResolvedConfig, + ).toHaveBeenCalledWith({ model: fallbackModel }); + expect(secondCall?.model).toBe(fallbackModel); + expect(secondCall?.config?.temperature).toBe(0.9); + }); + }); }); diff --git a/packages/core/src/core/baseLlmClient.ts b/packages/core/src/core/baseLlmClient.ts index 166579b166..8313ee4bc9 100644 --- a/packages/core/src/core/baseLlmClient.ts +++ b/packages/core/src/core/baseLlmClient.ts @@ -20,6 +20,10 @@ import { logMalformedJsonResponse } from '../telemetry/loggers.js'; import { MalformedJsonResponseEvent } from '../telemetry/types.js'; import { retryWithBackoff } from '../utils/retry.js'; import type { ModelConfigKey } from '../services/modelConfigService.js'; +import { + applyModelSelection, + createAvailabilityContextProvider, +} from '../availability/policyHelpers.js'; const DEFAULT_MAX_ATTEMPTS = 5; @@ -232,13 +236,56 @@ export class BaseLlmClient { ): Promise { const abortSignal = requestParams.config?.abortSignal; + // Define callback to fetch context dynamically since active model may get updated during retry loop + const getAvailabilityContext = createAvailabilityContextProvider( + this.config, + () => requestParams.model, + ); + + const { + model, + config: newConfig, + maxAttempts: availabilityMaxAttempts, + } = applyModelSelection( + this.config, + requestParams.model, + requestParams.config, + ); + requestParams.model = model; + if (newConfig) { + requestParams.config = newConfig; + } + if (abortSignal) { + requestParams.config = { ...requestParams.config, abortSignal }; + } + try { - const apiCall = () => - this.contentGenerator.generateContent(requestParams, promptId); + const apiCall = () => { + // If availability is enabled, ensure we use the current active model + // in case a fallback occurred in a previous attempt. + if (this.config.isModelAvailabilityServiceEnabled()) { + const activeModel = this.config.getActiveModel(); + if (activeModel !== requestParams.model) { + requestParams.model = activeModel; + // Re-resolve config if model changed during retry + const { generateContentConfig } = + this.config.modelConfigService.getResolvedConfig({ + model: activeModel, + }); + requestParams.config = { + ...requestParams.config, + ...generateContentConfig, + }; + } + } + return this.contentGenerator.generateContent(requestParams, promptId); + }; return await retryWithBackoff(apiCall, { shouldRetryOnContent, - maxAttempts: maxAttempts ?? DEFAULT_MAX_ATTEMPTS, + maxAttempts: + availabilityMaxAttempts ?? maxAttempts ?? DEFAULT_MAX_ATTEMPTS, + getAvailabilityContext, }); } catch (error) { if (abortSignal?.aborted) { diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index f6f8e98c40..e4aca47fc2 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -38,12 +38,15 @@ import { ideContextStore } from '../ide/ideContext.js'; import type { ModelRouterService } from '../routing/modelRouterService.js'; import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; import { ChatCompressionService } from '../services/chatCompressionService.js'; +import { createAvailabilityServiceMock } from '../availability/testUtils.js'; +import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js'; import type { ModelConfigKey, ResolvedModelConfig, } from '../services/modelConfigService.js'; import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js'; import { HookSystem } from '../hooks/hookSystem.js'; +import * as policyCatalog from '../availability/policyCatalog.js'; vi.mock('../services/chatCompressionService.js'); @@ -145,6 +148,7 @@ describe('Gemini Client (client.ts)', () => { let mockConfig: Config; let client: GeminiClient; let mockGenerateContentFn: Mock; + let mockRouterService: { route: Mock }; beforeEach(async () => { vi.resetAllMocks(); ClearcutLogger.clearInstance(); @@ -166,6 +170,12 @@ describe('Gemini Client (client.ts)', () => { // Disable 429 simulation for tests setSimulate429(false); + mockRouterService = { + route: vi + .fn() + .mockResolvedValue({ model: 'default-routed-model', reason: 'test' }), + }; + mockContentGenerator = { generateContent: mockGenerateContentFn, generateContentStream: vi.fn(), @@ -192,6 +202,7 @@ describe('Gemini Client (client.ts)', () => { .mockReturnValue(contentGeneratorConfig), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), getModel: vi.fn().mockReturnValue('test-model'), + getUserTier: vi.fn().mockReturnValue(undefined), getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'), getApiKey: vi.fn().mockReturnValue('test-key'), getVertexAI: vi.fn().mockReturnValue(false), @@ -215,9 +226,9 @@ describe('Gemini Client (client.ts)', () => { getDirectories: vi.fn().mockReturnValue(['/test/dir']), }), getGeminiClient: vi.fn(), - getModelRouterService: vi.fn().mockReturnValue({ - route: vi.fn().mockResolvedValue({ model: 'default-routed-model' }), - }), + getModelRouterService: vi + .fn() + .mockReturnValue(mockRouterService as unknown as ModelRouterService), getMessageBus: vi.fn().mockReturnValue(undefined), getEnableHooks: vi.fn().mockReturnValue(false), isInFallbackMode: vi.fn().mockReturnValue(false), @@ -251,6 +262,11 @@ describe('Gemini Client (client.ts)', () => { }, isInteractive: vi.fn().mockReturnValue(false), getExperiments: () => {}, + isModelAvailabilityServiceEnabled: vi.fn().mockReturnValue(false), + getActiveModel: vi.fn().mockReturnValue('test-model'), + setActiveModel: vi.fn(), + resetTurn: vi.fn(), + getModelAvailabilityService: vi.fn(), } as unknown as Config; mockConfig.getHookSystem = vi .fn() @@ -1614,6 +1630,7 @@ ${JSON.stringify( expect(events).toEqual([ { type: GeminiEventType.ModelInfo, value: 'default-routed-model' }, { type: GeminiEventType.InvalidStream }, + { type: GeminiEventType.ModelInfo, value: 'default-routed-model' }, { type: GeminiEventType.Content, value: 'Continued content' }, ]); @@ -1701,8 +1718,8 @@ ${JSON.stringify( const events = await fromAsync(stream); // Assert - // We expect 3 events (model_info + original + 1 retry) - expect(events.length).toBe(3); + // We expect 4 events (model_info + original + model_info + 1 retry) + expect(events.length).toBe(4); expect( events .filter((e) => e.type !== GeminiEventType.ModelInfo) @@ -1977,6 +1994,154 @@ ${JSON.stringify( }); }); + describe('Availability Service Integration', () => { + let mockAvailabilityService: ModelAvailabilityService; + + beforeEach(() => { + mockAvailabilityService = createAvailabilityServiceMock(); + + vi.mocked(mockConfig.getModelAvailabilityService).mockReturnValue( + mockAvailabilityService, + ); + vi.mocked(mockConfig.isModelAvailabilityServiceEnabled).mockReturnValue( + true, + ); + vi.mocked(mockConfig.setActiveModel).mockClear(); + mockRouterService.route.mockResolvedValue({ + model: 'model-a', + reason: 'test', + }); + vi.mocked(mockConfig.getModelRouterService).mockReturnValue( + mockRouterService as unknown as ModelRouterService, + ); + vi.spyOn(policyCatalog, 'getModelPolicyChain').mockReturnValue([ + { + model: 'model-a', + isLastResort: false, + actions: {}, + stateTransitions: {}, + }, + { + model: 'model-b', + isLastResort: true, + actions: {}, + stateTransitions: {}, + }, + ]); + + mockTurnRunFn.mockReturnValue( + (async function* () { + yield { type: 'content', value: 'Hello' }; + })(), + ); + }); + + it('should select first available model, set active, and not consume sticky attempt (done lower in chain)', async () => { + vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue( + { + selectedModel: 'model-a', + attempts: 1, + skipped: [], + }, + ); + + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-avail', + ); + await fromAsync(stream); + + expect( + mockAvailabilityService.selectFirstAvailable, + ).toHaveBeenCalledWith(['model-a', 'model-b']); + expect(mockConfig.setActiveModel).toHaveBeenCalledWith('model-a'); + expect( + mockAvailabilityService.consumeStickyAttempt, + ).not.toHaveBeenCalled(); + // Ensure turn.run used the selected model + expect(mockTurnRunFn).toHaveBeenCalledWith( + expect.objectContaining({ model: 'model-a' }), + expect.anything(), + expect.anything(), + ); + }); + + it('should default to last resort model if selection returns null', async () => { + vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue( + { + selectedModel: null, + skipped: [], + }, + ); + + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-avail-fallback', + ); + await fromAsync(stream); + + expect(mockConfig.setActiveModel).toHaveBeenCalledWith('model-b'); // Last resort + expect( + mockAvailabilityService.consumeStickyAttempt, + ).not.toHaveBeenCalled(); + }); + + it('should reset turn on new message stream', async () => { + vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue( + { + selectedModel: 'model-a', + skipped: [], + }, + ); + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-reset', + ); + await fromAsync(stream); + + expect(mockConfig.resetTurn).toHaveBeenCalled(); + }); + + it('should NOT reset turn on invalid stream retry', async () => { + vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue( + { + selectedModel: 'model-a', + skipped: [], + }, + ); + // We simulate a retry by calling sendMessageStream with isInvalidStreamRetry=true + // But the public API doesn't expose that argument directly unless we use the private method or simulate the recursion. + // We can simulate recursion by mocking turn run to return invalid stream once. + + vi.spyOn( + client['config'], + 'getContinueOnFailedApiCall', + ).mockReturnValue(true); + const mockStream1 = (async function* () { + yield { type: GeminiEventType.InvalidStream }; + })(); + const mockStream2 = (async function* () { + yield { type: 'content', value: 'ok' }; + })(); + mockTurnRunFn + .mockReturnValueOnce(mockStream1) + .mockReturnValueOnce(mockStream2); + + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-retry', + ); + await fromAsync(stream); + + // resetTurn should be called once (for the initial call) but NOT for the recursive call + expect(mockConfig.resetTurn).toHaveBeenCalledTimes(1); + }); + }); + describe('IDE context with pending tool calls', () => { let mockChat: Partial; diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index f7d915bc9a..3ce8c1306f 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -57,6 +57,11 @@ import type { RoutingContext } from '../routing/routingStrategy.js'; import { debugLogger } from '../utils/debugLogger.js'; import type { ModelConfigKey } from '../services/modelConfigService.js'; import { calculateRequestTokenCount } from '../utils/tokenCalculation.js'; +import { + applyModelSelection, + createAvailabilityContextProvider, +} from '../availability/policyHelpers.js'; +import type { RetryAvailabilityContext } from '../utils/retry.js'; const MAX_TURNS = 100; @@ -405,6 +410,10 @@ export class GeminiClient { turns: number = MAX_TURNS, isInvalidStreamRetry: boolean = false, ): AsyncGenerator { + if (!isInvalidStreamRetry) { + this.config.resetTurn(); + } + // Fire BeforeAgent hook through MessageBus (only if hooks are enabled) const hooksEnabled = this.config.getEnableHooks(); const messageBus = this.config.getMessageBus(); @@ -534,11 +543,21 @@ export class GeminiClient { const router = await this.config.getModelRouterService(); const decision = await router.route(routingContext); modelToUse = decision.model; - // Lock the model for the rest of the sequence - this.currentSequenceModel = modelToUse; - yield { type: GeminiEventType.ModelInfo, value: modelToUse }; } + // availability logic + const { model: finalModel } = applyModelSelection( + this.config, + modelToUse, + undefined, + undefined, + { consumeAttempt: false }, + ); + modelToUse = finalModel; + + this.currentSequenceModel = modelToUse; + yield { type: GeminiEventType.ModelInfo, value: modelToUse }; + const resultStream = turn.run({ model: modelToUse }, request, linkedSignal); for await (const event of resultStream) { if (this.loopDetector.addAndCheck(event)) { @@ -665,14 +684,53 @@ export class GeminiClient { try { const userMemory = this.config.getUserMemory(); const systemInstruction = getCoreSystemPrompt(this.config, userMemory); + const { + model, + config: newConfig, + maxAttempts: availabilityMaxAttempts, + } = applyModelSelection( + this.config, + currentAttemptModel, + currentAttemptGenerateContentConfig, + modelConfigKey.overrideScope, + ); + currentAttemptModel = model; + if (newConfig) { + currentAttemptGenerateContentConfig = newConfig; + } + + // Define callback to refresh context based on currentAttemptModel which might be updated by fallback handler + const getAvailabilityContext: () => RetryAvailabilityContext | undefined = + createAvailabilityContextProvider( + this.config, + () => currentAttemptModel, + ); const apiCall = () => { - const modelConfigToUse = this.config.isInFallbackMode() - ? fallbackModelConfig - : desiredModelConfig; - currentAttemptModel = modelConfigToUse.model; - currentAttemptGenerateContentConfig = - modelConfigToUse.generateContentConfig; + let modelConfigToUse = desiredModelConfig; + + if (!this.config.isModelAvailabilityServiceEnabled()) { + modelConfigToUse = this.config.isInFallbackMode() + ? fallbackModelConfig + : desiredModelConfig; + currentAttemptModel = modelConfigToUse.model; + currentAttemptGenerateContentConfig = + modelConfigToUse.generateContentConfig; + } else { + // AvailabilityService + const active = this.config.getActiveModel(); + if (active !== currentAttemptModel) { + currentAttemptModel = active; + // Re-resolve config if model changed + const newConfig = this.config.modelConfigService.getResolvedConfig({ + ...modelConfigKey, + model: currentAttemptModel, + }); + currentAttemptGenerateContentConfig = + newConfig.generateContentConfig; + } + } + const requestConfig: GenerateContentConfig = { ...currentAttemptGenerateContentConfig, abortSignal, @@ -698,7 +756,10 @@ export class GeminiClient { const result = await retryWithBackoff(apiCall, { onPersistent429: onPersistent429Callback, authType: this.config.getContentGeneratorConfig()?.authType, + maxAttempts: availabilityMaxAttempts, + getAvailabilityContext, }); + return result; } catch (error: unknown) { if (abortSignal.aborted) { diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 4c2dc2786d..7afdd00ec7 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -29,6 +29,10 @@ import { retryWithBackoff, type RetryOptions } from '../utils/retry.js'; import { uiTelemetryService } from '../telemetry/uiTelemetry.js'; import { HookSystem } from '../hooks/hookSystem.js'; import { createMockMessageBus } from '../test-utils/mock-message-bus.js'; +import { createAvailabilityServiceMock } from '../availability/testUtils.js'; +import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js'; +import * as policyHelpers from '../availability/policyHelpers.js'; +import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js'; // Mock fs module to prevent actual file system operations during tests const mockFileSystem = new Map(); @@ -115,7 +119,14 @@ describe('GeminiChat', () => { mockHandleFallback.mockClear(); // Default mock implementation for tests that don't care about retry logic - mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall()); + mockRetryWithBackoff.mockImplementation(async (apiCall, options) => { + const result = await apiCall(); + const context = options?.getAvailabilityContext?.(); + if (context) { + context.service.markHealthy(context.policy.model); + } + return result; + }); mockConfig = { getSessionId: () => 'test-session-id', getTelemetryLogPromptsEnabled: () => true, @@ -141,6 +152,7 @@ describe('GeminiChat', () => { }), getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator), getRetryFetchErrors: vi.fn().mockReturnValue(false), + getUserTier: vi.fn().mockReturnValue(undefined), modelConfigService: { getResolvedConfig: vi.fn().mockImplementation((modelConfigKey) => { const thinkingConfig = modelConfigKey.model.startsWith('gemini-3') @@ -165,6 +177,10 @@ describe('GeminiChat', () => { setPreviewModelFallbackMode: vi.fn(), isInteractive: vi.fn().mockReturnValue(false), getEnableHooks: vi.fn().mockReturnValue(false), + isModelAvailabilityServiceEnabled: vi.fn().mockReturnValue(false), + getActiveModel: vi.fn().mockReturnValue('gemini-pro'), + setActiveModel: vi.fn(), + getModelAvailabilityService: vi.fn(), } as unknown as Config; // Use proper MessageBus mocking for Phase 3 preparation @@ -2359,4 +2375,260 @@ describe('GeminiChat', () => { expect(newContents).toEqual(history); }); }); + + describe('Availability Service Integration', () => { + let mockAvailabilityService: ModelAvailabilityService; + + beforeEach(async () => { + mockAvailabilityService = createAvailabilityServiceMock(); + vi.mocked(mockConfig.getModelAvailabilityService).mockReturnValue( + mockAvailabilityService, + ); + vi.mocked(mockConfig.isModelAvailabilityServiceEnabled).mockReturnValue( + true, + ); + + // Stateful mock for activeModel + let activeModel = 'model-a'; + vi.mocked(mockConfig.getActiveModel).mockImplementation( + () => activeModel, + ); + vi.mocked(mockConfig.setActiveModel).mockImplementation((model) => { + activeModel = model; + }); + + vi.spyOn(policyHelpers, 'resolvePolicyChain').mockReturnValue([ + { + model: 'model-a', + isLastResort: false, + actions: {}, + stateTransitions: {}, + }, + { + model: 'model-b', + isLastResort: false, + actions: {}, + stateTransitions: {}, + }, + { + model: 'model-c', + isLastResort: true, + actions: {}, + stateTransitions: {}, + }, + ]); + }); + + it('should mark healthy on successful stream', async () => { + vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({ + selectedModel: 'model-b', + skipped: [], + }); + // Simulate selection happening upstream + mockConfig.setActiveModel('model-b'); + + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Response' }], role: 'model' }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(), + ); + + const stream = await chat.sendMessageStream( + { model: 'gemini-pro' }, + 'test', + 'prompt-healthy', + new AbortController().signal, + ); + for await (const _ of stream) { + // consume + } + + expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith( + 'model-b', + ); + }); + + it('caps retries to a single attempt when selection is sticky', async () => { + vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({ + selectedModel: 'model-a', + attempts: 1, + skipped: [], + }); + + vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue( + (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Response' }], role: 'model' }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(), + ); + + const stream = await chat.sendMessageStream( + { model: 'gemini-pro' }, + 'test', + 'prompt-sticky-once', + new AbortController().signal, + ); + for await (const _ of stream) { + // consume + } + + expect(mockRetryWithBackoff).toHaveBeenCalledWith( + expect.any(Function), + expect.objectContaining({ maxAttempts: 1 }), + ); + expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith( + 'model-a', + ); + }); + + it('should pass attempted model to onPersistent429 callback which calls handleFallback', async () => { + vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({ + selectedModel: 'model-a', + skipped: [], + }); + // Simulate selection happening upstream + mockConfig.setActiveModel('model-a'); + + // Simulate retry logic behavior: catch error, call onPersistent429 + const error = new TerminalQuotaError('Quota', { + code: 429, + message: 'quota', + details: [], + }); + vi.mocked(mockContentGenerator.generateContentStream).mockRejectedValue( + error, + ); + + // We need retryWithBackoff to trigger the callback + mockRetryWithBackoff.mockImplementation(async (apiCall, options) => { + try { + await apiCall(); + } catch (e) { + if (options?.onPersistent429) { + await options.onPersistent429(AuthType.LOGIN_WITH_GOOGLE, e); + } + throw e; // throw anyway to end test + } + }); + + const consume = async () => { + const stream = await chat.sendMessageStream( + { model: 'gemini-pro' }, + 'test', + 'prompt-fallback-arg', + new AbortController().signal, + ); + for await (const _ of stream) { + // consume + } + }; + + await expect(consume()).rejects.toThrow(); + + // handleFallback is called with the ATTEMPTED model (model-a), not the requested one (gemini-pro) + expect(mockHandleFallback).toHaveBeenCalledWith( + expect.anything(), + 'model-a', + expect.anything(), + error, + ); + }); + + it('re-resolves generateContentConfig when active model changes between retries', async () => { + // Availability enabled with stateful active model + let activeModel = 'model-a'; + vi.mocked(mockConfig.getActiveModel).mockImplementation( + () => activeModel, + ); + vi.mocked(mockConfig.setActiveModel).mockImplementation((model) => { + activeModel = model; + }); + + // Different configs per model + vi.mocked(mockConfig.modelConfigService.getResolvedConfig) + .mockReturnValueOnce( + makeResolvedModelConfig('model-a', { temperature: 0.1 }), + ) + .mockReturnValueOnce( + makeResolvedModelConfig('model-b', { temperature: 0.9 }), + ); + + // First attempt uses model-a, then simulate availability switching to model-b + mockRetryWithBackoff.mockImplementation(async (apiCall) => { + await apiCall(); // first attempt + activeModel = 'model-b'; // simulate switch before retry + return apiCall(); // second attempt + }); + + // Generators for each attempt + const firstResponse = (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'first' }], role: 'model' }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(); + const secondResponse = (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'second' }], role: 'model' }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(); + vi.mocked(mockContentGenerator.generateContentStream) + .mockResolvedValueOnce(firstResponse) + .mockResolvedValueOnce(secondResponse); + + const stream = await chat.sendMessageStream( + { model: 'gemini-pro' }, + 'test', + 'prompt-config-refresh', + new AbortController().signal, + ); + // Consume to drive both attempts + for await (const _ of stream) { + // consume + } + + expect( + mockContentGenerator.generateContentStream, + ).toHaveBeenNthCalledWith( + 1, + expect.objectContaining({ + model: 'model-a', + config: expect.objectContaining({ temperature: 0.1 }), + }), + expect.any(String), + ); + expect( + mockContentGenerator.generateContentStream, + ).toHaveBeenNthCalledWith( + 2, + expect.objectContaining({ + model: 'model-b', + config: expect.objectContaining({ temperature: 0.9 }), + }), + expect.any(String), + ); + }); + }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 98b99dbe2f..bec85a5152 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -48,6 +48,10 @@ import { isFunctionResponse } from '../utils/messageInspectors.js'; import { partListUnionToString } from './geminiRequest.js'; import type { ModelConfigKey } from '../services/modelConfigService.js'; import { estimateTokenCountSync } from '../utils/tokenCalculation.js'; +import { + applyModelSelection, + createAvailabilityContextProvider, +} from '../availability/policyHelpers.js'; import { fireAfterModelHook, fireBeforeModelHook, @@ -410,35 +414,74 @@ export class GeminiChat { requestContents: Content[], prompt_id: string, ): Promise> { - let effectiveModel = model; const contentsForPreviewModel = this.ensureActiveLoopHasThoughtSignatures(requestContents); // Track final request parameters for AfterModel hooks - let lastModelToUse = model; - let lastConfig: GenerateContentConfig = generateContentConfig; + const { + model: availabilityFinalModel, + config: newAvailabilityConfig, + maxAttempts: availabilityMaxAttempts, + } = applyModelSelection(this.config, model, generateContentConfig); + + const abortSignal = generateContentConfig.abortSignal; + let lastModelToUse = availabilityFinalModel; + let currentGenerateContentConfig: GenerateContentConfig = + newAvailabilityConfig ?? generateContentConfig; + if (abortSignal) { + currentGenerateContentConfig = { + ...currentGenerateContentConfig, + abortSignal, + }; + } + let lastConfig: GenerateContentConfig = currentGenerateContentConfig; let lastContentsToUse: Content[] = requestContents; + const getAvailabilityContext = createAvailabilityContextProvider( + this.config, + () => lastModelToUse, + ); const apiCall = async () => { - let modelToUse = getEffectiveModel( - this.config.isInFallbackMode(), - model, - this.config.getPreviewFeatures(), - ); + let modelToUse: string; - // Preview Model Bypass Logic: - // If we are in "Preview Model Bypass Mode" (transient failure), we force downgrade to 2.5 Pro - // IF the effective model is currently Preview Model. - if ( - this.config.isPreviewModelBypassMode() && - modelToUse === PREVIEW_GEMINI_MODEL - ) { - modelToUse = DEFAULT_GEMINI_MODEL; + if (this.config.isModelAvailabilityServiceEnabled()) { + modelToUse = this.config.getActiveModel(); + if (modelToUse !== lastModelToUse) { + const { generateContentConfig: newConfig } = + this.config.modelConfigService.getResolvedConfig({ + model: modelToUse, + }); + currentGenerateContentConfig = { + ...currentGenerateContentConfig, + ...newConfig, + }; + if (abortSignal) { + currentGenerateContentConfig.abortSignal = abortSignal; + } + } + } else { + modelToUse = getEffectiveModel( + this.config.isInFallbackMode(), + model, + this.config.getPreviewFeatures(), + ); + + // Preview Model Bypass Logic: + // If we are in "Preview Model Bypass Mode" (transient failure), we force downgrade to 2.5 Pro + // IF the effective model is currently Preview Model. + // Note: In availability mode, this should ideally be handled by policy, but preserving + // bypass logic for now as it handles specific transient behavior. + if ( + this.config.isPreviewModelBypassMode() && + modelToUse === PREVIEW_GEMINI_MODEL + ) { + modelToUse = DEFAULT_GEMINI_MODEL; + } } - effectiveModel = modelToUse; + lastModelToUse = modelToUse; const config = { - ...generateContentConfig, + ...currentGenerateContentConfig, // TODO(12622): Ensure we don't overrwrite these when they are // passed via config. systemInstruction: this.systemInstruction, @@ -543,7 +586,7 @@ export class GeminiChat { const onPersistent429Callback = async ( authType?: string, error?: unknown, - ) => handleFallback(this.config, effectiveModel, authType, error); + ) => handleFallback(this.config, lastModelToUse, authType, error); const streamResponse = await retryWithBackoff(apiCall, { onPersistent429: onPersistent429Callback, @@ -551,10 +594,12 @@ export class GeminiChat { retryFetchErrors: this.config.getRetryFetchErrors(), signal: generateContentConfig.abortSignal, maxAttempts: - this.config.isPreviewModelFallbackMode() && + availabilityMaxAttempts ?? + (this.config.isPreviewModelFallbackMode() && model === PREVIEW_GEMINI_MODEL ? 1 - : undefined, + : undefined), + getAvailabilityContext, }); // Store the original request for AfterModel hooks @@ -565,7 +610,7 @@ export class GeminiChat { }; return this.processStreamResponse( - effectiveModel, + lastModelToUse, streamResponse, originalRequest, ); diff --git a/packages/core/src/core/geminiChat_network_retry.test.ts b/packages/core/src/core/geminiChat_network_retry.test.ts index 8952a1053a..71085188dc 100644 --- a/packages/core/src/core/geminiChat_network_retry.test.ts +++ b/packages/core/src/core/geminiChat_network_retry.test.ts @@ -93,6 +93,7 @@ describe('GeminiChat Network Retries', () => { generateContentConfig: { temperature: 0 }, })), }, + isModelAvailabilityServiceEnabled: vi.fn().mockReturnValue(false), isPreviewModelBypassMode: vi.fn().mockReturnValue(false), setPreviewModelBypassMode: vi.fn(), isPreviewModelFallbackMode: vi.fn().mockReturnValue(false), diff --git a/packages/core/src/fallback/handler.test.ts b/packages/core/src/fallback/handler.test.ts index 8abcffc4c4..29ba2d5b9e 100644 --- a/packages/core/src/fallback/handler.test.ts +++ b/packages/core/src/fallback/handler.test.ts @@ -17,6 +17,7 @@ import { import { handleFallback } from './handler.js'; import type { Config } from '../config/config.js'; import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js'; +import { createAvailabilityServiceMock } from '../availability/testUtils.js'; import { AuthType } from '../core/contentGenerator.js'; import { DEFAULT_GEMINI_FLASH_MODEL, @@ -45,25 +46,20 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({ openBrowserSecurely: vi.fn(), })); +// Mock debugLogger to prevent console pollution and allow spying +vi.mock('../utils/debugLogger.js', () => ({ + debugLogger: { + warn: vi.fn(), + error: vi.fn(), + log: vi.fn(), + }, +})); + const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL; const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL; const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE; const AUTH_API_KEY = AuthType.USE_GEMINI; -function createAvailabilityMock( - result: ReturnType, -): ModelAvailabilityService { - return { - markTerminal: vi.fn(), - markHealthy: vi.fn(), - markRetryOncePerTurn: vi.fn(), - consumeStickyAttempt: vi.fn(), - snapshot: vi.fn(), - selectFirstAvailable: vi.fn().mockReturnValue(result), - resetTurn: vi.fn(), - } as unknown as ModelAvailabilityService; -} - const createMockConfig = (overrides: Partial = {}): Config => ({ isInFallbackMode: vi.fn(() => false), @@ -75,8 +71,12 @@ const createMockConfig = (overrides: Partial = {}): Config => setPreviewModelBypassMode: vi.fn(), fallbackHandler: undefined, getFallbackModelHandler: vi.fn(), + setActiveModel: vi.fn(), getModelAvailabilityService: vi.fn(() => - createAvailabilityMock({ selectedModel: FALLBACK_MODEL, skipped: [] }), + createAvailabilityServiceMock({ + selectedModel: FALLBACK_MODEL, + skipped: [], + }), ), getModel: vi.fn(() => MOCK_PRO_MODEL), getPreviewFeatures: vi.fn(() => false), @@ -98,6 +98,12 @@ describe('handleFallback', () => { mockConfig = createMockConfig({ fallbackModelHandler: mockHandler, }); + // Explicitly set the property to ensure it's present for legacy checks + mockConfig.fallbackModelHandler = mockHandler; + + // We mocked debugLogger, so we don't need to spy on console.error for handler failures + // But tests might check console.error usage in legacy code if any? + // The handler uses console.error in legacyHandleFallback. consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); fallbackEventSpy = vi.spyOn(coreEvents, 'emitFallbackModeChanged'); }); @@ -538,7 +544,7 @@ describe('handleFallback', () => { beforeEach(() => { vi.clearAllMocks(); - availability = createAvailabilityMock({ + availability = createAvailabilityServiceMock({ selectedModel: DEFAULT_GEMINI_FLASH_MODEL, skipped: [], }); @@ -612,7 +618,17 @@ describe('handleFallback', () => { expect(result).toBe(true); expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled(); - expect(policyConfig.setFallbackMode).toHaveBeenCalledWith(true); + expect(policyConfig.setActiveModel).toHaveBeenCalledWith( + DEFAULT_GEMINI_FLASH_MODEL, + ); + // Silent actions should not trigger the legacy fallback mode (via activateFallbackMode), + // but setActiveModel might trigger it via legacy sync if it switches to Flash. + // However, the test requirement is "doesn't emit fallback mode". + // Since we are mocking setActiveModel, we can verify setFallbackMode isn't called *independently*. + // But setActiveModel is mocked, so it won't trigger side effects unless the implementation does. + // We verified setActiveModel is called. + // We verify setFallbackMode is NOT called (which would happen if activateFallbackMode was called). + expect(policyConfig.setFallbackMode).not.toHaveBeenCalled(); } finally { chainSpy.mockRestore(); } @@ -707,5 +723,34 @@ describe('handleFallback', () => { expect(policyConfig.getModelAvailabilityService).toHaveBeenCalled(); expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled(); }); + + it('calls setActiveModel and logs telemetry when handler returns "retry_always"', async () => { + policyHandler.mockResolvedValue('retry_always'); + + const result = await handleFallback( + policyConfig, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + + expect(result).toBe(true); + expect(policyConfig.setActiveModel).toHaveBeenCalledWith(FALLBACK_MODEL); + expect(policyConfig.setFallbackMode).not.toHaveBeenCalled(); + // TODO: add logging expect statement + }); + + it('calls setActiveModel when handler returns "stop"', async () => { + policyHandler.mockResolvedValue('stop'); + + const result = await handleFallback( + policyConfig, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + + expect(result).toBe(false); + expect(policyConfig.setActiveModel).toHaveBeenCalledWith(FALLBACK_MODEL); + // TODO: add logging expect statement + }); }); }); diff --git a/packages/core/src/fallback/handler.ts b/packages/core/src/fallback/handler.ts index a6406d095b..93cab502fb 100644 --- a/packages/core/src/fallback/handler.ts +++ b/packages/core/src/fallback/handler.ts @@ -12,17 +12,14 @@ import { PREVIEW_GEMINI_MODEL, } from '../config/models.js'; import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js'; -import { coreEvents } from '../utils/events.js'; import { openBrowserSecurely } from '../utils/secure-browser-launcher.js'; import { debugLogger } from '../utils/debugLogger.js'; import { getErrorMessage } from '../utils/errors.js'; import { ModelNotFoundError } from '../utils/httpErrors.js'; -import { - RetryableQuotaError, - TerminalQuotaError, -} from '../utils/googleQuotaErrors.js'; +import { TerminalQuotaError } from '../utils/googleQuotaErrors.js'; +import { coreEvents } from '../utils/events.js'; import type { FallbackIntent, FallbackRecommendation } from './types.js'; -import type { FailureKind } from '../availability/modelPolicy.js'; +import { classifyFailureKind } from '../availability/errorClassification.js'; import { buildFallbackPolicyContext, resolvePolicyChain, @@ -126,6 +123,9 @@ async function handlePolicyDrivenFallback( chain, failedModel, ); + + const failureKind = classifyFailureKind(error); + if (!candidates.length) { return null; } @@ -145,7 +145,7 @@ async function handlePolicyDrivenFallback( return null; } - const failureKind = classifyFailureKind(error); + // failureKind is already declared and calculated above const action = resolvePolicyAction(failureKind, selectedPolicy); if (action === 'silent') { @@ -183,6 +183,7 @@ async function handlePolicyDrivenFallback( failedModel, fallbackModel, authType, + error, // Pass the error so processIntent can handle preview-specific logic ); } catch (handlerError) { debugLogger.error('Fallback handler failed:', handlerError); @@ -209,25 +210,42 @@ async function processIntent( authType?: string, error?: unknown, ): Promise { + const isAvailabilityEnabled = config.isModelAvailabilityServiceEnabled(); + switch (intent) { case 'retry_always': - // If the error is non-retryable, e.g. TerminalQuota Error, trigger a regular fallback to flash. - // For all other errors, activate previewModel fallback. - if ( - failedModel === PREVIEW_GEMINI_MODEL && - !(error instanceof TerminalQuotaError) - ) { - activatePreviewModelFallbackMode(config); + if (isAvailabilityEnabled) { + // TODO(telemetry): Implement generic fallback event logging. Existing + // logFlashFallback is specific to a single Model. + config.setActiveModel(fallbackModel); } else { - activateFallbackMode(config, authType); + // If the error is non-retryable, e.g. TerminalQuota Error, trigger a regular fallback to flash. + // For all other errors, activate previewModel fallback. + if ( + failedModel === PREVIEW_GEMINI_MODEL && + !(error instanceof TerminalQuotaError) + ) { + activatePreviewModelFallbackMode(config); + } else { + activateFallbackMode(config, authType); + } } return true; case 'retry_once': + if (isAvailabilityEnabled) { + config.setActiveModel(fallbackModel); + } return true; case 'stop': - activateFallbackMode(config, authType); + if (isAvailabilityEnabled) { + // TODO(telemetry): Implement generic fallback event logging. Existing + // logFlashFallback is specific to a single Model. + config.setActiveModel(fallbackModel); + } else { + activateFallbackMode(config, authType); + } return false; case 'retry_later': @@ -260,16 +278,3 @@ function activatePreviewModelFallbackMode(config: Config) { // We might want a specific event for Preview Model fallback, but for now we just set the mode. } } - -function classifyFailureKind(error?: unknown): FailureKind { - if (error instanceof TerminalQuotaError) { - return 'terminal'; - } - if (error instanceof RetryableQuotaError) { - return 'transient'; - } - if (error instanceof ModelNotFoundError) { - return 'not_found'; - } - return 'unknown'; -} diff --git a/packages/core/src/services/modelConfigServiceTestUtils.ts b/packages/core/src/services/modelConfigServiceTestUtils.ts new file mode 100644 index 0000000000..f6d0b9fbfc --- /dev/null +++ b/packages/core/src/services/modelConfigServiceTestUtils.ts @@ -0,0 +1,23 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { ResolvedModelConfig } from '../services/modelConfigService.js'; + +/** + * Creates a ResolvedModelConfig with sensible defaults, allowing overrides. + */ +export const makeResolvedModelConfig = ( + model: string, + overrides: Partial = {}, +): ResolvedModelConfig => + ({ + model, + generateContentConfig: { + temperature: 0, + topP: 1, + ...overrides, + }, + }) as ResolvedModelConfig; diff --git a/packages/core/src/utils/retry.test.ts b/packages/core/src/utils/retry.test.ts index f0d78677e3..1b940ed38b 100644 --- a/packages/core/src/utils/retry.test.ts +++ b/packages/core/src/utils/retry.test.ts @@ -17,6 +17,9 @@ import { RetryableQuotaError, } from './googleQuotaErrors.js'; import { PREVIEW_GEMINI_MODEL } from '../config/models.js'; +import type { ModelPolicy } from '../availability/modelPolicy.js'; +import { createAvailabilityServiceMock } from '../availability/testUtils.js'; +import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js'; // Helper to create a mock function that fails a certain number of times const createFailingFunction = ( @@ -104,7 +107,6 @@ describe('retryWithBackoff', () => { const promise = retryWithBackoff(mockFn); - // Expect it to fail with the error from the 5th attempt. await Promise.all([ expect(promise).rejects.toThrow('Simulated error attempt 3'), vi.runAllTimersAsync(), @@ -566,4 +568,171 @@ describe('retryWithBackoff', () => { ); expect(mockFn).toHaveBeenCalledTimes(2); }); + + describe('Availability Context Integration', () => { + let mockService: ModelAvailabilityService; + let mockPolicy1: ModelPolicy; + let mockPolicy2: ModelPolicy; + + beforeEach(() => { + vi.useRealTimers(); + mockService = createAvailabilityServiceMock(); + + mockPolicy1 = { + model: 'model-1', + actions: {}, + stateTransitions: { + terminal: 'terminal', + transient: 'sticky_retry', + }, + }; + + mockPolicy2 = { + model: 'model-2', + actions: {}, + stateTransitions: { + terminal: 'terminal', + }, + }; + }); + + it('updates availability context per attempt and applies transitions to the correct policy', async () => { + const error = new TerminalQuotaError( + 'quota exceeded', + { code: 429, message: 'quota', details: [] }, + 10, + ); + + const fn = vi.fn().mockImplementation(async () => { + throw error; // Always fail with quota + }); + + const onPersistent429 = vi + .fn() + .mockResolvedValueOnce('model-2') // First fallback success + .mockResolvedValueOnce(null); // Second fallback fails (give up) + + // Context provider returns policy1 first, then policy2 + const getContext = vi + .fn() + .mockReturnValueOnce({ service: mockService, policy: mockPolicy1 }) + .mockReturnValueOnce({ service: mockService, policy: mockPolicy2 }); + + await expect( + retryWithBackoff(fn, { + maxAttempts: 3, + initialDelayMs: 1, + getAvailabilityContext: getContext, + onPersistent429, + authType: AuthType.LOGIN_WITH_GOOGLE, + }), + ).rejects.toThrow(TerminalQuotaError); + + // Verify failures + expect(mockService.markTerminal).toHaveBeenCalledWith('model-1', 'quota'); + expect(mockService.markTerminal).toHaveBeenCalledWith('model-2', 'quota'); + + // Verify sequences + expect(mockService.markTerminal).toHaveBeenNthCalledWith( + 1, + 'model-1', + 'quota', + ); + expect(mockService.markTerminal).toHaveBeenNthCalledWith( + 2, + 'model-2', + 'quota', + ); + }); + + it('marks sticky_retry after retries are exhausted for transient failures', async () => { + const transientError = new RetryableQuotaError( + 'transient error', + { code: 429, message: 'transient', details: [] }, + 0, + ); + + const fn = vi.fn().mockRejectedValue(transientError); + + const getContext = vi + .fn() + .mockReturnValue({ service: mockService, policy: mockPolicy1 }); + + vi.useFakeTimers(); + const promise = retryWithBackoff(fn, { + maxAttempts: 3, + getAvailabilityContext: getContext, + initialDelayMs: 1, + maxDelayMs: 1, + }).catch((err) => err); + + await vi.runAllTimersAsync(); + const result = await promise; + expect(result).toBe(transientError); + + expect(fn).toHaveBeenCalledTimes(3); + expect(mockService.markRetryOncePerTurn).toHaveBeenCalledWith('model-1'); + expect(mockService.markRetryOncePerTurn).toHaveBeenCalledTimes(1); + expect(mockService.markTerminal).not.toHaveBeenCalled(); + }); + + it('maps different failure kinds to correct terminal reasons', async () => { + const quotaError = new TerminalQuotaError( + 'quota', + { code: 429, message: 'q', details: [] }, + 10, + ); + const notFoundError = new ModelNotFoundError('not found', 404); + const genericError = new Error('unknown error'); + + const fn = vi + .fn() + .mockRejectedValueOnce(quotaError) + .mockRejectedValueOnce(notFoundError) + .mockRejectedValueOnce(genericError); + + const policy: ModelPolicy = { + model: 'model-1', + actions: {}, + stateTransitions: { + terminal: 'terminal', // from quotaError + not_found: 'terminal', // from notFoundError + unknown: 'terminal', // from genericError + }, + }; + + const getContext = vi + .fn() + .mockReturnValue({ service: mockService, policy }); + + // Run for quotaError + await retryWithBackoff(fn, { + maxAttempts: 1, + getAvailabilityContext: getContext, + }).catch(() => {}); + expect(mockService.markTerminal).toHaveBeenCalledWith('model-1', 'quota'); + + // Run for notFoundError + await retryWithBackoff(fn, { + maxAttempts: 1, + getAvailabilityContext: getContext, + }).catch(() => {}); + expect(mockService.markTerminal).toHaveBeenCalledWith( + 'model-1', + 'capacity', + ); + + // Run for genericError + await retryWithBackoff(fn, { + maxAttempts: 1, + getAvailabilityContext: getContext, + }).catch(() => {}); + expect(mockService.markTerminal).toHaveBeenCalledWith( + 'model-1', + 'capacity', + ); + + expect(mockService.markTerminal).toHaveBeenCalledTimes(3); + }); + }); }); diff --git a/packages/core/src/utils/retry.ts b/packages/core/src/utils/retry.ts index 515eb92426..b9224fe304 100644 --- a/packages/core/src/utils/retry.ts +++ b/packages/core/src/utils/retry.ts @@ -8,13 +8,18 @@ import type { GenerateContentResponse } from '@google/genai'; import { ApiError } from '@google/genai'; import { AuthType } from '../core/contentGenerator.js'; import { - classifyGoogleError, - RetryableQuotaError, TerminalQuotaError, + RetryableQuotaError, + classifyGoogleError, } from './googleQuotaErrors.js'; import { delay, createAbortError } from './delay.js'; import { debugLogger } from './debugLogger.js'; import { getErrorStatus, ModelNotFoundError } from './httpErrors.js'; +import type { RetryAvailabilityContext } from '../availability/modelPolicy.js'; +import { classifyFailureKind } from '../availability/errorClassification.js'; +import { applyAvailabilityTransition } from '../availability/policyHelpers.js'; + +export type { RetryAvailabilityContext }; export interface RetryOptions { maxAttempts: number; @@ -29,6 +34,7 @@ export interface RetryOptions { authType?: string; retryFetchErrors?: boolean; signal?: AbortSignal; + getAvailabilityContext?: () => RetryAvailabilityContext | undefined; } const DEFAULT_RETRY_OPTIONS: RetryOptions = { @@ -145,6 +151,7 @@ export async function retryWithBackoff( shouldRetryOnContent, retryFetchErrors, signal, + getAvailabilityContext, } = { ...DEFAULT_RETRY_OPTIONS, shouldRetryOnError: isRetryableError, @@ -173,6 +180,11 @@ export async function retryWithBackoff( continue; } + const successContext = getAvailabilityContext?.(); + if (successContext) { + successContext.service.markHealthy(successContext.policy.model); + } + return result; } catch (error) { if (error instanceof Error && error.name === 'AbortError') { @@ -180,6 +192,13 @@ export async function retryWithBackoff( } const classifiedError = classifyGoogleError(error); + const failureKind = classifyFailureKind(classifiedError); + const appliedImmediate = + failureKind === 'terminal' || failureKind === 'not_found'; + if (appliedImmediate) { + applyAvailabilityTransition(getAvailabilityContext, failureKind); + } + const errorCode = getErrorStatus(error); if ( @@ -201,6 +220,7 @@ export async function retryWithBackoff( debugLogger.warn('Fallback to Flash model failed:', fallbackError); } } + // Terminal/not_found already recorded; nothing else to mark here. throw classifiedError; // Throw if no fallback or fallback failed. } @@ -224,6 +244,9 @@ export async function retryWithBackoff( console.warn('Model fallback failed:', fallbackError); } } + if (!appliedImmediate) { + applyAvailabilityTransition(getAvailabilityContext, failureKind); + } throw classifiedError instanceof RetryableQuotaError ? classifiedError : error; @@ -253,6 +276,9 @@ export async function retryWithBackoff( attempt >= maxAttempts || !shouldRetryOnError(error as Error, retryFetchErrors) ) { + if (!appliedImmediate) { + applyAvailabilityTransition(getAvailabilityContext, failureKind); + } throw error; }