From 4cee7e83c43f1e672f56bf77b1644eb1e7f18f13 Mon Sep 17 00:00:00 2001 From: Sehoon Shon Date: Fri, 12 Dec 2025 13:41:16 -0500 Subject: [PATCH] Do not fallback for manual models (#84) * Update display name for alias model * fix tests --- packages/cli/src/test-utils/render.tsx | 1 + packages/cli/src/ui/components/Footer.tsx | 2 +- .../src/ui/components/ProQuotaDialog.test.tsx | 34 ++++++++++++++++++- .../cli/src/ui/components/ProQuotaDialog.tsx | 18 ++-------- .../core/src/availability/policyCatalog.ts | 4 +++ .../src/availability/policyHelpers.test.ts | 32 ++++++++++++++++- .../core/src/availability/policyHelpers.ts | 34 +++++++++++++------ packages/core/src/config/models.ts | 13 ++++++- packages/core/src/core/client.test.ts | 13 +++++-- packages/core/src/fallback/handler.test.ts | 28 +++++++++++++-- 10 files changed, 145 insertions(+), 34 deletions(-) diff --git a/packages/cli/src/test-utils/render.tsx b/packages/cli/src/test-utils/render.tsx index 894fd06568..4d9349196d 100644 --- a/packages/cli/src/test-utils/render.tsx +++ b/packages/cli/src/test-utils/render.tsx @@ -91,6 +91,7 @@ const mockConfig = { isTrustedFolder: () => true, getIdeMode: () => false, getEnableInteractiveShell: () => true, + getPreviewFeatures: () => false, }; const configProxy = new Proxy(mockConfig, { diff --git a/packages/cli/src/ui/components/Footer.tsx b/packages/cli/src/ui/components/Footer.tsx index 32a5376cd7..5192461c43 100644 --- a/packages/cli/src/ui/components/Footer.tsx +++ b/packages/cli/src/ui/components/Footer.tsx @@ -149,7 +149,7 @@ export const Footer: React.FC = () => { - {getDisplayString(model)} + {getDisplayString(model, config.getPreviewFeatures())} {!hideContextPercentage && ( <> {' '} diff --git a/packages/cli/src/ui/components/ProQuotaDialog.test.tsx b/packages/cli/src/ui/components/ProQuotaDialog.test.tsx index 51c3de02cd..2bf83a6067 100644 --- a/packages/cli/src/ui/components/ProQuotaDialog.test.tsx +++ b/packages/cli/src/ui/components/ProQuotaDialog.test.tsx @@ -33,7 +33,7 @@ describe('ProQuotaDialog', () => { const { unmount } = render( { unmount(); }); + it('should render "Keep trying" and "Stop" options when failed model and fallback model are the same', () => { + const { unmount } = render( + , + ); + + expect(RadioButtonSelect).toHaveBeenCalledWith( + expect.objectContaining({ + items: [ + { + label: 'Keep trying', + value: 'retry_once', + key: 'retry_once', + }, + { + label: 'Stop', + value: 'retry_later', + key: 'retry_later', + }, + ], + }), + undefined, + ); + unmount(); + }); + it('should render switch, upgrade, and stop options for free tier', () => { const { unmount } = render( = {}): Config => describe('policyHelpers', () => { describe('resolvePolicyChain', () => { - it('inserts the active model when missing from the catalog', () => { + it('returns a single-model chain for a custom model', () => { const config = createMockConfig({ getModel: () => 'custom-model', }); @@ -53,6 +53,25 @@ describe('policyHelpers', () => { expect(chain[0]?.model).toBe('gemini-2.5-pro'); expect(chain[1]?.model).toBe('gemini-2.5-flash'); }); + + it('starts chain from preferredModel when model is "auto"', () => { + const config = createMockConfig({ + getModel: () => DEFAULT_GEMINI_MODEL_AUTO, + }); + const chain = resolvePolicyChain(config, 'gemini-2.5-flash'); + expect(chain).toHaveLength(1); + expect(chain[0]?.model).toBe('gemini-2.5-flash'); + }); + + it('wraps around the chain when wrapsAround is true', () => { + const config = createMockConfig({ + getModel: () => DEFAULT_GEMINI_MODEL_AUTO, + }); + const chain = resolvePolicyChain(config, 'gemini-2.5-flash', true); + expect(chain).toHaveLength(2); + expect(chain[0]?.model).toBe('gemini-2.5-flash'); + expect(chain[1]?.model).toBe('gemini-2.5-pro'); + }); }); describe('buildFallbackPolicyContext', () => { @@ -67,6 +86,17 @@ describe('policyHelpers', () => { expect(context.candidates.map((p) => p.model)).toEqual(['c']); }); + it('wraps around when building fallback context if wrapsAround is true', () => { + const chain = [ + createDefaultPolicy('a'), + createDefaultPolicy('b'), + createDefaultPolicy('c'), + ]; + const context = buildFallbackPolicyContext(chain, 'b', true); + expect(context.failedPolicy?.model).toBe('b'); + expect(context.candidates.map((p) => p.model)).toEqual(['c', 'a']); + }); + it('returns full chain when model is not in policy list', () => { const chain = [createDefaultPolicy('a'), createDefaultPolicy('b')]; const context = buildFallbackPolicyContext(chain, 'x'); diff --git a/packages/core/src/availability/policyHelpers.ts b/packages/core/src/availability/policyHelpers.ts index b170587c9e..4177461544 100644 --- a/packages/core/src/availability/policyHelpers.ts +++ b/packages/core/src/availability/policyHelpers.ts @@ -13,8 +13,17 @@ import type { ModelPolicyChain, RetryAvailabilityContext, } from './modelPolicy.js'; -import { createDefaultPolicy, getModelPolicyChain } from './policyCatalog.js'; -import { DEFAULT_GEMINI_MODEL, resolveModel } from '../config/models.js'; +import { + createDefaultPolicy, + createSingleModelChain, + getModelPolicyChain, +} from './policyCatalog.js'; +import { + DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_MODEL_AUTO, + PREVIEW_GEMINI_MODEL_AUTO, + resolveModel, +} from '../config/models.js'; import type { ModelSelectionResult } from './modelAvailabilityService.js'; /** @@ -31,15 +40,20 @@ export function resolvePolicyChain( const modelFromConfig = preferredModel ?? config.getActiveModel?.() ?? config.getModel(); - const isPreviewRequest = - modelFromConfig.includes('gemini-3') || - modelFromConfig.includes('preview') || - modelFromConfig === 'fiercefalcon'; + let chain; + + if ( + config.getModel() === PREVIEW_GEMINI_MODEL_AUTO || + config.getModel() === DEFAULT_GEMINI_MODEL_AUTO + ) { + chain = getModelPolicyChain({ + previewEnabled: config.getModel() === PREVIEW_GEMINI_MODEL_AUTO, + userTier: config.getUserTier(), + }); + } else { + chain = createSingleModelChain(modelFromConfig); + } - const chain = getModelPolicyChain({ - previewEnabled: isPreviewRequest, - userTier: config.getUserTier(), - }); const activeModel = resolveModel(modelFromConfig); const activeIndex = chain.findIndex((policy) => policy.model === activeModel); diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index 3ea6c9f053..111ca57634 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -115,12 +115,23 @@ export function getEffectiveModel( return resolveModel(requestedModel, previewFeaturesEnabled); } -export function getDisplayString(model: string) { +export function getDisplayString( + model: string, + previewFeaturesEnabled: boolean = false, +) { switch (model) { case PREVIEW_GEMINI_MODEL_AUTO: return 'Auto (Gemini 3)'; case DEFAULT_GEMINI_MODEL_AUTO: return 'Auto (Gemini 2.5)'; + case GEMINI_MODEL_ALIAS_PRO: + return previewFeaturesEnabled + ? PREVIEW_GEMINI_MODEL + : DEFAULT_GEMINI_MODEL; + case GEMINI_MODEL_ALIAS_FLASH: + return previewFeaturesEnabled + ? PREVIEW_GEMINI_FLASH_MODEL + : DEFAULT_GEMINI_FLASH_MODEL; default: return model; } diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index e4aca47fc2..09c6308e88 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -30,7 +30,10 @@ import { type ChatCompressionInfo, } from './turn.js'; import { getCoreSystemPrompt } from './prompts.js'; -import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { + DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_MODEL_AUTO, +} from '../config/models.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { setSimulate429 } from '../utils/testUtils.js'; import { tokenLimit } from './tokenLimits.js'; @@ -2044,7 +2047,9 @@ ${JSON.stringify( skipped: [], }, ); - + vi.mocked(mockConfig.getModel).mockReturnValue( + DEFAULT_GEMINI_MODEL_AUTO, + ); const stream = client.sendMessageStream( [{ text: 'Hi' }], new AbortController().signal, @@ -2074,7 +2079,9 @@ ${JSON.stringify( skipped: [], }, ); - + vi.mocked(mockConfig.getModel).mockReturnValue( + DEFAULT_GEMINI_MODEL_AUTO, + ); const stream = client.sendMessageStream( [{ text: 'Hi' }], new AbortController().signal, diff --git a/packages/core/src/fallback/handler.test.ts b/packages/core/src/fallback/handler.test.ts index 2ffef97a05..488006c1da 100644 --- a/packages/core/src/fallback/handler.test.ts +++ b/packages/core/src/fallback/handler.test.ts @@ -22,8 +22,10 @@ import { AuthType } from '../core/contentGenerator.js'; import { DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_MODEL, + DEFAULT_GEMINI_MODEL_AUTO, PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_MODEL, + PREVIEW_GEMINI_MODEL_AUTO, } from '../config/models.js'; import type { FallbackModelHandler } from './types.js'; import { openBrowserSecurely } from '../utils/secure-browser-launcher.js'; @@ -152,7 +154,9 @@ describe('handleFallback', () => { it('uses availability selection with correct candidates when enabled', async () => { // Direct mock manipulation since it's already a vi.fn() vi.mocked(policyConfig.getPreviewFeatures).mockReturnValue(true); - vi.mocked(policyConfig.getModel).mockReturnValue(DEFAULT_GEMINI_MODEL); + vi.mocked(policyConfig.getModel).mockReturnValue( + DEFAULT_GEMINI_MODEL_AUTO, + ); await handleFallback(policyConfig, DEFAULT_GEMINI_MODEL, AUTH_OAUTH); @@ -162,6 +166,9 @@ describe('handleFallback', () => { }); it('falls back to last resort when availability returns null', async () => { + vi.mocked(policyConfig.getModel).mockReturnValue( + DEFAULT_GEMINI_MODEL_AUTO, + ); availability.selectFirstAvailable = vi .fn() .mockReturnValue({ selectedModel: null, skipped: [] }); @@ -224,6 +231,9 @@ describe('handleFallback', () => { it('does not wrap around to upgrade candidates if the current model was selected at the end (e.g. by router)', async () => { // Last-resort failure (Flash) in [Preview, Pro, Flash] checks Preview then Pro (all upstream). vi.mocked(policyConfig.getPreviewFeatures).mockReturnValue(true); + vi.mocked(policyConfig.getModel).mockReturnValue( + DEFAULT_GEMINI_MODEL_AUTO, + ); availability.selectFirstAvailable = vi.fn().mockReturnValue({ selectedModel: MOCK_PRO_MODEL, @@ -255,7 +265,9 @@ describe('handleFallback', () => { vi.mocked(policyConfig.getActiveModel).mockReturnValue( PREVIEW_GEMINI_MODEL, ); - vi.mocked(policyConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL); + vi.mocked(policyConfig.getModel).mockReturnValue( + PREVIEW_GEMINI_MODEL_AUTO, + ); const result = await handleFallback( policyConfig, @@ -315,6 +327,9 @@ describe('handleFallback', () => { 5, ); policyHandler.mockResolvedValue('retry_always'); + vi.mocked(policyConfig.getModel).mockReturnValue( + DEFAULT_GEMINI_MODEL_AUTO, + ); await handleFallback( policyConfig, @@ -342,6 +357,9 @@ describe('handleFallback', () => { 1000, ); policyHandler.mockResolvedValue('retry_once'); + vi.mocked(policyConfig.getModel).mockReturnValue( + DEFAULT_GEMINI_MODEL_AUTO, + ); await handleFallback( policyConfig, @@ -362,6 +380,9 @@ describe('handleFallback', () => { availability.selectFirstAvailable = vi .fn() .mockReturnValue({ selectedModel: null, skipped: [] }); + vi.mocked(policyConfig.getModel).mockReturnValue( + DEFAULT_GEMINI_MODEL_AUTO, + ); const result = await handleFallback( policyConfig, @@ -381,6 +402,9 @@ describe('handleFallback', () => { it('calls setActiveModel and logs telemetry when handler returns "retry_always"', async () => { policyHandler.mockResolvedValue('retry_always'); + vi.mocked(policyConfig.getModel).mockReturnValue( + DEFAULT_GEMINI_MODEL_AUTO, + ); const result = await handleFallback( policyConfig,