feat(modelAvailabilityService): integrate model availability service into backend logic (#14470)

This commit is contained in:
Adam Weidman
2025-12-08 06:44:34 -08:00
committed by GitHub
parent 7a72037572
commit 8f4f8baa81
20 changed files with 1611 additions and 119 deletions

View File

@@ -0,0 +1,25 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
TerminalQuotaError,
RetryableQuotaError,
} from '../utils/googleQuotaErrors.js';
import { ModelNotFoundError } from '../utils/httpErrors.js';
import type { FailureKind } from './modelPolicy.js';
/**
 * Buckets an arbitrary thrown error into the FailureKind categories that
 * model availability policies react to. Unrecognized errors map to 'unknown'.
 */
export function classifyFailureKind(error: unknown): FailureKind {
  if (error instanceof TerminalQuotaError) return 'terminal';
  if (error instanceof RetryableQuotaError) return 'transient';
  if (error instanceof ModelNotFoundError) return 'not_found';
  return 'unknown';
}

View File

@@ -4,7 +4,11 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { ModelHealthStatus, ModelId } from './modelAvailabilityService.js';
import type {
ModelAvailabilityService,
ModelHealthStatus,
ModelId,
} from './modelAvailabilityService.js';
/**
* Whether to prompt the user or fallback silently on a model API failure.
@@ -49,3 +53,11 @@ export interface ModelPolicy {
* The first model in the chain is the primary model.
*/
export type ModelPolicyChain = ModelPolicy[];
/**
* Context required by retry logic to apply availability policies on failure.
*/
export interface RetryAvailabilityContext {
service: ModelAvailabilityService;
policy: ModelPolicy;
}

View File

@@ -51,6 +51,7 @@ const PREVIEW_CHAIN: ModelPolicyChain = [
definePolicy({
model: PREVIEW_GEMINI_MODEL,
stateTransitions: { transient: 'sticky_retry' },
actions: { transient: 'silent' },
}),
definePolicy({ model: DEFAULT_GEMINI_MODEL }),
definePolicy({ model: DEFAULT_GEMINI_FLASH_MODEL, isLastResort: true }),

View File

@@ -4,38 +4,54 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { describe, it, expect, vi, beforeEach } from 'vitest';
import {
resolvePolicyChain,
buildFallbackPolicyContext,
applyModelSelection,
} from './policyHelpers.js';
import { createDefaultPolicy } from './policyCatalog.js';
import type { Config } from '../config/config.js';
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
({
getPreviewFeatures: () => false,
getUserTier: () => undefined,
getModel: () => 'gemini-2.5-pro',
isInFallbackMode: () => false,
...overrides,
}) as unknown as Config;
describe('policyHelpers', () => {
describe('resolvePolicyChain', () => {
it('inserts the active model when missing from the catalog', () => {
const config = {
getPreviewFeatures: () => false,
getUserTier: () => undefined,
const config = createMockConfig({
getModel: () => 'custom-model',
isInFallbackMode: () => false,
} as unknown as Config;
});
const chain = resolvePolicyChain(config);
expect(chain).toHaveLength(1);
expect(chain[0]?.model).toBe('custom-model');
});
it('leaves catalog order untouched when active model already present', () => {
const config = {
getPreviewFeatures: () => false,
getUserTier: () => undefined,
const config = createMockConfig({
getModel: () => 'gemini-2.5-pro',
isInFallbackMode: () => false,
} as unknown as Config;
});
const chain = resolvePolicyChain(config);
expect(chain[0]?.model).toBe('gemini-2.5-pro');
});
it('returns the default chain when active model is "auto"', () => {
const config = createMockConfig({
getModel: () => 'auto',
});
const chain = resolvePolicyChain(config);
// Expect default chain [Pro, Flash]
expect(chain).toHaveLength(2);
expect(chain[0]?.model).toBe('gemini-2.5-pro');
expect(chain[1]?.model).toBe('gemini-2.5-flash');
});
});
describe('buildFallbackPolicyContext', () => {
@@ -57,4 +73,112 @@ describe('policyHelpers', () => {
expect(context.candidates).toEqual(chain);
});
});
describe('applyModelSelection', () => {
const mockModelConfigService = {
getResolvedConfig: vi.fn(),
};
const mockAvailabilityService = {
selectFirstAvailable: vi.fn(),
consumeStickyAttempt: vi.fn(),
};
const createExtendedMockConfig = (
overrides: Partial<Config> = {},
): Config => {
const defaults = {
isModelAvailabilityServiceEnabled: () => true,
getModelAvailabilityService: () => mockAvailabilityService,
setActiveModel: vi.fn(),
modelConfigService: mockModelConfigService,
};
return createMockConfig({ ...defaults, ...overrides } as Partial<Config>);
};
beforeEach(() => {
vi.clearAllMocks();
});
it('returns requested model if availability service is disabled', () => {
const config = createExtendedMockConfig({
isModelAvailabilityServiceEnabled: () => false,
});
const result = applyModelSelection(config, 'gemini-pro');
expect(result.model).toBe('gemini-pro');
expect(config.setActiveModel).not.toHaveBeenCalled();
});
it('returns requested model if it is available', () => {
const config = createExtendedMockConfig();
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
selectedModel: 'gemini-pro',
});
const result = applyModelSelection(config, 'gemini-pro');
expect(result.model).toBe('gemini-pro');
expect(result.maxAttempts).toBeUndefined();
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
});
it('switches to backup model and updates config if requested is unavailable', () => {
const config = createExtendedMockConfig();
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
selectedModel: 'gemini-flash',
});
mockModelConfigService.getResolvedConfig.mockReturnValue({
generateContentConfig: { temperature: 0.1 },
});
const currentConfig = { temperature: 0.9, topP: 1 };
const result = applyModelSelection(config, 'gemini-pro', currentConfig);
expect(result.model).toBe('gemini-flash');
expect(result.config).toEqual({
temperature: 0.1,
topP: 1,
});
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
model: 'gemini-flash',
});
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash');
});
it('consumes sticky attempt if indicated', () => {
const config = createExtendedMockConfig();
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
selectedModel: 'gemini-pro',
attempts: 1,
});
const result = applyModelSelection(config, 'gemini-pro');
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
'gemini-pro',
);
expect(result.maxAttempts).toBe(1);
});
it('does not consume sticky attempt if consumeAttempt is false', () => {
const config = createExtendedMockConfig();
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
selectedModel: 'gemini-pro',
attempts: 1,
});
const result = applyModelSelection(
config,
'gemini-pro',
undefined,
undefined,
{
consumeAttempt: false,
},
);
expect(
mockAvailabilityService.consumeStickyAttempt,
).not.toHaveBeenCalled();
expect(result.maxAttempts).toBe(1);
});
});
});

View File

