mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-30 23:14:32 -07:00
fix(core): distinguish fallback chains and fix maxAttempts for auto vs explicit model selection (#26163)
This commit is contained in:
@@ -0,0 +1,414 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { BaseLlmClient } from '../core/baseLlmClient.js';
|
||||
import { FakeContentGenerator } from '../core/fakeContentGenerator.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { RetryableQuotaError } from '../utils/googleQuotaErrors.js';
|
||||
import {
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
PREVIEW_GEMINI_FLASH_MODEL,
|
||||
PREVIEW_GEMINI_MODEL_AUTO,
|
||||
} from '../config/models.js';
|
||||
import fs from 'node:fs';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import type { FallbackIntent } from '../fallback/types.js';
|
||||
import { LlmRole } from '../telemetry/types.js';
|
||||
import type { GenerateContentResponse } from '@google/genai';
|
||||
|
||||
vi.mock('node:fs');
|
||||
|
||||
describe('Auto Routing Fallback Integration', () => {
|
||||
let config: Config;
|
||||
let fakeGenerator: FakeContentGenerator;
|
||||
let client: BaseLlmClient;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
// Mock fs to avoid real file system access
|
||||
vi.mocked(fs.existsSync).mockReturnValue(true);
|
||||
vi.mocked(fs.statSync).mockReturnValue({
|
||||
isDirectory: () => true,
|
||||
} as fs.Stats);
|
||||
|
||||
// Provide a valid dummy sandbox policy for any readFileSync calls for TOML files
|
||||
vi.mocked(fs.readFileSync).mockImplementation((path) => {
|
||||
if (typeof path === 'string' && path.endsWith('.toml')) {
|
||||
return `
|
||||
[modes.plan]
|
||||
network = false
|
||||
readonly = true
|
||||
approvedTools = []
|
||||
|
||||
[modes.default]
|
||||
network = false
|
||||
readonly = false
|
||||
approvedTools = []
|
||||
|
||||
[modes.accepting_edits]
|
||||
network = false
|
||||
readonly = false
|
||||
approvedTools = []
|
||||
`;
|
||||
}
|
||||
return ''; // Fallback for other files
|
||||
});
|
||||
|
||||
fakeGenerator = new FakeContentGenerator([]);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it('should fallback to Flash after 3 tries and try 10 times for Flash in auto mode', async () => {
|
||||
// Instantiate real Config in auto mode
|
||||
config = new Config({
|
||||
sessionId: 'test-session',
|
||||
targetDir: '/test',
|
||||
debugMode: false,
|
||||
cwd: '/test',
|
||||
model: PREVIEW_GEMINI_MODEL_AUTO, // Trigger auto mode
|
||||
});
|
||||
|
||||
// Force interactive mode to enable fallback handler in BaseLlmClient
|
||||
vi.spyOn(config, 'isInteractive').mockReturnValue(true);
|
||||
|
||||
client = new BaseLlmClient(
|
||||
fakeGenerator,
|
||||
config,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
);
|
||||
|
||||
let attemptsPro = 0;
|
||||
let attemptsFlash = 0;
|
||||
|
||||
const mockGoogleApiError = {
|
||||
code: 429,
|
||||
message: 'Quota exceeded',
|
||||
details: [],
|
||||
};
|
||||
|
||||
// Spy on generateContent to simulate failures
|
||||
vi.spyOn(fakeGenerator, 'generateContent').mockImplementation(
|
||||
async (params) => {
|
||||
if (params.model === PREVIEW_GEMINI_MODEL) {
|
||||
attemptsPro++;
|
||||
throw new RetryableQuotaError(
|
||||
'Quota exceeded for Pro',
|
||||
mockGoogleApiError,
|
||||
0,
|
||||
);
|
||||
} else if (params.model === PREVIEW_GEMINI_FLASH_MODEL) {
|
||||
attemptsFlash++;
|
||||
throw new RetryableQuotaError(
|
||||
'Quota exceeded for Flash',
|
||||
mockGoogleApiError,
|
||||
0,
|
||||
);
|
||||
}
|
||||
throw new Error(`Unexpected model: ${params.model}`);
|
||||
},
|
||||
);
|
||||
|
||||
// Set a fallback handler that approves the switch (simulating user or auto approval)
|
||||
config.setFallbackModelHandler(
|
||||
async (failed, _fallback, _error): Promise<FallbackIntent | null> => {
|
||||
if (failed === PREVIEW_GEMINI_FLASH_MODEL) {
|
||||
return 'stop'; // Stop retrying after Flash fails
|
||||
}
|
||||
return 'retry_always'; // Trigger fallback to Flash
|
||||
},
|
||||
);
|
||||
|
||||
// Call generateContent
|
||||
const promise = client.generateContent({
|
||||
modelConfigKey: { model: PREVIEW_GEMINI_MODEL, isChatModel: true },
|
||||
contents: [{ role: 'user', parts: [{ text: 'hi' }] }],
|
||||
abortSignal: new AbortController().signal,
|
||||
promptId: 'test-prompt',
|
||||
role: LlmRole.UTILITY_TOOL,
|
||||
});
|
||||
|
||||
await Promise.all([
|
||||
expect(promise).rejects.toThrow('Quota exceeded for Flash'),
|
||||
vi.runAllTimersAsync(),
|
||||
]);
|
||||
|
||||
// Verify attempts
|
||||
expect(attemptsPro).toBe(3);
|
||||
expect(attemptsFlash).toBe(10);
|
||||
});
|
||||
|
||||
it('should try 10 times and prompt user in non-auto mode', async () => {
|
||||
// Instantiate real Config in non-auto mode
|
||||
const configNonAuto = new Config({
|
||||
sessionId: 'test-session',
|
||||
targetDir: '/test',
|
||||
debugMode: false,
|
||||
cwd: '/test',
|
||||
model: PREVIEW_GEMINI_MODEL, // Non-auto mode
|
||||
});
|
||||
|
||||
// Force interactive mode to enable fallback handler in BaseLlmClient
|
||||
vi.spyOn(configNonAuto, 'isInteractive').mockReturnValue(true);
|
||||
|
||||
const clientNonAuto = new BaseLlmClient(
|
||||
fakeGenerator,
|
||||
configNonAuto,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
);
|
||||
|
||||
let attemptsPro = 0;
|
||||
|
||||
const mockGoogleApiError = {
|
||||
code: 429,
|
||||
message: 'Quota exceeded',
|
||||
details: [],
|
||||
};
|
||||
|
||||
// Spy on generateContent to simulate failures
|
||||
vi.spyOn(fakeGenerator, 'generateContent').mockImplementation(
|
||||
async (params) => {
|
||||
if (params.model === PREVIEW_GEMINI_MODEL) {
|
||||
attemptsPro++;
|
||||
throw new RetryableQuotaError(
|
||||
'Quota exceeded for Pro',
|
||||
mockGoogleApiError,
|
||||
0,
|
||||
);
|
||||
}
|
||||
throw new Error(`Unexpected model: ${params.model}`);
|
||||
},
|
||||
);
|
||||
|
||||
// Set a fallback handler that returns 'stop' (simulating user stopping or failing to handle)
|
||||
const handler = vi.fn(
|
||||
async (_failed, _fallback, _error): Promise<FallbackIntent | null> =>
|
||||
'stop',
|
||||
);
|
||||
configNonAuto.setFallbackModelHandler(handler);
|
||||
|
||||
const promise = clientNonAuto.generateContent({
|
||||
modelConfigKey: { model: PREVIEW_GEMINI_MODEL, isChatModel: true },
|
||||
contents: [{ role: 'user', parts: [{ text: 'hi' }] }],
|
||||
abortSignal: new AbortController().signal,
|
||||
promptId: 'test-prompt',
|
||||
role: LlmRole.UTILITY_TOOL,
|
||||
maxAttempts: 10,
|
||||
});
|
||||
|
||||
await Promise.all([
|
||||
expect(promise).rejects.toThrow('Quota exceeded for Pro'),
|
||||
vi.runAllTimersAsync(),
|
||||
]);
|
||||
|
||||
// Verify attempts (should default to 10)
|
||||
expect(attemptsPro).toBe(10);
|
||||
|
||||
// Verify handler was called once after 10 attempts to prompt user
|
||||
expect(handler).toHaveBeenCalledTimes(1);
|
||||
expect(handler).toHaveBeenCalledWith(
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
PREVIEW_GEMINI_FLASH_MODEL,
|
||||
expect.any(RetryableQuotaError),
|
||||
);
|
||||
});
|
||||
|
||||
it('should fallback to Flash after 3 tries in experimental dynamic mode', async () => {
|
||||
// Instantiate real Config in auto mode
|
||||
const configDynamic = new Config({
|
||||
sessionId: 'test-session',
|
||||
targetDir: '/test',
|
||||
debugMode: false,
|
||||
cwd: '/test',
|
||||
model: PREVIEW_GEMINI_MODEL_AUTO, // Trigger auto mode
|
||||
});
|
||||
|
||||
// Force interactive mode to enable fallback handler in BaseLlmClient
|
||||
vi.spyOn(configDynamic, 'isInteractive').mockReturnValue(true);
|
||||
|
||||
// Enable experimental dynamic model configuration
|
||||
vi.spyOn(
|
||||
configDynamic,
|
||||
'getExperimentalDynamicModelConfiguration',
|
||||
).mockReturnValue(true);
|
||||
|
||||
const clientDynamic = new BaseLlmClient(
|
||||
fakeGenerator,
|
||||
configDynamic,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
);
|
||||
|
||||
let attemptsPro = 0;
|
||||
let attemptsFlash = 0;
|
||||
|
||||
const mockGoogleApiError = {
|
||||
code: 429,
|
||||
message: 'Quota exceeded',
|
||||
details: [],
|
||||
};
|
||||
|
||||
// Spy on generateContent to simulate failures
|
||||
vi.spyOn(fakeGenerator, 'generateContent').mockImplementation(
|
||||
async (params) => {
|
||||
if (params.model === PREVIEW_GEMINI_MODEL) {
|
||||
attemptsPro++;
|
||||
throw new RetryableQuotaError(
|
||||
'Quota exceeded for Pro',
|
||||
mockGoogleApiError,
|
||||
0,
|
||||
);
|
||||
} else if (params.model === PREVIEW_GEMINI_FLASH_MODEL) {
|
||||
attemptsFlash++;
|
||||
throw new RetryableQuotaError(
|
||||
'Quota exceeded for Flash',
|
||||
mockGoogleApiError,
|
||||
0,
|
||||
);
|
||||
}
|
||||
throw new Error(`Unexpected model: ${params.model}`);
|
||||
},
|
||||
);
|
||||
|
||||
// Set a fallback handler that approves the switch
|
||||
configDynamic.setFallbackModelHandler(
|
||||
async (failed, _fallback, _error): Promise<FallbackIntent | null> => {
|
||||
if (failed === PREVIEW_GEMINI_FLASH_MODEL) {
|
||||
return 'stop';
|
||||
}
|
||||
return 'retry_always';
|
||||
},
|
||||
);
|
||||
|
||||
const promise = clientDynamic.generateContent({
|
||||
modelConfigKey: { model: PREVIEW_GEMINI_MODEL, isChatModel: true },
|
||||
contents: [{ role: 'user', parts: [{ text: 'hi' }] }],
|
||||
abortSignal: new AbortController().signal,
|
||||
promptId: 'test-prompt',
|
||||
role: LlmRole.UTILITY_TOOL,
|
||||
});
|
||||
|
||||
await Promise.all([
|
||||
expect(promise).rejects.toThrow('Quota exceeded for Flash'),
|
||||
vi.runAllTimersAsync(),
|
||||
]);
|
||||
|
||||
// Verify attempts
|
||||
expect(attemptsPro).toBe(3);
|
||||
expect(attemptsFlash).toBe(10);
|
||||
});
|
||||
|
||||
it('should retry Pro on next turn after successful fallback to Flash', async () => {
|
||||
// Instantiate real Config in auto mode
|
||||
config = new Config({
|
||||
sessionId: 'test-session',
|
||||
targetDir: '/test',
|
||||
debugMode: false,
|
||||
cwd: '/test',
|
||||
model: PREVIEW_GEMINI_MODEL_AUTO, // Trigger auto mode
|
||||
});
|
||||
|
||||
// Force interactive mode to enable fallback handler in BaseLlmClient
|
||||
vi.spyOn(config, 'isInteractive').mockReturnValue(true);
|
||||
|
||||
client = new BaseLlmClient(
|
||||
fakeGenerator,
|
||||
config,
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
);
|
||||
|
||||
let attemptsPro = 0;
|
||||
let attemptsFlash = 0;
|
||||
|
||||
const mockGoogleApiError = {
|
||||
code: 429,
|
||||
message: 'Quota exceeded',
|
||||
details: [],
|
||||
};
|
||||
|
||||
// Turn 1: Pro fails, Flash succeeds
|
||||
vi.spyOn(fakeGenerator, 'generateContent').mockImplementation(
|
||||
async (params) => {
|
||||
if (params.model === PREVIEW_GEMINI_MODEL) {
|
||||
attemptsPro++;
|
||||
throw new RetryableQuotaError(
|
||||
'Quota exceeded for Pro',
|
||||
mockGoogleApiError,
|
||||
0,
|
||||
);
|
||||
} else if (params.model === PREVIEW_GEMINI_FLASH_MODEL) {
|
||||
attemptsFlash++;
|
||||
return {
|
||||
candidates: [
|
||||
{
|
||||
content: { role: 'model', parts: [{ text: 'Flash success' }] },
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
}
|
||||
throw new Error(`Unexpected model: ${params.model}`);
|
||||
},
|
||||
);
|
||||
|
||||
config.setFallbackModelHandler(
|
||||
async (_failed, _fallback, _error): Promise<FallbackIntent | null> =>
|
||||
'retry_always', // Approve switch to Flash
|
||||
);
|
||||
|
||||
// Call generateContent for Turn 1
|
||||
const promise1 = client.generateContent({
|
||||
modelConfigKey: { model: PREVIEW_GEMINI_MODEL, isChatModel: true },
|
||||
contents: [{ role: 'user', parts: [{ text: 'hi' }] }],
|
||||
abortSignal: new AbortController().signal,
|
||||
promptId: 'test-prompt-1',
|
||||
role: LlmRole.UTILITY_TOOL,
|
||||
});
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
const result1 = await promise1;
|
||||
|
||||
expect(result1.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
|
||||
'Flash success',
|
||||
);
|
||||
expect(attemptsPro).toBe(3);
|
||||
expect(attemptsFlash).toBe(1);
|
||||
|
||||
// Simulate start of next turn
|
||||
config.getModelAvailabilityService().resetTurn();
|
||||
|
||||
// Turn 2: Pro should be attempted again!
|
||||
// Let's make it succeed this time to verify it works!
|
||||
vi.spyOn(fakeGenerator, 'generateContent').mockImplementation(
|
||||
async (params) => {
|
||||
if (params.model === PREVIEW_GEMINI_MODEL) {
|
||||
return {
|
||||
candidates: [
|
||||
{ content: { role: 'model', parts: [{ text: 'Pro success' }] } },
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
}
|
||||
throw new Error(`Unexpected model: ${params.model}`);
|
||||
},
|
||||
);
|
||||
|
||||
const promise2 = client.generateContent({
|
||||
modelConfigKey: { model: PREVIEW_GEMINI_MODEL, isChatModel: true }, // Request Pro again
|
||||
contents: [{ role: 'user', parts: [{ text: 'hello again' }] }],
|
||||
abortSignal: new AbortController().signal,
|
||||
promptId: 'test-prompt-2',
|
||||
role: LlmRole.UTILITY_TOOL,
|
||||
});
|
||||
|
||||
const result2 = await promise2;
|
||||
expect(result2.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
|
||||
'Pro success',
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -77,4 +77,38 @@ describe('Fallback Integration', () => {
|
||||
// 5. Expect it to fallback to Flash (because Gemini 3 uses PREVIEW_CHAIN)
|
||||
expect(result.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);
|
||||
});
|
||||
|
||||
it('should fallback to Flash after failures and restore Pro on next turn', () => {
|
||||
const requestedModel = PREVIEW_GEMINI_MODEL;
|
||||
|
||||
// 1. Initial call should return Pro with 3 attempts
|
||||
const result1 = applyModelSelection(config, {
|
||||
model: requestedModel,
|
||||
isChatModel: true,
|
||||
});
|
||||
expect(result1.model).toBe(PREVIEW_GEMINI_MODEL);
|
||||
expect(result1.maxAttempts).toBe(3);
|
||||
|
||||
// 2. Simulate failure and transition to sticky_retry with consumed=true
|
||||
availabilityService.markRetryOncePerTurn(PREVIEW_GEMINI_MODEL, 3);
|
||||
availabilityService.consumeStickyAttempt(PREVIEW_GEMINI_MODEL);
|
||||
|
||||
// 3. Next call in same turn should fallback to Flash
|
||||
const result2 = applyModelSelection(config, {
|
||||
model: requestedModel,
|
||||
isChatModel: true,
|
||||
});
|
||||
expect(result2.model).toBe(PREVIEW_GEMINI_FLASH_MODEL);
|
||||
|
||||
// 4. Reset turn (start of new interaction)
|
||||
availabilityService.resetTurn();
|
||||
|
||||
// 5. Next call should restore Pro with 3 attempts
|
||||
const result3 = applyModelSelection(config, {
|
||||
model: requestedModel,
|
||||
isChatModel: true,
|
||||
});
|
||||
expect(result3.model).toBe(PREVIEW_GEMINI_MODEL);
|
||||
expect(result3.maxAttempts).toBe(3);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -34,6 +34,12 @@ describe('ModelAvailabilityService', () => {
|
||||
expect(service.snapshot(model)).toEqual({ available: true });
|
||||
});
|
||||
|
||||
it('tracks retry with custom attempts', () => {
|
||||
service.markRetryOncePerTurn(model, 3);
|
||||
const selection = service.selectFirstAvailable([model]);
|
||||
expect(selection.attempts).toBe(3);
|
||||
});
|
||||
|
||||
it('tracks terminal failures', () => {
|
||||
service.markTerminal(model, 'quota');
|
||||
expect(service.snapshot(model)).toEqual({
|
||||
|
||||
@@ -22,6 +22,7 @@ type HealthState =
|
||||
status: 'sticky_retry';
|
||||
reason: TurnUnavailabilityReason;
|
||||
consumed: boolean;
|
||||
attempts: number;
|
||||
};
|
||||
|
||||
export interface ModelAvailabilitySnapshot {
|
||||
@@ -52,7 +53,7 @@ export class ModelAvailabilityService {
|
||||
this.clearState(model);
|
||||
}
|
||||
|
||||
markRetryOncePerTurn(model: ModelId) {
|
||||
markRetryOncePerTurn(model: ModelId, attempts: number = 1) {
|
||||
const currentState = this.health.get(model);
|
||||
// Do not override a terminal failure with a transient one.
|
||||
if (currentState?.status === 'terminal') {
|
||||
@@ -70,6 +71,7 @@ export class ModelAvailabilityService {
|
||||
status: 'sticky_retry',
|
||||
reason: 'retry_once_per_turn',
|
||||
consumed,
|
||||
attempts,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -106,7 +108,8 @@ export class ModelAvailabilityService {
|
||||
if (snapshot.available) {
|
||||
const state = this.health.get(model);
|
||||
// A sticky model is being attempted, so note that.
|
||||
const attempts = state?.status === 'sticky_retry' ? 1 : undefined;
|
||||
const attempts =
|
||||
state?.status === 'sticky_retry' ? state.attempts : undefined;
|
||||
return { selectedModel: model, skipped, attempts };
|
||||
} else {
|
||||
skipped.push({ model, reason: snapshot.reason ?? 'unknown' });
|
||||
|
||||
@@ -46,6 +46,7 @@ export interface ModelPolicy {
|
||||
actions: ModelPolicyActionMap;
|
||||
stateTransitions: ModelPolicyStateMap;
|
||||
isLastResort?: boolean;
|
||||
maxAttempts?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -53,8 +53,11 @@ describe('policyCatalog', () => {
|
||||
expect(chain).toHaveLength(2);
|
||||
});
|
||||
|
||||
it('marks preview transients as sticky retries', () => {
|
||||
const [previewPolicy] = getModelPolicyChain({ previewEnabled: true });
|
||||
it('marks preview transients as sticky retries when auto-selected', () => {
|
||||
const [previewPolicy] = getModelPolicyChain({
|
||||
previewEnabled: true,
|
||||
isAutoSelection: true,
|
||||
});
|
||||
expect(previewPolicy.model).toBe(PREVIEW_GEMINI_MODEL);
|
||||
expect(previewPolicy.stateTransitions.transient).toBe('sticky_retry');
|
||||
});
|
||||
|
||||
@@ -28,6 +28,7 @@ type PolicyConfig = Omit<ModelPolicy, 'actions' | 'stateTransitions'> & {
|
||||
|
||||
export interface ModelPolicyOptions {
|
||||
previewEnabled: boolean;
|
||||
isAutoSelection?: boolean;
|
||||
userTier?: UserTierId;
|
||||
useGemini31?: boolean;
|
||||
useGemini31FlashLite?: boolean;
|
||||
@@ -50,15 +51,19 @@ export const SILENT_ACTIONS: ModelPolicyActionMap = {
|
||||
|
||||
const DEFAULT_STATE: ModelPolicyStateMap = {
|
||||
terminal: 'terminal',
|
||||
transient: 'sticky_retry',
|
||||
transient: 'terminal',
|
||||
not_found: 'terminal',
|
||||
unknown: 'terminal',
|
||||
};
|
||||
|
||||
const DEFAULT_CHAIN: ModelPolicyChain = [
|
||||
definePolicy({ model: DEFAULT_GEMINI_MODEL }),
|
||||
definePolicy({ model: DEFAULT_GEMINI_FLASH_MODEL, isLastResort: true }),
|
||||
];
|
||||
const AUTO_ROUTING_OVERRIDES = {
|
||||
maxAttempts: 3,
|
||||
actions: { ...DEFAULT_ACTIONS, transient: 'silent' } as ModelPolicyActionMap,
|
||||
stateTransitions: {
|
||||
...DEFAULT_STATE,
|
||||
transient: 'sticky_retry',
|
||||
} as ModelPolicyStateMap,
|
||||
};
|
||||
|
||||
const FLASH_LITE_CHAIN: ModelPolicyChain = [
|
||||
definePolicy({
|
||||
@@ -82,20 +87,45 @@ const FLASH_LITE_CHAIN: ModelPolicyChain = [
|
||||
export function getModelPolicyChain(
|
||||
options: ModelPolicyOptions,
|
||||
): ModelPolicyChain {
|
||||
const isAuto = options.isAutoSelection ?? false;
|
||||
|
||||
if (options.previewEnabled) {
|
||||
const previewModel = resolveModel(
|
||||
const proModel = resolveModel(
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
options.useGemini31,
|
||||
options.useGemini31FlashLite,
|
||||
options.useCustomToolModel,
|
||||
);
|
||||
return [
|
||||
definePolicy({ model: previewModel }),
|
||||
definePolicy({ model: PREVIEW_GEMINI_FLASH_MODEL, isLastResort: true }),
|
||||
definePolicy({
|
||||
model: proModel,
|
||||
...(isAuto
|
||||
? {
|
||||
maxAttempts: 3,
|
||||
actions: { ...DEFAULT_ACTIONS, transient: 'silent' },
|
||||
stateTransitions: { ...DEFAULT_STATE, transient: 'sticky_retry' },
|
||||
}
|
||||
: {}),
|
||||
}),
|
||||
definePolicy({
|
||||
model: PREVIEW_GEMINI_FLASH_MODEL,
|
||||
isLastResort: true,
|
||||
maxAttempts: 10,
|
||||
}),
|
||||
];
|
||||
}
|
||||
|
||||
return cloneChain(DEFAULT_CHAIN);
|
||||
return [
|
||||
definePolicy({
|
||||
model: DEFAULT_GEMINI_MODEL,
|
||||
...(isAuto ? AUTO_ROUTING_OVERRIDES : {}),
|
||||
}),
|
||||
definePolicy({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
isLastResort: true,
|
||||
maxAttempts: 10,
|
||||
}),
|
||||
];
|
||||
}
|
||||
|
||||
export function createSingleModelChain(model: string): ModelPolicyChain {
|
||||
@@ -137,6 +167,7 @@ function definePolicy(config: PolicyConfig): ModelPolicy {
|
||||
return {
|
||||
model: config.model,
|
||||
isLastResort: config.isLastResort,
|
||||
maxAttempts: config.maxAttempts,
|
||||
actions: { ...DEFAULT_ACTIONS, ...(config.actions ?? {}) },
|
||||
stateTransitions: {
|
||||
...DEFAULT_STATE,
|
||||
|
||||
@@ -9,8 +9,10 @@ import {
|
||||
resolvePolicyChain,
|
||||
buildFallbackPolicyContext,
|
||||
applyModelSelection,
|
||||
applyAvailabilityTransition,
|
||||
} from './policyHelpers.js';
|
||||
import { createDefaultPolicy, SILENT_ACTIONS } from './policyCatalog.js';
|
||||
import type { RetryAvailabilityContext } from './modelPolicy.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_LITE_MODEL,
|
||||
@@ -35,6 +37,7 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config => {
|
||||
return useGemini31 && authType === AuthType.USE_GEMINI;
|
||||
},
|
||||
getContentGeneratorConfig: () => ({ authType: undefined }),
|
||||
getMaxAttemptsPerTurn: () => 3,
|
||||
...overrides,
|
||||
} as unknown as Config;
|
||||
return config;
|
||||
@@ -201,6 +204,7 @@ describe('policyHelpers', () => {
|
||||
hasAccess: false,
|
||||
},
|
||||
{ name: 'Concrete Model (2.5 Pro)', model: 'gemini-2.5-pro' },
|
||||
{ name: 'Explicit Gemini 3', model: 'gemini-3-pro-preview' },
|
||||
{ name: 'Custom Model', model: 'my-custom-model' },
|
||||
{
|
||||
name: 'Wrap Around',
|
||||
@@ -438,4 +442,51 @@ describe('policyHelpers', () => {
|
||||
expect(result.maxAttempts).toBe(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('applyAvailabilityTransition', () => {
|
||||
it('marks terminal on terminal transition', () => {
|
||||
const mockService = { markTerminal: vi.fn() };
|
||||
const context = {
|
||||
service: mockService,
|
||||
policy: {
|
||||
model: 'test-model',
|
||||
stateTransitions: { transient: 'terminal' },
|
||||
},
|
||||
};
|
||||
const getContext = () => context as unknown as RetryAvailabilityContext;
|
||||
|
||||
applyAvailabilityTransition(getContext, 'transient');
|
||||
|
||||
expect(mockService.markTerminal).toHaveBeenCalledWith(
|
||||
'test-model',
|
||||
'capacity',
|
||||
);
|
||||
});
|
||||
|
||||
it('marks sticky and consumes on sticky_retry transition', () => {
|
||||
const mockService = {
|
||||
markRetryOncePerTurn: vi.fn(),
|
||||
consumeStickyAttempt: vi.fn(),
|
||||
};
|
||||
const context = {
|
||||
service: mockService,
|
||||
policy: {
|
||||
model: 'test-model',
|
||||
stateTransitions: { transient: 'sticky_retry' },
|
||||
maxAttempts: 3,
|
||||
},
|
||||
};
|
||||
const getContext = () => context as unknown as RetryAvailabilityContext;
|
||||
|
||||
applyAvailabilityTransition(getContext, 'transient');
|
||||
|
||||
expect(mockService.markRetryOncePerTurn).toHaveBeenCalledWith(
|
||||
'test-model',
|
||||
3,
|
||||
);
|
||||
expect(mockService.consumeStickyAttempt).toHaveBeenCalledWith(
|
||||
'test-model',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -77,12 +77,12 @@ export function resolvePolicyChain(
|
||||
chain = config.modelConfigService.resolveChain('lite', context);
|
||||
} else if (
|
||||
isGemini3Model(resolvedModel, config) ||
|
||||
isAutoModel(preferredModel ?? '', config) ||
|
||||
isAutoModel(configuredModel, config)
|
||||
isAutoPreferred ||
|
||||
isAutoConfigured
|
||||
) {
|
||||
// 1. Try to find a chain specifically for the current configured alias
|
||||
if (
|
||||
isAutoModel(configuredModel, config) &&
|
||||
isAutoConfigured &&
|
||||
config.modelConfigService.getModelChain(configuredModel)
|
||||
) {
|
||||
chain = config.modelConfigService.resolveChain(
|
||||
@@ -92,13 +92,18 @@ export function resolvePolicyChain(
|
||||
}
|
||||
// 2. Fallback to family-based auto-routing
|
||||
if (!chain) {
|
||||
const isAutoSelection = isAutoPreferred || isAutoConfigured;
|
||||
const previewEnabled =
|
||||
hasAccessToPreview &&
|
||||
(isGemini3Model(resolvedModel, config) ||
|
||||
preferredModel === PREVIEW_GEMINI_MODEL_AUTO ||
|
||||
configuredModel === PREVIEW_GEMINI_MODEL_AUTO);
|
||||
const autoPrefix = isAutoSelection ? 'auto-' : '';
|
||||
const chainKey = previewEnabled ? 'preview' : 'default';
|
||||
chain = config.modelConfigService.resolveChain(chainKey, context);
|
||||
chain = config.modelConfigService.resolveChain(
|
||||
`${autoPrefix}${chainKey}`,
|
||||
context,
|
||||
);
|
||||
}
|
||||
}
|
||||
if (!chain) {
|
||||
@@ -116,6 +121,7 @@ export function resolvePolicyChain(
|
||||
isAutoPreferred ||
|
||||
isAutoConfigured
|
||||
) {
|
||||
const isAutoSelection = isAutoPreferred || isAutoConfigured;
|
||||
if (hasAccessToPreview) {
|
||||
const previewEnabled =
|
||||
isGemini3Model(resolvedModel, config) ||
|
||||
@@ -123,6 +129,7 @@ export function resolvePolicyChain(
|
||||
configuredModel === PREVIEW_GEMINI_MODEL_AUTO;
|
||||
chain = getModelPolicyChain({
|
||||
previewEnabled,
|
||||
isAutoSelection,
|
||||
userTier: config.getUserTier(),
|
||||
useGemini31,
|
||||
useGemini31FlashLite,
|
||||
@@ -133,6 +140,7 @@ export function resolvePolicyChain(
|
||||
// to the stable Gemini 2.5 chain.
|
||||
chain = getModelPolicyChain({
|
||||
previewEnabled: false,
|
||||
isAutoSelection,
|
||||
userTier: config.getUserTier(),
|
||||
useGemini31,
|
||||
useGemini31FlashLite,
|
||||
@@ -144,7 +152,6 @@ export function resolvePolicyChain(
|
||||
}
|
||||
chain = applyDynamicSlicing(chain, resolvedModel, wrapsAround);
|
||||
}
|
||||
|
||||
// Apply Unified Silent Injection for Plan Mode with defensive checks
|
||||
if (config?.getApprovalMode?.() === ApprovalMode.PLAN) {
|
||||
return chain.map((policy) => ({
|
||||
@@ -295,10 +302,13 @@ export function applyModelSelection(
|
||||
config.getModelAvailabilityService().consumeStickyAttempt(finalModel);
|
||||
}
|
||||
|
||||
const chain = resolvePolicyChain(config, finalModel);
|
||||
const policy = chain.find((p) => p.model === finalModel);
|
||||
|
||||
return {
|
||||
model: finalModel,
|
||||
config: generateContentConfig,
|
||||
maxAttempts: selection.attempts,
|
||||
maxAttempts: selection.attempts ?? policy?.maxAttempts,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -318,6 +328,10 @@ export function applyAvailabilityTransition(
|
||||
failureKind === 'terminal' ? 'quota' : 'capacity',
|
||||
);
|
||||
} else if (transition === 'sticky_retry') {
|
||||
context.service.markRetryOncePerTurn(context.policy.model);
|
||||
context.service.markRetryOncePerTurn(
|
||||
context.policy.model,
|
||||
context.policy.maxAttempts,
|
||||
);
|
||||
context.service.consumeStickyAttempt(context.policy.model);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -557,7 +557,7 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal',
|
||||
transient: 'sticky_retry',
|
||||
transient: 'terminal',
|
||||
not_found: 'terminal',
|
||||
unknown: 'terminal',
|
||||
},
|
||||
@@ -565,12 +565,31 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
{
|
||||
model: 'gemini-3-flash-preview',
|
||||
isLastResort: true,
|
||||
maxAttempts: 10,
|
||||
actions: {
|
||||
terminal: 'prompt',
|
||||
transient: 'prompt',
|
||||
not_found: 'prompt',
|
||||
unknown: 'prompt',
|
||||
},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal',
|
||||
transient: 'terminal',
|
||||
not_found: 'terminal',
|
||||
unknown: 'terminal',
|
||||
},
|
||||
},
|
||||
],
|
||||
'auto-preview': [
|
||||
{
|
||||
model: 'gemini-3-pro-preview',
|
||||
maxAttempts: 3,
|
||||
actions: {
|
||||
terminal: 'prompt',
|
||||
transient: 'silent',
|
||||
not_found: 'prompt',
|
||||
unknown: 'prompt',
|
||||
},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal',
|
||||
transient: 'sticky_retry',
|
||||
@@ -578,6 +597,23 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
unknown: 'terminal',
|
||||
},
|
||||
},
|
||||
{
|
||||
model: 'gemini-3-flash-preview',
|
||||
isLastResort: true,
|
||||
maxAttempts: 10,
|
||||
actions: {
|
||||
terminal: 'prompt',
|
||||
transient: 'prompt',
|
||||
not_found: 'prompt',
|
||||
unknown: 'prompt',
|
||||
},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal',
|
||||
transient: 'terminal',
|
||||
not_found: 'terminal',
|
||||
unknown: 'terminal',
|
||||
},
|
||||
},
|
||||
],
|
||||
default: [
|
||||
{
|
||||
@@ -598,12 +634,31 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
{
|
||||
model: 'gemini-2.5-flash',
|
||||
isLastResort: true,
|
||||
maxAttempts: 10,
|
||||
actions: {
|
||||
terminal: 'prompt',
|
||||
transient: 'prompt',
|
||||
not_found: 'prompt',
|
||||
unknown: 'prompt',
|
||||
},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal',
|
||||
transient: 'terminal',
|
||||
not_found: 'terminal',
|
||||
unknown: 'terminal',
|
||||
},
|
||||
},
|
||||
],
|
||||
'auto-default': [
|
||||
{
|
||||
model: 'gemini-2.5-pro',
|
||||
maxAttempts: 3,
|
||||
actions: {
|
||||
terminal: 'prompt',
|
||||
transient: 'silent',
|
||||
not_found: 'prompt',
|
||||
unknown: 'prompt',
|
||||
},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal',
|
||||
transient: 'sticky_retry',
|
||||
@@ -611,6 +666,23 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
unknown: 'terminal',
|
||||
},
|
||||
},
|
||||
{
|
||||
model: 'gemini-2.5-flash',
|
||||
isLastResort: true,
|
||||
maxAttempts: 10,
|
||||
actions: {
|
||||
terminal: 'prompt',
|
||||
transient: 'prompt',
|
||||
not_found: 'prompt',
|
||||
unknown: 'prompt',
|
||||
},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal',
|
||||
transient: 'terminal',
|
||||
not_found: 'terminal',
|
||||
unknown: 'terminal',
|
||||
},
|
||||
},
|
||||
],
|
||||
lite: [
|
||||
{
|
||||
@@ -623,7 +695,7 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal',
|
||||
transient: 'sticky_retry',
|
||||
transient: 'terminal',
|
||||
not_found: 'terminal',
|
||||
unknown: 'terminal',
|
||||
},
|
||||
@@ -638,7 +710,7 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal',
|
||||
transient: 'sticky_retry',
|
||||
transient: 'terminal',
|
||||
not_found: 'terminal',
|
||||
unknown: 'terminal',
|
||||
},
|
||||
@@ -654,7 +726,7 @@ export const DEFAULT_MODEL_CONFIGS: ModelConfigServiceConfig = {
|
||||
},
|
||||
stateTransitions: {
|
||||
terminal: 'terminal',
|
||||
transient: 'sticky_retry',
|
||||
transient: 'terminal',
|
||||
not_found: 'terminal',
|
||||
unknown: 'terminal',
|
||||
},
|
||||
|
||||
@@ -43,36 +43,41 @@ vi.mock('../utils/errors.js', async (importOriginal) => {
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../utils/retry.js', () => ({
|
||||
retryWithBackoff: vi.fn(async (fn, options) => {
|
||||
// Default implementation - just call the function
|
||||
const result = await fn();
|
||||
vi.mock('../utils/retry.js', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('../utils/retry.js')>();
|
||||
return {
|
||||
...actual,
|
||||
retryWithBackoff: vi.fn(async (fn, options) => {
|
||||
// Default implementation - just call the function
|
||||
const result = await fn();
|
||||
|
||||
// If shouldRetryOnContent is provided, test it but don't actually retry
|
||||
// (unless we want to simulate retry exhaustion for testing)
|
||||
if (options?.shouldRetryOnContent) {
|
||||
const shouldRetry = options.shouldRetryOnContent(result);
|
||||
if (shouldRetry) {
|
||||
// Check if we need to simulate retry exhaustion (for error testing)
|
||||
const responseText = result?.candidates?.[0]?.content?.parts?.[0]?.text;
|
||||
if (
|
||||
!responseText ||
|
||||
responseText.trim() === '' ||
|
||||
responseText.includes('{"color": "blue"')
|
||||
) {
|
||||
throw new Error('Retry attempts exhausted for invalid content');
|
||||
// If shouldRetryOnContent is provided, test it but don't actually retry
|
||||
// (unless we want to simulate retry exhaustion for testing)
|
||||
if (options?.shouldRetryOnContent) {
|
||||
const shouldRetry = options.shouldRetryOnContent(result);
|
||||
if (shouldRetry) {
|
||||
// Check if we need to simulate retry exhaustion (for error testing)
|
||||
const responseText =
|
||||
result?.candidates?.[0]?.content?.parts?.[0]?.text;
|
||||
if (
|
||||
!responseText ||
|
||||
responseText.trim() === '' ||
|
||||
responseText.includes('{"color": "blue"')
|
||||
) {
|
||||
throw new Error('Retry attempts exhausted for invalid content');
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const context = options?.getAvailabilityContext?.();
|
||||
if (context) {
|
||||
context.service.markHealthy(context.policy.model);
|
||||
}
|
||||
const context = options?.getAvailabilityContext?.();
|
||||
if (context) {
|
||||
context.service.markHealthy(context.policy.model);
|
||||
}
|
||||
|
||||
return result;
|
||||
}),
|
||||
}));
|
||||
return result;
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const mockGenerateContent = vi.fn();
|
||||
const mockEmbedContent = vi.fn();
|
||||
|
||||
@@ -339,7 +339,9 @@ export class BaseLlmClient {
|
||||
retryFetchErrors: this.config.getRetryFetchErrors(),
|
||||
onRetry: (attempt, error, delayMs) => {
|
||||
const actualMaxAttempts =
|
||||
availabilityMaxAttempts ?? maxAttempts ?? DEFAULT_MAX_ATTEMPTS;
|
||||
getAvailabilityContext()?.policy.maxAttempts ??
|
||||
maxAttempts ??
|
||||
DEFAULT_MAX_ATTEMPTS;
|
||||
const modelName = getDisplayString(currentModel);
|
||||
const errorType = getRetryErrorType(error);
|
||||
|
||||
|
||||
@@ -249,6 +249,9 @@ export async function retryWithBackoff<T>(
|
||||
...cleanOptions,
|
||||
};
|
||||
|
||||
const getCurrentMaxAttempts = () =>
|
||||
getAvailabilityContext?.()?.policy.maxAttempts ?? maxAttempts;
|
||||
|
||||
let attempt = 0;
|
||||
let currentDelay = initialDelayMs;
|
||||
const throwIfAborted = () => {
|
||||
@@ -257,7 +260,7 @@ export async function retryWithBackoff<T>(
|
||||
}
|
||||
};
|
||||
|
||||
while (attempt < maxAttempts) {
|
||||
while (attempt < getCurrentMaxAttempts()) {
|
||||
if (signal?.aborted) {
|
||||
throw createAbortError();
|
||||
}
|
||||
@@ -344,7 +347,7 @@ export async function retryWithBackoff<T>(
|
||||
errorCode !== undefined && errorCode >= 500 && errorCode < 600;
|
||||
|
||||
if (classifiedError instanceof RetryableQuotaError || is500) {
|
||||
if (attempt >= maxAttempts) {
|
||||
if (attempt >= getCurrentMaxAttempts()) {
|
||||
const errorMessage =
|
||||
classifiedError instanceof Error ? classifiedError.message : '';
|
||||
debugLogger.warn(
|
||||
@@ -405,7 +408,7 @@ export async function retryWithBackoff<T>(
|
||||
|
||||
// Generic retry logic for other errors
|
||||
if (
|
||||
attempt >= maxAttempts ||
|
||||
attempt >= getCurrentMaxAttempts() ||
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
!shouldRetryOnError(error as Error, retryFetchErrors)
|
||||
) {
|
||||
|
||||
Reference in New Issue
Block a user