Do not fallback for manual models (#84)

* Update display name for alias model

* fix tests
This commit is contained in:
Sehoon Shon
2025-12-12 13:41:16 -05:00
committed by Tommaso Sciortino
parent 16e06adb46
commit 9ab79b712c
10 changed files with 145 additions and 34 deletions
+1
View File
@@ -91,6 +91,7 @@ const mockConfig = {
isTrustedFolder: () => true, isTrustedFolder: () => true,
getIdeMode: () => false, getIdeMode: () => false,
getEnableInteractiveShell: () => true, getEnableInteractiveShell: () => true,
getPreviewFeatures: () => false,
}; };
const configProxy = new Proxy(mockConfig, { const configProxy = new Proxy(mockConfig, {
+1 -1
View File
@@ -149,7 +149,7 @@ export const Footer: React.FC = () => {
<Box alignItems="center" justifyContent="flex-end"> <Box alignItems="center" justifyContent="flex-end">
<Box alignItems="center"> <Box alignItems="center">
<Text color={theme.text.accent}> <Text color={theme.text.accent}>
{getDisplayString(model)} {getDisplayString(model, config.getPreviewFeatures())}
{!hideContextPercentage && ( {!hideContextPercentage && (
<> <>
{' '} {' '}
@@ -33,7 +33,7 @@ describe('ProQuotaDialog', () => {
const { unmount } = render( const { unmount } = render(
<ProQuotaDialog <ProQuotaDialog
failedModel={DEFAULT_GEMINI_FLASH_MODEL} failedModel={DEFAULT_GEMINI_FLASH_MODEL}
fallbackModel="gemini-2.5-pro" fallbackModel={DEFAULT_GEMINI_FLASH_MODEL}
message="flash error" message="flash error"
isTerminalQuotaError={true} // should not matter isTerminalQuotaError={true} // should not matter
onChoice={mockOnChoice} onChoice={mockOnChoice}
@@ -97,6 +97,38 @@ describe('ProQuotaDialog', () => {
unmount(); unmount();
}); });
it('should render "Keep trying" and "Stop" options when failed model and fallback model are the same', () => {
const { unmount } = render(
<ProQuotaDialog
failedModel={PREVIEW_GEMINI_MODEL}
fallbackModel={PREVIEW_GEMINI_MODEL}
message="flash error"
isTerminalQuotaError={true}
onChoice={mockOnChoice}
userTier={UserTierId.FREE}
/>,
);
expect(RadioButtonSelect).toHaveBeenCalledWith(
expect.objectContaining({
items: [
{
label: 'Keep trying',
value: 'retry_once',
key: 'retry_once',
},
{
label: 'Stop',
value: 'retry_later',
key: 'retry_later',
},
],
}),
undefined,
);
unmount();
});
it('should render switch, upgrade, and stop options for free tier', () => { it('should render switch, upgrade, and stop options for free tier', () => {
const { unmount } = render( const { unmount } = render(
<ProQuotaDialog <ProQuotaDialog
@@ -9,14 +9,7 @@ import { Box, Text } from 'ink';
import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; import { RadioButtonSelect } from './shared/RadioButtonSelect.js';
import { theme } from '../semantic-colors.js'; import { theme } from '../semantic-colors.js';
import { import { DEFAULT_GEMINI_MODEL, UserTierId } from '@google/gemini-cli-core';
DEFAULT_GEMINI_FLASH_LITE_MODEL,
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL,
FLASH_PREVIEW_MODEL_REVERT_BEFORE_MERGE,
PREVIEW_GEMINI_FLASH_MODEL,
UserTierId,
} from '@google/gemini-cli-core';
interface ProQuotaDialogProps { interface ProQuotaDialogProps {
failedModel: string; failedModel: string;
@@ -43,13 +36,8 @@ export function ProQuotaDialog({
const isPaidTier = const isPaidTier =
userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD; userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD;
let items; let items;
// flash and flash lite don't have options to switch or upgrade. // Do not provide a fallback option if failed model and fallbackmodel are same.
if ( if (failedModel === fallbackModel) {
failedModel === DEFAULT_GEMINI_FLASH_MODEL ||
failedModel === DEFAULT_GEMINI_FLASH_LITE_MODEL ||
failedModel === PREVIEW_GEMINI_FLASH_MODEL ||
failedModel === FLASH_PREVIEW_MODEL_REVERT_BEFORE_MERGE
) {
items = [ items = [
{ {
label: 'Keep trying', label: 'Keep trying',
@@ -66,6 +66,10 @@ export function getModelPolicyChain(
return cloneChain(DEFAULT_CHAIN); return cloneChain(DEFAULT_CHAIN);
} }
export function createSingleModelChain(model: string): ModelPolicyChain {
return [definePolicy({ model, isLastResort: true })];
}
/** /**
* Provides a default policy scaffold for models not present in the catalog. * Provides a default policy scaffold for models not present in the catalog.
*/ */
@@ -25,7 +25,7 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>
describe('policyHelpers', () => { describe('policyHelpers', () => {
describe('resolvePolicyChain', () => { describe('resolvePolicyChain', () => {
it('inserts the active model when missing from the catalog', () => { it('returns a single-model chain for a custom model', () => {
const config = createMockConfig({ const config = createMockConfig({
getModel: () => 'custom-model', getModel: () => 'custom-model',
}); });
@@ -53,6 +53,25 @@ describe('policyHelpers', () => {
expect(chain[0]?.model).toBe('gemini-2.5-pro'); expect(chain[0]?.model).toBe('gemini-2.5-pro');
expect(chain[1]?.model).toBe('gemini-2.5-flash'); expect(chain[1]?.model).toBe('gemini-2.5-flash');
}); });
it('starts chain from preferredModel when model is "auto"', () => {
const config = createMockConfig({
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
});
const chain = resolvePolicyChain(config, 'gemini-2.5-flash');
expect(chain).toHaveLength(1);
expect(chain[0]?.model).toBe('gemini-2.5-flash');
});
it('wraps around the chain when wrapsAround is true', () => {
const config = createMockConfig({
getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
});
const chain = resolvePolicyChain(config, 'gemini-2.5-flash', true);
expect(chain).toHaveLength(2);
expect(chain[0]?.model).toBe('gemini-2.5-flash');
expect(chain[1]?.model).toBe('gemini-2.5-pro');
});
}); });
describe('buildFallbackPolicyContext', () => { describe('buildFallbackPolicyContext', () => {
@@ -67,6 +86,17 @@ describe('policyHelpers', () => {
expect(context.candidates.map((p) => p.model)).toEqual(['c']); expect(context.candidates.map((p) => p.model)).toEqual(['c']);
}); });
it('wraps around when building fallback context if wrapsAround is true', () => {
const chain = [
createDefaultPolicy('a'),
createDefaultPolicy('b'),
createDefaultPolicy('c'),
];
const context = buildFallbackPolicyContext(chain, 'b', true);
expect(context.failedPolicy?.model).toBe('b');
expect(context.candidates.map((p) => p.model)).toEqual(['c', 'a']);
});
it('returns full chain when model is not in policy list', () => { it('returns full chain when model is not in policy list', () => {
const chain = [createDefaultPolicy('a'), createDefaultPolicy('b')]; const chain = [createDefaultPolicy('a'), createDefaultPolicy('b')];
const context = buildFallbackPolicyContext(chain, 'x'); const context = buildFallbackPolicyContext(chain, 'x');
+24 -10
View File
@@ -13,8 +13,17 @@ import type {
ModelPolicyChain, ModelPolicyChain,
RetryAvailabilityContext, RetryAvailabilityContext,
} from './modelPolicy.js'; } from './modelPolicy.js';
import { createDefaultPolicy, getModelPolicyChain } from './policyCatalog.js'; import {
import { DEFAULT_GEMINI_MODEL, resolveModel } from '../config/models.js'; createDefaultPolicy,
createSingleModelChain,
getModelPolicyChain,
} from './policyCatalog.js';
import {
DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_MODEL_AUTO,
resolveModel,
} from '../config/models.js';
import type { ModelSelectionResult } from './modelAvailabilityService.js'; import type { ModelSelectionResult } from './modelAvailabilityService.js';
/** /**
@@ -31,15 +40,20 @@ export function resolvePolicyChain(
const modelFromConfig = const modelFromConfig =
preferredModel ?? config.getActiveModel?.() ?? config.getModel(); preferredModel ?? config.getActiveModel?.() ?? config.getModel();
const isPreviewRequest = let chain;
modelFromConfig.includes('gemini-3') ||
modelFromConfig.includes('preview') || if (
modelFromConfig === 'fiercefalcon'; config.getModel() === PREVIEW_GEMINI_MODEL_AUTO ||
config.getModel() === DEFAULT_GEMINI_MODEL_AUTO
) {
chain = getModelPolicyChain({
previewEnabled: config.getModel() === PREVIEW_GEMINI_MODEL_AUTO,
userTier: config.getUserTier(),
});
} else {
chain = createSingleModelChain(modelFromConfig);
}
const chain = getModelPolicyChain({
previewEnabled: isPreviewRequest,
userTier: config.getUserTier(),
});
const activeModel = resolveModel(modelFromConfig); const activeModel = resolveModel(modelFromConfig);
const activeIndex = chain.findIndex((policy) => policy.model === activeModel); const activeIndex = chain.findIndex((policy) => policy.model === activeModel);
+12 -1
View File
@@ -115,12 +115,23 @@ export function getEffectiveModel(
return resolveModel(requestedModel, previewFeaturesEnabled); return resolveModel(requestedModel, previewFeaturesEnabled);
} }
export function getDisplayString(model: string) { export function getDisplayString(
model: string,
previewFeaturesEnabled: boolean = false,
) {
switch (model) { switch (model) {
case PREVIEW_GEMINI_MODEL_AUTO: case PREVIEW_GEMINI_MODEL_AUTO:
return 'Auto (Gemini 3)'; return 'Auto (Gemini 3)';
case DEFAULT_GEMINI_MODEL_AUTO: case DEFAULT_GEMINI_MODEL_AUTO:
return 'Auto (Gemini 2.5)'; return 'Auto (Gemini 2.5)';
case GEMINI_MODEL_ALIAS_PRO:
return previewFeaturesEnabled
? PREVIEW_GEMINI_MODEL
: DEFAULT_GEMINI_MODEL;
case GEMINI_MODEL_ALIAS_FLASH:
return previewFeaturesEnabled
? PREVIEW_GEMINI_FLASH_MODEL
: DEFAULT_GEMINI_FLASH_MODEL;
default: default:
return model; return model;
} }
+10 -3
View File
@@ -30,7 +30,10 @@ import {
type ChatCompressionInfo, type ChatCompressionInfo,
} from './turn.js'; } from './turn.js';
import { getCoreSystemPrompt } from './prompts.js'; import { getCoreSystemPrompt } from './prompts.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
} from '../config/models.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
import { setSimulate429 } from '../utils/testUtils.js'; import { setSimulate429 } from '../utils/testUtils.js';
import { tokenLimit } from './tokenLimits.js'; import { tokenLimit } from './tokenLimits.js';
@@ -2044,7 +2047,9 @@ ${JSON.stringify(
skipped: [], skipped: [],
}, },
); );
vi.mocked(mockConfig.getModel).mockReturnValue(
DEFAULT_GEMINI_MODEL_AUTO,
);
const stream = client.sendMessageStream( const stream = client.sendMessageStream(
[{ text: 'Hi' }], [{ text: 'Hi' }],
new AbortController().signal, new AbortController().signal,
@@ -2074,7 +2079,9 @@ ${JSON.stringify(
skipped: [], skipped: [],
}, },
); );
vi.mocked(mockConfig.getModel).mockReturnValue(
DEFAULT_GEMINI_MODEL_AUTO,
);
const stream = client.sendMessageStream( const stream = client.sendMessageStream(
[{ text: 'Hi' }], [{ text: 'Hi' }],
new AbortController().signal, new AbortController().signal,
+26 -2
View File
@@ -22,8 +22,10 @@ import { AuthType } from '../core/contentGenerator.js';
import { import {
DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL, PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_MODEL_AUTO,
} from '../config/models.js'; } from '../config/models.js';
import type { FallbackModelHandler } from './types.js'; import type { FallbackModelHandler } from './types.js';
import { openBrowserSecurely } from '../utils/secure-browser-launcher.js'; import { openBrowserSecurely } from '../utils/secure-browser-launcher.js';
@@ -152,7 +154,9 @@ describe('handleFallback', () => {
it('uses availability selection with correct candidates when enabled', async () => { it('uses availability selection with correct candidates when enabled', async () => {
// Direct mock manipulation since it's already a vi.fn() // Direct mock manipulation since it's already a vi.fn()
vi.mocked(policyConfig.getPreviewFeatures).mockReturnValue(true); vi.mocked(policyConfig.getPreviewFeatures).mockReturnValue(true);
vi.mocked(policyConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL); vi.mocked(policyConfig.getModel).mockReturnValue(
DEFAULT_GEMINI_MODEL_AUTO,
);
await handleFallback(policyConfig, DEFAULT_GEMINI_MODEL, AUTH_OAUTH); await handleFallback(policyConfig, DEFAULT_GEMINI_MODEL, AUTH_OAUTH);
@@ -162,6 +166,9 @@ describe('handleFallback', () => {
}); });
it('falls back to last resort when availability returns null', async () => { it('falls back to last resort when availability returns null', async () => {
vi.mocked(policyConfig.getModel).mockReturnValue(
DEFAULT_GEMINI_MODEL_AUTO,
);
availability.selectFirstAvailable = vi availability.selectFirstAvailable = vi
.fn() .fn()
.mockReturnValue({ selectedModel: null, skipped: [] }); .mockReturnValue({ selectedModel: null, skipped: [] });
@@ -224,6 +231,9 @@ describe('handleFallback', () => {
it('does not wrap around to upgrade candidates if the current model was selected at the end (e.g. by router)', async () => { it('does not wrap around to upgrade candidates if the current model was selected at the end (e.g. by router)', async () => {
// Last-resort failure (Flash) in [Preview, Pro, Flash] checks Preview then Pro (all upstream). // Last-resort failure (Flash) in [Preview, Pro, Flash] checks Preview then Pro (all upstream).
vi.mocked(policyConfig.getPreviewFeatures).mockReturnValue(true); vi.mocked(policyConfig.getPreviewFeatures).mockReturnValue(true);
vi.mocked(policyConfig.getModel).mockReturnValue(
DEFAULT_GEMINI_MODEL_AUTO,
);
availability.selectFirstAvailable = vi.fn().mockReturnValue({ availability.selectFirstAvailable = vi.fn().mockReturnValue({
selectedModel: MOCK_PRO_MODEL, selectedModel: MOCK_PRO_MODEL,
@@ -255,7 +265,9 @@ describe('handleFallback', () => {
vi.mocked(policyConfig.getActiveModel).mockReturnValue( vi.mocked(policyConfig.getActiveModel).mockReturnValue(
PREVIEW_GEMINI_MODEL, PREVIEW_GEMINI_MODEL,
); );
vi.mocked(policyConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL); vi.mocked(policyConfig.getModel).mockReturnValue(
PREVIEW_GEMINI_MODEL_AUTO,
);
const result = await handleFallback( const result = await handleFallback(
policyConfig, policyConfig,
@@ -315,6 +327,9 @@ describe('handleFallback', () => {
5, 5,
); );
policyHandler.mockResolvedValue('retry_always'); policyHandler.mockResolvedValue('retry_always');
vi.mocked(policyConfig.getModel).mockReturnValue(
DEFAULT_GEMINI_MODEL_AUTO,
);
await handleFallback( await handleFallback(
policyConfig, policyConfig,
@@ -342,6 +357,9 @@ describe('handleFallback', () => {
1000, 1000,
); );
policyHandler.mockResolvedValue('retry_once'); policyHandler.mockResolvedValue('retry_once');
vi.mocked(policyConfig.getModel).mockReturnValue(
DEFAULT_GEMINI_MODEL_AUTO,
);
await handleFallback( await handleFallback(
policyConfig, policyConfig,
@@ -362,6 +380,9 @@ describe('handleFallback', () => {
availability.selectFirstAvailable = vi availability.selectFirstAvailable = vi
.fn() .fn()
.mockReturnValue({ selectedModel: null, skipped: [] }); .mockReturnValue({ selectedModel: null, skipped: [] });
vi.mocked(policyConfig.getModel).mockReturnValue(
DEFAULT_GEMINI_MODEL_AUTO,
);
const result = await handleFallback( const result = await handleFallback(
policyConfig, policyConfig,
@@ -381,6 +402,9 @@ describe('handleFallback', () => {
it('calls setActiveModel and logs telemetry when handler returns "retry_always"', async () => { it('calls setActiveModel and logs telemetry when handler returns "retry_always"', async () => {
policyHandler.mockResolvedValue('retry_always'); policyHandler.mockResolvedValue('retry_always');
vi.mocked(policyConfig.getModel).mockReturnValue(
DEFAULT_GEMINI_MODEL_AUTO,
);
const result = await handleFallback( const result = await handleFallback(
policyConfig, policyConfig,