@@ -4,34 +4,47 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { GenerateContentConfig } from '@google/genai';
import type { Config } from '../config/config.js';
import type {
FailureKind,
FallbackAction,
ModelPolicy,
ModelPolicyChain,
RetryAvailabilityContext,
} from './modelPolicy.js';
import { createDefaultPolicy, getModelPolicyChain } from './policyCatalog.js';
import { getEffectiveModel } from '../config/models.js';
import { DEFAULT_GEMINI_MODEL, getEffectiveModel } from '../config/models.js';
import type { ModelSelectionResult } from './modelAvailabilityService.js';
/**
* Resolves the active policy chain for the given config, ensuring the
* user-selected active model is represented.
*/
export function resolvePolicyChain(config: Config): ModelPolicyChain {
export function resolvePolicyChain(
config: Config,
preferredModel?: string,
): ModelPolicyChain {
const chain = getModelPolicyChain({
previewEnabled: !!config.getPreviewFeatures(),
userTier: config.getUserTier(),
});
// TODO: This will be replaced when we get rid of Fallback Modes
const activeModel = getEffectiveModel(
config.isInFallbackMode(),
config.getModel(),
config.getPreviewFeatures(),
);
// TODO: This will be replaced when we get rid of Fallback Modes.
// Switch to getActiveModel()
const activeModel =
preferredModel ??
getEffectiveModel(
config.isInFallbackMode(),
config.getModel(),
config.getPreviewFeatures(),
);
if (activeModel === 'auto') {
return [...chain];
}
if (chain.some((policy) => policy.model === activeModel)) {
return chain;
return [...chain];
}
// If the user specified a model not in the default chain, we assume they want
@@ -68,3 +81,120 @@ export function resolvePolicyAction(
): FallbackAction {
return policy.actions?.[failureKind] ?? 'prompt';
}
/**
 * Creates a context provider for retry logic that returns the availability
 * service and resolves the current model's policy.
 *
 * @param modelGetter A function that returns the model ID currently being attempted.
 *   (Allows handling dynamic model changes during retries).
 * @returns A thunk yielding the retry availability context, or `undefined`
 *   when the service is disabled or no policy matches the attempted model.
 */
export function createAvailabilityContextProvider(
  config: Config,
  modelGetter: () => string,
): () => RetryAvailabilityContext | undefined {
  return () => {
    if (!config.isModelAvailabilityServiceEnabled()) return undefined;
    const availabilityService = config.getModelAvailabilityService();
    const attemptedModel = modelGetter();
    // Resolve the chain for the specific model we are attempting.
    const matching = resolvePolicyChain(config, attemptedModel).find(
      (candidate) => candidate.model === attemptedModel,
    );
    return matching
      ? { service: availabilityService, policy: matching }
      : undefined;
  };
}
/**
 * Selects the model to use for an attempt via the availability service and
 * returns the selection context.
 *
 * @returns The selection result, or `undefined` when the availability
 *   service is disabled. When nothing in the chain is available, falls back
 *   to the chain's designated last resort (or the global default model).
 */
export function selectModelForAvailability(
  config: Config,
  requestedModel: string,
): ModelSelectionResult | undefined {
  if (!config.isModelAvailabilityServiceEnabled()) return undefined;
  const chain = resolvePolicyChain(config, requestedModel);
  const candidateModels = chain.map((policy) => policy.model);
  const selection = config
    .getModelAvailabilityService()
    .selectFirstAvailable(candidateModels);
  if (selection.selectedModel) {
    return selection;
  }
  // Nothing in the chain is currently available; pick the last-resort model
  // (or the hard-coded default when no policy is flagged as last resort).
  const lastResort = chain.find((policy) => policy.isLastResort);
  return {
    selectedModel: lastResort?.model ?? DEFAULT_GEMINI_MODEL,
    skipped: [],
  };
}
/**
 * Applies the model availability selection logic, including side effects
 * (setting active model, consuming sticky attempts) and config updates.
 *
 * @param requestedModel The model the caller intended to use.
 * @param currentConfig The caller's existing generation config, if any.
 * @param overrideScope Optional scope forwarded to config resolution.
 * @param options Set `consumeAttempt: false` to skip consuming a sticky
 *   retry attempt (e.g. when it will be consumed lower in the chain).
 * @returns The model to use, the (possibly re-resolved) config, and the
 *   attempt budget imposed by the availability selection, if any.
 */
export function applyModelSelection(
  config: Config,
  requestedModel: string,
  currentConfig?: GenerateContentConfig,
  overrideScope?: string,
  options: { consumeAttempt?: boolean } = {},
): { model: string; config?: GenerateContentConfig; maxAttempts?: number } {
  const selection = selectModelForAvailability(config, requestedModel);
  const chosenModel = selection?.selectedModel;
  if (!chosenModel) {
    // Availability disabled or no selection made: keep the caller's request.
    return { model: requestedModel, config: currentConfig };
  }
  let resolvedConfig = currentConfig;
  if (chosenModel !== requestedModel) {
    // The selection switched models, so merge in that model's resolved
    // generation config (its values take precedence over the caller's).
    const { generateContentConfig } =
      config.modelConfigService.getResolvedConfig({
        overrideScope,
        model: chosenModel,
      });
    resolvedConfig = currentConfig
      ? { ...currentConfig, ...generateContentConfig }
      : generateContentConfig;
  }
  config.setActiveModel(chosenModel);
  const shouldConsumeAttempt = options.consumeAttempt !== false;
  if (selection.attempts && shouldConsumeAttempt) {
    config.getModelAvailabilityService().consumeStickyAttempt(chosenModel);
  }
  return {
    model: chosenModel,
    config: resolvedConfig,
    maxAttempts: selection.attempts,
  };
}
/**
 * Applies the policy-defined health-state transition for a failed attempt.
 * No-op when no context provider is supplied, availability is inactive, or
 * the policy defines no transition for this failure kind.
 */
export function applyAvailabilityTransition(
  getContext: (() => RetryAvailabilityContext | undefined) | undefined,
  failureKind: FailureKind,
): void {
  const ctx = getContext?.();
  const transition = ctx?.policy.stateTransitions?.[failureKind];
  if (!ctx || !transition) {
    return;
  }
  switch (transition) {
    case 'terminal':
      // Terminal quota failures are marked as 'quota'; everything else
      // routed here (e.g. capacity exhaustion) is marked 'capacity'.
      ctx.service.markTerminal(
        ctx.policy.model,
        failureKind === 'terminal' ? 'quota' : 'capacity',
      );
      break;
    case 'sticky_retry':
      ctx.service.markRetryOncePerTurn(ctx.policy.model);
      break;
    default:
      break;
  }
}

View File

@@ -0,0 +1,30 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { vi } from 'vitest';
import type {
ModelAvailabilityService,
ModelSelectionResult,
} from './modelAvailabilityService.js';
/**
* Test helper to create a fully mocked ModelAvailabilityService.
*/
/**
 * Test helper to create a fully mocked ModelAvailabilityService.
 *
 * @param selection The result `selectFirstAvailable` should return
 *   (defaults to "nothing selected, nothing skipped").
 */
export function createAvailabilityServiceMock(
  selection: ModelSelectionResult = { selectedModel: null, skipped: [] },
): ModelAvailabilityService {
  const mocked = {
    markTerminal: vi.fn(),
    markHealthy: vi.fn(),
    markRetryOncePerTurn: vi.fn(),
    consumeStickyAttempt: vi.fn(),
    snapshot: vi.fn(),
    resetTurn: vi.fn(),
    selectFirstAvailable: vi.fn().mockReturnValue(selection),
  };
  // Cast: the mock intentionally covers only the interface surface used in tests.
  return mocked as unknown as ModelAvailabilityService;
}

View File

@@ -1632,3 +1632,42 @@ describe('Config setExperiments logging', () => {
debugSpy.mockRestore();
});
});
describe('Availability Service Integration', () => {
const baseModel = 'test-model';
const baseParams: ConfigParameters = {
sessionId: 'test',
targetDir: '.',
debugMode: false,
model: baseModel,
cwd: '.',
};
it('setActiveModel updates active model and emits event', async () => {
const config = new Config(baseParams);
const model1 = 'model1';
const model2 = 'model2';
config.setActiveModel(model1);
expect(config.getActiveModel()).toBe(model1);
expect(mockCoreEvents.emitModelChanged).toHaveBeenCalledWith(model1);
config.setActiveModel(model2);
expect(config.getActiveModel()).toBe(model2);
expect(mockCoreEvents.emitModelChanged).toHaveBeenCalledWith(model2);
});
it('getActiveModel defaults to configured model if not set', () => {
const config = new Config(baseParams);
expect(config.getActiveModel()).toBe(baseModel);
});
it('resetTurn delegates to availability service', () => {
const config = new Config(baseParams);
const service = config.getModelAvailabilityService();
const spy = vi.spyOn(service, 'resetTurn');
config.resetTurn();
expect(spy).toHaveBeenCalled();
});
});

View File

