mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-19 02:20:42 -07:00
feat(modelAvailabilityService): integrate model availability service into backend logic (#14470)
This commit is contained in:
25
packages/core/src/availability/errorClassification.ts
Normal file
25
packages/core/src/availability/errorClassification.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
TerminalQuotaError,
|
||||
RetryableQuotaError,
|
||||
} from '../utils/googleQuotaErrors.js';
|
||||
import { ModelNotFoundError } from '../utils/httpErrors.js';
|
||||
import type { FailureKind } from './modelPolicy.js';
|
||||
|
||||
/**
 * Maps a value thrown by a model API call onto the coarse failure category
 * used by the availability policies.
 *
 * The value is only inspected with `instanceof`; the checks run in the order
 * terminal quota, retryable quota, model-not-found.
 *
 * @param error The caught value (any type).
 * @returns `'terminal'` for a `TerminalQuotaError`, `'transient'` for a
 *   `RetryableQuotaError`, `'not_found'` for a `ModelNotFoundError`, and
 *   `'unknown'` for everything else.
 */
export function classifyFailureKind(error: unknown): FailureKind {
  if (error instanceof TerminalQuotaError) {
    return 'terminal';
  }
  if (error instanceof RetryableQuotaError) {
    return 'transient';
  }
  if (error instanceof ModelNotFoundError) {
    return 'not_found';
  }
  // Anything that is not one of the recognised error classes.
  return 'unknown';
}
|
||||
@@ -4,7 +4,11 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { ModelHealthStatus, ModelId } from './modelAvailabilityService.js';
|
||||
import type {
|
||||
ModelAvailabilityService,
|
||||
ModelHealthStatus,
|
||||
ModelId,
|
||||
} from './modelAvailabilityService.js';
|
||||
|
||||
/**
|
||||
* Whether to prompt the user or fallback silently on a model API failure.
|
||||
@@ -49,3 +53,11 @@ export interface ModelPolicy {
|
||||
* The first model in the chain is the primary model.
|
||||
*/
|
||||
export type ModelPolicyChain = ModelPolicy[];
|
||||
|
||||
/**
 * Context required by retry logic to apply availability policies on failure.
 */