@@ -382,6 +382,7 @@ export class Config {
private ideMode: boolean;
private inFallbackMode = false;
private _activeModel: string;
private readonly maxSessionTurns: number;
private readonly listSessions: boolean;
private readonly deleteSession: string | undefined;
@@ -504,6 +505,7 @@ export class Config {
this.fileDiscoveryService = params.fileDiscoveryService ?? null;
this.bugCommand = params.bugCommand;
this.model = params.model;
this._activeModel = params.model;
this.enableModelAvailabilityService =
params.enableModelAvailabilityService ?? false;
this.enableAgents = params.enableAgents ?? false;
@@ -810,11 +812,28 @@ export class Config {
/**
 * Explicitly sets the configured model and exits fallback mode.
 * Emits a model-changed event when the model actually changes or when the
 * config was previously in fallback mode.
 */
setModel(newModel: string): void {
  if (this.model !== newModel || this.inFallbackMode) {
    this.model = newModel;
    // When the user explicitly sets a model, that becomes the active model.
    this._activeModel = newModel;
    coreEvents.emitModelChanged(newModel);
  }
  this.setFallbackMode(false);
}
/**
 * The model currently in use for requests. Falls back to the configured
 * model when no availability-driven switch has recorded an active model.
 */
getActiveModel(): string {
  return this._activeModel ?? this.model;
}
/**
 * Records an availability-driven model switch (without changing the
 * user-configured model) and emits a model-changed event only when the
 * active model actually changes.
 */
setActiveModel(model: string): void {
  if (this._activeModel !== model) {
    this._activeModel = model;
    coreEvents.emitModelChanged(model);
  }
}
/** Resets per-turn availability state by delegating to the availability service. */
resetTurn(): void {
  this.modelAvailabilityService.resetTurn();
}
/** Whether the config is currently operating in model-fallback mode. */
isInFallbackMode(): boolean {
  return this.inFallbackMode;
}

View File

@@ -15,9 +15,12 @@ import {
type Mock,
} from 'vitest';
import type { GenerateContentResponse } from '@google/genai';
import { BaseLlmClient, type GenerateJsonOptions } from './baseLlmClient.js';
import type { ContentGenerator } from './contentGenerator.js';
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
import type { GenerateContentOptions } from './baseLlmClient.js';
import type { GenerateContentResponse } from '@google/genai';
import type { Config } from '../config/config.js';
import { AuthType } from './contentGenerator.js';
import { reportError } from '../utils/errorReporting.js';
@@ -25,6 +28,8 @@ import { logMalformedJsonResponse } from '../telemetry/loggers.js';
import { retryWithBackoff } from '../utils/retry.js';
import { MalformedJsonResponseEvent } from '../telemetry/types.js';
import { getErrorMessage } from '../utils/errors.js';
import type { ModelConfigService } from '../services/modelConfigService.js';
import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js';
vi.mock('../utils/errorReporting.js');
vi.mock('../telemetry/loggers.js');
@@ -58,6 +63,11 @@ vi.mock('../utils/retry.js', () => ({
}
}
const context = options?.getAvailabilityContext?.();
if (context) {
context.service.markHealthy(context.policy.model);
}
return result;
}),
}));
@@ -97,14 +107,18 @@ describe('BaseLlmClient', () => {
getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
isInteractive: vi.fn().mockReturnValue(false),
modelConfigService: {
getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({
model,
generateContentConfig: {
temperature: 0,
topP: 1,
},
})),
},
getResolvedConfig: vi
.fn()
.mockImplementation(({ model }) => makeResolvedModelConfig(model)),
} as unknown as ModelConfigService,
isModelAvailabilityServiceEnabled: vi.fn().mockReturnValue(false),
getModelAvailabilityService: vi.fn(),
setActiveModel: vi.fn(),
getPreviewFeatures: vi.fn().mockReturnValue(false),
getUserTier: vi.fn().mockReturnValue(undefined),
isInFallbackMode: vi.fn().mockReturnValue(false),
getModel: vi.fn().mockReturnValue('test-model'),
getActiveModel: vi.fn().mockReturnValue('test-model'),
} as unknown as Mocked<Config>;
client = new BaseLlmClient(mockContentGenerator, mockConfig);
@@ -593,4 +607,243 @@ describe('BaseLlmClient', () => {
);
});
});
describe('Availability Service Integration', () => {
let mockAvailabilityService: ModelAvailabilityService;
let contentOptions: GenerateContentOptions;
let jsonOptions: GenerateJsonOptions;
beforeEach(() => {
mockConfig.isModelAvailabilityServiceEnabled = vi
.fn()
.mockReturnValue(true);
mockAvailabilityService = createAvailabilityServiceMock({
selectedModel: 'test-model',
skipped: [],
});
// Reflect setActiveModel into getActiveModel so availability-driven updates
// are visible to the client under test.
mockConfig.getActiveModel = vi.fn().mockReturnValue('test-model');
mockConfig.setActiveModel = vi.fn((model: string) => {
vi.mocked(mockConfig.getActiveModel).mockReturnValue(model);
});
vi.spyOn(mockConfig, 'getModelAvailabilityService').mockReturnValue(
mockAvailabilityService,
);
contentOptions = {
modelConfigKey: { model: 'test-model' },
contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }],
abortSignal: abortController.signal,
promptId: 'content-prompt-id',
};
jsonOptions = {
...defaultOptions,
promptId: 'json-prompt-id',
};
});
it('should preserve legacy behavior when availability is disabled', async () => {
mockConfig.isModelAvailabilityServiceEnabled = vi
.fn()
.mockReturnValue(false);
mockGenerateContent.mockResolvedValue(
createMockResponse('Some text response'),
);
await client.generateContent(contentOptions);
expect(
mockAvailabilityService.selectFirstAvailable,
).not.toHaveBeenCalled();
expect(mockConfig.setActiveModel).not.toHaveBeenCalled();
expect(mockAvailabilityService.markHealthy).not.toHaveBeenCalled();
});
it('should mark model as healthy on success', async () => {
const successfulModel = 'gemini-pro';
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
selectedModel: successfulModel,
skipped: [],
});
mockGenerateContent.mockResolvedValue(
createMockResponse('Some text response'),
);
await client.generateContent({
...contentOptions,
modelConfigKey: { model: successfulModel },
});
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
successfulModel,
);
});
it('marks the final attempted model healthy after a retry with availability enabled', async () => {
const firstModel = 'gemini-pro';
const fallbackModel = 'gemini-flash';
vi.mocked(mockAvailabilityService.selectFirstAvailable)
.mockReturnValueOnce({ selectedModel: firstModel, skipped: [] })
.mockReturnValueOnce({ selectedModel: fallbackModel, skipped: [] });
mockGenerateContent
.mockResolvedValueOnce(createMockResponse('retry-me'))
.mockResolvedValueOnce(createMockResponse('final-response'));
// Run the real retryWithBackoff (with fake timers) to exercise the retry path
vi.useFakeTimers();
const retryPromise = client.generateContent({
...contentOptions,
modelConfigKey: { model: firstModel },
maxAttempts: 2,
});
await vi.runAllTimersAsync();
await retryPromise;
await client.generateContent({
...contentOptions,
modelConfigKey: { model: firstModel },
maxAttempts: 2,
});
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(firstModel);
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(fallbackModel);
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
fallbackModel,
);
expect(mockGenerateContent).toHaveBeenLastCalledWith(
expect.objectContaining({ model: fallbackModel }),
expect.any(String),
);
});
it('should consume sticky attempt if selection has attempts', async () => {
const stickyModel = 'gemini-pro-sticky';
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
selectedModel: stickyModel,
attempts: 1,
skipped: [],
});
mockGenerateContent.mockResolvedValue(
createMockResponse('Some text response'),
);
vi.mocked(retryWithBackoff).mockImplementation(async (fn, options) => {
const result = await fn();
const context = options?.getAvailabilityContext?.();
if (context) {
context.service.markHealthy(context.policy.model);
}
return result;
});
await client.generateContent({
...contentOptions,
modelConfigKey: { model: stickyModel },
});
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
stickyModel,
);
expect(retryWithBackoff).toHaveBeenCalledWith(
expect.any(Function),
expect.objectContaining({ maxAttempts: 1 }),
);
});
it('should mark healthy and honor availability selection when using generateJson', async () => {
const availableModel = 'gemini-json-pro';
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
selectedModel: availableModel,
skipped: [],
});
mockGenerateContent.mockResolvedValue(
createMockResponse('{"color":"violet"}'),
);
vi.mocked(retryWithBackoff).mockImplementation(async (fn, options) => {
const result = await fn();
const context = options?.getAvailabilityContext?.();
if (context) {
context.service.markHealthy(context.policy.model);
}
return result;
});
const result = await client.generateJson(jsonOptions);
expect(result).toEqual({ color: 'violet' });
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(availableModel);
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
availableModel,
);
expect(mockGenerateContent).toHaveBeenLastCalledWith(
expect.objectContaining({ model: availableModel }),
jsonOptions.promptId,
);
});
it('should refresh configuration when model changes mid-retry', async () => {
const firstModel = 'gemini-pro';
const fallbackModel = 'gemini-flash';
// Provide distinct configs per model
const getResolvedConfigMock = vi.mocked(
mockConfig.modelConfigService.getResolvedConfig,
);
getResolvedConfigMock
.mockReturnValueOnce(
makeResolvedModelConfig(firstModel, { temperature: 0.1 }),
)
.mockReturnValueOnce(
makeResolvedModelConfig(fallbackModel, { temperature: 0.9 }),
);
// Availability selects the first model initially
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
selectedModel: firstModel,
skipped: [],
});
// Change active model after the first attempt
let activeModel = firstModel;
mockConfig.setActiveModel = vi.fn(); // Prevent setActiveModel from resetting getActiveModel mock
mockConfig.getActiveModel.mockImplementation(() => activeModel);
// First response empty -> triggers retry; second response valid
mockGenerateContent
.mockResolvedValueOnce(createMockResponse(''))
.mockResolvedValueOnce(createMockResponse('final-response'));
// Custom retry to force two attempts
vi.mocked(retryWithBackoff).mockImplementation(async (fn, options) => {
const first = (await fn()) as GenerateContentResponse;
if (options?.shouldRetryOnContent?.(first)) {
activeModel = fallbackModel; // simulate handler switching active model before retry
return (await fn()) as GenerateContentResponse;
}
return first;
});
await client.generateContent({
...contentOptions,
modelConfigKey: { model: firstModel },
maxAttempts: 2,
});
expect(mockGenerateContent).toHaveBeenCalledTimes(2);
const secondCall = mockGenerateContent.mock.calls[1]?.[0];
expect(
mockConfig.modelConfigService.getResolvedConfig,
).toHaveBeenCalledWith({ model: fallbackModel });
expect(secondCall?.model).toBe(fallbackModel);
expect(secondCall?.config?.temperature).toBe(0.9);
});
});
});

View File

@@ -20,6 +20,10 @@ import { logMalformedJsonResponse } from '../telemetry/loggers.js';
import { MalformedJsonResponseEvent } from '../telemetry/types.js';
import { retryWithBackoff } from '../utils/retry.js';
import type { ModelConfigKey } from '../services/modelConfigService.js';
import {
applyModelSelection,
createAvailabilityContextProvider,
} from '../availability/policyHelpers.js';
const DEFAULT_MAX_ATTEMPTS = 5;
@@ -232,13 +236,56 @@ export class BaseLlmClient {
): Promise<GenerateContentResponse> {
const abortSignal = requestParams.config?.abortSignal;
// Define callback to fetch context dynamically since active model may get updated during retry loop
const getAvailabilityContext = createAvailabilityContextProvider(
this.config,
() => requestParams.model,
);
const {
model,
config: newConfig,
maxAttempts: availabilityMaxAttempts,
} = applyModelSelection(
this.config,
requestParams.model,
requestParams.config,
);
requestParams.model = model;
if (newConfig) {
requestParams.config = newConfig;
}
if (abortSignal) {
requestParams.config = { ...requestParams.config, abortSignal };
}
try {
const apiCall = () =>
this.contentGenerator.generateContent(requestParams, promptId);
const apiCall = () => {
// If availability is enabled, ensure we use the current active model
// in case a fallback occurred in a previous attempt.
if (this.config.isModelAvailabilityServiceEnabled()) {
const activeModel = this.config.getActiveModel();
if (activeModel !== requestParams.model) {
requestParams.model = activeModel;
// Re-resolve config if model changed during retry
const { generateContentConfig } =
this.config.modelConfigService.getResolvedConfig({
model: activeModel,
});
requestParams.config = {
...requestParams.config,
...generateContentConfig,
};
}
}
return this.contentGenerator.generateContent(requestParams, promptId);
};
return await retryWithBackoff(apiCall, {
shouldRetryOnContent,
maxAttempts: maxAttempts ?? DEFAULT_MAX_ATTEMPTS,
maxAttempts:
availabilityMaxAttempts ?? maxAttempts ?? DEFAULT_MAX_ATTEMPTS,
getAvailabilityContext,
});
} catch (error) {
if (abortSignal?.aborted) {

View File

@@ -38,12 +38,15 @@ import { ideContextStore } from '../ide/ideContext.js';
import type { ModelRouterService } from '../routing/modelRouterService.js';
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
import { ChatCompressionService } from '../services/chatCompressionService.js';
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
import type {
ModelConfigKey,
ResolvedModelConfig,
} from '../services/modelConfigService.js';
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
import { HookSystem } from '../hooks/hookSystem.js';
import * as policyCatalog from '../availability/policyCatalog.js';
vi.mock('../services/chatCompressionService.js');
@@ -145,6 +148,7 @@ describe('Gemini Client (client.ts)', () => {
let mockConfig: Config;
let client: GeminiClient;
let mockGenerateContentFn: Mock;
let mockRouterService: { route: Mock };
beforeEach(async () => {
vi.resetAllMocks();
ClearcutLogger.clearInstance();
@@ -166,6 +170,12 @@ describe('Gemini Client (client.ts)', () => {
// Disable 429 simulation for tests
setSimulate429(false);
mockRouterService = {
route: vi
.fn()
.mockResolvedValue({ model: 'default-routed-model', reason: 'test' }),
};
mockContentGenerator = {
generateContent: mockGenerateContentFn,
generateContentStream: vi.fn(),
@@ -192,6 +202,7 @@ describe('Gemini Client (client.ts)', () => {
.mockReturnValue(contentGeneratorConfig),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
getModel: vi.fn().mockReturnValue('test-model'),
getUserTier: vi.fn().mockReturnValue(undefined),
getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
getApiKey: vi.fn().mockReturnValue('test-key'),
getVertexAI: vi.fn().mockReturnValue(false),
@@ -215,9 +226,9 @@ describe('Gemini Client (client.ts)', () => {
getDirectories: vi.fn().mockReturnValue(['/test/dir']),
}),
getGeminiClient: vi.fn(),
getModelRouterService: vi.fn().mockReturnValue({
route: vi.fn().mockResolvedValue({ model: 'default-routed-model' }),
}),
getModelRouterService: vi
.fn()
.mockReturnValue(mockRouterService as unknown as ModelRouterService),
getMessageBus: vi.fn().mockReturnValue(undefined),
getEnableHooks: vi.fn().mockReturnValue(false),
isInFallbackMode: vi.fn().mockReturnValue(false),
@@ -251,6 +262,11 @@ describe('Gemini Client (client.ts)', () => {
},
isInteractive: vi.fn().mockReturnValue(false),
getExperiments: () => {},
isModelAvailabilityServiceEnabled: vi.fn().mockReturnValue(false),
getActiveModel: vi.fn().mockReturnValue('test-model'),
setActiveModel: vi.fn(),
resetTurn: vi.fn(),
getModelAvailabilityService: vi.fn(),
} as unknown as Config;
mockConfig.getHookSystem = vi
.fn()
@@ -1614,6 +1630,7 @@ ${JSON.stringify(
expect(events).toEqual([
{ type: GeminiEventType.ModelInfo, value: 'default-routed-model' },
{ type: GeminiEventType.InvalidStream },
{ type: GeminiEventType.ModelInfo, value: 'default-routed-model' },
{ type: GeminiEventType.Content, value: 'Continued content' },
]);
@@ -1701,8 +1718,8 @@ ${JSON.stringify(
const events = await fromAsync(stream);
// Assert
// We expect 3 events (model_info + original + 1 retry)
expect(events.length).toBe(3);
// We expect 4 events (model_info + original + model_info + 1 retry)
expect(events.length).toBe(4);
expect(
events
.filter((e) => e.type !== GeminiEventType.ModelInfo)
@@ -1977,6 +1994,154 @@ ${JSON.stringify(
});
});
// Covers GeminiClient's integration with the ModelAvailabilityService:
// model selection from the policy chain, active-model bookkeeping, and
// per-stream turn resets.
describe('Availability Service Integration', () => {
  let mockAvailabilityService: ModelAvailabilityService;

  beforeEach(() => {
    // Fresh availability mock per test, wired into the config so the client
    // sees the service as enabled.
    mockAvailabilityService = createAvailabilityServiceMock();
    vi.mocked(mockConfig.getModelAvailabilityService).mockReturnValue(
      mockAvailabilityService,
    );
    vi.mocked(mockConfig.isModelAvailabilityServiceEnabled).mockReturnValue(
      true,
    );
    vi.mocked(mockConfig.setActiveModel).mockClear();
    // Router always proposes 'model-a'; the availability layer then decides
    // what is actually used.
    mockRouterService.route.mockResolvedValue({
      model: 'model-a',
      reason: 'test',
    });
    vi.mocked(mockConfig.getModelRouterService).mockReturnValue(
      mockRouterService as unknown as ModelRouterService,
    );
    // Two-model policy chain: 'model-a' primary, 'model-b' last resort.
    vi.spyOn(policyCatalog, 'getModelPolicyChain').mockReturnValue([
      {
        model: 'model-a',
        isLastResort: false,
        actions: {},
        stateTransitions: {},
      },
      {
        model: 'model-b',
        isLastResort: true,
        actions: {},
        stateTransitions: {},
      },
    ]);
    // Minimal successful turn: a single content event.
    mockTurnRunFn.mockReturnValue(
      (async function* () {
        yield { type: 'content', value: 'Hello' };
      })(),
    );
  });

  it('should select first available model, set active, and not consume sticky attempt (done lower in chain)', async () => {
    vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue(
      {
        selectedModel: 'model-a',
        attempts: 1,
        skipped: [],
      },
    );
    const stream = client.sendMessageStream(
      [{ text: 'Hi' }],
      new AbortController().signal,
      'prompt-avail',
    );
    await fromAsync(stream);
    // Selection is asked for the full chain in policy order.
    expect(
      mockAvailabilityService.selectFirstAvailable,
    ).toHaveBeenCalledWith(['model-a', 'model-b']);
    expect(mockConfig.setActiveModel).toHaveBeenCalledWith('model-a');
    // Sticky attempts are consumed lower in the stack, not here.
    expect(
      mockAvailabilityService.consumeStickyAttempt,
    ).not.toHaveBeenCalled();
    // Ensure turn.run used the selected model
    expect(mockTurnRunFn).toHaveBeenCalledWith(
      expect.objectContaining({ model: 'model-a' }),
      expect.anything(),
      expect.anything(),
    );
  });

  it('should default to last resort model if selection returns null', async () => {
    // No model passes availability checks.
    vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue(
      {
        selectedModel: null,
        skipped: [],
      },
    );
    const stream = client.sendMessageStream(
      [{ text: 'Hi' }],
      new AbortController().signal,
      'prompt-avail-fallback',
    );
    await fromAsync(stream);
    expect(mockConfig.setActiveModel).toHaveBeenCalledWith('model-b'); // Last resort
    expect(
      mockAvailabilityService.consumeStickyAttempt,
    ).not.toHaveBeenCalled();
  });

  it('should reset turn on new message stream', async () => {
    vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue(
      {
        selectedModel: 'model-a',
        skipped: [],
      },
    );
    const stream = client.sendMessageStream(
      [{ text: 'Hi' }],
      new AbortController().signal,
      'prompt-reset',
    );
    await fromAsync(stream);
    // A fresh top-level stream must reset per-turn availability state.
    expect(mockConfig.resetTurn).toHaveBeenCalled();
  });

  it('should NOT reset turn on invalid stream retry', async () => {
    vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue(
      {
        selectedModel: 'model-a',
        skipped: [],
      },
    );
    // We simulate a retry by calling sendMessageStream with isInvalidStreamRetry=true
    // But the public API doesn't expose that argument directly unless we use the private method or simulate the recursion.
    // We can simulate recursion by mocking turn run to return invalid stream once.
    vi.spyOn(
      client['config'],
      'getContinueOnFailedApiCall',
    ).mockReturnValue(true);
    // First turn yields an invalid stream, which triggers the internal
    // recursive retry; the second turn succeeds.
    const mockStream1 = (async function* () {
      yield { type: GeminiEventType.InvalidStream };
    })();
    const mockStream2 = (async function* () {
      yield { type: 'content', value: 'ok' };
    })();
    mockTurnRunFn
      .mockReturnValueOnce(mockStream1)
      .mockReturnValueOnce(mockStream2);
    const stream = client.sendMessageStream(
      [{ text: 'Hi' }],
      new AbortController().signal,
      'prompt-retry',
    );
    await fromAsync(stream);
    // resetTurn should be called once (for the initial call) but NOT for the recursive call
    expect(mockConfig.resetTurn).toHaveBeenCalledTimes(1);
  });
});
describe('IDE context with pending tool calls', () => {
let mockChat: Partial<GeminiChat>;

View File

@@ -57,6 +57,11 @@ import type { RoutingContext } from '../routing/routingStrategy.js';
import { debugLogger } from '../utils/debugLogger.js';
import type { ModelConfigKey } from '../services/modelConfigService.js';
import { calculateRequestTokenCount } from '../utils/tokenCalculation.js';
import {
applyModelSelection,
createAvailabilityContextProvider,
} from '../availability/policyHelpers.js';
import type { RetryAvailabilityContext } from '../utils/retry.js';
const MAX_TURNS = 100;
@@ -405,6 +410,10 @@ export class GeminiClient {
turns: number = MAX_TURNS,
isInvalidStreamRetry: boolean = false,
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (!isInvalidStreamRetry) {
this.config.resetTurn();
}
// Fire BeforeAgent hook through MessageBus (only if hooks are enabled)
const hooksEnabled = this.config.getEnableHooks();
const messageBus = this.config.getMessageBus();
@@ -534,11 +543,21 @@ export class GeminiClient {
const router = await this.config.getModelRouterService();
const decision = await router.route(routingContext);
modelToUse = decision.model;
// Lock the model for the rest of the sequence
this.currentSequenceModel = modelToUse;
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
}
// availability logic
const { model: finalModel } = applyModelSelection(
this.config,
modelToUse,
undefined,
undefined,
{ consumeAttempt: false },
);
modelToUse = finalModel;
this.currentSequenceModel = modelToUse;
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
const resultStream = turn.run({ model: modelToUse }, request, linkedSignal);
for await (const event of resultStream) {
if (this.loopDetector.addAndCheck(event)) {
@@ -665,14 +684,53 @@ export class GeminiClient {
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(this.config, userMemory);
const {
model,
config: newConfig,
maxAttempts: availabilityMaxAttempts,
} = applyModelSelection(
this.config,
currentAttemptModel,
currentAttemptGenerateContentConfig,
modelConfigKey.overrideScope,
);
currentAttemptModel = model;
if (newConfig) {
currentAttemptGenerateContentConfig = newConfig;
}
// Define callback to refresh context based on currentAttemptModel which might be updated by fallback handler
const getAvailabilityContext: () => RetryAvailabilityContext | undefined =
createAvailabilityContextProvider(
this.config,
() => currentAttemptModel,
);
const apiCall = () => {
const modelConfigToUse = this.config.isInFallbackMode()
? fallbackModelConfig
: desiredModelConfig;
currentAttemptModel = modelConfigToUse.model;
currentAttemptGenerateContentConfig =
modelConfigToUse.generateContentConfig;
let modelConfigToUse = desiredModelConfig;
if (!this.config.isModelAvailabilityServiceEnabled()) {
modelConfigToUse = this.config.isInFallbackMode()
? fallbackModelConfig
: desiredModelConfig;
currentAttemptModel = modelConfigToUse.model;
currentAttemptGenerateContentConfig =
modelConfigToUse.generateContentConfig;
} else {
// AvailabilityService
const active = this.config.getActiveModel();
if (active !== currentAttemptModel) {
currentAttemptModel = active;
// Re-resolve config if model changed
const newConfig = this.config.modelConfigService.getResolvedConfig({
...modelConfigKey,
model: currentAttemptModel,
});
currentAttemptGenerateContentConfig =
newConfig.generateContentConfig;
}
}
const requestConfig: GenerateContentConfig = {
...currentAttemptGenerateContentConfig,
abortSignal,
@@ -698,7 +756,10 @@ export class GeminiClient {
const result = await retryWithBackoff(apiCall, {
onPersistent429: onPersistent429Callback,
authType: this.config.getContentGeneratorConfig()?.authType,
maxAttempts: availabilityMaxAttempts,
getAvailabilityContext,
});
return result;
} catch (error: unknown) {
if (abortSignal.aborted) {

View File

@@ -29,6 +29,10 @@ import { retryWithBackoff, type RetryOptions } from '../utils/retry.js';
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
import { HookSystem } from '../hooks/hookSystem.js';
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
import * as policyHelpers from '../availability/policyHelpers.js';
import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js';
// Mock fs module to prevent actual file system operations during tests
const mockFileSystem = new Map<string, string>();
@@ -115,7 +119,14 @@ describe('GeminiChat', () => {
mockHandleFallback.mockClear();
// Default mock implementation for tests that don't care about retry logic
mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall());
mockRetryWithBackoff.mockImplementation(async (apiCall, options) => {
const result = await apiCall();
const context = options?.getAvailabilityContext?.();
if (context) {
context.service.markHealthy(context.policy.model);
}
return result;
});
mockConfig = {
getSessionId: () => 'test-session-id',
getTelemetryLogPromptsEnabled: () => true,
@@ -141,6 +152,7 @@ describe('GeminiChat', () => {
}),
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
getRetryFetchErrors: vi.fn().mockReturnValue(false),
getUserTier: vi.fn().mockReturnValue(undefined),
modelConfigService: {
getResolvedConfig: vi.fn().mockImplementation((modelConfigKey) => {
const thinkingConfig = modelConfigKey.model.startsWith('gemini-3')
@@ -165,6 +177,10 @@ describe('GeminiChat', () => {
setPreviewModelFallbackMode: vi.fn(),
isInteractive: vi.fn().mockReturnValue(false),
getEnableHooks: vi.fn().mockReturnValue(false),
isModelAvailabilityServiceEnabled: vi.fn().mockReturnValue(false),
getActiveModel: vi.fn().mockReturnValue('gemini-pro'),
setActiveModel: vi.fn(),
getModelAvailabilityService: vi.fn(),
} as unknown as Config;
// Use proper MessageBus mocking for Phase 3 preparation
@@ -2359,4 +2375,260 @@ describe('GeminiChat', () => {
expect(newContents).toEqual(history);
});
});
// Covers GeminiChat's availability-aware API call path: marking models
// healthy after successful streams, capping retries for sticky selections,
// reporting the attempted model to the fallback handler, and re-resolving
// generation config when the active model changes between retry attempts.
describe('Availability Service Integration', () => {
  let mockAvailabilityService: ModelAvailabilityService;

  beforeEach(async () => {
    mockAvailabilityService = createAvailabilityServiceMock();
    vi.mocked(mockConfig.getModelAvailabilityService).mockReturnValue(
      mockAvailabilityService,
    );
    vi.mocked(mockConfig.isModelAvailabilityServiceEnabled).mockReturnValue(
      true,
    );
    // Stateful mock for activeModel
    let activeModel = 'model-a';
    vi.mocked(mockConfig.getActiveModel).mockImplementation(
      () => activeModel,
    );
    vi.mocked(mockConfig.setActiveModel).mockImplementation((model) => {
      activeModel = model;
    });
    // Three-model chain with 'model-c' as the last resort.
    vi.spyOn(policyHelpers, 'resolvePolicyChain').mockReturnValue([
      {
        model: 'model-a',
        isLastResort: false,
        actions: {},
        stateTransitions: {},
      },
      {
        model: 'model-b',
        isLastResort: false,
        actions: {},
        stateTransitions: {},
      },
      {
        model: 'model-c',
        isLastResort: true,
        actions: {},
        stateTransitions: {},
      },
    ]);
  });

  it('should mark healthy on successful stream', async () => {
    vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
      selectedModel: 'model-b',
      skipped: [],
    });
    // Simulate selection happening upstream
    mockConfig.setActiveModel('model-b');
    // One-chunk successful response stream.
    vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
      (async function* () {
        yield {
          candidates: [
            {
              content: { parts: [{ text: 'Response' }], role: 'model' },
              finishReason: 'STOP',
            },
          ],
        } as unknown as GenerateContentResponse;
      })(),
    );
    const stream = await chat.sendMessageStream(
      { model: 'gemini-pro' },
      'test',
      'prompt-healthy',
      new AbortController().signal,
    );
    for await (const _ of stream) {
      // consume
    }
    // Success must be reported against the ACTIVE model, not the requested one.
    expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
      'model-b',
    );
  });

  it('caps retries to a single attempt when selection is sticky', async () => {
    // attempts: 1 marks this selection as a sticky, single-shot attempt.
    vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
      selectedModel: 'model-a',
      attempts: 1,
      skipped: [],
    });
    vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
      (async function* () {
        yield {
          candidates: [
            {
              content: { parts: [{ text: 'Response' }], role: 'model' },
              finishReason: 'STOP',
            },
          ],
        } as unknown as GenerateContentResponse;
      })(),
    );
    const stream = await chat.sendMessageStream(
      { model: 'gemini-pro' },
      'test',
      'prompt-sticky-once',
      new AbortController().signal,
    );
    for await (const _ of stream) {
      // consume
    }
    // The sticky attempt budget must flow into retryWithBackoff's maxAttempts.
    expect(mockRetryWithBackoff).toHaveBeenCalledWith(
      expect.any(Function),
      expect.objectContaining({ maxAttempts: 1 }),
    );
    expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
      'model-a',
    );
  });

  it('should pass attempted model to onPersistent429 callback which calls handleFallback', async () => {
    vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
      selectedModel: 'model-a',
      skipped: [],
    });
    // Simulate selection happening upstream
    mockConfig.setActiveModel('model-a');
    // Simulate retry logic behavior: catch error, call onPersistent429
    const error = new TerminalQuotaError('Quota', {
      code: 429,
      message: 'quota',
      details: [],
    });
    vi.mocked(mockContentGenerator.generateContentStream).mockRejectedValue(
      error,
    );
    // We need retryWithBackoff to trigger the callback
    mockRetryWithBackoff.mockImplementation(async (apiCall, options) => {
      try {
        await apiCall();
      } catch (e) {
        if (options?.onPersistent429) {
          await options.onPersistent429(AuthType.LOGIN_WITH_GOOGLE, e);
        }
        throw e; // throw anyway to end test
      }
    });
    const consume = async () => {
      const stream = await chat.sendMessageStream(
        { model: 'gemini-pro' },
        'test',
        'prompt-fallback-arg',
        new AbortController().signal,
      );
      for await (const _ of stream) {
        // consume
      }
    };
    await expect(consume()).rejects.toThrow();
    // handleFallback is called with the ATTEMPTED model (model-a), not the requested one (gemini-pro)
    expect(mockHandleFallback).toHaveBeenCalledWith(
      expect.anything(),
      'model-a',
      expect.anything(),
      error,
    );
  });

  it('re-resolves generateContentConfig when active model changes between retries', async () => {
    // Availability enabled with stateful active model
    let activeModel = 'model-a';
    vi.mocked(mockConfig.getActiveModel).mockImplementation(
      () => activeModel,
    );
    vi.mocked(mockConfig.setActiveModel).mockImplementation((model) => {
      activeModel = model;
    });
    // Different configs per model
    vi.mocked(mockConfig.modelConfigService.getResolvedConfig)
      .mockReturnValueOnce(
        makeResolvedModelConfig('model-a', { temperature: 0.1 }),
      )
      .mockReturnValueOnce(
        makeResolvedModelConfig('model-b', { temperature: 0.9 }),
      );
    // First attempt uses model-a, then simulate availability switching to model-b
    mockRetryWithBackoff.mockImplementation(async (apiCall) => {
      await apiCall(); // first attempt
      activeModel = 'model-b'; // simulate switch before retry
      return apiCall(); // second attempt
    });
    // Generators for each attempt
    const firstResponse = (async function* () {
      yield {
        candidates: [
          {
            content: { parts: [{ text: 'first' }], role: 'model' },
            finishReason: 'STOP',
          },
        ],
      } as unknown as GenerateContentResponse;
    })();
    const secondResponse = (async function* () {
      yield {
        candidates: [
          {
            content: { parts: [{ text: 'second' }], role: 'model' },
            finishReason: 'STOP',
          },
        ],
      } as unknown as GenerateContentResponse;
    })();
    vi.mocked(mockContentGenerator.generateContentStream)
      .mockResolvedValueOnce(firstResponse)
      .mockResolvedValueOnce(secondResponse);
    const stream = await chat.sendMessageStream(
      { model: 'gemini-pro' },
      'test',
      'prompt-config-refresh',
      new AbortController().signal,
    );
    // Consume to drive both attempts
    for await (const _ of stream) {
      // consume
    }
    // Attempt 1 must use model-a's resolved config...
    expect(
      mockContentGenerator.generateContentStream,
    ).toHaveBeenNthCalledWith(
      1,
      expect.objectContaining({
        model: 'model-a',
        config: expect.objectContaining({ temperature: 0.1 }),
      }),
      expect.any(String),
    );
    // ...and attempt 2 must pick up model-b's freshly resolved config.
    expect(
      mockContentGenerator.generateContentStream,
    ).toHaveBeenNthCalledWith(
      2,
      expect.objectContaining({
        model: 'model-b',
        config: expect.objectContaining({ temperature: 0.9 }),
      }),
      expect.any(String),
    );
  });
});
});

View File

@@ -48,6 +48,10 @@ import { isFunctionResponse } from '../utils/messageInspectors.js';
import { partListUnionToString } from './geminiRequest.js';
import type { ModelConfigKey } from '../services/modelConfigService.js';
import { estimateTokenCountSync } from '../utils/tokenCalculation.js';
import {
applyModelSelection,
createAvailabilityContextProvider,
} from '../availability/policyHelpers.js';
import {
fireAfterModelHook,
fireBeforeModelHook,
@@ -410,35 +414,74 @@ export class GeminiChat {
requestContents: Content[],
prompt_id: string,
): Promise<AsyncGenerator<GenerateContentResponse>> {
let effectiveModel = model;
const contentsForPreviewModel =
this.ensureActiveLoopHasThoughtSignatures(requestContents);
// Track final request parameters for AfterModel hooks
let lastModelToUse = model;
let lastConfig: GenerateContentConfig = generateContentConfig;
const {
model: availabilityFinalModel,
config: newAvailabilityConfig,
maxAttempts: availabilityMaxAttempts,
} = applyModelSelection(this.config, model, generateContentConfig);
const abortSignal = generateContentConfig.abortSignal;
let lastModelToUse = availabilityFinalModel;
let currentGenerateContentConfig: GenerateContentConfig =
newAvailabilityConfig ?? generateContentConfig;
if (abortSignal) {
currentGenerateContentConfig = {
...currentGenerateContentConfig,
abortSignal,
};
}
let lastConfig: GenerateContentConfig = currentGenerateContentConfig;
let lastContentsToUse: Content[] = requestContents;
const getAvailabilityContext = createAvailabilityContextProvider(
this.config,
() => lastModelToUse,
);
const apiCall = async () => {
let modelToUse = getEffectiveModel(
this.config.isInFallbackMode(),
model,
this.config.getPreviewFeatures(),
);
let modelToUse: string;
// Preview Model Bypass Logic:
// If we are in "Preview Model Bypass Mode" (transient failure), we force downgrade to 2.5 Pro
// IF the effective model is currently Preview Model.
if (
this.config.isPreviewModelBypassMode() &&
modelToUse === PREVIEW_GEMINI_MODEL
) {
modelToUse = DEFAULT_GEMINI_MODEL;
if (this.config.isModelAvailabilityServiceEnabled()) {
modelToUse = this.config.getActiveModel();
if (modelToUse !== lastModelToUse) {
const { generateContentConfig: newConfig } =
this.config.modelConfigService.getResolvedConfig({
model: modelToUse,
});
currentGenerateContentConfig = {
...currentGenerateContentConfig,
...newConfig,
};
if (abortSignal) {
currentGenerateContentConfig.abortSignal = abortSignal;
}
}
} else {
modelToUse = getEffectiveModel(
this.config.isInFallbackMode(),
model,
this.config.getPreviewFeatures(),
);
// Preview Model Bypass Logic:
// If we are in "Preview Model Bypass Mode" (transient failure), we force downgrade to 2.5 Pro
// IF the effective model is currently Preview Model.
// Note: In availability mode, this should ideally be handled by policy, but preserving
// bypass logic for now as it handles specific transient behavior.
if (
this.config.isPreviewModelBypassMode() &&
modelToUse === PREVIEW_GEMINI_MODEL
) {
modelToUse = DEFAULT_GEMINI_MODEL;
}
}
effectiveModel = modelToUse;
lastModelToUse = modelToUse;
const config = {
...generateContentConfig,
...currentGenerateContentConfig,
// TODO(12622): Ensure we don't overrwrite these when they are
// passed via config.
systemInstruction: this.systemInstruction,
@@ -543,7 +586,7 @@ export class GeminiChat {
const onPersistent429Callback = async (
authType?: string,
error?: unknown,
) => handleFallback(this.config, effectiveModel, authType, error);
) => handleFallback(this.config, lastModelToUse, authType, error);
const streamResponse = await retryWithBackoff(apiCall, {
onPersistent429: onPersistent429Callback,
@@ -551,10 +594,12 @@ export class GeminiChat {
retryFetchErrors: this.config.getRetryFetchErrors(),
signal: generateContentConfig.abortSignal,
maxAttempts:
this.config.isPreviewModelFallbackMode() &&
availabilityMaxAttempts ??
(this.config.isPreviewModelFallbackMode() &&
model === PREVIEW_GEMINI_MODEL
? 1
: undefined,
: undefined),
getAvailabilityContext,
});
// Store the original request for AfterModel hooks
@@ -565,7 +610,7 @@ export class GeminiChat {
};
return this.processStreamResponse(
effectiveModel,
lastModelToUse,
streamResponse,
originalRequest,
);

View File

@@ -93,6 +93,7 @@ describe('GeminiChat Network Retries', () => {
generateContentConfig: { temperature: 0 },
})),
},
isModelAvailabilityServiceEnabled: vi.fn().mockReturnValue(false),
isPreviewModelBypassMode: vi.fn().mockReturnValue(false),
setPreviewModelBypassMode: vi.fn(),
isPreviewModelFallbackMode: vi.fn().mockReturnValue(false),

View File

@@ -17,6 +17,7 @@ import {
import { handleFallback } from './handler.js';
import type { Config } from '../config/config.js';
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
import { AuthType } from '../core/contentGenerator.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
@@ -45,25 +46,20 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({
openBrowserSecurely: vi.fn(),
}));
// Mock debugLogger to prevent console pollution and allow spying
vi.mock('../utils/debugLogger.js', () => ({
debugLogger: {
warn: vi.fn(),
error: vi.fn(),
log: vi.fn(),
},
}));
const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL;
const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL;
const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE;
const AUTH_API_KEY = AuthType.USE_GEMINI;
/**
 * Builds a fully stubbed ModelAvailabilityService for tests.
 *
 * Every method is an inert vi.fn(); only selectFirstAvailable is
 * pre-programmed, always returning the supplied selection result.
 */
function createAvailabilityMock(
  result: ReturnType<ModelAvailabilityService['selectFirstAvailable']>,
): ModelAvailabilityService {
  const selectFirstAvailable = vi.fn().mockReturnValue(result);
  const stub = {
    selectFirstAvailable,
    markTerminal: vi.fn(),
    markHealthy: vi.fn(),
    markRetryOncePerTurn: vi.fn(),
    consumeStickyAttempt: vi.fn(),
    snapshot: vi.fn(),
    resetTurn: vi.fn(),
  };
  // The stub intentionally omits any real state; cast through unknown to
  // satisfy the service interface.
  return stub as unknown as ModelAvailabilityService;
}
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
({
isInFallbackMode: vi.fn(() => false),
@@ -75,8 +71,12 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>
setPreviewModelBypassMode: vi.fn(),
fallbackHandler: undefined,
getFallbackModelHandler: vi.fn(),
setActiveModel: vi.fn(),
getModelAvailabilityService: vi.fn(() =>
createAvailabilityMock({ selectedModel: FALLBACK_MODEL, skipped: [] }),
createAvailabilityServiceMock({
selectedModel: FALLBACK_MODEL,
skipped: [],
}),
),
getModel: vi.fn(() => MOCK_PRO_MODEL),
getPreviewFeatures: vi.fn(() => false),
@@ -98,6 +98,12 @@ describe('handleFallback', () => {
mockConfig = createMockConfig({
fallbackModelHandler: mockHandler,
});
// Explicitly set the property to ensure it's present for legacy checks
mockConfig.fallbackModelHandler = mockHandler;
// We mocked debugLogger, so we don't need to spy on console.error for handler failures
// But tests might check console.error usage in legacy code if any?
// The handler uses console.error in legacyHandleFallback.
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
fallbackEventSpy = vi.spyOn(coreEvents, 'emitFallbackModeChanged');
});
@@ -538,7 +544,7 @@ describe('handleFallback', () => {
beforeEach(() => {
vi.clearAllMocks();
availability = createAvailabilityMock({
availability = createAvailabilityServiceMock({
selectedModel: DEFAULT_GEMINI_FLASH_MODEL,
skipped: [],
});
@@ -612,7 +618,17 @@ describe('handleFallback', () => {
expect(result).toBe(true);
expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
expect(policyConfig.setFallbackMode).toHaveBeenCalledWith(true);
expect(policyConfig.setActiveModel).toHaveBeenCalledWith(
DEFAULT_GEMINI_FLASH_MODEL,
);
// Silent actions should not trigger the legacy fallback mode (via activateFallbackMode),
// but setActiveModel might trigger it via legacy sync if it switches to Flash.
// However, the test requirement is "doesn't emit fallback mode".
// Since we are mocking setActiveModel, we can verify setFallbackMode isn't called *independently*.
// But setActiveModel is mocked, so it won't trigger side effects unless the implementation does.
// We verified setActiveModel is called.
// We verify setFallbackMode is NOT called (which would happen if activateFallbackMode was called).
expect(policyConfig.setFallbackMode).not.toHaveBeenCalled();
} finally {
chainSpy.mockRestore();
}
@@ -707,5 +723,34 @@ describe('handleFallback', () => {
expect(policyConfig.getModelAvailabilityService).toHaveBeenCalled();
expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
});
// NOTE(review): test name mentions telemetry logging, but the assertion is
// still a TODO below — confirm whether telemetry coverage was intended here.
it('calls setActiveModel and logs telemetry when handler returns "retry_always"', async () => {
  policyHandler.mockResolvedValue('retry_always');
  const result = await handleFallback(
    policyConfig,
    MOCK_PRO_MODEL,
    AUTH_OAUTH,
  );
  // retry_always => caller should retry (true) with the fallback model active.
  expect(result).toBe(true);
  expect(policyConfig.setActiveModel).toHaveBeenCalledWith(FALLBACK_MODEL);
  // Availability path must not flip the legacy fallback-mode flag.
  expect(policyConfig.setFallbackMode).not.toHaveBeenCalled();
  // TODO: add logging expect statement
});
it('calls setActiveModel when handler returns "stop"', async () => {
  policyHandler.mockResolvedValue('stop');
  const result = await handleFallback(
    policyConfig,
    MOCK_PRO_MODEL,
    AUTH_OAUTH,
  );
  // stop => caller should not retry (false), but the active model is still
  // switched to the fallback for subsequent turns.
  expect(result).toBe(false);
  expect(policyConfig.setActiveModel).toHaveBeenCalledWith(FALLBACK_MODEL);
  // TODO: add logging expect statement
});
});
});

View File

@@ -12,17 +12,14 @@ import {
PREVIEW_GEMINI_MODEL,
} from '../config/models.js';
import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js';
import { coreEvents } from '../utils/events.js';
import { openBrowserSecurely } from '../utils/secure-browser-launcher.js';
import { debugLogger } from '../utils/debugLogger.js';
import { getErrorMessage } from '../utils/errors.js';
import { ModelNotFoundError } from '../utils/httpErrors.js';
import {
RetryableQuotaError,
TerminalQuotaError,
} from '../utils/googleQuotaErrors.js';
import { TerminalQuotaError } from '../utils/googleQuotaErrors.js';
import { coreEvents } from '../utils/events.js';
import type { FallbackIntent, FallbackRecommendation } from './types.js';
import type { FailureKind } from '../availability/modelPolicy.js';
import { classifyFailureKind } from '../availability/errorClassification.js';
import {
buildFallbackPolicyContext,
resolvePolicyChain,
@@ -126,6 +123,9 @@ async function handlePolicyDrivenFallback(
chain,
failedModel,
);
const failureKind = classifyFailureKind(error);
if (!candidates.length) {
return null;
}
@@ -145,7 +145,7 @@ async function handlePolicyDrivenFallback(
return null;
}
const failureKind = classifyFailureKind(error);
// failureKind is already declared and calculated above
const action = resolvePolicyAction(failureKind, selectedPolicy);
if (action === 'silent') {
@@ -183,6 +183,7 @@ async function handlePolicyDrivenFallback(
failedModel,
fallbackModel,
authType,
error, // Pass the error so processIntent can handle preview-specific logic
);
} catch (handlerError) {
debugLogger.error('Fallback handler failed:', handlerError);
@@ -209,25 +210,42 @@ async function processIntent(
authType?: string,
error?: unknown,
): Promise<boolean> {
const isAvailabilityEnabled = config.isModelAvailabilityServiceEnabled();
switch (intent) {
case 'retry_always':
// If the error is non-retryable, e.g. TerminalQuota Error, trigger a regular fallback to flash.
// For all other errors, activate previewModel fallback.
if (
failedModel === PREVIEW_GEMINI_MODEL &&
!(error instanceof TerminalQuotaError)
) {
activatePreviewModelFallbackMode(config);
if (isAvailabilityEnabled) {
// TODO(telemetry): Implement generic fallback event logging. Existing
// logFlashFallback is specific to a single Model.
config.setActiveModel(fallbackModel);
} else {
activateFallbackMode(config, authType);
// If the error is non-retryable, e.g. TerminalQuota Error, trigger a regular fallback to flash.
// For all other errors, activate previewModel fallback.
if (
failedModel === PREVIEW_GEMINI_MODEL &&
!(error instanceof TerminalQuotaError)
) {
activatePreviewModelFallbackMode(config);
} else {
activateFallbackMode(config, authType);
}
}
return true;
case 'retry_once':
if (isAvailabilityEnabled) {
config.setActiveModel(fallbackModel);
}
return true;
case 'stop':
activateFallbackMode(config, authType);
if (isAvailabilityEnabled) {
// TODO(telemetry): Implement generic fallback event logging. Existing
// logFlashFallback is specific to a single Model.
config.setActiveModel(fallbackModel);
} else {
activateFallbackMode(config, authType);
}
return false;
case 'retry_later':
@@ -260,16 +278,3 @@ function activatePreviewModelFallbackMode(config: Config) {
// We might want a specific event for Preview Model fallback, but for now we just set the mode.
}
}
/**
 * Buckets an arbitrary thrown value into a FailureKind for policy lookup.
 *
 * Checked in order of specificity: terminal quota errors first, then
 * retryable quota errors, then missing-model errors; anything else is
 * 'unknown'.
 */
function classifyFailureKind(error?: unknown): FailureKind {
  if (error instanceof TerminalQuotaError) return 'terminal';
  if (error instanceof RetryableQuotaError) return 'transient';
  return error instanceof ModelNotFoundError ? 'not_found' : 'unknown';
}

View File

@@ -0,0 +1,23 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { ResolvedModelConfig } from '../services/modelConfigService.js';
/**
* Creates a ResolvedModelConfig with sensible defaults, allowing overrides.
*/
/**
 * Creates a ResolvedModelConfig with sensible defaults, allowing overrides.
 *
 * Defaults are temperature 0 and topP 1; any keys in `overrides` win via
 * spread. The cast is deliberate: this test helper only fills the fields
 * tests actually read.
 */
export const makeResolvedModelConfig = (
  model: string,
  overrides: Partial<ResolvedModelConfig['generateContentConfig']> = {},
): ResolvedModelConfig => {
  const generateContentConfig = {
    temperature: 0,
    topP: 1,
    ...overrides,
  };
  return { model, generateContentConfig } as ResolvedModelConfig;
};

View File

@@ -17,6 +17,9 @@ import {
RetryableQuotaError,
} from './googleQuotaErrors.js';
import { PREVIEW_GEMINI_MODEL } from '../config/models.js';
import type { ModelPolicy } from '../availability/modelPolicy.js';
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
// Helper to create a mock function that fails a certain number of times
const createFailingFunction = (
@@ -104,7 +107,6 @@ describe('retryWithBackoff', () => {
const promise = retryWithBackoff(mockFn);
// Expect it to fail with the error from the 5th attempt.
await Promise.all([
expect(promise).rejects.toThrow('Simulated error attempt 3'),
vi.runAllTimersAsync(),
@@ -566,4 +568,171 @@ describe('retryWithBackoff', () => {
);
expect(mockFn).toHaveBeenCalledTimes(2);
});
// Covers retryWithBackoff's availability hooks: per-attempt context lookup,
// failure-kind classification, and the resulting state transitions recorded
// on the ModelAvailabilityService.
describe('Availability Context Integration', () => {
  let mockService: ModelAvailabilityService;
  let mockPolicy1: ModelPolicy;
  let mockPolicy2: ModelPolicy;

  beforeEach(() => {
    vi.useRealTimers();
    mockService = createAvailabilityServiceMock();
    // policy1 escalates transient failures to sticky_retry; policy2 does not.
    mockPolicy1 = {
      model: 'model-1',
      actions: {},
      stateTransitions: {
        terminal: 'terminal',
        transient: 'sticky_retry',
      },
    };
    mockPolicy2 = {
      model: 'model-2',
      actions: {},
      stateTransitions: {
        terminal: 'terminal',
      },
    };
  });

  it('updates availability context per attempt and applies transitions to the correct policy', async () => {
    const error = new TerminalQuotaError(
      'quota exceeded',
      { code: 429, message: 'quota', details: [] },
      10,
    );
    const fn = vi.fn().mockImplementation(async () => {
      throw error; // Always fail with quota
    });
    const onPersistent429 = vi
      .fn()
      .mockResolvedValueOnce('model-2') // First fallback success
      .mockResolvedValueOnce(null); // Second fallback fails (give up)
    // Context provider returns policy1 first, then policy2
    const getContext = vi
      .fn()
      .mockReturnValueOnce({ service: mockService, policy: mockPolicy1 })
      .mockReturnValueOnce({ service: mockService, policy: mockPolicy2 });
    await expect(
      retryWithBackoff(fn, {
        maxAttempts: 3,
        initialDelayMs: 1,
        getAvailabilityContext: getContext,
        onPersistent429,
        authType: AuthType.LOGIN_WITH_GOOGLE,
      }),
    ).rejects.toThrow(TerminalQuotaError);
    // Verify failures
    expect(mockService.markTerminal).toHaveBeenCalledWith('model-1', 'quota');
    expect(mockService.markTerminal).toHaveBeenCalledWith('model-2', 'quota');
    // Verify sequences
    expect(mockService.markTerminal).toHaveBeenNthCalledWith(
      1,
      'model-1',
      'quota',
    );
    expect(mockService.markTerminal).toHaveBeenNthCalledWith(
      2,
      'model-2',
      'quota',
    );
  });

  it('marks sticky_retry after retries are exhausted for transient failures', async () => {
    const transientError = new RetryableQuotaError(
      'transient error',
      { code: 429, message: 'transient', details: [] },
      0,
    );
    const fn = vi.fn().mockRejectedValue(transientError);
    const getContext = vi
      .fn()
      .mockReturnValue({ service: mockService, policy: mockPolicy1 });
    // Fake timers so the backoff delays resolve instantly.
    vi.useFakeTimers();
    const promise = retryWithBackoff(fn, {
      maxAttempts: 3,
      getAvailabilityContext: getContext,
      initialDelayMs: 1,
      maxDelayMs: 1,
    }).catch((err) => err);
    await vi.runAllTimersAsync();
    const result = await promise;
    expect(result).toBe(transientError);
    expect(fn).toHaveBeenCalledTimes(3);
    // sticky_retry transition is recorded exactly once, only after the
    // retry budget is exhausted — never as a terminal failure.
    expect(mockService.markRetryOncePerTurn).toHaveBeenCalledWith('model-1');
    expect(mockService.markRetryOncePerTurn).toHaveBeenCalledTimes(1);
    expect(mockService.markTerminal).not.toHaveBeenCalled();
  });

  it('maps different failure kinds to correct terminal reasons', async () => {
    const quotaError = new TerminalQuotaError(
      'quota',
      { code: 429, message: 'q', details: [] },
      10,
    );
    const notFoundError = new ModelNotFoundError('not found', 404);
    const genericError = new Error('unknown error');
    const fn = vi
      .fn()
      .mockRejectedValueOnce(quotaError)
      .mockRejectedValueOnce(notFoundError)
      .mockRejectedValueOnce(genericError);
    const policy: ModelPolicy = {
      model: 'model-1',
      actions: {},
      stateTransitions: {
        terminal: 'terminal', // from quotaError
        not_found: 'terminal', // from notFoundError
        unknown: 'terminal', // from genericError
      },
    };
    const getContext = vi
      .fn()
      .mockReturnValue({ service: mockService, policy });
    // Run for quotaError
    await retryWithBackoff(fn, {
      maxAttempts: 1,
      getAvailabilityContext: getContext,
    }).catch(() => {});
    expect(mockService.markTerminal).toHaveBeenCalledWith('model-1', 'quota');
    // Run for notFoundError
    await retryWithBackoff(fn, {
      maxAttempts: 1,
      getAvailabilityContext: getContext,
    }).catch(() => {});
    // NOTE(review): not_found and unknown both map to reason 'capacity' —
    // confirm this is the intended reason string for non-quota failures.
    expect(mockService.markTerminal).toHaveBeenCalledWith(
      'model-1',
      'capacity',
    );
    // Run for genericError
    await retryWithBackoff(fn, {
      maxAttempts: 1,
      getAvailabilityContext: getContext,
    }).catch(() => {});
    expect(mockService.markTerminal).toHaveBeenCalledWith(
      'model-1',
      'capacity',
    );
    expect(mockService.markTerminal).toHaveBeenCalledTimes(3);
  });
});
});

View File

@@ -8,13 +8,18 @@ import type { GenerateContentResponse } from '@google/genai';
import { ApiError } from '@google/genai';
import { AuthType } from '../core/contentGenerator.js';
import {
classifyGoogleError,
RetryableQuotaError,
TerminalQuotaError,
RetryableQuotaError,
classifyGoogleError,
} from './googleQuotaErrors.js';
import { delay, createAbortError } from './delay.js';
import { debugLogger } from './debugLogger.js';
import { getErrorStatus, ModelNotFoundError } from './httpErrors.js';
import type { RetryAvailabilityContext } from '../availability/modelPolicy.js';
import { classifyFailureKind } from '../availability/errorClassification.js';
import { applyAvailabilityTransition } from '../availability/policyHelpers.js';
export type { RetryAvailabilityContext };
export interface RetryOptions {
maxAttempts: number;
@@ -29,6 +34,7 @@ export interface RetryOptions {
authType?: string;
retryFetchErrors?: boolean;
signal?: AbortSignal;
getAvailabilityContext?: () => RetryAvailabilityContext | undefined;
}
const DEFAULT_RETRY_OPTIONS: RetryOptions = {
@@ -145,6 +151,7 @@ export async function retryWithBackoff<T>(
shouldRetryOnContent,
retryFetchErrors,
signal,
getAvailabilityContext,
} = {
...DEFAULT_RETRY_OPTIONS,
shouldRetryOnError: isRetryableError,
@@ -173,6 +180,11 @@ export async function retryWithBackoff<T>(
continue;
}
const successContext = getAvailabilityContext?.();
if (successContext) {
successContext.service.markHealthy(successContext.policy.model);
}
return result;
} catch (error) {
if (error instanceof Error && error.name === 'AbortError') {
@@ -180,6 +192,13 @@ export async function retryWithBackoff<T>(
}
const classifiedError = classifyGoogleError(error);
const failureKind = classifyFailureKind(classifiedError);
const appliedImmediate =
failureKind === 'terminal' || failureKind === 'not_found';
if (appliedImmediate) {
applyAvailabilityTransition(getAvailabilityContext, failureKind);
}
const errorCode = getErrorStatus(error);
if (
@@ -201,6 +220,7 @@ export async function retryWithBackoff<T>(
debugLogger.warn('Fallback to Flash model failed:', fallbackError);
}
}
// Terminal/not_found already recorded; nothing else to mark here.
throw classifiedError; // Throw if no fallback or fallback failed.
}
@@ -224,6 +244,9 @@ export async function retryWithBackoff<T>(
console.warn('Model fallback failed:', fallbackError);
}
}
if (!appliedImmediate) {
applyAvailabilityTransition(getAvailabilityContext, failureKind);
}
throw classifiedError instanceof RetryableQuotaError
? classifiedError
: error;
@@ -253,6 +276,9 @@ export async function retryWithBackoff<T>(
attempt >= maxAttempts ||
!shouldRetryOnError(error as Error, retryFetchErrors)
) {
if (!appliedImmediate) {
applyAvailabilityTransition(getAvailabilityContext, failureKind);
}
throw error;
}