export interface RetryAvailabilityContext {
  /** Availability service that receives the resulting state updates. */
  service: ModelAvailabilityService;
  /** Policy whose `model` and transition rules apply to the current attempt. */
  policy: ModelPolicy;
}
|
||||
|
||||
@@ -51,6 +51,7 @@ const PREVIEW_CHAIN: ModelPolicyChain = [
|
||||
definePolicy({
|
||||
model: PREVIEW_GEMINI_MODEL,
|
||||
stateTransitions: { transient: 'sticky_retry' },
|
||||
actions: { transient: 'silent' },
|
||||
}),
|
||||
definePolicy({ model: DEFAULT_GEMINI_MODEL }),
|
||||
definePolicy({ model: DEFAULT_GEMINI_FLASH_MODEL, isLastResort: true }),
|
||||
|
||||
@@ -4,38 +4,54 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import {
|
||||
resolvePolicyChain,
|
||||
buildFallbackPolicyContext,
|
||||
applyModelSelection,
|
||||
} from './policyHelpers.js';
|
||||
import { createDefaultPolicy } from './policyCatalog.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
|
||||
/**
 * Builds a minimal `Config` stand-in for these tests.
 *
 * The defaults describe a non-preview, non-fallback session on
 * 'gemini-2.5-pro'; any member supplied in `overrides` replaces the default
 * with the same name.
 */
const createMockConfig = (overrides: Partial<Config> = {}): Config => {
  const defaults = {
    getPreviewFeatures: () => false,
    getUserTier: () => undefined,
    getModel: () => 'gemini-2.5-pro',
    isInFallbackMode: () => false,
  };
  // Only the members the tests touch exist, hence the double cast.
  return { ...defaults, ...overrides } as unknown as Config;
};
|
||||
|
||||
describe('policyHelpers', () => {
|
||||
describe('resolvePolicyChain', () => {
|
||||
it('inserts the active model when missing from the catalog', () => {
|
||||
const config = {
|
||||
getPreviewFeatures: () => false,
|
||||
getUserTier: () => undefined,
|
||||
const config = createMockConfig({
|
||||
getModel: () => 'custom-model',
|
||||
isInFallbackMode: () => false,
|
||||
} as unknown as Config;
|
||||
});
|
||||
const chain = resolvePolicyChain(config);
|
||||
expect(chain).toHaveLength(1);
|
||||
expect(chain[0]?.model).toBe('custom-model');
|
||||
});
|
||||
|
||||
it('leaves catalog order untouched when active model already present', () => {
|
||||
const config = {
|
||||
getPreviewFeatures: () => false,
|
||||
getUserTier: () => undefined,
|
||||
const config = createMockConfig({
|
||||
getModel: () => 'gemini-2.5-pro',
|
||||
isInFallbackMode: () => false,
|
||||
} as unknown as Config;
|
||||
});
|
||||
const chain = resolvePolicyChain(config);
|
||||
expect(chain[0]?.model).toBe('gemini-2.5-pro');
|
||||
});
|
||||
|
||||
it('returns the default chain when active model is "auto"', () => {
|
||||
const config = createMockConfig({
|
||||
getModel: () => 'auto',
|
||||
});
|
||||
const chain = resolvePolicyChain(config);
|
||||
|
||||
// Expect default chain [Pro, Flash]
|
||||
expect(chain).toHaveLength(2);
|
||||
expect(chain[0]?.model).toBe('gemini-2.5-pro');
|
||||
expect(chain[1]?.model).toBe('gemini-2.5-flash');
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildFallbackPolicyContext', () => {
|
||||
@@ -57,4 +73,112 @@ describe('policyHelpers', () => {
|
||||
expect(context.candidates).toEqual(chain);
|
||||
});
|
||||
});
|
||||
|
||||
describe('applyModelSelection', () => {
|
||||
const mockModelConfigService = {
|
||||
getResolvedConfig: vi.fn(),
|
||||
};
|
||||
|
||||
const mockAvailabilityService = {
|
||||
selectFirstAvailable: vi.fn(),
|
||||
consumeStickyAttempt: vi.fn(),
|
||||
};
|
||||
|
||||
const createExtendedMockConfig = (
|
||||
overrides: Partial<Config> = {},
|
||||
): Config => {
|
||||
const defaults = {
|
||||
isModelAvailabilityServiceEnabled: () => true,
|
||||
getModelAvailabilityService: () => mockAvailabilityService,
|
||||
setActiveModel: vi.fn(),
|
||||
modelConfigService: mockModelConfigService,
|
||||
};
|
||||
return createMockConfig({ ...defaults, ...overrides } as Partial<Config>);
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('returns requested model if availability service is disabled', () => {
|
||||
const config = createExtendedMockConfig({
|
||||
isModelAvailabilityServiceEnabled: () => false,
|
||||
});
|
||||
const result = applyModelSelection(config, 'gemini-pro');
|
||||
expect(result.model).toBe('gemini-pro');
|
||||
expect(config.setActiveModel).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('returns requested model if it is available', () => {
|
||||
const config = createExtendedMockConfig();
|
||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||
selectedModel: 'gemini-pro',
|
||||
});
|
||||
|
||||
const result = applyModelSelection(config, 'gemini-pro');
|
||||
expect(result.model).toBe('gemini-pro');
|
||||
expect(result.maxAttempts).toBeUndefined();
|
||||
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
|
||||
});
|
||||
|
||||
it('switches to backup model and updates config if requested is unavailable', () => {
|
||||
const config = createExtendedMockConfig();
|
||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||
selectedModel: 'gemini-flash',
|
||||
});
|
||||
mockModelConfigService.getResolvedConfig.mockReturnValue({
|
||||
generateContentConfig: { temperature: 0.1 },
|
||||
});
|
||||
|
||||
const currentConfig = { temperature: 0.9, topP: 1 };
|
||||
const result = applyModelSelection(config, 'gemini-pro', currentConfig);
|
||||
|
||||
expect(result.model).toBe('gemini-flash');
|
||||
expect(result.config).toEqual({
|
||||
temperature: 0.1,
|
||||
topP: 1,
|
||||
});
|
||||
|
||||
expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
|
||||
model: 'gemini-flash',
|
||||
});
|
||||
expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash');
|
||||
});
|
||||
|
||||
it('consumes sticky attempt if indicated', () => {
|
||||
const config = createExtendedMockConfig();
|
||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||
selectedModel: 'gemini-pro',
|
||||
attempts: 1,
|
||||
});
|
||||
|
||||
const result = applyModelSelection(config, 'gemini-pro');
|
||||
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
|
||||
'gemini-pro',
|
||||
);
|
||||
expect(result.maxAttempts).toBe(1);
|
||||
});
|
||||
|
||||
it('does not consume sticky attempt if consumeAttempt is false', () => {
|
||||
const config = createExtendedMockConfig();
|
||||
mockAvailabilityService.selectFirstAvailable.mockReturnValue({
|
||||
selectedModel: 'gemini-pro',
|
||||
attempts: 1,
|
||||
});
|
||||
|
||||
const result = applyModelSelection(
|
||||
config,
|
||||
'gemini-pro',
|
||||
undefined,
|
||||
undefined,
|
||||
{
|
||||
consumeAttempt: false,
|
||||
},
|
||||
);
|
||||
expect(
|
||||
mockAvailabilityService.consumeStickyAttempt,
|
||||
).not.toHaveBeenCalled();
|
||||
expect(result.maxAttempts).toBe(1);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,34 +4,47 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { GenerateContentConfig } from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type {
|
||||
FailureKind,
|
||||
FallbackAction,
|
||||
ModelPolicy,
|
||||
ModelPolicyChain,
|
||||
RetryAvailabilityContext,
|
||||
} from './modelPolicy.js';
|
||||
import { createDefaultPolicy, getModelPolicyChain } from './policyCatalog.js';
|
||||
import { getEffectiveModel } from '../config/models.js';
|
||||
import { DEFAULT_GEMINI_MODEL, getEffectiveModel } from '../config/models.js';
|
||||
import type { ModelSelectionResult } from './modelAvailabilityService.js';
|
||||
|
||||
/**
|
||||
* Resolves the active policy chain for the given config, ensuring the
|
||||
* user-selected active model is represented.
|
||||
*/
|
||||
export function resolvePolicyChain(config: Config): ModelPolicyChain {
|
||||
export function resolvePolicyChain(
|
||||
config: Config,
|
||||
preferredModel?: string,
|
||||
): ModelPolicyChain {
|
||||
const chain = getModelPolicyChain({
|
||||
previewEnabled: !!config.getPreviewFeatures(),
|
||||
userTier: config.getUserTier(),
|
||||
});
|
||||
// TODO: This will be replaced when we get rid of Fallback Modes
|
||||
const activeModel = getEffectiveModel(
|
||||
config.isInFallbackMode(),
|
||||
config.getModel(),
|
||||
config.getPreviewFeatures(),
|
||||
);
|
||||
// TODO: This will be replaced when we get rid of Fallback Modes.
|
||||
// Switch to getActiveModel()
|
||||
const activeModel =
|
||||
preferredModel ??
|
||||
getEffectiveModel(
|
||||
config.isInFallbackMode(),
|
||||
config.getModel(),
|
||||
config.getPreviewFeatures(),
|
||||
);
|
||||
|
||||
if (activeModel === 'auto') {
|
||||
return [...chain];
|
||||
}
|
||||
|
||||
if (chain.some((policy) => policy.model === activeModel)) {
|
||||
return chain;
|
||||
return [...chain];
|
||||
}
|
||||
|
||||
// If the user specified a model not in the default chain, we assume they want
|
||||
@@ -68,3 +81,120 @@ export function resolvePolicyAction(
|
||||
): FallbackAction {
|
||||
return policy.actions?.[failureKind] ?? 'prompt';
|
||||
}
|
||||
|
||||
/**
 * Creates a context provider for retry logic that returns the availability
 * service and resolves the current model's policy.
 *
 * @param config Config that exposes the availability-service toggle and
 *   instance.
 * @param modelGetter A function that returns the model ID currently being attempted.
 *   (Allows handling dynamic model changes during retries).
 * @returns A callback that yields `{ service, policy }` for the model reported
 *   by `modelGetter`, or `undefined` when the availability service is disabled
 *   or the resolved chain holds no policy for that model.
 */
export function createAvailabilityContextProvider(
  config: Config,
  modelGetter: () => string,
): () => RetryAvailabilityContext | undefined {
  return () => {
    if (!config.isModelAvailabilityServiceEnabled()) {
      return undefined;
    }
    const service = config.getModelAvailabilityService();
    // Read the model on every invocation so a change between attempts is seen.
    const currentModel = modelGetter();

    // Resolve the chain for the specific model we are attempting.
    const chain = resolvePolicyChain(config, currentModel);
    const policy = chain.find((p) => p.model === currentModel);

    return policy ? { service, policy } : undefined;
  };
}
|
||||
|
||||
/**
 * Selects the model to use for an attempt via the availability service and
 * returns the selection context.
 *
 * @param config Config that exposes the availability-service toggle and
 *   instance.
 * @param requestedModel Model the caller wants; passed to `resolvePolicyChain`
 *   as the preferred model.
 * @returns `undefined` when the availability service is disabled. Otherwise
 *   the service's own selection when it names a model; failing that, a
 *   selection naming the chain's `isLastResort` model (or
 *   `DEFAULT_GEMINI_MODEL` when the chain has none) with an empty `skipped`
 *   list.
 */
export function selectModelForAvailability(
  config: Config,
  requestedModel: string,
): ModelSelectionResult | undefined {
  if (!config.isModelAvailabilityServiceEnabled()) {
    return undefined;
  }

  const chain = resolvePolicyChain(config, requestedModel);
  const selection = config
    .getModelAvailabilityService()
    .selectFirstAvailable(chain.map((p) => p.model));

  if (selection.selectedModel) return selection;

  // The service selected nothing: fall back to the last-resort policy.
  const backupModel =
    chain.find((p) => p.isLastResort)?.model ?? DEFAULT_GEMINI_MODEL;

  return { selectedModel: backupModel, skipped: [] };
}
|
||||
|
||||
/**
 * Applies the model availability selection logic, including side effects
 * (setting active model, consuming sticky attempts) and config updates.
 *
 * @param config Config providing the availability service, the model config
 *   service and `setActiveModel`.
 * @param requestedModel Model the caller asked for.
 * @param currentConfig Generation config the caller already holds, if any.
 * @param overrideScope Optional scope forwarded to `getResolvedConfig` when
 *   the config has to be re-resolved for a different model.
 * @param options `consumeAttempt: false` suppresses consuming the sticky
 *   attempt; any other value (including omission) allows it.
 * @returns The model to call, the config to send, and — when the selection
 *   carries an attempt count — that count as `maxAttempts`. When no selection
 *   is available the requested model and `currentConfig` are returned
 *   unchanged and no side effects occur.
 */
export function applyModelSelection(
  config: Config,
  requestedModel: string,
  currentConfig?: GenerateContentConfig,
  overrideScope?: string,
  options: { consumeAttempt?: boolean } = {},
): { model: string; config?: GenerateContentConfig; maxAttempts?: number } {
  const selection = selectModelForAvailability(config, requestedModel);

  if (!selection?.selectedModel) {
    return { model: requestedModel, config: currentConfig };
  }

  const finalModel = selection.selectedModel;
  let finalConfig = currentConfig;

  // If model changed, re-resolve config
  if (finalModel !== requestedModel) {
    const { generateContentConfig } =
      config.modelConfigService.getResolvedConfig({
        overrideScope,
        model: finalModel,
      });

    // Resolved values for the new model win over the caller's existing ones.
    finalConfig = currentConfig
      ? { ...currentConfig, ...generateContentConfig }
      : generateContentConfig;
  }

  config.setActiveModel(finalModel);

  if (selection.attempts && options.consumeAttempt !== false) {
    config.getModelAvailabilityService().consumeStickyAttempt(finalModel);
  }

  return {
    model: finalModel,
    config: finalConfig,
    maxAttempts: selection.attempts,
  };
}
|
||||
|
||||
/**
 * Records the availability state transition that the current model's policy
 * prescribes for a failure.
 *
 * @param getContext Optional provider of the availability context. When it is
 *   absent or yields `undefined`, nothing happens.
 * @param failureKind Failure category used to look up the transition in the
 *   policy's `stateTransitions`; a missing entry is a no-op.
 */
export function applyAvailabilityTransition(
  getContext: (() => RetryAvailabilityContext | undefined) | undefined,
  failureKind: FailureKind,
): void {
  const context = getContext?.();
  if (!context) return;

  const transition = context.policy.stateTransitions?.[failureKind];
  if (!transition) return;

  if (transition === 'terminal') {
    // A 'terminal' failure is reported with reason 'quota'; any other failure
    // kind that the policy maps to a terminal transition uses 'capacity'.
    context.service.markTerminal(
      context.policy.model,
      failureKind === 'terminal' ? 'quota' : 'capacity',
    );
  } else if (transition === 'sticky_retry') {
    context.service.markRetryOncePerTurn(context.policy.model);
  }
}
|
||||
|
||||
30
packages/core/src/availability/testUtils.ts
Normal file
30
packages/core/src/availability/testUtils.ts
Normal file
@@ -0,0 +1,30 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { vi } from 'vitest';
|
||||
import type {
|
||||
ModelAvailabilityService,
|
||||
ModelSelectionResult,
|
||||
} from './modelAvailabilityService.js';
|
||||
|
||||
/**
 * Test helper to create a fully mocked ModelAvailabilityService.
 *
 * @param selection Value that `selectFirstAvailable` returns on every call;
 *   defaults to "nothing selected, nothing skipped".
 * @returns An object whose members are all `vi.fn()` mocks, cast to
 *   `ModelAvailabilityService`.
 */
export function createAvailabilityServiceMock(
  selection: ModelSelectionResult = { selectedModel: null, skipped: [] },
): ModelAvailabilityService {
  const service = {
    markTerminal: vi.fn(),
    markHealthy: vi.fn(),
    markRetryOncePerTurn: vi.fn(),
    consumeStickyAttempt: vi.fn(),
    snapshot: vi.fn(),
    resetTurn: vi.fn(),
    selectFirstAvailable: vi.fn().mockReturnValue(selection),
  };

  // The literal is a structural stand-in, not a real service instance.
  return service as unknown as ModelAvailabilityService;
}
|
||||
@@ -1632,3 +1632,42 @@ describe('Config setExperiments logging', () => {
|
||||
debugSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Availability Service Integration', () => {
|
||||
const baseModel = 'test-model';
|
||||
const baseParams: ConfigParameters = {
|
||||
sessionId: 'test',
|
||||
targetDir: '.',
|
||||
debugMode: false,
|
||||
model: baseModel,
|
||||
cwd: '.',
|
||||
};
|
||||
|
||||
it('setActiveModel updates active model and emits event', async () => {
|
||||
const config = new Config(baseParams);
|
||||
const model1 = 'model1';
|
||||
const model2 = 'model2';
|
||||
|
||||
config.setActiveModel(model1);
|
||||
expect(config.getActiveModel()).toBe(model1);
|
||||
expect(mockCoreEvents.emitModelChanged).toHaveBeenCalledWith(model1);
|
||||
|
||||
config.setActiveModel(model2);
|
||||
expect(config.getActiveModel()).toBe(model2);
|
||||
expect(mockCoreEvents.emitModelChanged).toHaveBeenCalledWith(model2);
|
||||
});
|
||||
|
||||
it('getActiveModel defaults to configured model if not set', () => {
|
||||
const config = new Config(baseParams);
|
||||
expect(config.getActiveModel()).toBe(baseModel);
|
||||
});
|
||||
|
||||
it('resetTurn delegates to availability service', () => {
|
||||
const config = new Config(baseParams);
|
||||
const service = config.getModelAvailabilityService();
|
||||
const spy = vi.spyOn(service, 'resetTurn');
|
||||
|
||||
config.resetTurn();
|
||||
expect(spy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -382,6 +382,7 @@ export class Config {
|
||||
private ideMode: boolean;
|
||||
|
||||
private inFallbackMode = false;
|
||||
private _activeModel: string;
|
||||
private readonly maxSessionTurns: number;
|
||||
private readonly listSessions: boolean;
|
||||
private readonly deleteSession: string | undefined;
|
||||
@@ -504,6 +505,7 @@ export class Config {
|
||||
this.fileDiscoveryService = params.fileDiscoveryService ?? null;
|
||||
this.bugCommand = params.bugCommand;
|
||||
this.model = params.model;
|
||||
this._activeModel = params.model;
|
||||
this.enableModelAvailabilityService =
|
||||
params.enableModelAvailabilityService ?? false;
|
||||
this.enableAgents = params.enableAgents ?? false;
|
||||
@@ -810,11 +812,28 @@ export class Config {
|
||||
/**
 * Sets the user-selected model and leaves fallback mode.
 *
 * The model, the active model and the model-changed event are only updated
 * when the value differs from the current one or fallback mode is on;
 * fallback mode is switched off in every case.
 */
setModel(newModel: string): void {
  if (this.model !== newModel || this.inFallbackMode) {
    this.model = newModel;
    // When the user explicitly sets a model, that becomes the active model.
    this._activeModel = newModel;
    coreEvents.emitModelChanged(newModel);
  }
  this.setFallbackMode(false);
}
|
||||
|
||||
/**
 * Returns the active model, falling back to the configured model when the
 * active model is nullish.
 */
getActiveModel(): string {
  return this._activeModel ?? this.model;
}
|
||||
|
||||
/**
 * Updates the active model without touching the configured model.
 *
 * The model-changed event is emitted only when the value actually changes.
 */
setActiveModel(model: string): void {
  if (this._activeModel !== model) {
    this._activeModel = model;
    coreEvents.emitModelChanged(model);
  }
}
|
||||
|
||||
/** Delegates to the model availability service's `resetTurn()`. */
resetTurn(): void {
  this.modelAvailabilityService.resetTurn();
}
|
||||
|
||||
/** Reports whether the config is currently in fallback mode. */
isInFallbackMode(): boolean {
  return this.inFallbackMode;
}
|
||||
|
||||
@@ -15,9 +15,12 @@ import {
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
|
||||
import type { GenerateContentResponse } from '@google/genai';
|
||||
import { BaseLlmClient, type GenerateJsonOptions } from './baseLlmClient.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
|
||||
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
|
||||
import type { GenerateContentOptions } from './baseLlmClient.js';
|
||||
import type { GenerateContentResponse } from '@google/genai';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { AuthType } from './contentGenerator.js';
|
||||
import { reportError } from '../utils/errorReporting.js';
|
||||
@@ -25,6 +28,8 @@ import { logMalformedJsonResponse } from '../telemetry/loggers.js';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import { MalformedJsonResponseEvent } from '../telemetry/types.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import type { ModelConfigService } from '../services/modelConfigService.js';
|
||||
import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js';
|
||||
|
||||
vi.mock('../utils/errorReporting.js');
|
||||
vi.mock('../telemetry/loggers.js');
|
||||
@@ -58,6 +63,11 @@ vi.mock('../utils/retry.js', () => ({
|
||||
}
|
||||
}
|
||||
|
||||
const context = options?.getAvailabilityContext?.();
|
||||
if (context) {
|
||||
context.service.markHealthy(context.policy.model);
|
||||
}
|
||||
|
||||
return result;
|
||||
}),
|
||||
}));
|
||||
@@ -97,14 +107,18 @@ describe('BaseLlmClient', () => {
|
||||
getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
|
||||
isInteractive: vi.fn().mockReturnValue(false),
|
||||
modelConfigService: {
|
||||
getResolvedConfig: vi.fn().mockImplementation(({ model }) => ({
|
||||
model,
|
||||
generateContentConfig: {
|
||||
temperature: 0,
|
||||
topP: 1,
|
||||
},
|
||||
})),
|
||||
},
|
||||
getResolvedConfig: vi
|
||||
.fn()
|
||||
.mockImplementation(({ model }) => makeResolvedModelConfig(model)),
|
||||
} as unknown as ModelConfigService,
|
||||
isModelAvailabilityServiceEnabled: vi.fn().mockReturnValue(false),
|
||||
getModelAvailabilityService: vi.fn(),
|
||||
setActiveModel: vi.fn(),
|
||||
getPreviewFeatures: vi.fn().mockReturnValue(false),
|
||||
getUserTier: vi.fn().mockReturnValue(undefined),
|
||||
isInFallbackMode: vi.fn().mockReturnValue(false),
|
||||
getModel: vi.fn().mockReturnValue('test-model'),
|
||||
getActiveModel: vi.fn().mockReturnValue('test-model'),
|
||||
} as unknown as Mocked<Config>;
|
||||
|
||||
client = new BaseLlmClient(mockContentGenerator, mockConfig);
|
||||
@@ -593,4 +607,243 @@ describe('BaseLlmClient', () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Availability Service Integration', () => {
|
||||
let mockAvailabilityService: ModelAvailabilityService;
|
||||
let contentOptions: GenerateContentOptions;
|
||||
let jsonOptions: GenerateJsonOptions;
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.isModelAvailabilityServiceEnabled = vi
|
||||
.fn()
|
||||
.mockReturnValue(true);
|
||||
|
||||
mockAvailabilityService = createAvailabilityServiceMock({
|
||||
selectedModel: 'test-model',
|
||||
skipped: [],
|
||||
});
|
||||
|
||||
// Reflect setActiveModel into getActiveModel so availability-driven updates
|
||||
// are visible to the client under test.
|
||||
mockConfig.getActiveModel = vi.fn().mockReturnValue('test-model');
|
||||
mockConfig.setActiveModel = vi.fn((model: string) => {
|
||||
vi.mocked(mockConfig.getActiveModel).mockReturnValue(model);
|
||||
});
|
||||
|
||||
vi.spyOn(mockConfig, 'getModelAvailabilityService').mockReturnValue(
|
||||
mockAvailabilityService,
|
||||
);
|
||||
|
||||
contentOptions = {
|
||||
modelConfigKey: { model: 'test-model' },
|
||||
contents: [{ role: 'user', parts: [{ text: 'Give me a color.' }] }],
|
||||
abortSignal: abortController.signal,
|
||||
promptId: 'content-prompt-id',
|
||||
};
|
||||
|
||||
jsonOptions = {
|
||||
...defaultOptions,
|
||||
promptId: 'json-prompt-id',
|
||||
};
|
||||
});
|
||||
|
||||
it('should preserve legacy behavior when availability is disabled', async () => {
|
||||
mockConfig.isModelAvailabilityServiceEnabled = vi
|
||||
.fn()
|
||||
.mockReturnValue(false);
|
||||
mockGenerateContent.mockResolvedValue(
|
||||
createMockResponse('Some text response'),
|
||||
);
|
||||
|
||||
await client.generateContent(contentOptions);
|
||||
|
||||
expect(
|
||||
mockAvailabilityService.selectFirstAvailable,
|
||||
).not.toHaveBeenCalled();
|
||||
expect(mockConfig.setActiveModel).not.toHaveBeenCalled();
|
||||
expect(mockAvailabilityService.markHealthy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should mark model as healthy on success', async () => {
|
||||
const successfulModel = 'gemini-pro';
|
||||
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
|
||||
selectedModel: successfulModel,
|
||||
skipped: [],
|
||||
});
|
||||
mockGenerateContent.mockResolvedValue(
|
||||
createMockResponse('Some text response'),
|
||||
);
|
||||
|
||||
await client.generateContent({
|
||||
...contentOptions,
|
||||
modelConfigKey: { model: successfulModel },
|
||||
});
|
||||
|
||||
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
|
||||
successfulModel,
|
||||
);
|
||||
});
|
||||
|
||||
it('marks the final attempted model healthy after a retry with availability enabled', async () => {
  const firstModel = 'gemini-pro';
  const fallbackModel = 'gemini-flash';
  vi.mocked(mockAvailabilityService.selectFirstAvailable)
    .mockReturnValueOnce({ selectedModel: firstModel, skipped: [] })
    .mockReturnValueOnce({ selectedModel: fallbackModel, skipped: [] });

  mockGenerateContent
    .mockResolvedValueOnce(createMockResponse('retry-me'))
    .mockResolvedValueOnce(createMockResponse('final-response'));

  // Fake timers let any backoff delay inside retryWithBackoff resolve
  // immediately.
  // NOTE(review): '../utils/retry.js' is mocked at module level in this file —
  // confirm whether this test really exercises a retry path.
  vi.useFakeTimers();

  try {
    const retryPromise = client.generateContent({
      ...contentOptions,
      modelConfigKey: { model: firstModel },
      maxAttempts: 2,
    });

    await vi.runAllTimersAsync();
    await retryPromise;

    // Second request: the availability service now selects the fallback.
    await client.generateContent({
      ...contentOptions,
      modelConfigKey: { model: firstModel },
      maxAttempts: 2,
    });

    expect(mockConfig.setActiveModel).toHaveBeenCalledWith(firstModel);
    expect(mockConfig.setActiveModel).toHaveBeenCalledWith(fallbackModel);
    expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
      fallbackModel,
    );
    expect(mockGenerateContent).toHaveBeenLastCalledWith(
      expect.objectContaining({ model: fallbackModel }),
      expect.any(String),
    );
  } finally {
    // Always restore real timers so later tests are not left on fake ones,
    // even when an assertion above throws.
    vi.useRealTimers();
  }
});
|
||||
|
||||
it('should consume sticky attempt if selection has attempts', async () => {
|
||||
const stickyModel = 'gemini-pro-sticky';
|
||||
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
|
||||
selectedModel: stickyModel,
|
||||
attempts: 1,
|
||||
skipped: [],
|
||||
});
|
||||
mockGenerateContent.mockResolvedValue(
|
||||
createMockResponse('Some text response'),
|
||||
);
|
||||
vi.mocked(retryWithBackoff).mockImplementation(async (fn, options) => {
|
||||
const result = await fn();
|
||||
const context = options?.getAvailabilityContext?.();
|
||||
if (context) {
|
||||
context.service.markHealthy(context.policy.model);
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
await client.generateContent({
|
||||
...contentOptions,
|
||||
modelConfigKey: { model: stickyModel },
|
||||
});
|
||||
|
||||
expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
|
||||
stickyModel,
|
||||
);
|
||||
expect(retryWithBackoff).toHaveBeenCalledWith(
|
||||
expect.any(Function),
|
||||
expect.objectContaining({ maxAttempts: 1 }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should mark healthy and honor availability selection when using generateJson', async () => {
|
||||
const availableModel = 'gemini-json-pro';
|
||||
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
|
||||
selectedModel: availableModel,
|
||||
skipped: [],
|
||||
});
|
||||
mockGenerateContent.mockResolvedValue(
|
||||
createMockResponse('{"color":"violet"}'),
|
||||
);
|
||||
vi.mocked(retryWithBackoff).mockImplementation(async (fn, options) => {
|
||||
const result = await fn();
|
||||
const context = options?.getAvailabilityContext?.();
|
||||
if (context) {
|
||||
context.service.markHealthy(context.policy.model);
|
||||
}
|
||||
return result;
|
||||
});
|
||||
|
||||
const result = await client.generateJson(jsonOptions);
|
||||
|
||||
expect(result).toEqual({ color: 'violet' });
|
||||
expect(mockConfig.setActiveModel).toHaveBeenCalledWith(availableModel);
|
||||
expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith(
|
||||
availableModel,
|
||||
);
|
||||
expect(mockGenerateContent).toHaveBeenLastCalledWith(
|
||||
expect.objectContaining({ model: availableModel }),
|
||||
jsonOptions.promptId,
|
||||
);
|
||||
});
|
||||
|
||||
it('should refresh configuration when model changes mid-retry', async () => {
|
||||
const firstModel = 'gemini-pro';
|
||||
const fallbackModel = 'gemini-flash';
|
||||
|
||||
// Provide distinct configs per model
|
||||
const getResolvedConfigMock = vi.mocked(
|
||||
mockConfig.modelConfigService.getResolvedConfig,
|
||||
);
|
||||
getResolvedConfigMock
|
||||
.mockReturnValueOnce(
|
||||
makeResolvedModelConfig(firstModel, { temperature: 0.1 }),
|
||||
)
|
||||
.mockReturnValueOnce(
|
||||
makeResolvedModelConfig(fallbackModel, { temperature: 0.9 }),
|
||||
);
|
||||
|
||||
// Availability selects the first model initially
|
||||
vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
|
||||
selectedModel: firstModel,
|
||||
skipped: [],
|
||||
});
|
||||
|
||||
// Change active model after the first attempt
|
||||
let activeModel = firstModel;
|
||||
mockConfig.setActiveModel = vi.fn(); // Prevent setActiveModel from resetting getActiveModel mock
|
||||
mockConfig.getActiveModel.mockImplementation(() => activeModel);
|
||||
|
||||
// First response empty -> triggers retry; second response valid
|
||||
mockGenerateContent
|
||||
.mockResolvedValueOnce(createMockResponse(''))
|
||||
.mockResolvedValueOnce(createMockResponse('final-response'));
|
||||
|
||||
// Custom retry to force two attempts
|
||||
vi.mocked(retryWithBackoff).mockImplementation(async (fn, options) => {
|
||||
const first = (await fn()) as GenerateContentResponse;
|
||||
if (options?.shouldRetryOnContent?.(first)) {
|
||||
activeModel = fallbackModel; // simulate handler switching active model before retry
|
||||
return (await fn()) as GenerateContentResponse;
|
||||
}
|
||||
return first;
|
||||
});
|
||||
|
||||
await client.generateContent({
|
||||
...contentOptions,
|
||||
modelConfigKey: { model: firstModel },
|
||||
maxAttempts: 2,
|
||||
});
|
||||
|
||||
expect(mockGenerateContent).toHaveBeenCalledTimes(2);
|
||||
const secondCall = mockGenerateContent.mock.calls[1]?.[0];
|
||||
|
||||
expect(
|
||||
mockConfig.modelConfigService.getResolvedConfig,
|
||||
).toHaveBeenCalledWith({ model: fallbackModel });
|
||||
expect(secondCall?.model).toBe(fallbackModel);
|
||||
expect(secondCall?.config?.temperature).toBe(0.9);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -20,6 +20,10 @@ import { logMalformedJsonResponse } from '../telemetry/loggers.js';
|
||||
import { MalformedJsonResponseEvent } from '../telemetry/types.js';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import {
|
||||
applyModelSelection,
|
||||
createAvailabilityContextProvider,
|
||||
} from '../availability/policyHelpers.js';
|
||||
|
||||
const DEFAULT_MAX_ATTEMPTS = 5;
|
||||
|
||||
@@ -232,13 +236,56 @@ export class BaseLlmClient {
|
||||
): Promise<GenerateContentResponse> {
|
||||
const abortSignal = requestParams.config?.abortSignal;
|
||||
|
||||
// Define callback to fetch context dynamically since active model may get updated during retry loop
|
||||
const getAvailabilityContext = createAvailabilityContextProvider(
|
||||
this.config,
|
||||
() => requestParams.model,
|
||||
);
|
||||
|
||||
const {
|
||||
model,
|
||||
config: newConfig,
|
||||
maxAttempts: availabilityMaxAttempts,
|
||||
} = applyModelSelection(
|
||||
this.config,
|
||||
requestParams.model,
|
||||
requestParams.config,
|
||||
);
|
||||
requestParams.model = model;
|
||||
if (newConfig) {
|
||||
requestParams.config = newConfig;
|
||||
}
|
||||
if (abortSignal) {
|
||||
requestParams.config = { ...requestParams.config, abortSignal };
|
||||
}
|
||||
|
||||
try {
|
||||
const apiCall = () =>
|
||||
this.contentGenerator.generateContent(requestParams, promptId);
|
||||
const apiCall = () => {
|
||||
// If availability is enabled, ensure we use the current active model
|
||||
// in case a fallback occurred in a previous attempt.
|
||||
if (this.config.isModelAvailabilityServiceEnabled()) {
|
||||
const activeModel = this.config.getActiveModel();
|
||||
if (activeModel !== requestParams.model) {
|
||||
requestParams.model = activeModel;
|
||||
// Re-resolve config if model changed during retry
|
||||
const { generateContentConfig } =
|
||||
this.config.modelConfigService.getResolvedConfig({
|
||||
model: activeModel,
|
||||
});
|
||||
requestParams.config = {
|
||||
...requestParams.config,
|
||||
...generateContentConfig,
|
||||
};
|
||||
}
|
||||
}
|
||||
return this.contentGenerator.generateContent(requestParams, promptId);
|
||||
};
|
||||
|
||||
return await retryWithBackoff(apiCall, {
|
||||
shouldRetryOnContent,
|
||||
maxAttempts: maxAttempts ?? DEFAULT_MAX_ATTEMPTS,
|
||||
maxAttempts:
|
||||
availabilityMaxAttempts ?? maxAttempts ?? DEFAULT_MAX_ATTEMPTS,
|
||||
getAvailabilityContext,
|
||||
});
|
||||
} catch (error) {
|
||||
if (abortSignal?.aborted) {
|
||||
|
||||
@@ -38,12 +38,15 @@ import { ideContextStore } from '../ide/ideContext.js';
|
||||
import type { ModelRouterService } from '../routing/modelRouterService.js';
|
||||
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
|
||||
import { ChatCompressionService } from '../services/chatCompressionService.js';
|
||||
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
|
||||
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
|
||||
import type {
|
||||
ModelConfigKey,
|
||||
ResolvedModelConfig,
|
||||
} from '../services/modelConfigService.js';
|
||||
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
|
||||
import { HookSystem } from '../hooks/hookSystem.js';
|
||||
import * as policyCatalog from '../availability/policyCatalog.js';
|
||||
|
||||
vi.mock('../services/chatCompressionService.js');
|
||||
|
||||
@@ -145,6 +148,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
let mockConfig: Config;
|
||||
let client: GeminiClient;
|
||||
let mockGenerateContentFn: Mock;
|
||||
let mockRouterService: { route: Mock };
|
||||
beforeEach(async () => {
|
||||
vi.resetAllMocks();
|
||||
ClearcutLogger.clearInstance();
|
||||
@@ -166,6 +170,12 @@ describe('Gemini Client (client.ts)', () => {
|
||||
// Disable 429 simulation for tests
|
||||
setSimulate429(false);
|
||||
|
||||
mockRouterService = {
|
||||
route: vi
|
||||
.fn()
|
||||
.mockResolvedValue({ model: 'default-routed-model', reason: 'test' }),
|
||||
};
|
||||
|
||||
mockContentGenerator = {
|
||||
generateContent: mockGenerateContentFn,
|
||||
generateContentStream: vi.fn(),
|
||||
@@ -192,6 +202,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
.mockReturnValue(contentGeneratorConfig),
|
||||
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
|
||||
getModel: vi.fn().mockReturnValue('test-model'),
|
||||
getUserTier: vi.fn().mockReturnValue(undefined),
|
||||
getEmbeddingModel: vi.fn().mockReturnValue('test-embedding-model'),
|
||||
getApiKey: vi.fn().mockReturnValue('test-key'),
|
||||
getVertexAI: vi.fn().mockReturnValue(false),
|
||||
@@ -215,9 +226,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
getDirectories: vi.fn().mockReturnValue(['/test/dir']),
|
||||
}),
|
||||
getGeminiClient: vi.fn(),
|
||||
getModelRouterService: vi.fn().mockReturnValue({
|
||||
route: vi.fn().mockResolvedValue({ model: 'default-routed-model' }),
|
||||
}),
|
||||
getModelRouterService: vi
|
||||
.fn()
|
||||
.mockReturnValue(mockRouterService as unknown as ModelRouterService),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
getEnableHooks: vi.fn().mockReturnValue(false),
|
||||
isInFallbackMode: vi.fn().mockReturnValue(false),
|
||||
@@ -251,6 +262,11 @@ describe('Gemini Client (client.ts)', () => {
|
||||
},
|
||||
isInteractive: vi.fn().mockReturnValue(false),
|
||||
getExperiments: () => {},
|
||||
isModelAvailabilityServiceEnabled: vi.fn().mockReturnValue(false),
|
||||
getActiveModel: vi.fn().mockReturnValue('test-model'),
|
||||
setActiveModel: vi.fn(),
|
||||
resetTurn: vi.fn(),
|
||||
getModelAvailabilityService: vi.fn(),
|
||||
} as unknown as Config;
|
||||
mockConfig.getHookSystem = vi
|
||||
.fn()
|
||||
@@ -1614,6 +1630,7 @@ ${JSON.stringify(
|
||||
expect(events).toEqual([
|
||||
{ type: GeminiEventType.ModelInfo, value: 'default-routed-model' },
|
||||
{ type: GeminiEventType.InvalidStream },
|
||||
{ type: GeminiEventType.ModelInfo, value: 'default-routed-model' },
|
||||
{ type: GeminiEventType.Content, value: 'Continued content' },
|
||||
]);
|
||||
|
||||
@@ -1701,8 +1718,8 @@ ${JSON.stringify(
|
||||
const events = await fromAsync(stream);
|
||||
|
||||
// Assert
|
||||
// We expect 3 events (model_info + original + 1 retry)
|
||||
expect(events.length).toBe(3);
|
||||
// We expect 4 events (model_info + original + model_info + 1 retry)
|
||||
expect(events.length).toBe(4);
|
||||
expect(
|
||||
events
|
||||
.filter((e) => e.type !== GeminiEventType.ModelInfo)
|
||||
@@ -1977,6 +1994,154 @@ ${JSON.stringify(
|
||||
});
|
||||
});
|
||||
|
||||
describe('Availability Service Integration', () => {
|
||||
let mockAvailabilityService: ModelAvailabilityService;
|
||||
|
||||
beforeEach(() => {
  // Policy chain used by every test in this suite: 'model-a' is the primary
  // entry and 'model-b' is flagged as the last resort.
  vi.spyOn(policyCatalog, 'getModelPolicyChain').mockReturnValue([
    {
      model: 'model-a',
      isLastResort: false,
      actions: {},
      stateTransitions: {},
    },
    {
      model: 'model-b',
      isLastResort: true,
      actions: {},
      stateTransitions: {},
    },
  ]);

  // Fresh availability mock, exposed through the config with the feature
  // switched on.
  mockAvailabilityService = createAvailabilityServiceMock();
  vi.mocked(mockConfig.isModelAvailabilityServiceEnabled).mockReturnValue(
    true,
  );
  vi.mocked(mockConfig.getModelAvailabilityService).mockReturnValue(
    mockAvailabilityService,
  );
  vi.mocked(mockConfig.setActiveModel).mockClear();

  // The router always proposes the primary model.
  mockRouterService.route.mockResolvedValue({
    model: 'model-a',
    reason: 'test',
  });
  vi.mocked(mockConfig.getModelRouterService).mockReturnValue(
    mockRouterService as unknown as ModelRouterService,
  );

  // Default turn: a single content event.
  async function* helloTurn() {
    yield { type: 'content', value: 'Hello' };
  }
  mockTurnRunFn.mockReturnValue(helloTurn());
});
|
||||
|
||||
it('should select first available model, set active, and not consume sticky attempt (done lower in chain)', async () => {
  // Selection picks the primary model and reports one allowed attempt.
  const selectFirstAvailable = vi.mocked(
    mockAvailabilityService.selectFirstAvailable,
  );
  selectFirstAvailable.mockReturnValue({
    selectedModel: 'model-a',
    attempts: 1,
    skipped: [],
  });

  const { signal } = new AbortController();
  await fromAsync(
    client.sendMessageStream([{ text: 'Hi' }], signal, 'prompt-avail'),
  );

  // The full policy chain is offered for selection and the winner is activated.
  expect(selectFirstAvailable).toHaveBeenCalledWith(['model-a', 'model-b']);
  expect(mockConfig.setActiveModel).toHaveBeenCalledWith('model-a');
  // The sticky attempt must not be consumed at this level.
  expect(mockAvailabilityService.consumeStickyAttempt).not.toHaveBeenCalled();
  // The turn itself runs on the selected model.
  expect(mockTurnRunFn).toHaveBeenCalledWith(
    expect.objectContaining({ model: 'model-a' }),
    expect.anything(),
    expect.anything(),
  );
});
|
||||
|
||||
it('should default to last resort model if selection returns null', async () => {
  // No model in the chain is reported as available.
  const nothingAvailable = { selectedModel: null, skipped: [] };
  vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue(
    nothingAvailable,
  );

  const { signal } = new AbortController();
  await fromAsync(
    client.sendMessageStream(
      [{ text: 'Hi' }],
      signal,
      'prompt-avail-fallback',
    ),
  );

  // 'model-b' is the chain entry flagged isLastResort.
  expect(mockConfig.setActiveModel).toHaveBeenCalledWith('model-b');
  expect(mockAvailabilityService.consumeStickyAttempt).not.toHaveBeenCalled();
});
|
||||
|
||||
it('should reset turn on new message stream', async () => {
  const primarySelected = { selectedModel: 'model-a', skipped: [] };
  vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue(
    primarySelected,
  );

  // A fresh top-level call into the client.
  const { signal } = new AbortController();
  await fromAsync(
    client.sendMessageStream([{ text: 'Hi' }], signal, 'prompt-reset'),
  );

  expect(mockConfig.resetTurn).toHaveBeenCalled();
});
|
||||
|
||||
it('should NOT reset turn on invalid stream retry', async () => {
  const primarySelected = { selectedModel: 'model-a', skipped: [] };
  vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue(
    primarySelected,
  );

  // Rather than passing isInvalidStreamRetry ourselves, drive the retry
  // through the client's own recursion: the first turn yields InvalidStream
  // and the second yields ordinary content.
  // NOTE(review): getContinueOnFailedApiCall=true is presumably what allows
  // the client to retry after an invalid stream — confirm against client.ts.
  vi.spyOn(client['config'], 'getContinueOnFailedApiCall').mockReturnValue(
    true,
  );

  async function* invalidTurn() {
    yield { type: GeminiEventType.InvalidStream };
  }
  async function* recoveredTurn() {
    yield { type: 'content', value: 'ok' };
  }
  const firstTurn = invalidTurn();
  const secondTurn = recoveredTurn();
  mockTurnRunFn.mockReturnValueOnce(firstTurn).mockReturnValueOnce(secondTurn);

  const { signal } = new AbortController();
  await fromAsync(
    client.sendMessageStream([{ text: 'Hi' }], signal, 'prompt-retry'),
  );

  // Exactly one reset: the initial call resets the turn, the recursive
  // retry must not.
  expect(mockConfig.resetTurn).toHaveBeenCalledTimes(1);
});
|
||||
});
|
||||
|
||||
describe('IDE context with pending tool calls', () => {
|
||||
let mockChat: Partial<GeminiChat>;
|
||||
|
||||
|
||||
@@ -57,6 +57,11 @@ import type { RoutingContext } from '../routing/routingStrategy.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import { calculateRequestTokenCount } from '../utils/tokenCalculation.js';
|
||||
import {
|
||||
applyModelSelection,
|
||||
createAvailabilityContextProvider,
|
||||
} from '../availability/policyHelpers.js';
|
||||
import type { RetryAvailabilityContext } from '../utils/retry.js';
|
||||
|
||||
const MAX_TURNS = 100;
|
||||
|
||||
@@ -405,6 +410,10 @@ export class GeminiClient {
|
||||
turns: number = MAX_TURNS,
|
||||
isInvalidStreamRetry: boolean = false,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
if (!isInvalidStreamRetry) {
|
||||
this.config.resetTurn();
|
||||
}
|
||||
|
||||
// Fire BeforeAgent hook through MessageBus (only if hooks are enabled)
|
||||
const hooksEnabled = this.config.getEnableHooks();
|
||||
const messageBus = this.config.getMessageBus();
|
||||
@@ -534,11 +543,21 @@ export class GeminiClient {
|
||||
const router = await this.config.getModelRouterService();
|
||||
const decision = await router.route(routingContext);
|
||||
modelToUse = decision.model;
|
||||
// Lock the model for the rest of the sequence
|
||||
this.currentSequenceModel = modelToUse;
|
||||
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
|
||||
}
|
||||
|
||||
// availability logic
|
||||
const { model: finalModel } = applyModelSelection(
|
||||
this.config,
|
||||
modelToUse,
|
||||
undefined,
|
||||
undefined,
|
||||
{ consumeAttempt: false },
|
||||
);
|
||||
modelToUse = finalModel;
|
||||
|
||||
this.currentSequenceModel = modelToUse;
|
||||
yield { type: GeminiEventType.ModelInfo, value: modelToUse };
|
||||
|
||||
const resultStream = turn.run({ model: modelToUse }, request, linkedSignal);
|
||||
for await (const event of resultStream) {
|
||||
if (this.loopDetector.addAndCheck(event)) {
|
||||
@@ -665,14 +684,53 @@ export class GeminiClient {
|
||||
try {
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const systemInstruction = getCoreSystemPrompt(this.config, userMemory);
|
||||
const {
|
||||
model,
|
||||
config: newConfig,
|
||||
maxAttempts: availabilityMaxAttempts,
|
||||
} = applyModelSelection(
|
||||
this.config,
|
||||
currentAttemptModel,
|
||||
currentAttemptGenerateContentConfig,
|
||||
modelConfigKey.overrideScope,
|
||||
);
|
||||
currentAttemptModel = model;
|
||||
if (newConfig) {
|
||||
currentAttemptGenerateContentConfig = newConfig;
|
||||
}
|
||||
|
||||
// Define callback to refresh context based on currentAttemptModel which might be updated by fallback handler
|
||||
const getAvailabilityContext: () => RetryAvailabilityContext | undefined =
|
||||
createAvailabilityContextProvider(
|
||||
this.config,
|
||||
() => currentAttemptModel,
|
||||
);
|
||||
|
||||
const apiCall = () => {
|
||||
const modelConfigToUse = this.config.isInFallbackMode()
|
||||
? fallbackModelConfig
|
||||
: desiredModelConfig;
|
||||
currentAttemptModel = modelConfigToUse.model;
|
||||
currentAttemptGenerateContentConfig =
|
||||
modelConfigToUse.generateContentConfig;
|
||||
let modelConfigToUse = desiredModelConfig;
|
||||
|
||||
if (!this.config.isModelAvailabilityServiceEnabled()) {
|
||||
modelConfigToUse = this.config.isInFallbackMode()
|
||||
? fallbackModelConfig
|
||||
: desiredModelConfig;
|
||||
currentAttemptModel = modelConfigToUse.model;
|
||||
currentAttemptGenerateContentConfig =
|
||||
modelConfigToUse.generateContentConfig;
|
||||
} else {
|
||||
// AvailabilityService
|
||||
const active = this.config.getActiveModel();
|
||||
if (active !== currentAttemptModel) {
|
||||
currentAttemptModel = active;
|
||||
// Re-resolve config if model changed
|
||||
const newConfig = this.config.modelConfigService.getResolvedConfig({
|
||||
...modelConfigKey,
|
||||
model: currentAttemptModel,
|
||||
});
|
||||
currentAttemptGenerateContentConfig =
|
||||
newConfig.generateContentConfig;
|
||||
}
|
||||
}
|
||||
|
||||
const requestConfig: GenerateContentConfig = {
|
||||
...currentAttemptGenerateContentConfig,
|
||||
abortSignal,
|
||||
@@ -698,7 +756,10 @@ export class GeminiClient {
|
||||
const result = await retryWithBackoff(apiCall, {
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
maxAttempts: availabilityMaxAttempts,
|
||||
getAvailabilityContext,
|
||||
});
|
||||
|
||||
return result;
|
||||
} catch (error: unknown) {
|
||||
if (abortSignal.aborted) {
|
||||
|
||||
@@ -29,6 +29,10 @@ import { retryWithBackoff, type RetryOptions } from '../utils/retry.js';
|
||||
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
|
||||
import { HookSystem } from '../hooks/hookSystem.js';
|
||||
import { createMockMessageBus } from '../test-utils/mock-message-bus.js';
|
||||
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
|
||||
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
|
||||
import * as policyHelpers from '../availability/policyHelpers.js';
|
||||
import { makeResolvedModelConfig } from '../services/modelConfigServiceTestUtils.js';
|
||||
|
||||
// Mock fs module to prevent actual file system operations during tests
|
||||
const mockFileSystem = new Map<string, string>();
|
||||
@@ -115,7 +119,14 @@ describe('GeminiChat', () => {
|
||||
|
||||
mockHandleFallback.mockClear();
|
||||
// Default mock implementation for tests that don't care about retry logic
|
||||
mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall());
|
||||
mockRetryWithBackoff.mockImplementation(async (apiCall, options) => {
|
||||
const result = await apiCall();
|
||||
const context = options?.getAvailabilityContext?.();
|
||||
if (context) {
|
||||
context.service.markHealthy(context.policy.model);
|
||||
}
|
||||
return result;
|
||||
});
|
||||
mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getTelemetryLogPromptsEnabled: () => true,
|
||||
@@ -141,6 +152,7 @@ describe('GeminiChat', () => {
|
||||
}),
|
||||
getContentGenerator: vi.fn().mockReturnValue(mockContentGenerator),
|
||||
getRetryFetchErrors: vi.fn().mockReturnValue(false),
|
||||
getUserTier: vi.fn().mockReturnValue(undefined),
|
||||
modelConfigService: {
|
||||
getResolvedConfig: vi.fn().mockImplementation((modelConfigKey) => {
|
||||
const thinkingConfig = modelConfigKey.model.startsWith('gemini-3')
|
||||
@@ -165,6 +177,10 @@ describe('GeminiChat', () => {
|
||||
setPreviewModelFallbackMode: vi.fn(),
|
||||
isInteractive: vi.fn().mockReturnValue(false),
|
||||
getEnableHooks: vi.fn().mockReturnValue(false),
|
||||
isModelAvailabilityServiceEnabled: vi.fn().mockReturnValue(false),
|
||||
getActiveModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
setActiveModel: vi.fn(),
|
||||
getModelAvailabilityService: vi.fn(),
|
||||
} as unknown as Config;
|
||||
|
||||
// Use proper MessageBus mocking for Phase 3 preparation
|
||||
@@ -2359,4 +2375,260 @@ describe('GeminiChat', () => {
|
||||
expect(newContents).toEqual(history);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Availability Service Integration', () => {
|
||||
let mockAvailabilityService: ModelAvailabilityService;
|
||||
|
||||
beforeEach(async () => {
  // Three-entry chain; only 'model-c' is flagged as the last resort.
  vi.spyOn(policyHelpers, 'resolvePolicyChain').mockReturnValue([
    {
      model: 'model-a',
      isLastResort: false,
      actions: {},
      stateTransitions: {},
    },
    {
      model: 'model-b',
      isLastResort: false,
      actions: {},
      stateTransitions: {},
    },
    {
      model: 'model-c',
      isLastResort: true,
      actions: {},
      stateTransitions: {},
    },
  ]);

  // Enable the availability service and hand out a fresh mock of it.
  mockAvailabilityService = createAvailabilityServiceMock();
  vi.mocked(mockConfig.isModelAvailabilityServiceEnabled).mockReturnValue(
    true,
  );
  vi.mocked(mockConfig.getModelAvailabilityService).mockReturnValue(
    mockAvailabilityService,
  );

  // get/setActiveModel share one closure variable, so a model set by a test
  // (or by the code under test) is what later reads return.
  let currentModel = 'model-a';
  vi.mocked(mockConfig.setActiveModel).mockImplementation((next) => {
    currentModel = next;
  });
  vi.mocked(mockConfig.getActiveModel).mockImplementation(() => currentModel);
});
|
||||
|
||||
it('should mark healthy on successful stream', async () => {
  vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
    selectedModel: 'model-b',
    skipped: [],
  });
  // Stand in for the selection step that normally runs upstream.
  mockConfig.setActiveModel('model-b');

  // One well-formed chunk that finishes with STOP.
  async function* successfulStream() {
    yield {
      candidates: [
        {
          content: { parts: [{ text: 'Response' }], role: 'model' },
          finishReason: 'STOP',
        },
      ],
    } as unknown as GenerateContentResponse;
  }
  vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
    successfulStream(),
  );

  const { signal } = new AbortController();
  const stream = await chat.sendMessageStream(
    { model: 'gemini-pro' },
    'test',
    'prompt-healthy',
    signal,
  );
  for await (const _ of stream) {
    // drain the stream
  }

  // The active model ('model-b'), not the requested 'gemini-pro', is the one
  // reported healthy.
  expect(mockAvailabilityService.markHealthy).toHaveBeenCalledWith('model-b');
});
|
||||
|
||||
it('caps retries to a single attempt when selection is sticky', async () => {
  // Selection carries an explicit attempt budget of 1.
  vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
    selectedModel: 'model-a',
    attempts: 1,
    skipped: [],
  });

  async function* successfulStream() {
    yield {
      candidates: [
        {
          content: { parts: [{ text: 'Response' }], role: 'model' },
          finishReason: 'STOP',
        },
      ],
    } as unknown as GenerateContentResponse;
  }
  vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
    successfulStream(),
  );

  const { signal } = new AbortController();
  const stream = await chat.sendMessageStream(
    { model: 'gemini-pro' },
    'test',
    'prompt-sticky-once',
    signal,
  );
  for await (const _ of stream) {
    // drain the stream
  }

  // The attempt budget is forwarded to the retry helper as maxAttempts...
  const retryOptions = expect.objectContaining({ maxAttempts: 1 });
  expect(mockRetryWithBackoff).toHaveBeenCalledWith(
    expect.any(Function),
    retryOptions,
  );
  // ...and the sticky attempt for the selected model is consumed.
  expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
    'model-a',
  );
});
|
||||
|
||||
it('should pass attempted model to onPersistent429 callback which calls handleFallback', async () => {
  vi.mocked(mockAvailabilityService.selectFirstAvailable).mockReturnValue({
    selectedModel: 'model-a',
    skipped: [],
  });
  // Stand in for the selection step that normally runs upstream.
  mockConfig.setActiveModel('model-a');

  // Every API call fails with a terminal quota error.
  const quotaError = new TerminalQuotaError('Quota', {
    code: 429,
    message: 'quota',
    details: [],
  });
  vi.mocked(mockContentGenerator.generateContentStream).mockRejectedValue(
    quotaError,
  );

  // Replace the retry helper with a minimal version that forwards a failure
  // to onPersistent429 and then rethrows, which ends the request.
  mockRetryWithBackoff.mockImplementation(async (attempt, retryOptions) => {
    try {
      await attempt();
    } catch (failure) {
      if (retryOptions?.onPersistent429) {
        await retryOptions.onPersistent429(
          AuthType.LOGIN_WITH_GOOGLE,
          failure,
        );
      }
      throw failure;
    }
  });

  const drainStream = async () => {
    const { signal } = new AbortController();
    const stream = await chat.sendMessageStream(
      { model: 'gemini-pro' },
      'test',
      'prompt-fallback-arg',
      signal,
    );
    for await (const _ of stream) {
      // drain the stream
    }
  };

  await expect(drainStream()).rejects.toThrow();

  // handleFallback receives the model that was actually attempted
  // ('model-a'), not the one originally requested ('gemini-pro').
  expect(mockHandleFallback).toHaveBeenCalledWith(
    expect.anything(),
    'model-a',
    expect.anything(),
    quotaError,
  );
});
|
||||
|
||||
it('re-resolves generateContentConfig when active model changes between retries', async () => {
  // Local stateful active model so the retry stub below can flip it.
  let currentModel = 'model-a';
  vi.mocked(mockConfig.setActiveModel).mockImplementation((next) => {
    currentModel = next;
  });
  vi.mocked(mockConfig.getActiveModel).mockImplementation(() => currentModel);

  // First resolution yields model-a's config, the second yields model-b's.
  vi.mocked(mockConfig.modelConfigService.getResolvedConfig)
    .mockReturnValueOnce(
      makeResolvedModelConfig('model-a', { temperature: 0.1 }),
    )
    .mockReturnValueOnce(
      makeResolvedModelConfig('model-b', { temperature: 0.9 }),
    );

  // Retry stub: run one attempt, switch the active model, run a second.
  mockRetryWithBackoff.mockImplementation(async (attempt) => {
    await attempt();
    currentModel = 'model-b';
    return attempt();
  });

  // One single-chunk response stream per attempt.
  const streamOf = (text: string) =>
    (async function* () {
      yield {
        candidates: [
          {
            content: { parts: [{ text }], role: 'model' },
            finishReason: 'STOP',
          },
        ],
      } as unknown as GenerateContentResponse;
    })();
  vi.mocked(mockContentGenerator.generateContentStream)
    .mockResolvedValueOnce(streamOf('first'))
    .mockResolvedValueOnce(streamOf('second'));

  const { signal } = new AbortController();
  const stream = await chat.sendMessageStream(
    { model: 'gemini-pro' },
    'test',
    'prompt-config-refresh',
    signal,
  );
  // Draining the stream drives both attempts.
  for await (const _ of stream) {
    // drain the stream
  }

  // Each attempt must carry the model and the temperature resolved for it.
  const expectAttempt = (nth: number, model: string, temperature: number) =>
    expect(mockContentGenerator.generateContentStream).toHaveBeenNthCalledWith(
      nth,
      expect.objectContaining({
        model,
        config: expect.objectContaining({ temperature }),
      }),
      expect.any(String),
    );
  expectAttempt(1, 'model-a', 0.1);
  expectAttempt(2, 'model-b', 0.9);
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -48,6 +48,10 @@ import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { partListUnionToString } from './geminiRequest.js';
|
||||
import type { ModelConfigKey } from '../services/modelConfigService.js';
|
||||
import { estimateTokenCountSync } from '../utils/tokenCalculation.js';
|
||||
import {
|
||||
applyModelSelection,
|
||||
createAvailabilityContextProvider,
|
||||
} from '../availability/policyHelpers.js';
|
||||
import {
|
||||
fireAfterModelHook,
|
||||
fireBeforeModelHook,
|
||||
@@ -410,35 +414,74 @@ export class GeminiChat {
|
||||
requestContents: Content[],
|
||||
prompt_id: string,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
let effectiveModel = model;
|
||||
const contentsForPreviewModel =
|
||||
this.ensureActiveLoopHasThoughtSignatures(requestContents);
|
||||
|
||||
// Track final request parameters for AfterModel hooks
|
||||
let lastModelToUse = model;
|
||||
let lastConfig: GenerateContentConfig = generateContentConfig;
|
||||
const {
|
||||
model: availabilityFinalModel,
|
||||
config: newAvailabilityConfig,
|
||||
maxAttempts: availabilityMaxAttempts,
|
||||
} = applyModelSelection(this.config, model, generateContentConfig);
|
||||
|
||||
const abortSignal = generateContentConfig.abortSignal;
|
||||
let lastModelToUse = availabilityFinalModel;
|
||||
let currentGenerateContentConfig: GenerateContentConfig =
|
||||
newAvailabilityConfig ?? generateContentConfig;
|
||||
if (abortSignal) {
|
||||
currentGenerateContentConfig = {
|
||||
...currentGenerateContentConfig,
|
||||
abortSignal,
|
||||
};
|
||||
}
|
||||
let lastConfig: GenerateContentConfig = currentGenerateContentConfig;
|
||||
let lastContentsToUse: Content[] = requestContents;
|
||||
|
||||
const getAvailabilityContext = createAvailabilityContextProvider(
|
||||
this.config,
|
||||
() => lastModelToUse,
|
||||
);
|
||||
const apiCall = async () => {
|
||||
let modelToUse = getEffectiveModel(
|
||||
this.config.isInFallbackMode(),
|
||||
model,
|
||||
this.config.getPreviewFeatures(),
|
||||
);
|
||||
let modelToUse: string;
|
||||
|
||||
// Preview Model Bypass Logic:
|
||||
// If we are in "Preview Model Bypass Mode" (transient failure), we force downgrade to 2.5 Pro
|
||||
// IF the effective model is currently Preview Model.
|
||||
if (
|
||||
this.config.isPreviewModelBypassMode() &&
|
||||
modelToUse === PREVIEW_GEMINI_MODEL
|
||||
) {
|
||||
modelToUse = DEFAULT_GEMINI_MODEL;
|
||||
if (this.config.isModelAvailabilityServiceEnabled()) {
|
||||
modelToUse = this.config.getActiveModel();
|
||||
if (modelToUse !== lastModelToUse) {
|
||||
const { generateContentConfig: newConfig } =
|
||||
this.config.modelConfigService.getResolvedConfig({
|
||||
model: modelToUse,
|
||||
});
|
||||
currentGenerateContentConfig = {
|
||||
...currentGenerateContentConfig,
|
||||
...newConfig,
|
||||
};
|
||||
if (abortSignal) {
|
||||
currentGenerateContentConfig.abortSignal = abortSignal;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
modelToUse = getEffectiveModel(
|
||||
this.config.isInFallbackMode(),
|
||||
model,
|
||||
this.config.getPreviewFeatures(),
|
||||
);
|
||||
|
||||
// Preview Model Bypass Logic:
|
||||
// If we are in "Preview Model Bypass Mode" (transient failure), we force downgrade to 2.5 Pro
|
||||
// IF the effective model is currently Preview Model.
|
||||
// Note: In availability mode, this should ideally be handled by policy, but preserving
|
||||
// bypass logic for now as it handles specific transient behavior.
|
||||
if (
|
||||
this.config.isPreviewModelBypassMode() &&
|
||||
modelToUse === PREVIEW_GEMINI_MODEL
|
||||
) {
|
||||
modelToUse = DEFAULT_GEMINI_MODEL;
|
||||
}
|
||||
}
|
||||
|
||||
effectiveModel = modelToUse;
|
||||
lastModelToUse = modelToUse;
|
||||
const config = {
|
||||
...generateContentConfig,
|
||||
...currentGenerateContentConfig,
|
||||
// TODO(12622): Ensure we don't overwrite these when they are
|
||||
// passed via config.
|
||||
systemInstruction: this.systemInstruction,
|
||||
@@ -543,7 +586,7 @@ export class GeminiChat {
|
||||
const onPersistent429Callback = async (
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
) => handleFallback(this.config, effectiveModel, authType, error);
|
||||
) => handleFallback(this.config, lastModelToUse, authType, error);
|
||||
|
||||
const streamResponse = await retryWithBackoff(apiCall, {
|
||||
onPersistent429: onPersistent429Callback,
|
||||
@@ -551,10 +594,12 @@ export class GeminiChat {
|
||||
retryFetchErrors: this.config.getRetryFetchErrors(),
|
||||
signal: generateContentConfig.abortSignal,
|
||||
maxAttempts:
|
||||
this.config.isPreviewModelFallbackMode() &&
|
||||
availabilityMaxAttempts ??
|
||||
(this.config.isPreviewModelFallbackMode() &&
|
||||
model === PREVIEW_GEMINI_MODEL
|
||||
? 1
|
||||
: undefined,
|
||||
: undefined),
|
||||
getAvailabilityContext,
|
||||
});
|
||||
|
||||
// Store the original request for AfterModel hooks
|
||||
@@ -565,7 +610,7 @@ export class GeminiChat {
|
||||
};
|
||||
|
||||
return this.processStreamResponse(
|
||||
effectiveModel,
|
||||
lastModelToUse,
|
||||
streamResponse,
|
||||
originalRequest,
|
||||
);
|
||||
|
||||
@@ -93,6 +93,7 @@ describe('GeminiChat Network Retries', () => {
|
||||
generateContentConfig: { temperature: 0 },
|
||||
})),
|
||||
},
|
||||
isModelAvailabilityServiceEnabled: vi.fn().mockReturnValue(false),
|
||||
isPreviewModelBypassMode: vi.fn().mockReturnValue(false),
|
||||
setPreviewModelBypassMode: vi.fn(),
|
||||
isPreviewModelFallbackMode: vi.fn().mockReturnValue(false),
|
||||
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
import { handleFallback } from './handler.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
|
||||
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
@@ -45,25 +46,20 @@ vi.mock('../utils/secure-browser-launcher.js', () => ({
|
||||
openBrowserSecurely: vi.fn(),
|
||||
}));
|
||||
|
||||
// Replace debugLogger with spies: keeps test output clean and lets tests
// assert on what was logged.
vi.mock('../utils/debugLogger.js', () => {
  const debugLogger = {
    warn: vi.fn(),
    error: vi.fn(),
    log: vi.fn(),
  };
  return { debugLogger };
});
|
||||
|
||||
const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL;
|
||||
const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE;
|
||||
const AUTH_API_KEY = AuthType.USE_GEMINI;
|
||||
|
||||
/**
 * Builds a stub ModelAvailabilityService for tests.
 *
 * Every method is a bare `vi.fn()` except `selectFirstAvailable`, which
 * always returns the supplied selection result.
 *
 * @param result Value that `selectFirstAvailable` should return.
 * @returns The stub, cast to ModelAvailabilityService.
 */
function createAvailabilityMock(
  result: ReturnType<ModelAvailabilityService['selectFirstAvailable']>,
): ModelAvailabilityService {
  const selectFirstAvailable = vi.fn().mockReturnValue(result);
  const stub = {
    markTerminal: vi.fn(),
    markHealthy: vi.fn(),
    markRetryOncePerTurn: vi.fn(),
    consumeStickyAttempt: vi.fn(),
    snapshot: vi.fn(),
    selectFirstAvailable,
    resetTurn: vi.fn(),
  };
  return stub as unknown as ModelAvailabilityService;
}
|
||||
|
||||
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
|
||||
({
|
||||
isInFallbackMode: vi.fn(() => false),
|
||||
@@ -75,8 +71,12 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>
|
||||
setPreviewModelBypassMode: vi.fn(),
|
||||
fallbackHandler: undefined,
|
||||
getFallbackModelHandler: vi.fn(),
|
||||
setActiveModel: vi.fn(),
|
||||
getModelAvailabilityService: vi.fn(() =>
|
||||
createAvailabilityMock({ selectedModel: FALLBACK_MODEL, skipped: [] }),
|
||||
createAvailabilityServiceMock({
|
||||
selectedModel: FALLBACK_MODEL,
|
||||
skipped: [],
|
||||
}),
|
||||
),
|
||||
getModel: vi.fn(() => MOCK_PRO_MODEL),
|
||||
getPreviewFeatures: vi.fn(() => false),
|
||||
@@ -98,6 +98,12 @@ describe('handleFallback', () => {
|
||||
mockConfig = createMockConfig({
|
||||
fallbackModelHandler: mockHandler,
|
||||
});
|
||||
// Explicitly set the property to ensure it's present for legacy checks
|
||||
mockConfig.fallbackModelHandler = mockHandler;
|
||||
|
||||
// debugLogger is mocked above, so handler failures logged through it need no spy.
// console.error is still spied on because legacyHandleFallback reports errors through it.
|
||||
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
fallbackEventSpy = vi.spyOn(coreEvents, 'emitFallbackModeChanged');
|
||||
});
|
||||
@@ -538,7 +544,7 @@ describe('handleFallback', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
availability = createAvailabilityMock({
|
||||
availability = createAvailabilityServiceMock({
|
||||
selectedModel: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
skipped: [],
|
||||
});
|
||||
@@ -612,7 +618,17 @@ describe('handleFallback', () => {
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
|
||||
expect(policyConfig.setFallbackMode).toHaveBeenCalledWith(true);
|
||||
expect(policyConfig.setActiveModel).toHaveBeenCalledWith(
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
// Silent actions must not activate the legacy fallback mode (activateFallbackMode).
// setActiveModel is mocked, so it cannot trigger setFallbackMode through the
// legacy sync path; any setFallbackMode call would have to come from
// activateFallbackMode itself, so assert that none occurred.
|
||||
expect(policyConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
} finally {
|
||||
chainSpy.mockRestore();
|
||||
}
|
||||
@@ -707,5 +723,34 @@ describe('handleFallback', () => {
|
||||
expect(policyConfig.getModelAvailabilityService).toHaveBeenCalled();
|
||||
expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('calls setActiveModel and logs telemetry when handler returns "retry_always"', async () => {
  // The policy handler resolves with the 'retry_always' intent.
  policyHandler.mockResolvedValue('retry_always');

  const shouldRetry = await handleFallback(
    policyConfig,
    MOCK_PRO_MODEL,
    AUTH_OAUTH,
  );

  // handleFallback reports that the request should be retried...
  expect(shouldRetry).toBe(true);
  // ...without touching the legacy fallback flag...
  expect(policyConfig.setFallbackMode).not.toHaveBeenCalled();
  // ...and with the fallback model made active.
  expect(policyConfig.setActiveModel).toHaveBeenCalledWith(FALLBACK_MODEL);
  // TODO: add logging expect statement
});
|
||||
|
||||
it('calls setActiveModel when handler returns "stop"', async () => {
  // The policy handler resolves with the 'stop' intent.
  policyHandler.mockResolvedValue('stop');

  const shouldRetry = await handleFallback(
    policyConfig,
    MOCK_PRO_MODEL,
    AUTH_OAUTH,
  );

  // The fallback model is still made active, but no retry is requested.
  expect(policyConfig.setActiveModel).toHaveBeenCalledWith(FALLBACK_MODEL);
  expect(shouldRetry).toBe(false);
  // TODO: add logging expect statement
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -12,17 +12,14 @@ import {
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
} from '../config/models.js';
|
||||
import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
import { openBrowserSecurely } from '../utils/secure-browser-launcher.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import { ModelNotFoundError } from '../utils/httpErrors.js';
|
||||
import {
|
||||
RetryableQuotaError,
|
||||
TerminalQuotaError,
|
||||
} from '../utils/googleQuotaErrors.js';
|
||||
import { TerminalQuotaError } from '../utils/googleQuotaErrors.js';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
import type { FallbackIntent, FallbackRecommendation } from './types.js';
|
||||
import type { FailureKind } from '../availability/modelPolicy.js';
|
||||
import { classifyFailureKind } from '../availability/errorClassification.js';
|
||||
import {
|
||||
buildFallbackPolicyContext,
|
||||
resolvePolicyChain,
|
||||
@@ -126,6 +123,9 @@ async function handlePolicyDrivenFallback(
|
||||
chain,
|
||||
failedModel,
|
||||
);
|
||||
|
||||
const failureKind = classifyFailureKind(error);
|
||||
|
||||
if (!candidates.length) {
|
||||
return null;
|
||||
}
|
||||
@@ -145,7 +145,7 @@ async function handlePolicyDrivenFallback(
|
||||
return null;
|
||||
}
|
||||
|
||||
const failureKind = classifyFailureKind(error);
|
||||
// failureKind is already declared and calculated above
|
||||
const action = resolvePolicyAction(failureKind, selectedPolicy);
|
||||
|
||||
if (action === 'silent') {
|
||||
@@ -183,6 +183,7 @@ async function handlePolicyDrivenFallback(
|
||||
failedModel,
|
||||
fallbackModel,
|
||||
authType,
|
||||
error, // Pass the error so processIntent can handle preview-specific logic
|
||||
);
|
||||
} catch (handlerError) {
|
||||
debugLogger.error('Fallback handler failed:', handlerError);
|
||||
@@ -209,25 +210,42 @@ async function processIntent(
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
): Promise<boolean> {
|
||||
const isAvailabilityEnabled = config.isModelAvailabilityServiceEnabled();
|
||||
|
||||
switch (intent) {
|
||||
case 'retry_always':
|
||||
// If the error is non-retryable, e.g. TerminalQuota Error, trigger a regular fallback to flash.
|
||||
// For all other errors, activate previewModel fallback.
|
||||
if (
|
||||
failedModel === PREVIEW_GEMINI_MODEL &&
|
||||
!(error instanceof TerminalQuotaError)
|
||||
) {
|
||||
activatePreviewModelFallbackMode(config);
|
||||
if (isAvailabilityEnabled) {
|
||||
// TODO(telemetry): Implement generic fallback event logging. Existing
|
||||
// logFlashFallback is specific to a single Model.
|
||||
config.setActiveModel(fallbackModel);
|
||||
} else {
|
||||
activateFallbackMode(config, authType);
|
||||
// If the error is non-retryable, e.g. TerminalQuota Error, trigger a regular fallback to flash.
|
||||
// For all other errors, activate previewModel fallback.
|
||||
if (
|
||||
failedModel === PREVIEW_GEMINI_MODEL &&
|
||||
!(error instanceof TerminalQuotaError)
|
||||
) {
|
||||
activatePreviewModelFallbackMode(config);
|
||||
} else {
|
||||
activateFallbackMode(config, authType);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
|
||||
case 'retry_once':
|
||||
if (isAvailabilityEnabled) {
|
||||
config.setActiveModel(fallbackModel);
|
||||
}
|
||||
return true;
|
||||
|
||||
case 'stop':
|
||||
activateFallbackMode(config, authType);
|
||||
if (isAvailabilityEnabled) {
|
||||
// TODO(telemetry): Implement generic fallback event logging. Existing
|
||||
// logFlashFallback is specific to a single Model.
|
||||
config.setActiveModel(fallbackModel);
|
||||
} else {
|
||||
activateFallbackMode(config, authType);
|
||||
}
|
||||
return false;
|
||||
|
||||
case 'retry_later':
|
||||
@@ -260,16 +278,3 @@ function activatePreviewModelFallbackMode(config: Config) {
|
||||
// We might want a specific event for Preview Model fallback, but for now we just set the mode.
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Maps a thrown value onto a FailureKind.
 *
 * Checks run in a fixed order: TerminalQuotaError, then RetryableQuotaError,
 * then ModelNotFoundError. Anything else, including `undefined`, is 'unknown'.
 *
 * @param error The value caught from a failed model call, if any.
 * @returns The failure kind matching the first class that `error` is an instance of.
 */
function classifyFailureKind(error?: unknown): FailureKind {
  const kind: FailureKind =
    error instanceof TerminalQuotaError
      ? 'terminal'
      : error instanceof RetryableQuotaError
        ? 'transient'
        : error instanceof ModelNotFoundError
          ? 'not_found'
          : 'unknown';
  return kind;
}
|
||||
|
||||
23
packages/core/src/services/modelConfigServiceTestUtils.ts
Normal file
23
packages/core/src/services/modelConfigServiceTestUtils.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { ResolvedModelConfig } from '../services/modelConfigService.js';
|
||||
|
||||
/**
 * Test helper that builds a ResolvedModelConfig for `model`.
 *
 * `generateContentConfig` starts from `temperature: 0` and `topP: 1`; every
 * key present in `overrides` replaces or extends those defaults.
 *
 * @param model Model identifier stored on the returned config.
 * @param overrides Optional generateContentConfig fields layered over the defaults.
 * @returns The assembled config, cast to ResolvedModelConfig.
 */
export const makeResolvedModelConfig = (
  model: string,
  overrides: Partial<ResolvedModelConfig['generateContentConfig']> = {},
): ResolvedModelConfig => {
  // Spread last so caller-supplied values win over the defaults.
  const generateContentConfig = { temperature: 0, topP: 1, ...overrides };
  return { model, generateContentConfig } as ResolvedModelConfig;
};
|
||||
@@ -17,6 +17,9 @@ import {
|
||||
RetryableQuotaError,
|
||||
} from './googleQuotaErrors.js';
|
||||
import { PREVIEW_GEMINI_MODEL } from '../config/models.js';
|
||||
import type { ModelPolicy } from '../availability/modelPolicy.js';
|
||||
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
|
||||
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
|
||||
|
||||
// Helper to create a mock function that fails a certain number of times
|
||||
const createFailingFunction = (
|
||||
@@ -104,7 +107,6 @@ describe('retryWithBackoff', () => {
|
||||
|
||||
const promise = retryWithBackoff(mockFn);
|
||||
|
||||
// Expect it to fail with the error from the final (3rd) attempt.
|
||||
await Promise.all([
|
||||
expect(promise).rejects.toThrow('Simulated error attempt 3'),
|
||||
vi.runAllTimersAsync(),
|
||||
@@ -566,4 +568,171 @@ describe('retryWithBackoff', () => {
|
||||
);
|
||||
expect(mockFn).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
/**
 * Covers how retryWithBackoff reports failures to the availability service
 * supplied through `getAvailabilityContext`: which policy's model is marked,
 * and with which service method / reason.
 */
describe('Availability Context Integration', () => {
  let mockService: ModelAvailabilityService;
  let mockPolicy1: ModelPolicy;
  let mockPolicy2: ModelPolicy;

  beforeEach(() => {
    // Start every test on real timers; individual tests opt into fake ones.
    vi.useRealTimers();
    mockService = createAvailabilityServiceMock();

    mockPolicy1 = {
      model: 'model-1',
      actions: {},
      stateTransitions: {
        terminal: 'terminal',
        transient: 'sticky_retry',
      },
    };

    mockPolicy2 = {
      model: 'model-2',
      actions: {},
      stateTransitions: {
        terminal: 'terminal',
      },
    };
  });

  it('updates availability context per attempt and applies transitions to the correct policy', async () => {
    const error = new TerminalQuotaError(
      'quota exceeded',
      { code: 429, message: 'quota', details: [] },
      10,
    );

    const fn = vi.fn().mockImplementation(async () => {
      throw error; // Always fail with quota
    });

    const onPersistent429 = vi
      .fn()
      .mockResolvedValueOnce('model-2') // First fallback success
      .mockResolvedValueOnce(null); // Second fallback fails (give up)

    // Context provider returns policy1 first, then policy2
    const getContext = vi
      .fn()
      .mockReturnValueOnce({ service: mockService, policy: mockPolicy1 })
      .mockReturnValueOnce({ service: mockService, policy: mockPolicy2 });

    await expect(
      retryWithBackoff(fn, {
        maxAttempts: 3,
        initialDelayMs: 1,
        getAvailabilityContext: getContext,
        onPersistent429,
        authType: AuthType.LOGIN_WITH_GOOGLE,
      }),
    ).rejects.toThrow(TerminalQuotaError);

    // Verify failures
    expect(mockService.markTerminal).toHaveBeenCalledWith('model-1', 'quota');
    expect(mockService.markTerminal).toHaveBeenCalledWith('model-2', 'quota');

    // Verify sequences
    expect(mockService.markTerminal).toHaveBeenNthCalledWith(
      1,
      'model-1',
      'quota',
    );
    expect(mockService.markTerminal).toHaveBeenNthCalledWith(
      2,
      'model-2',
      'quota',
    );
  });

  it('marks sticky_retry after retries are exhausted for transient failures', async () => {
    const transientError = new RetryableQuotaError(
      'transient error',
      { code: 429, message: 'transient', details: [] },
      0,
    );

    const fn = vi.fn().mockRejectedValue(transientError);

    const getContext = vi
      .fn()
      .mockReturnValue({ service: mockService, policy: mockPolicy1 });

    vi.useFakeTimers();
    // Convert the rejection into a resolved value so it can be inspected
    // after the fake timers have been flushed.
    const promise = retryWithBackoff(fn, {
      maxAttempts: 3,
      getAvailabilityContext: getContext,
      initialDelayMs: 1,
      maxDelayMs: 1,
    }).catch((err) => err);

    await vi.runAllTimersAsync();
    const result = await promise;
    expect(result).toBe(transientError);

    expect(fn).toHaveBeenCalledTimes(3);
    expect(mockService.markRetryOncePerTurn).toHaveBeenCalledWith('model-1');
    expect(mockService.markRetryOncePerTurn).toHaveBeenCalledTimes(1);
    expect(mockService.markTerminal).not.toHaveBeenCalled();
  });

  it('maps different failure kinds to correct terminal reasons', async () => {
    const quotaError = new TerminalQuotaError(
      'quota',
      { code: 429, message: 'q', details: [] },
      10,
    );
    const notFoundError = new ModelNotFoundError('not found', 404);
    const genericError = new Error('unknown error');

    // Each retryWithBackoff call below consumes exactly one of these rejections.
    const fn = vi
      .fn()
      .mockRejectedValueOnce(quotaError)
      .mockRejectedValueOnce(notFoundError)
      .mockRejectedValueOnce(genericError);

    const policy: ModelPolicy = {
      model: 'model-1',
      actions: {},
      stateTransitions: {
        terminal: 'terminal', // from quotaError
        not_found: 'terminal', // from notFoundError
        unknown: 'terminal', // from genericError
      },
    };

    const getContext = vi
      .fn()
      .mockReturnValue({ service: mockService, policy });

    // After every run, assert both the call count and the most recent call.
    // A plain toHaveBeenCalledWith would still pass on the third run thanks
    // to the identical arguments recorded by the second run.

    // Run for quotaError
    await retryWithBackoff(fn, {
      maxAttempts: 1,
      getAvailabilityContext: getContext,
    }).catch(() => {});
    expect(mockService.markTerminal).toHaveBeenCalledTimes(1);
    expect(mockService.markTerminal).toHaveBeenLastCalledWith(
      'model-1',
      'quota',
    );

    // Run for notFoundError
    await retryWithBackoff(fn, {
      maxAttempts: 1,
      getAvailabilityContext: getContext,
    }).catch(() => {});
    expect(mockService.markTerminal).toHaveBeenCalledTimes(2);
    expect(mockService.markTerminal).toHaveBeenLastCalledWith(
      'model-1',
      'capacity',
    );

    // Run for genericError
    await retryWithBackoff(fn, {
      maxAttempts: 1,
      getAvailabilityContext: getContext,
    }).catch(() => {});
    expect(mockService.markTerminal).toHaveBeenCalledTimes(3);
    expect(mockService.markTerminal).toHaveBeenLastCalledWith(
      'model-1',
      'capacity',
    );
  });
});
|
||||
});
|
||||
|
||||
@@ -8,13 +8,18 @@ import type { GenerateContentResponse } from '@google/genai';
|
||||
import { ApiError } from '@google/genai';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import {
|
||||
classifyGoogleError,
|
||||
RetryableQuotaError,
|
||||
TerminalQuotaError,
|
||||
RetryableQuotaError,
|
||||
classifyGoogleError,
|
||||
} from './googleQuotaErrors.js';
|
||||
import { delay, createAbortError } from './delay.js';
|
||||
import { debugLogger } from './debugLogger.js';
|
||||
import { getErrorStatus, ModelNotFoundError } from './httpErrors.js';
|
||||
import type { RetryAvailabilityContext } from '../availability/modelPolicy.js';
|
||||
import { classifyFailureKind } from '../availability/errorClassification.js';
|
||||
import { applyAvailabilityTransition } from '../availability/policyHelpers.js';
|
||||
|
||||
export type { RetryAvailabilityContext };
|
||||
|
||||
export interface RetryOptions {
|
||||
maxAttempts: number;
|
||||
@@ -29,6 +34,7 @@ export interface RetryOptions {
|
||||
authType?: string;
|
||||
retryFetchErrors?: boolean;
|
||||
signal?: AbortSignal;
|
||||
getAvailabilityContext?: () => RetryAvailabilityContext | undefined;
|
||||
}
|
||||
|
||||
const DEFAULT_RETRY_OPTIONS: RetryOptions = {
|
||||
@@ -145,6 +151,7 @@ export async function retryWithBackoff<T>(
|
||||
shouldRetryOnContent,
|
||||
retryFetchErrors,
|
||||
signal,
|
||||
getAvailabilityContext,
|
||||
} = {
|
||||
...DEFAULT_RETRY_OPTIONS,
|
||||
shouldRetryOnError: isRetryableError,
|
||||
@@ -173,6 +180,11 @@ export async function retryWithBackoff<T>(
|
||||
continue;
|
||||
}
|
||||
|
||||
const successContext = getAvailabilityContext?.();
|
||||
if (successContext) {
|
||||
successContext.service.markHealthy(successContext.policy.model);
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
@@ -180,6 +192,13 @@ export async function retryWithBackoff<T>(
|
||||
}
|
||||
|
||||
const classifiedError = classifyGoogleError(error);
|
||||
const failureKind = classifyFailureKind(classifiedError);
|
||||
const appliedImmediate =
|
||||
failureKind === 'terminal' || failureKind === 'not_found';
|
||||
if (appliedImmediate) {
|
||||
applyAvailabilityTransition(getAvailabilityContext, failureKind);
|
||||
}
|
||||
|
||||
const errorCode = getErrorStatus(error);
|
||||
|
||||
if (
|
||||
@@ -201,6 +220,7 @@ export async function retryWithBackoff<T>(
|
||||
debugLogger.warn('Fallback to Flash model failed:', fallbackError);
|
||||
}
|
||||
}
|
||||
// Terminal/not_found already recorded; nothing else to mark here.
|
||||
throw classifiedError; // Throw if no fallback or fallback failed.
|
||||
}
|
||||
|
||||
@@ -224,6 +244,9 @@ export async function retryWithBackoff<T>(
|
||||
console.warn('Model fallback failed:', fallbackError);
|
||||
}
|
||||
}
|
||||
if (!appliedImmediate) {
|
||||
applyAvailabilityTransition(getAvailabilityContext, failureKind);
|
||||
}
|
||||
throw classifiedError instanceof RetryableQuotaError
|
||||
? classifiedError
|
||||
: error;
|
||||
@@ -253,6 +276,9 @@ export async function retryWithBackoff<T>(
|
||||
attempt >= maxAttempts ||
|
||||
!shouldRetryOnError(error as Error, retryFetchErrors)
|
||||
) {
|
||||
if (!appliedImmediate) {
|
||||
applyAvailabilityTransition(getAvailabilityContext, failureKind);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user