diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index 4514aac011..9fdebe1b2a 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -88,12 +88,11 @@ export class Task { this.eventBus = eventBus; this.completedToolCalls = []; this._resetToolCompletionPromise(); - this.config.setFlashFallbackHandler( - async (currentModel: string, fallbackModel: string): Promise => { - config.setModel(fallbackModel); // gemini-cli-core sets to DEFAULT_GEMINI_FLASH_MODEL - // Switch model for future use but return false to stop current retry - return false; - }, + this.config.setFallbackModelHandler( + // For a2a-server, we want to automatically switch to the fallback model + // for future requests without retrying the current one. The 'stop' + // intent achieves this. + async () => 'stop', ); } diff --git a/packages/a2a-server/src/utils/testing_utils.ts b/packages/a2a-server/src/utils/testing_utils.ts index 92ef706d3c..ab7c19dd70 100644 --- a/packages/a2a-server/src/utils/testing_utils.ts +++ b/packages/a2a-server/src/utils/testing_utils.ts @@ -43,7 +43,7 @@ export function createMockConfig( getContentGeneratorConfig: vi.fn().mockReturnValue({ model: 'gemini-pro' }), getModel: vi.fn().mockReturnValue('gemini-pro'), getUsageStatisticsEnabled: vi.fn().mockReturnValue(false), - setFlashFallbackHandler: vi.fn(), + setFallbackModelHandler: vi.fn(), initialize: vi.fn().mockResolvedValue(undefined), getProxy: vi.fn().mockReturnValue(undefined), getHistory: vi.fn().mockReturnValue([]), diff --git a/packages/cli/src/ui/AppContainer.test.tsx b/packages/cli/src/ui/AppContainer.test.tsx index 1f6024d469..a5d83ca383 100644 --- a/packages/cli/src/ui/AppContainer.test.tsx +++ b/packages/cli/src/ui/AppContainer.test.tsx @@ -4,31 +4,55 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; import { render, cleanup } from 'ink-testing-library'; import { AppContainer } from './AppContainer.js'; import { type Config, makeFakeConfig } from '@google/gemini-cli-core'; import type { LoadedSettings } from '../config/settings.js'; import type { InitializationResult } from '../core/initializer.js'; +import { useQuotaAndFallback } from './hooks/useQuotaAndFallback.js'; +import { UIStateContext, type UIState } from './contexts/UIStateContext.js'; +import { + UIActionsContext, + type UIActions, +} from './contexts/UIActionsContext.js'; +import { useContext } from 'react'; + +// Helper component will read the context values provided by AppContainer +// so we can assert against them in our tests. +let capturedUIState: UIState; +let capturedUIActions: UIActions; +function TestContextConsumer() { + capturedUIState = useContext(UIStateContext)!; + capturedUIActions = useContext(UIActionsContext)!; + return null; +} -// Mock App component to isolate AppContainer testing vi.mock('./App.js', () => ({ - App: () => 'App Component', + App: TestContextConsumer, })); -// Mock all the hooks and utilities -vi.mock('./hooks/useHistory.js'); +vi.mock('./hooks/useQuotaAndFallback.js'); +vi.mock('./hooks/useHistoryManager.js'); vi.mock('./hooks/useThemeCommand.js'); -vi.mock('./hooks/useAuthCommand.js'); +vi.mock('./auth/useAuth.js'); vi.mock('./hooks/useEditorSettings.js'); vi.mock('./hooks/useSettingsCommand.js'); -vi.mock('./hooks/useSlashCommandProcessor.js'); +vi.mock('./hooks/slashCommandProcessor.js'); vi.mock('./hooks/useConsoleMessages.js'); vi.mock('./hooks/useTerminalSize.js', () => ({ useTerminalSize: vi.fn(() => ({ columns: 80, rows: 24 })), })); vi.mock('./hooks/useGeminiStream.js'); -vi.mock('./hooks/useVim.js'); +vi.mock('./hooks/vim.js'); vi.mock('./hooks/useFocus.js'); vi.mock('./hooks/useBracketedPaste.js'); vi.mock('./hooks/useKeypress.js'); @@ -40,7 +64,7 @@ vi.mock('./hooks/useWorkspaceMigration.js'); vi.mock('./hooks/useGitBranchName.js'); vi.mock('./contexts/VimModeContext.js'); vi.mock('./contexts/SessionContext.js'); -vi.mock('./hooks/useTextBuffer.js'); +vi.mock('./components/shared/text-buffer.js'); vi.mock('./hooks/useLogger.js'); // Mock external utilities @@ -49,14 +73,153 @@ vi.mock('../utils/handleAutoUpdate.js'); vi.mock('./utils/ConsolePatcher.js'); vi.mock('../utils/cleanup.js'); +import { useHistory } from './hooks/useHistoryManager.js'; +import { useThemeCommand } from './hooks/useThemeCommand.js'; +import { useAuthCommand } from './auth/useAuth.js'; +import { useEditorSettings } from './hooks/useEditorSettings.js'; +import { useSettingsCommand } from './hooks/useSettingsCommand.js'; +import { useSlashCommandProcessor } from './hooks/slashCommandProcessor.js'; +import { useConsoleMessages } from './hooks/useConsoleMessages.js'; +import { useGeminiStream } from './hooks/useGeminiStream.js'; +import { useVim } from './hooks/vim.js'; +import { useFolderTrust } from './hooks/useFolderTrust.js'; +import { useMessageQueue } from './hooks/useMessageQueue.js'; +import { useAutoAcceptIndicator } from './hooks/useAutoAcceptIndicator.js'; +import { useWorkspaceMigration } from './hooks/useWorkspaceMigration.js'; +import { useGitBranchName } from './hooks/useGitBranchName.js'; +import { useVimMode } from './contexts/VimModeContext.js'; +import { useSessionStats } from './contexts/SessionContext.js'; +import { useTextBuffer } from './components/shared/text-buffer.js'; +import { useLogger } from './hooks/useLogger.js'; +import { useLoadingIndicator } from './hooks/useLoadingIndicator.js'; + describe('AppContainer State Management', () => { let mockConfig: Config; let mockSettings: LoadedSettings; let mockInitResult: InitializationResult; + // Create typed mocks for all hooks + const mockedUseQuotaAndFallback = useQuotaAndFallback as Mock; + const mockedUseHistory = useHistory as Mock; + const mockedUseThemeCommand = useThemeCommand as Mock; + const mockedUseAuthCommand = useAuthCommand as Mock; + const mockedUseEditorSettings = useEditorSettings as Mock; + const mockedUseSettingsCommand = useSettingsCommand as Mock; + const mockedUseSlashCommandProcessor = useSlashCommandProcessor as Mock; + const mockedUseConsoleMessages = useConsoleMessages as Mock; + const mockedUseGeminiStream = useGeminiStream as Mock; + const mockedUseVim = useVim as Mock; + const mockedUseFolderTrust = useFolderTrust as Mock; + const mockedUseMessageQueue = useMessageQueue as Mock; + const mockedUseAutoAcceptIndicator = useAutoAcceptIndicator as Mock; + const mockedUseWorkspaceMigration = useWorkspaceMigration as Mock; + const mockedUseGitBranchName = useGitBranchName as Mock; + const mockedUseVimMode = useVimMode as Mock; + const mockedUseSessionStats = useSessionStats as Mock; + const mockedUseTextBuffer = useTextBuffer as Mock; + const mockedUseLogger = useLogger as Mock; + const mockedUseLoadingIndicator = useLoadingIndicator as Mock; + beforeEach(() => { vi.clearAllMocks(); + capturedUIState = null!; + capturedUIActions = null!; + + // **Provide a default return value for EVERY mocked hook.** + mockedUseQuotaAndFallback.mockReturnValue({ + proQuotaRequest: null, + handleProQuotaChoice: vi.fn(), + }); + mockedUseHistory.mockReturnValue({ + history: [], + addItem: vi.fn(), + updateItem: vi.fn(), + clearItems: vi.fn(), + loadHistory: vi.fn(), + }); + mockedUseThemeCommand.mockReturnValue({ + isThemeDialogOpen: false, + openThemeDialog: vi.fn(), + handleThemeSelect: vi.fn(), + handleThemeHighlight: vi.fn(), + }); + mockedUseAuthCommand.mockReturnValue({ + authState: 'authenticated', + setAuthState: vi.fn(), + authError: null, + onAuthError: vi.fn(), + }); + mockedUseEditorSettings.mockReturnValue({ + isEditorDialogOpen: false, + openEditorDialog: vi.fn(), + handleEditorSelect: vi.fn(), + exitEditorDialog: vi.fn(), + }); + mockedUseSettingsCommand.mockReturnValue({ + isSettingsDialogOpen: false, + openSettingsDialog: vi.fn(), + closeSettingsDialog: vi.fn(), + }); + mockedUseSlashCommandProcessor.mockReturnValue({ + handleSlashCommand: vi.fn(), + slashCommands: [], + pendingHistoryItems: [], + commandContext: {}, + shellConfirmationRequest: null, + confirmationRequest: null, + }); + mockedUseConsoleMessages.mockReturnValue({ + consoleMessages: [], + handleNewMessage: vi.fn(), + clearConsoleMessages: vi.fn(), + }); + mockedUseGeminiStream.mockReturnValue({ + streamingState: 'idle', + submitQuery: vi.fn(), + initError: null, + pendingHistoryItems: [], + thought: null, + cancelOngoingRequest: vi.fn(), + }); + mockedUseVim.mockReturnValue({ handleInput: vi.fn() }); + mockedUseFolderTrust.mockReturnValue({ + isFolderTrustDialogOpen: false, + handleFolderTrustSelect: vi.fn(), + isRestarting: false, + }); + mockedUseMessageQueue.mockReturnValue({ + messageQueue: [], + addMessage: vi.fn(), + clearQueue: vi.fn(), + getQueuedMessagesText: vi.fn().mockReturnValue(''), + }); + mockedUseAutoAcceptIndicator.mockReturnValue(false); + mockedUseWorkspaceMigration.mockReturnValue({ + showWorkspaceMigrationDialog: false, + workspaceExtensions: [], + onWorkspaceMigrationDialogOpen: vi.fn(), + onWorkspaceMigrationDialogClose: vi.fn(), + }); + mockedUseGitBranchName.mockReturnValue('main'); + mockedUseVimMode.mockReturnValue({ + isVimEnabled: false, + toggleVimEnabled: vi.fn(), + }); + mockedUseSessionStats.mockReturnValue({ stats: {} }); + mockedUseTextBuffer.mockReturnValue({ + text: '', + setText: vi.fn(), + // Add other properties if AppContainer uses them + }); + mockedUseLogger.mockReturnValue({ + getPreviousUserMessages: vi.fn().mockResolvedValue([]), + }); + mockedUseLoadingIndicator.mockReturnValue({ + elapsedTime: '0.0s', + currentLoadingPhrase: '', + }); + // Mock Config mockConfig = makeFakeConfig(); @@ -325,7 +488,73 @@ describe('AppContainer State Management', () => { expect(() => unmount()).not.toThrow(); }); }); -}); -// TODO: Add comprehensive integration test once all hook mocks are complete -// For now, the 14 passing unit tests provide good coverage of AppContainer functionality + describe('Quota and Fallback Integration', () => { + it('passes a null proQuotaRequest to UIStateContext by default', () => { + // The default mock from beforeEach already sets proQuotaRequest to null + render( + , + ); + + // Assert that the context value is as expected + expect(capturedUIState.proQuotaRequest).toBeNull(); + }); + + it('passes a valid proQuotaRequest to UIStateContext when provided by the hook', () => { + // Arrange: Create a mock request object that a UI dialog would receive + const mockRequest = { + failedModel: 'gemini-pro', + fallbackModel: 'gemini-flash', + resolve: vi.fn(), + }; + mockedUseQuotaAndFallback.mockReturnValue({ + proQuotaRequest: mockRequest, + handleProQuotaChoice: vi.fn(), + }); + + // Act: Render the container + render( + , + ); + + // Assert: The mock request is correctly passed through the context + expect(capturedUIState.proQuotaRequest).toEqual(mockRequest); + }); + + it('passes the handleProQuotaChoice function to UIActionsContext', () => { + // Arrange: Create a mock handler function + const mockHandler = vi.fn(); + mockedUseQuotaAndFallback.mockReturnValue({ + proQuotaRequest: null, + handleProQuotaChoice: mockHandler, + }); + + // Act: Render the container + render( + , + ); + + // Assert: The action in the context is the mock handler we provided + expect(capturedUIActions.handleProQuotaChoice).toBe(mockHandler); + + // You can even verify that the plumbed function is callable + capturedUIActions.handleProQuotaChoice('auth'); + expect(mockHandler).toHaveBeenCalledWith('auth'); + }); + }); +}); diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index b253ff88b5..e7b6af7eb9 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -24,18 +24,15 @@ import { MessageType, StreamingState } from './types.js'; import { type EditorType, type Config, - IdeClient, type DetectedIde, - ideContext, type IdeContext, + type UserTierId, + DEFAULT_GEMINI_FLASH_MODEL, + IdeClient, + ideContext, getErrorMessage, getAllGeminiMdFilenames, - UserTierId, AuthType, - isProQuotaExceededError, - isGenericQuotaExceededError, - logFlashFallback, - FlashFallbackEvent, clearCachedCredentialFile, } from '@google/gemini-cli-core'; import { validateAuthMethod } from '../config/auth.js'; @@ -44,6 +41,7 @@ import process from 'node:process'; import { useHistory } from './hooks/useHistoryManager.js'; import { useThemeCommand } from './hooks/useThemeCommand.js'; import { useAuthCommand } from './auth/useAuth.js'; +import { useQuotaAndFallback } from './hooks/useQuotaAndFallback.js'; import { useEditorSettings } from './hooks/useEditorSettings.js'; import { useSettingsCommand } from './hooks/useSettingsCommand.js'; import { useSlashCommandProcessor } from './hooks/slashCommandProcessor.js'; @@ -123,12 +121,18 @@ export const AppContainer = (props: AppContainerProps) => { const [isTrustedFolder, setIsTrustedFolder] = useState( config.isTrustedFolder(), ); - const [currentModel, setCurrentModel] = useState(config.getModel()); + + // Helper to determine the effective model, considering the fallback state. + const getEffectiveModel = useCallback(() => { + if (config.isInFallbackMode()) { + return DEFAULT_GEMINI_FLASH_MODEL; + } + return config.getModel(); + }, [config]); + + const [currentModel, setCurrentModel] = useState(getEffectiveModel()); + const [userTier, setUserTier] = useState(undefined); - const [isProQuotaDialogOpen, setIsProQuotaDialogOpen] = useState(false); - const [proQuotaDialogResolver, setProQuotaDialogResolver] = useState< - ((value: boolean) => void) | null - >(null); // Auto-accept indicator const showAutoAcceptIndicator = useAutoAcceptIndicator({ @@ -167,18 +171,17 @@ export const AppContainer = (props: AppContainerProps) => { // Watch for model changes (e.g., from Flash fallback) useEffect(() => { const checkModelChange = () => { - const configModel = config.getModel(); - if (configModel !== currentModel) { - setCurrentModel(configModel); + const effectiveModel = getEffectiveModel(); + if (effectiveModel !== currentModel) { + setCurrentModel(effectiveModel); } }; - // Check immediately and then periodically checkModelChange(); const interval = setInterval(checkModelChange, 1000); // Check every second return () => clearInterval(interval); - }, [config, currentModel]); + }, [config, currentModel, getEffectiveModel]); const { consoleMessages, @@ -273,6 +276,14 @@ export const AppContainer = (props: AppContainerProps) => { config, ); + const { proQuotaRequest, handleProQuotaChoice } = useQuotaAndFallback({ + config, + historyManager, + userTier, + setAuthState, + setModelSwitchedFromQuotaError, + }); + // Derive auth state variables for backward compatibility with UIStateContext const isAuthDialogOpen = authState === AuthState.Updating; const isAuthenticating = authState === AuthState.Unauthenticated; @@ -477,132 +488,6 @@ Logging in with Google... Please restart Gemini CLI to continue. } }, [config, historyManager, settings.merged]); - // Set up Flash fallback handler - useEffect(() => { - const flashFallbackHandler = async ( - currentModel: string, - fallbackModel: string, - error?: unknown, - ): Promise => { - // Check if we've already switched to the fallback model - if (config.isInFallbackMode()) { - // If we're already in fallback mode, don't show the dialog again - return false; - } - - let message: string; - - if ( - config.getContentGeneratorConfig().authType === - AuthType.LOGIN_WITH_GOOGLE - ) { - // Use actual user tier if available; otherwise, default to FREE tier behavior (safe default) - const isPaidTier = - userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD; - - // Check if this is a Pro quota exceeded error - if (error && isProQuotaExceededError(error)) { - if (isPaidTier) { - message = `⚡ You have reached your daily ${currentModel} quota limit. -⚡ You can choose to authenticate with a paid API key or continue with the fallback model. -⚡ To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; - } else { - message = `⚡ You have reached your daily ${currentModel} quota limit. -⚡ You can choose to authenticate with a paid API key or continue with the fallback model. -⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist -⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key -⚡ You can switch authentication methods by typing /auth`; - } - } else if (error && isGenericQuotaExceededError(error)) { - if (isPaidTier) { - message = `⚡ You have reached your daily quota limit. -⚡ Automatically switching from ${currentModel} to ${fallbackModel} for the remainder of this session. -⚡ To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; - } else { - message = `⚡ You have reached your daily quota limit. -⚡ Automatically switching from ${currentModel} to ${fallbackModel} for the remainder of this session. -⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist -⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key -⚡ You can switch authentication methods by typing /auth`; - } - } else { - if (isPaidTier) { - // Default fallback message for other cases (like consecutive 429s) - message = `⚡ Automatically switching from ${currentModel} to ${fallbackModel} for faster responses for the remainder of this session. -⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${currentModel} quota limit -⚡ To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; - } else { - // Default fallback message for other cases (like consecutive 429s) - message = `⚡ Automatically switching from ${currentModel} to ${fallbackModel} for faster responses for the remainder of this session. -⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${currentModel} quota limit -⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist -⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key -⚡ You can switch authentication methods by typing /auth`; - } - } - - // Add message to UI history - historyManager.addItem( - { - type: MessageType.INFO, - text: message, - }, - Date.now(), - ); - - // For Pro quota errors, show the dialog and wait for user's choice - if (error && isProQuotaExceededError(error)) { - // Set the flag to prevent tool continuation - setModelSwitchedFromQuotaError(true); - // Set global quota error flag to prevent Flash model calls - config.setQuotaErrorOccurred(true); - - // Show the ProQuotaDialog and wait for user's choice - const shouldContinueWithFallback = await new Promise( - (resolve) => { - setIsProQuotaDialogOpen(true); - setProQuotaDialogResolver(() => resolve); - }, - ); - - // If user chose to continue with fallback, we don't need to stop the current prompt - if (shouldContinueWithFallback) { - // Switch to fallback model for future use - config.setModel(fallbackModel); - config.setFallbackMode(true); - logFlashFallback( - config, - new FlashFallbackEvent( - config.getContentGeneratorConfig().authType!, - ), - ); - return true; // Continue with current prompt using fallback model - } - - // If user chose to authenticate, stop current prompt - return false; - } - - // For other quota errors, automatically switch to fallback model - // Set the flag to prevent tool continuation - setModelSwitchedFromQuotaError(true); - // Set global quota error flag to prevent Flash model calls - config.setQuotaErrorOccurred(true); - } - - // Switch model for future use but return false to stop current retry - config.setModel(fallbackModel); - config.setFallbackMode(true); - logFlashFallback( - config, - new FlashFallbackEvent(config.getContentGeneratorConfig().authType!), - ); - return false; // Don't continue with current prompt - }; - - config.setFlashFallbackHandler(flashFallbackHandler); - }, [config, historyManager, userTier]); - const cancelHandlerRef = useRef<() => void>(() => {}); const { @@ -681,22 +566,6 @@ Logging in with Google... Please restart Gemini CLI to continue. refreshStatic(); }, [historyManager, clearConsoleMessagesState, refreshStatic]); - const handleProQuotaChoice = useCallback( - (choice: 'auth' | 'continue') => { - setIsProQuotaDialogOpen(false); - if (proQuotaDialogResolver) { - if (choice === 'auth') { - proQuotaDialogResolver(false); // Don't continue with fallback, show auth dialog - setAuthState(AuthState.Updating); - } else { - proQuotaDialogResolver(true); // Continue with fallback model - } - setProQuotaDialogResolver(null); - } - }, - [proQuotaDialogResolver, setAuthState], - ); - const { handleInput: vimHandleInput } = useVim(buffer, handleFinalSubmit); /** @@ -712,7 +581,7 @@ Logging in with Google... Please restart Gemini CLI to continue. !isProcessing && (streamingState === StreamingState.Idle || streamingState === StreamingState.Responding) && - !isProQuotaDialogOpen; + !proQuotaRequest; // Compute available terminal height based on controls measurement const availableTerminalHeight = useMemo(() => { @@ -1029,7 +898,7 @@ Logging in with Google... Please restart Gemini CLI to continue. isAuthDialogOpen || isEditorDialogOpen || showPrivacyNotice || - isProQuotaDialogOpen, + !!proQuotaRequest, [ showWorkspaceMigrationDialog, shouldShowIdePrompt, @@ -1042,7 +911,7 @@ Logging in with Google... Please restart Gemini CLI to continue. isAuthDialogOpen, isEditorDialogOpen, showPrivacyNotice, - isProQuotaDialogOpen, + proQuotaRequest, ], ); @@ -1101,11 +970,9 @@ Logging in with Google... Please restart Gemini CLI to continue. showAutoAcceptIndicator, showWorkspaceMigrationDialog, workspaceExtensions, - // Use current state values instead of config.getModel() currentModel, userTier, - isProQuotaDialogOpen, - // New fields + proQuotaRequest, contextFileNames, errorCount, availableTerminalHeight, @@ -1174,10 +1041,8 @@ Logging in with Google... Please restart Gemini CLI to continue. showAutoAcceptIndicator, showWorkspaceMigrationDialog, workspaceExtensions, - // Quota-related state dependencies userTier, - isProQuotaDialogOpen, - // New fields dependencies + proQuotaRequest, contextFileNames, errorCount, availableTerminalHeight, @@ -1196,7 +1061,6 @@ Logging in with Google... Please restart Gemini CLI to continue. updateInfo, showIdeRestartPrompt, isRestarting, - // Quota-related dependencies currentModel, ], ); diff --git a/packages/cli/src/ui/components/DialogManager.tsx b/packages/cli/src/ui/components/DialogManager.tsx index 76cfb2e6f8..9a79d2eebd 100644 --- a/packages/cli/src/ui/components/DialogManager.tsx +++ b/packages/cli/src/ui/components/DialogManager.tsx @@ -22,7 +22,6 @@ import { useUIState } from '../contexts/UIStateContext.js'; import { useUIActions } from '../contexts/UIActionsContext.js'; import { useConfig } from '../contexts/ConfigContext.js'; import { useSettings } from '../contexts/SettingsContext.js'; -import { DEFAULT_GEMINI_FLASH_MODEL } from '@google/gemini-cli-core'; import process from 'node:process'; // Props for DialogManager @@ -54,11 +53,11 @@ export const DialogManager = () => { /> ); } - if (uiState.isProQuotaDialogOpen) { + if (uiState.proQuotaRequest) { return ( ); diff --git a/packages/cli/src/ui/components/ProQuotaDialog.test.tsx b/packages/cli/src/ui/components/ProQuotaDialog.test.tsx index 31bb4f03f6..c3a1afda6a 100644 --- a/packages/cli/src/ui/components/ProQuotaDialog.test.tsx +++ b/packages/cli/src/ui/components/ProQuotaDialog.test.tsx @@ -22,7 +22,7 @@ describe('ProQuotaDialog', () => { it('should render with correct title and options', () => { const { lastFrame } = render( {}} />, @@ -53,7 +53,7 @@ describe('ProQuotaDialog', () => { const mockOnChoice = vi.fn(); render( , @@ -72,7 +72,7 @@ describe('ProQuotaDialog', () => { const mockOnChoice = vi.fn(); render( , diff --git a/packages/cli/src/ui/components/ProQuotaDialog.tsx b/packages/cli/src/ui/components/ProQuotaDialog.tsx index d94d069857..c547967508 100644 --- a/packages/cli/src/ui/components/ProQuotaDialog.tsx +++ b/packages/cli/src/ui/components/ProQuotaDialog.tsx @@ -10,13 +10,13 @@ import { RadioButtonSelect } from './shared/RadioButtonSelect.js'; import { Colors } from '../colors.js'; interface ProQuotaDialogProps { - currentModel: string; + failedModel: string; fallbackModel: string; onChoice: (choice: 'auth' | 'continue') => void; } export function ProQuotaDialog({ - currentModel, + failedModel, fallbackModel, onChoice, }: ProQuotaDialogProps): React.JSX.Element { @@ -38,7 +38,7 @@ export function ProQuotaDialog({ return ( - Pro quota limit reached for {currentModel}. + Pro quota limit reached for {failedModel}. void; +} + export interface UIState { history: HistoryItem[]; isThemeDialogOpen: boolean; @@ -78,9 +85,8 @@ export interface UIState { workspaceExtensions: any[]; // Extension[] // Quota-related state userTier: UserTierId | undefined; - isProQuotaDialogOpen: boolean; + proQuotaRequest: ProQuotaDialogRequest | null; currentModel: string; - // New fields for complete state management contextFileNames: string[]; errorCount: number; availableTerminalHeight: number | undefined; diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts new file mode 100644 index 0000000000..7dd93eb72e --- /dev/null +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts @@ -0,0 +1,391 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + vi, + describe, + it, + expect, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; +import { act, renderHook } from '@testing-library/react'; +import { + type Config, + type FallbackModelHandler, + UserTierId, + AuthType, + isGenericQuotaExceededError, + isProQuotaExceededError, + makeFakeConfig, +} from '@google/gemini-cli-core'; +import { useQuotaAndFallback } from './useQuotaAndFallback.js'; +import type { UseHistoryManagerReturn } from './useHistoryManager.js'; +import { AuthState, MessageType } from '../types.js'; + +// Mock the error checking functions from the core package to control test scenarios +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const original = + await importOriginal(); + return { + ...original, + isGenericQuotaExceededError: vi.fn(), + isProQuotaExceededError: vi.fn(), + }; +}); + +// Use a type alias for SpyInstance as it's not directly exported +type SpyInstance = ReturnType; + +describe('useQuotaAndFallback', () => { + let mockConfig: Config; + let mockHistoryManager: UseHistoryManagerReturn; + let mockSetAuthState: Mock; + let mockSetModelSwitchedFromQuotaError: Mock; + let setFallbackHandlerSpy: SpyInstance; + + const mockedIsGenericQuotaExceededError = isGenericQuotaExceededError as Mock; + const mockedIsProQuotaExceededError = isProQuotaExceededError as Mock; + + beforeEach(() => { + mockConfig = makeFakeConfig(); + + // Spy on the method that requires the private field and mock its return. + // This is cleaner than modifying the config class for tests. + vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({ + model: 'gemini-pro', + authType: AuthType.LOGIN_WITH_GOOGLE, + }); + + mockHistoryManager = { + addItem: vi.fn(), + history: [], + updateItem: vi.fn(), + clearItems: vi.fn(), + loadHistory: vi.fn(), + }; + mockSetAuthState = vi.fn(); + mockSetModelSwitchedFromQuotaError = vi.fn(); + + setFallbackHandlerSpy = vi.spyOn(mockConfig, 'setFallbackModelHandler'); + vi.spyOn(mockConfig, 'setQuotaErrorOccurred'); + + mockedIsGenericQuotaExceededError.mockReturnValue(false); + mockedIsProQuotaExceededError.mockReturnValue(false); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should register a fallback handler on initialization', () => { + renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + ); + + expect(setFallbackHandlerSpy).toHaveBeenCalledTimes(1); + expect(setFallbackHandlerSpy.mock.calls[0][0]).toBeInstanceOf(Function); + }); + + describe('Fallback Handler Logic', () => { + // Helper function to render the hook and extract the registered handler + const getRegisteredHandler = ( + userTier: UserTierId = UserTierId.FREE, + ): FallbackModelHandler => { + renderHook( + (props) => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: props.userTier, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + { initialProps: { userTier } }, + ); + return setFallbackHandlerSpy.mock.calls[0][0] as FallbackModelHandler; + }; + + it('should return null and take no action if already in fallback mode', async () => { + vi.spyOn(mockConfig, 'isInFallbackMode').mockReturnValue(true); + const handler = getRegisteredHandler(); + const result = await handler('gemini-pro', 'gemini-flash', new Error()); + + expect(result).toBeNull(); + expect(mockHistoryManager.addItem).not.toHaveBeenCalled(); + }); + + it('should return null and take no action if authType is not LOGIN_WITH_GOOGLE', async () => { + // Override the default mock from beforeEach for this specific test + vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({ + model: 'gemini-pro', + authType: AuthType.USE_GEMINI, + }); + + const handler = getRegisteredHandler(); + const result = await handler('gemini-pro', 'gemini-flash', new Error()); + + expect(result).toBeNull(); + expect(mockHistoryManager.addItem).not.toHaveBeenCalled(); + }); + + describe('Automatic Fallback Scenarios', () => { + const testCases = [ + { + errorType: 'generic', + tier: UserTierId.FREE, + expectedMessageSnippets: [ + 'Automatically switching from model-A to model-B', + 'upgrade to a Gemini Code Assist Standard or Enterprise plan', + ], + }, + { + errorType: 'generic', + tier: UserTierId.STANDARD, // Paid tier + expectedMessageSnippets: [ + 'Automatically switching from model-A to model-B', + 'switch to using a paid API key from AI Studio', + ], + }, + { + errorType: 'other', + tier: UserTierId.FREE, + expectedMessageSnippets: [ + 'Automatically switching from model-A to model-B for faster responses', + 'upgrade to a Gemini Code Assist Standard or Enterprise plan', + ], + }, + { + errorType: 'other', + tier: UserTierId.LEGACY, // Paid tier + expectedMessageSnippets: [ + 'Automatically switching from model-A to model-B for faster responses', + 'switch to using a paid API key from AI Studio', + ], + }, + ]; + + for (const { errorType, tier, expectedMessageSnippets } of testCases) { + it(`should handle ${errorType} error for ${tier} tier correctly`, async () => { + mockedIsGenericQuotaExceededError.mockReturnValue( + errorType === 'generic', + ); + + const handler = getRegisteredHandler(tier); + const result = await handler( + 'model-A', + 'model-B', + new Error('quota exceeded'), + ); + + // Automatic fallbacks should return 'stop' + expect(result).toBe('stop'); + + expect(mockHistoryManager.addItem).toHaveBeenCalledWith( + expect.objectContaining({ type: MessageType.INFO }), + expect.any(Number), + ); + + const message = (mockHistoryManager.addItem as Mock).mock.calls[0][0] + .text; + for (const snippet of expectedMessageSnippets) { + expect(message).toContain(snippet); + } + + expect(mockSetModelSwitchedFromQuotaError).toHaveBeenCalledWith(true); + expect(mockConfig.setQuotaErrorOccurred).toHaveBeenCalledWith(true); + }); + } + }); + + describe('Interactive Fallback (Pro Quota Error)', () => { + beforeEach(() => { + mockedIsProQuotaExceededError.mockReturnValue(true); + }); + + it('should set an interactive request and wait for user choice', async () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + ); + + const handler = setFallbackHandlerSpy.mock + .calls[0][0] as FallbackModelHandler; + + // Call the handler but do not await it, to check the intermediate state + const promise = handler( + 'gemini-pro', + 'gemini-flash', + new Error('pro quota'), + ); + + await act(async () => {}); + + // The hook should now have a pending request for the UI to handle + expect(result.current.proQuotaRequest).not.toBeNull(); + expect(result.current.proQuotaRequest?.failedModel).toBe('gemini-pro'); + + // Simulate the user choosing to continue with the fallback model + act(() => { + result.current.handleProQuotaChoice('continue'); + }); + + // The original promise from the handler should now resolve + const intent = await promise; + expect(intent).toBe('retry'); + + // The pending request should be cleared from the state + expect(result.current.proQuotaRequest).toBeNull(); + }); + + it('should handle race conditions by stopping subsequent requests', async () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + ); + + const handler = setFallbackHandlerSpy.mock + .calls[0][0] as FallbackModelHandler; + + const promise1 = handler( + 'gemini-pro', + 'gemini-flash', + new Error('pro quota 1'), + ); + await act(async () => {}); + + const firstRequest = result.current.proQuotaRequest; + expect(firstRequest).not.toBeNull(); + + const result2 = await handler( + 'gemini-pro', + 'gemini-flash', + new Error('pro quota 2'), + ); + + // The lock should have stopped the second request + expect(result2).toBe('stop'); + expect(result.current.proQuotaRequest).toBe(firstRequest); + + act(() => { + result.current.handleProQuotaChoice('continue'); + }); + + const intent1 = await promise1; + expect(intent1).toBe('retry'); + expect(result.current.proQuotaRequest).toBeNull(); + }); + }); + }); + + describe('handleProQuotaChoice', () => { + beforeEach(() => { + mockedIsProQuotaExceededError.mockReturnValue(true); + }); + + it('should do nothing if there is no pending pro quota request', () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + ); + + act(() => { + result.current.handleProQuotaChoice('auth'); + }); + + expect(mockSetAuthState).not.toHaveBeenCalled(); + expect(mockHistoryManager.addItem).not.toHaveBeenCalled(); + }); + + it('should resolve intent to "auth" and trigger auth state update', async () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + ); + + const handler = setFallbackHandlerSpy.mock + .calls[0][0] as FallbackModelHandler; + const promise = handler( + 'gemini-pro', + 'gemini-flash', + new Error('pro quota'), + ); + await act(async () => {}); // Allow state to update + + act(() => { + result.current.handleProQuotaChoice('auth'); + }); + + const intent = await promise; + expect(intent).toBe('auth'); + expect(mockSetAuthState).toHaveBeenCalledWith(AuthState.Updating); + expect(result.current.proQuotaRequest).toBeNull(); + }); + + it('should resolve intent to "retry" and add info message on continue', async () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setAuthState: mockSetAuthState, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + }), + ); + + const handler = setFallbackHandlerSpy.mock + .calls[0][0] as FallbackModelHandler; + // The first `addItem` call is for the initial quota error message + const promise = handler( + 'gemini-pro', + 'gemini-flash', + new Error('pro quota'), + ); + await act(async () => {}); // Allow state to update + + act(() => { + result.current.handleProQuotaChoice('continue'); + }); + + const intent = await promise; + expect(intent).toBe('retry'); + expect(result.current.proQuotaRequest).toBeNull(); + + // Check for the second "Switched to fallback model" message + expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(2); + const lastCall = (mockHistoryManager.addItem as Mock).mock.calls[1][0]; + expect(lastCall.type).toBe(MessageType.INFO); + expect(lastCall.text).toContain('Switched to fallback model.'); + }); + }); +}); diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts new file mode 100644 index 0000000000..a7eb77659a --- /dev/null +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts @@ -0,0 +1,175 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + AuthType, + type Config, + type FallbackModelHandler, + type FallbackIntent, + isGenericQuotaExceededError, + isProQuotaExceededError, + UserTierId, +} from '@google/gemini-cli-core'; +import { useCallback, useEffect, useRef, useState } from 'react'; +import { type UseHistoryManagerReturn } from './useHistoryManager.js'; +import { AuthState, MessageType } from '../types.js'; +import { type ProQuotaDialogRequest } from '../contexts/UIStateContext.js'; + +interface UseQuotaAndFallbackArgs { + config: Config; + historyManager: UseHistoryManagerReturn; + userTier: UserTierId | undefined; + setAuthState: (state: AuthState) => void; + setModelSwitchedFromQuotaError: (value: boolean) => void; +} + +export function useQuotaAndFallback({ + config, + historyManager, + userTier, + setAuthState, + setModelSwitchedFromQuotaError, +}: UseQuotaAndFallbackArgs) { + const [proQuotaRequest, setProQuotaRequest] = + useState(null); + const isDialogPending = useRef(false); + + // Set up Flash fallback handler + useEffect(() => { + const fallbackHandler: FallbackModelHandler = async ( + failedModel, + fallbackModel, + error, + ): Promise => { + if (config.isInFallbackMode()) { + return null; + } + + // Fallbacks are currently only handled for OAuth users. + const contentGeneratorConfig = config.getContentGeneratorConfig(); + if ( + !contentGeneratorConfig || + contentGeneratorConfig.authType !== AuthType.LOGIN_WITH_GOOGLE + ) { + return null; + } + + // Use actual user tier if available; otherwise, default to FREE tier behavior (safe default) + const isPaidTier = + userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD; + + let message: string; + + if (error && isProQuotaExceededError(error)) { + // Pro Quota specific messages (Interactive) + if (isPaidTier) { + message = `⚡ You have reached your daily ${failedModel} quota limit. +⚡ You can choose to authenticate with a paid API key or continue with the fallback model. +⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; + } else { + message = `⚡ You have reached your daily ${failedModel} quota limit. +⚡ You can choose to authenticate with a paid API key or continue with the fallback model. +⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist +⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key +⚡ You can switch authentication methods by typing /auth`; + } + } else if (error && isGenericQuotaExceededError(error)) { + // Generic Quota (Automatic fallback) + const actionMessage = `⚡ You have reached your daily quota limit.\n⚡ Automatically switching from ${failedModel} to ${fallbackModel} for the remainder of this session.`; + + if (isPaidTier) { + message = `${actionMessage} +⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; + } else { + message = `${actionMessage} +⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist +⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key +⚡ You can switch authentication methods by typing /auth`; + } + } else { + // Consecutive 429s or other errors (Automatic fallback) + const actionMessage = `⚡ Automatically switching from ${failedModel} to ${fallbackModel} for faster responses for the remainder of this session.`; + + if (isPaidTier) { + message = `${actionMessage} +⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${failedModel} quota limit +⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`; + } else { + message = `${actionMessage} +⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${failedModel} quota limit +⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist +⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key +⚡ You can switch authentication methods by typing /auth`; + } + } + + // Add message to UI history + historyManager.addItem( + { + type: MessageType.INFO, + text: message, + }, + Date.now(), + ); + + setModelSwitchedFromQuotaError(true); + config.setQuotaErrorOccurred(true); + + // Interactive Fallback for Pro quota + if (error && isProQuotaExceededError(error)) { + if (isDialogPending.current) { + return 'stop'; // A dialog is already active, so just stop this request. + } + isDialogPending.current = true; + + const intent: FallbackIntent = await new Promise( + (resolve) => { + setProQuotaRequest({ + failedModel, + fallbackModel, + resolve, + }); + }, + ); + + return intent; + } + + return 'stop'; + }; + + config.setFallbackModelHandler(fallbackHandler); + }, [config, historyManager, userTier, setModelSwitchedFromQuotaError]); + + const handleProQuotaChoice = useCallback( + (choice: 'auth' | 'continue') => { + if (!proQuotaRequest) return; + + const intent: FallbackIntent = choice === 'auth' ? 'auth' : 'retry'; + proQuotaRequest.resolve(intent); + setProQuotaRequest(null); + isDialogPending.current = false; // Reset the flag here + + if (choice === 'auth') { + setAuthState(AuthState.Updating); + } else { + historyManager.addItem( + { + type: MessageType.INFO, + text: 'Switched to fallback model. Tip: Press Ctrl+P (or Up Arrow) to recall your previous prompt and submit it again if you wish.', + }, + Date.now(), + ); + } + }, + [proQuotaRequest, setAuthState, historyManager], + ); + + return { + proQuotaRequest, + handleProQuotaChoice, + }; +} diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 05dff79c06..fbc22f73e5 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -52,6 +52,7 @@ import type { FileSystemService } from '../services/fileSystemService.js'; import { StandardFileSystemService } from '../services/fileSystemService.js'; import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js'; import { IdeConnectionEvent, IdeConnectionType } from '../telemetry/types.js'; +import type { FallbackModelHandler } from '../fallback/types.js'; // Re-export OAuth config type export type { MCPOAuthConfig, AnyToolInvocation }; @@ -157,12 +158,6 @@ export interface SandboxConfig { image: string; } -export type FlashFallbackHandler = ( - currentModel: string, - fallbackModel: string, - error?: unknown, -) => Promise; - export interface ConfigParameters { sessionId: string; embeddingModel?: string; @@ -281,7 +276,7 @@ export class Config { name: string; extensionName: string; }>; - flashFallbackHandler?: FlashFallbackHandler; + fallbackModelHandler?: FallbackModelHandler; private quotaErrorOccurred: boolean = false; private readonly summarizeToolOutput: | Record @@ -490,8 +485,8 @@ export class Config { this.inFallbackMode = active; } - setFlashFallbackHandler(handler: FlashFallbackHandler): void { - this.flashFallbackHandler = handler; + setFallbackModelHandler(handler: FallbackModelHandler): void { + this.fallbackModelHandler = handler; } getMaxSessionTurns(): number { diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index fe110cca4a..0fb5a5b7d7 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -4,7 +4,15 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; import type { Content, GenerateContentResponse, Part } from '@google/genai'; import { @@ -212,16 +220,19 @@ describe('Gemini Client (client.ts)', () => { let mockContentGenerator: ContentGenerator; let mockConfig: Config; let client: GeminiClient; + let mockGenerateContentFn: Mock; beforeEach(async () => { vi.resetAllMocks(); + mockGenerateContentFn = vi.fn().mockResolvedValue({ + candidates: [{ content: { parts: [{ text: '{"key": "value"}' }] } }], + }); + // Disable 429 simulation for tests setSimulate429(false); mockContentGenerator = { - generateContent: vi.fn().mockResolvedValue({ - candidates: [{ content: { parts: [{ text: '{"key": "value"}' }] } }], - }), + generateContent: mockGenerateContentFn, generateContentStream: vi.fn(), countTokens: vi.fn(), embedContent: vi.fn(), @@ -270,6 +281,7 @@ describe('Gemini Client (client.ts)', () => { getDirectories: vi.fn().mockReturnValue(['/test/dir']), }), getGeminiClient: vi.fn(), + isInFallbackMode: vi.fn().mockReturnValue(false), setFallbackMode: vi.fn(), getChatCompression: vi.fn().mockReturnValue(undefined), getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false), @@ -453,6 +465,27 @@ describe('Gemini Client (client.ts)', () => { 'test-session-id', ); }); + + it('should use the Flash model when fallback mode is active', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const schema = { type: 'string' }; + const abortSignal = new AbortController().signal; + const requestedModel = 'gemini-2.5-pro'; // A non-flash model + + // Mock config to be in fallback mode + // We access the mock via the client instance which holds the mocked config + vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true); + + await client.generateJson(contents, schema, abortSignal, requestedModel); + + // Assert that the Flash model was used, not the requested model + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: DEFAULT_GEMINI_FLASH_MODEL, + }), + 'test-session-id', + ); + }); }); describe('addHistory', () => { @@ -2210,32 +2243,28 @@ ${JSON.stringify( 'test-session-id', ); }); - }); - describe('handleFlashFallback', () => { - it('should use current model from config when checking for fallback', async () => { - const initialModel = client['config'].getModel(); - const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; + it('should use the Flash model when fallback mode is active', async () => { + const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const generationConfig = { temperature: 0.5 }; + const abortSignal = new AbortController().signal; + const requestedModel = 'gemini-2.5-pro'; // A non-flash model - // mock config been changed - const currentModel = initialModel + '-changed'; - const getModelSpy = vi.spyOn(client['config'], 'getModel'); - getModelSpy.mockReturnValue(currentModel); + // Mock config to be in fallback mode + vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true); - const mockFallbackHandler = vi.fn().mockResolvedValue(true); - client['config'].flashFallbackHandler = mockFallbackHandler; - client['config'].setModel = vi.fn(); - - const result = await client['handleFlashFallback']( - AuthType.LOGIN_WITH_GOOGLE, + await client.generateContent( + contents, + generationConfig, + abortSignal, + requestedModel, ); - expect(result).toBe(fallbackModel); - - expect(mockFallbackHandler).toHaveBeenCalledWith( - currentModel, - fallbackModel, - undefined, + expect(mockGenerateContentFn).toHaveBeenCalledWith( + expect.objectContaining({ + model: DEFAULT_GEMINI_FLASH_MODEL, + }), + 'test-session-id', ); }); }); diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 01c52c4c92..658701517b 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -31,7 +31,6 @@ import { isFunctionResponse } from '../utils/messageInspectors.js'; import { tokenLimit } from './tokenLimits.js'; import type { ChatRecordingService } from '../services/chatRecordingService.js'; import type { ContentGenerator } from './contentGenerator.js'; -import { AuthType } from './contentGenerator.js'; import { DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_THINKING_MODE, @@ -49,6 +48,7 @@ import { NextSpeakerCheckEvent, } from '../telemetry/types.js'; import type { IdeContext, File } from '../ide/ideContext.js'; +import { handleFallback } from '../fallback/handler.js'; export function isThinkingSupported(model: string) { if (model.startsWith('gemini-2.5')) return true; @@ -550,6 +550,8 @@ export class GeminiClient { model: string, config: GenerateContentConfig = {}, ): Promise> { + let currentAttemptModel: string = model; + try { const userMemory = this.config.getUserMemory(); const systemInstruction = getCoreSystemPrompt(userMemory); @@ -559,10 +561,15 @@ export class GeminiClient { ...config, }; - const apiCall = () => - this.getContentGeneratorOrFail().generateContent( + const apiCall = () => { + const modelToUse = this.config.isInFallbackMode() + ? DEFAULT_GEMINI_FLASH_MODEL + : model; + currentAttemptModel = modelToUse; + + return this.getContentGeneratorOrFail().generateContent( { - model, + model: modelToUse, config: { ...requestConfig, systemInstruction, @@ -573,10 +580,17 @@ export class GeminiClient { }, this.lastPromptId, ); + }; + + const onPersistent429Callback = async ( + authType?: string, + error?: unknown, + ) => + // Pass the captured model to the centralized handler. + await handleFallback(this.config, currentAttemptModel, authType, error); const result = await retryWithBackoff(apiCall, { - onPersistent429: async (authType?: string, error?: unknown) => - await this.handleFlashFallback(authType, error), + onPersistent429: onPersistent429Callback, authType: this.config.getContentGeneratorConfig()?.authType, }); @@ -599,7 +613,7 @@ export class GeminiClient { if (text.startsWith(prefix) && text.endsWith(suffix)) { logMalformedJsonResponse( this.config, - new MalformedJsonResponseEvent(model), + new MalformedJsonResponseEvent(currentAttemptModel), ); text = text .substring(prefix.length, text.length - suffix.length) @@ -655,6 +669,8 @@ export class GeminiClient { abortSignal: AbortSignal, model: string, ): Promise { + let currentAttemptModel: string = model; + const configToUse: GenerateContentConfig = { ...this.generateContentConfig, ...generationConfig, @@ -670,19 +686,30 @@ export class GeminiClient { systemInstruction, }; - const apiCall = () => - this.getContentGeneratorOrFail().generateContent( + const apiCall = () => { + const modelToUse = this.config.isInFallbackMode() + ? DEFAULT_GEMINI_FLASH_MODEL + : model; + currentAttemptModel = modelToUse; + + return this.getContentGeneratorOrFail().generateContent( { - model, + model: modelToUse, config: requestConfig, contents, }, this.lastPromptId, ); + }; + const onPersistent429Callback = async ( + authType?: string, + error?: unknown, + ) => + // Pass the captured model to the centralized handler. + await handleFallback(this.config, currentAttemptModel, authType, error); const result = await retryWithBackoff(apiCall, { - onPersistent429: async (authType?: string, error?: unknown) => - await this.handleFlashFallback(authType, error), + onPersistent429: onPersistent429Callback, authType: this.config.getContentGeneratorConfig()?.authType, }); return result; @@ -693,7 +720,7 @@ export class GeminiClient { await reportError( error, - `Error generating content via API with model ${model}.`, + `Error generating content via API with model ${currentAttemptModel}.`, { requestContents: contents, requestConfig: configToUse, @@ -701,7 +728,7 @@ export class GeminiClient { 'generateContent-api', ); throw new Error( - `Failed to generate content with model ${model}: ${getErrorMessage(error)}`, + `Failed to generate content with model ${currentAttemptModel}: ${getErrorMessage(error)}`, ); } } @@ -880,53 +907,6 @@ export class GeminiClient { compressionStatus: CompressionStatus.COMPRESSED, }; } - - /** - * Handles falling back to Flash model when persistent 429 errors occur for OAuth users. - * Uses a fallback handler if provided by the config; otherwise, returns null. - */ - private async handleFlashFallback( - authType?: string, - error?: unknown, - ): Promise { - // Only handle fallback for OAuth users - if (authType !== AuthType.LOGIN_WITH_GOOGLE) { - return null; - } - - const currentModel = this.config.getModel(); - const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; - - // Don't fallback if already using Flash model - if (currentModel === fallbackModel) { - return null; - } - - // Check if config has a fallback handler (set by CLI package) - const fallbackHandler = this.config.flashFallbackHandler; - if (typeof fallbackHandler === 'function') { - try { - const accepted = await fallbackHandler( - currentModel, - fallbackModel, - error, - ); - if (accepted !== false && accepted !== null) { - this.config.setModel(fallbackModel); - this.config.setFallbackMode(true); - return fallbackModel; - } - // Check if the model was switched manually in the handler - if (this.config.getModel() === fallbackModel) { - return null; // Model was switched but don't continue with current prompt - } - } catch (error) { - console.warn('Flash fallback handler failed:', error); - } - } - - return null; - } } export const TEST_ONLY = { diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index 8305b1f359..c90ad2e722 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -20,6 +20,9 @@ import { } from './geminiChat.js'; import type { Config } from '../config/config.js'; import { setSimulate429 } from '../utils/testUtils.js'; +import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { AuthType } from './contentGenerator.js'; +import { type RetryOptions } from '../utils/retry.js'; // Mock fs module to prevent actual file system operations during tests const mockFileSystem = new Map(); @@ -47,6 +50,23 @@ vi.mock('node:fs', () => { }; }); +const { mockHandleFallback } = vi.hoisted(() => ({ + mockHandleFallback: vi.fn(), +})); + +// Add mock for the retry utility +const { mockRetryWithBackoff } = vi.hoisted(() => ({ + mockRetryWithBackoff: vi.fn(), +})); + +vi.mock('../utils/retry.js', () => ({ + retryWithBackoff: mockRetryWithBackoff, +})); + +vi.mock('../fallback/handler.js', () => ({ + handleFallback: mockHandleFallback, +})); + const { mockLogInvalidChunk, mockLogContentRetry, mockLogContentRetryFailure } = vi.hoisted(() => ({ mockLogInvalidChunk: vi.fn(), @@ -76,17 +96,21 @@ describe('GeminiChat', () => { batchEmbedContents: vi.fn(), } as unknown as ContentGenerator; + mockHandleFallback.mockClear(); + // Default mock implementation for tests that don't care about retry logic + mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall()); mockConfig = { getSessionId: () => 'test-session-id', getTelemetryLogPromptsEnabled: () => true, getUsageStatisticsEnabled: () => true, getDebugMode: () => false, - getContentGeneratorConfig: () => ({ - authType: 'oauth-personal', + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: 'oauth-personal', // Ensure this is set for fallback tests model: 'test-model', }), getModel: vi.fn().mockReturnValue('gemini-pro'), setModel: vi.fn(), + isInFallbackMode: vi.fn().mockReturnValue(false), getQuotaErrorOccurred: vi.fn().mockReturnValue(false), setQuotaErrorOccurred: vi.fn(), flashFallbackHandler: undefined, @@ -1476,8 +1500,176 @@ describe('GeminiChat', () => { expect(turn4.parts[0].text).toBe('second response'); }); + describe('Model Resolution', () => { + const mockResponse = { + candidates: [ + { + content: { parts: [{ text: 'response' }], role: 'model' }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + + it('should use the configured model when not in fallback mode (sendMessage)', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-pro'); + vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false); + vi.mocked(mockContentGenerator.generateContent).mockResolvedValue( + mockResponse, + ); + + await chat.sendMessage({ message: 'test' }, 'prompt-id-res1'); + + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: 'gemini-2.5-pro', + }), + 'prompt-id-res1', + ); + }); + + it('should use the FLASH model when in fallback mode (sendMessage)', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-pro'); + vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true); + vi.mocked(mockContentGenerator.generateContent).mockResolvedValue( + mockResponse, + ); + + await chat.sendMessage({ message: 'test' }, 'prompt-id-res2'); + + expect(mockContentGenerator.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + model: DEFAULT_GEMINI_FLASH_MODEL, + }), + 'prompt-id-res2', + ); + }); + + it('should use the FLASH model when in fallback mode (sendMessageStream)', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro'); + vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true); + vi.mocked(mockContentGenerator.generateContentStream).mockImplementation( + async () => + (async function* () { + yield mockResponse; + })(), + ); + + const stream = await chat.sendMessageStream( + { message: 'test' }, + 'prompt-id-res3', + ); + for await (const _ of stream) { + // consume stream + } + + expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith( + expect.objectContaining({ + model: DEFAULT_GEMINI_FLASH_MODEL, + }), + 'prompt-id-res3', + ); + }); + }); + + describe('Fallback Integration (Retries)', () => { + const error429 = Object.assign(new Error('API Error 429: Quota exceeded'), { + status: 429, + }); + + // Define the simulated behavior for retryWithBackoff for these tests. + // This simulation tries the apiCall, if it fails, it calls the callback, + // and then tries the apiCall again if the callback returns true. + const simulateRetryBehavior = async ( + apiCall: () => Promise, + options: Partial, + ) => { + try { + return await apiCall(); + } catch (error) { + if (options.onPersistent429) { + // We simulate the "persistent" trigger here for simplicity. + const shouldRetry = await options.onPersistent429( + options.authType, + error, + ); + if (shouldRetry) { + return await apiCall(); + } + } + throw error; // Stop if callback returns false/null or doesn't exist + } + }; + + beforeEach(() => { + mockRetryWithBackoff.mockImplementation(simulateRetryBehavior); + }); + + afterEach(() => { + mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall()); + }); + + it('should call handleFallback with the specific failed model and retry if handler returns true', async () => { + const FAILED_MODEL = 'gemini-2.5-pro'; + vi.mocked(mockConfig.getModel).mockReturnValue(FAILED_MODEL); + const authType = AuthType.LOGIN_WITH_GOOGLE; + vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ + authType, + model: FAILED_MODEL, + }); + + const isInFallbackModeSpy = vi.spyOn(mockConfig, 'isInFallbackMode'); + isInFallbackModeSpy.mockReturnValue(false); + + vi.mocked(mockContentGenerator.generateContent) + .mockRejectedValueOnce(error429) // Attempt 1 fails + .mockResolvedValueOnce({ + candidates: [{ content: { parts: [{ text: 'Success on retry' }] } }], + } as unknown as GenerateContentResponse); // Attempt 2 succeeds + + mockHandleFallback.mockImplementation(async () => { + isInFallbackModeSpy.mockReturnValue(true); + return true; // Signal retry + }); + + const result = await chat.sendMessage( + { message: 'trigger 429' }, + 'prompt-id-fb1', + ); + + expect(mockRetryWithBackoff).toHaveBeenCalledTimes(1); + expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(2); + expect(mockHandleFallback).toHaveBeenCalledTimes(1); + + expect(mockHandleFallback).toHaveBeenCalledWith( + mockConfig, + FAILED_MODEL, + authType, + error429, + ); + + expect(result.candidates?.[0]?.content?.parts?.[0]?.text).toBe( + 'Success on retry', + ); + }); + + it('should stop retrying if handleFallback returns false (e.g., auth intent)', async () => { + vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro'); + vi.mocked(mockContentGenerator.generateContent).mockRejectedValue( + error429, + ); + mockHandleFallback.mockResolvedValue(false); + + await expect( + chat.sendMessage({ message: 'test stop' }, 'prompt-id-fb2'), + ).rejects.toThrow(error429); + + expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(1); + expect(mockHandleFallback).toHaveBeenCalledTimes(1); + }); + }); + it('should discard valid partial content from a failed attempt upon retry', async () => { - // ARRANGE: Mock the stream to fail on the first attempt after yielding some valid content. + // Mock the stream to fail on the first attempt after yielding some valid content. vi.mocked(mockContentGenerator.generateContentStream) .mockImplementationOnce(async () => // First attempt: yields one valid chunk, then one invalid chunk @@ -1512,7 +1704,7 @@ describe('GeminiChat', () => { })(), ); - // ACT: Send a message and consume the stream + // Send a message and consume the stream const stream = await chat.sendMessageStream( { message: 'test' }, 'prompt-id-discard-test', @@ -1522,7 +1714,6 @@ describe('GeminiChat', () => { events.push(event); } - // ASSERT // Check that a retry happened expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2); expect(events.some((e) => e.type === StreamEventType.RETRY)).toBe(true); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 62942b25f8..8ce52a2640 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -18,7 +18,6 @@ import type { import { toParts } from '../code_assist/converter.js'; import { createUserContent } from '@google/genai'; import { retryWithBackoff } from '../utils/retry.js'; -import { AuthType } from './contentGenerator.js'; import type { Config } from '../config/config.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { hasCycleInSchema } from '../tools/tools.js'; @@ -35,6 +34,7 @@ import { ContentRetryFailureEvent, InvalidChunkEvent, } from '../telemetry/types.js'; +import { handleFallback } from '../fallback/handler.js'; import { isFunctionResponse } from '../utils/messageInspectors.js'; import { partListUnionToString } from './geminiRequest.js'; @@ -179,53 +179,6 @@ export class GeminiChat { this.chatRecordingService.initialize(); } - /** - * Handles falling back to Flash model when persistent 429 errors occur for OAuth users. - * Uses a fallback handler if provided by the config; otherwise, returns null. - */ - private async handleFlashFallback( - authType?: string, - error?: unknown, - ): Promise { - // Only handle fallback for OAuth users - if (authType !== AuthType.LOGIN_WITH_GOOGLE) { - return null; - } - - const currentModel = this.config.getModel(); - const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; - - // Don't fallback if already using Flash model - if (currentModel === fallbackModel) { - return null; - } - - // Check if config has a fallback handler (set by CLI package) - const fallbackHandler = this.config.flashFallbackHandler; - if (typeof fallbackHandler === 'function') { - try { - const accepted = await fallbackHandler( - currentModel, - fallbackModel, - error, - ); - if (accepted !== false && accepted !== null) { - this.config.setModel(fallbackModel); - this.config.setFallbackMode(true); - return fallbackModel; - } - // Check if the model was switched manually in the handler - if (this.config.getModel() === fallbackModel) { - return null; // Model was switched but don't continue with current prompt - } - } catch (error) { - console.warn('Flash fallback handler failed:', error); - } - } - - return null; - } - setSystemInstruction(sysInstr: string) { this.generationConfig.systemInstruction = sysInstr; } @@ -272,8 +225,13 @@ export class GeminiChat { let response: GenerateContentResponse; try { + let currentAttemptModel: string | undefined; + const apiCall = () => { - const modelToUse = this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL; + const modelToUse = this.config.isInFallbackMode() + ? DEFAULT_GEMINI_FLASH_MODEL + : this.config.getModel(); + currentAttemptModel = modelToUse; // Prevent Flash model calls immediately after quota error if ( @@ -295,6 +253,19 @@ export class GeminiChat { ); }; + const onPersistent429Callback = async ( + authType?: string, + error?: unknown, + ) => { + if (!currentAttemptModel) return null; + return await handleFallback( + this.config, + currentAttemptModel, + authType, + error, + ); + }; + response = await retryWithBackoff(apiCall, { shouldRetry: (error: unknown) => { // Check for known error messages and codes. @@ -305,8 +276,7 @@ export class GeminiChat { } return false; // Don't retry other errors by default }, - onPersistent429: async (authType?: string, error?: unknown) => - await this.handleFlashFallback(authType, error), + onPersistent429: onPersistent429Callback, authType: this.config.getContentGeneratorConfig()?.authType, }); @@ -484,8 +454,13 @@ export class GeminiChat { prompt_id: string, userContent: Content, ): Promise> { + let currentAttemptModel: string | undefined; + const apiCall = () => { - const modelToUse = this.config.getModel(); + const modelToUse = this.config.isInFallbackMode() + ? DEFAULT_GEMINI_FLASH_MODEL + : this.config.getModel(); + currentAttemptModel = modelToUse; if ( this.config.getQuotaErrorOccurred() && @@ -506,6 +481,19 @@ export class GeminiChat { ); }; + const onPersistent429Callback = async ( + authType?: string, + error?: unknown, + ) => { + if (!currentAttemptModel) return null; + return await handleFallback( + this.config, + currentAttemptModel, + authType, + error, + ); + }; + const streamResponse = await retryWithBackoff(apiCall, { shouldRetry: (error: unknown) => { if (error instanceof Error && error.message) { @@ -515,8 +503,7 @@ export class GeminiChat { } return false; }, - onPersistent429: async (authType?: string, error?: unknown) => - await this.handleFlashFallback(authType, error), + onPersistent429: onPersistent429Callback, authType: this.config.getContentGeneratorConfig()?.authType, }); diff --git a/packages/core/src/fallback/handler.test.ts b/packages/core/src/fallback/handler.test.ts new file mode 100644 index 0000000000..77c9375644 --- /dev/null +++ b/packages/core/src/fallback/handler.test.ts @@ -0,0 +1,218 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + type Mock, + type MockInstance, + afterEach, +} from 'vitest'; +import { handleFallback } from './handler.js'; +import type { Config } from '../config/config.js'; +import { AuthType } from '../core/contentGenerator.js'; +import { + DEFAULT_GEMINI_FLASH_MODEL, + DEFAULT_GEMINI_MODEL, +} from '../config/models.js'; +import { logFlashFallback } from '../telemetry/index.js'; +import type { FallbackModelHandler } from './types.js'; + +// Mock the telemetry logger and event class +vi.mock('../telemetry/index.js', () => ({ + logFlashFallback: vi.fn(), + FlashFallbackEvent: class {}, +})); + +const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL; +const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL; +const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE; +const AUTH_API_KEY = AuthType.USE_GEMINI; + +const createMockConfig = (overrides: Partial = {}): Config => + ({ + isInFallbackMode: vi.fn(() => false), + setFallbackMode: vi.fn(), + fallbackHandler: undefined, + ...overrides, + }) as unknown as Config; + +describe('handleFallback', () => { + let mockConfig: Config; + let mockHandler: Mock; + let consoleErrorSpy: MockInstance; + + beforeEach(() => { + vi.clearAllMocks(); + mockHandler = vi.fn(); + // Default setup: OAuth user, Pro model failed, handler injected + mockConfig = createMockConfig({ + fallbackModelHandler: mockHandler, + }); + consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + }); + + afterEach(() => { + consoleErrorSpy.mockRestore(); + }); + + it('should return null immediately if authType is not OAuth', async () => { + const result = await handleFallback( + mockConfig, + MOCK_PRO_MODEL, + AUTH_API_KEY, + ); + expect(result).toBeNull(); + expect(mockHandler).not.toHaveBeenCalled(); + expect(mockConfig.setFallbackMode).not.toHaveBeenCalled(); + }); + + it('should return null if the failed model is already the fallback model', async () => { + const result = await handleFallback( + mockConfig, + FALLBACK_MODEL, // Failed model is Flash + AUTH_OAUTH, + ); + expect(result).toBeNull(); + expect(mockHandler).not.toHaveBeenCalled(); + }); + + it('should return null if no fallbackHandler is injected in config', async () => { + const configWithoutHandler = createMockConfig({ + fallbackModelHandler: undefined, + }); + const result = await handleFallback( + configWithoutHandler, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + expect(result).toBeNull(); + }); + + describe('when handler returns "retry"', () => { + it('should activate fallback mode, log telemetry, and return true', async () => { + mockHandler.mockResolvedValue('retry'); + + const result = await handleFallback( + mockConfig, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + + expect(result).toBe(true); + expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true); + expect(logFlashFallback).toHaveBeenCalled(); + }); + }); + + describe('when handler returns "stop"', () => { + it('should activate fallback mode, log telemetry, and return false', async () => { + mockHandler.mockResolvedValue('stop'); + + const result = await handleFallback( + mockConfig, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + + expect(result).toBe(false); + expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true); + expect(logFlashFallback).toHaveBeenCalled(); + }); + }); + + describe('when handler returns "auth"', () => { + it('should NOT activate fallback mode and return false', async () => { + mockHandler.mockResolvedValue('auth'); + + const result = await handleFallback( + mockConfig, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + + expect(result).toBe(false); + expect(mockConfig.setFallbackMode).not.toHaveBeenCalled(); + expect(logFlashFallback).not.toHaveBeenCalled(); + }); + }); + + describe('when handler returns an unexpected value', () => { + it('should log an error and return null', async () => { + mockHandler.mockResolvedValue(null); + + const result = await handleFallback( + mockConfig, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + + expect(result).toBeNull(); + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Fallback UI handler failed:', + new Error( + 'Unexpected fallback intent received from fallbackModelHandler: "null"', + ), + ); + expect(mockConfig.setFallbackMode).not.toHaveBeenCalled(); + }); + }); + + it('should pass the correct context (failedModel, fallbackModel, error) to the handler', async () => { + const mockError = new Error('Quota Exceeded'); + mockHandler.mockResolvedValue('retry'); + + await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH, mockError); + + expect(mockHandler).toHaveBeenCalledWith( + MOCK_PRO_MODEL, + FALLBACK_MODEL, + mockError, + ); + }); + + it('should not call setFallbackMode or log telemetry if already in fallback mode', async () => { + // Setup config where fallback mode is already active + const activeFallbackConfig = createMockConfig({ + fallbackModelHandler: mockHandler, + isInFallbackMode: vi.fn(() => true), // Already active + setFallbackMode: vi.fn(), + }); + + mockHandler.mockResolvedValue('retry'); + + const result = await handleFallback( + activeFallbackConfig, + MOCK_PRO_MODEL, + AUTH_OAUTH, + ); + + // Should still return true to allow the retry (which will use the active fallback mode) + expect(result).toBe(true); + // Should still consult the handler + expect(mockHandler).toHaveBeenCalled(); + // But should not mutate state or log telemetry again + expect(activeFallbackConfig.setFallbackMode).not.toHaveBeenCalled(); + expect(logFlashFallback).not.toHaveBeenCalled(); + }); + + it('should catch errors from the handler, log an error, and return null', async () => { + const handlerError = new Error('UI interaction failed'); + mockHandler.mockRejectedValue(handlerError); + + const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH); + + expect(result).toBeNull(); + expect(consoleErrorSpy).toHaveBeenCalledWith( + 'Fallback UI handler failed:', + handlerError, + ); + expect(mockConfig.setFallbackMode).not.toHaveBeenCalled(); + }); +}); diff --git a/packages/core/src/fallback/handler.ts b/packages/core/src/fallback/handler.ts new file mode 100644 index 0000000000..762552cd2d --- /dev/null +++ b/packages/core/src/fallback/handler.ts @@ -0,0 +1,69 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Config } from '../config/config.js'; +import { AuthType } from '../core/contentGenerator.js'; +import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; +import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js'; + +export async function handleFallback( + config: Config, + failedModel: string, + authType?: string, + error?: unknown, +): Promise { + // Applicability Checks + if (authType !== AuthType.LOGIN_WITH_GOOGLE) return null; + + const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; + + if (failedModel === fallbackModel) return null; + + // Consult UI Handler for Intent + const fallbackModelHandler = config.fallbackModelHandler; + if (typeof fallbackModelHandler !== 'function') return null; + + try { + // Pass the specific failed model to the UI handler. + const intent = await fallbackModelHandler( + failedModel, + fallbackModel, + error, + ); + + // Process Intent and Update State + switch (intent) { + case 'retry': + // Activate fallback mode. The NEXT retry attempt will pick this up. + activateFallbackMode(config, authType); + return true; // Signal retryWithBackoff to continue. + + case 'stop': + activateFallbackMode(config, authType); + return false; + + case 'auth': + return false; + + default: + throw new Error( + `Unexpected fallback intent received from fallbackModelHandler: "${intent}"`, + ); + } + } catch (handlerError) { + console.error('Fallback UI handler failed:', handlerError); + return null; + } +} + +function activateFallbackMode(config: Config, authType: string | undefined) { + if (!config.isInFallbackMode()) { + config.setFallbackMode(true); + if (authType) { + logFlashFallback(config, new FlashFallbackEvent(authType)); + } + } +} diff --git a/packages/core/src/fallback/types.ts b/packages/core/src/fallback/types.ts new file mode 100644 index 0000000000..6543123371 --- /dev/null +++ b/packages/core/src/fallback/types.ts @@ -0,0 +1,23 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * Defines the intent returned by the UI layer during a fallback scenario. + */ +export type FallbackIntent = + | 'retry' // Immediately retry the current request with the fallback model. + | 'stop' // Switch to fallback for future requests, but stop the current request. + | 'auth'; // Stop the current request; user intends to change authentication. + +/** + * The interface for the handler provided by the UI layer (e.g., the CLI) + * to interact with the user during a fallback scenario. + */ +export type FallbackModelHandler = ( + failedModel: string, + fallbackModel: string, + error?: unknown, +) => Promise; diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 4b8d3aa3e8..047e43a529 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -20,6 +20,8 @@ export * from './core/geminiRequest.js'; export * from './core/coreToolScheduler.js'; export * from './core/nonInteractiveToolExecutor.js'; +export * from './fallback/types.js'; + export * from './code_assist/codeAssist.js'; export * from './code_assist/oauth2.js'; export * from './code_assist/server.js'; diff --git a/packages/core/src/utils/flashFallback.integration.test.ts b/packages/core/src/utils/flashFallback.test.ts similarity index 62% rename from packages/core/src/utils/flashFallback.integration.test.ts rename to packages/core/src/utils/flashFallback.test.ts index 9211ad2f20..6d4330f1ad 100644 --- a/packages/core/src/utils/flashFallback.integration.test.ts +++ b/packages/core/src/utils/flashFallback.test.ts @@ -17,10 +17,13 @@ import { import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { retryWithBackoff } from './retry.js'; import { AuthType } from '../core/contentGenerator.js'; +// Import the new types (Assuming this test file is in packages/core/src/utils/) +import type { FallbackModelHandler } from '../fallback/types.js'; vi.mock('node:fs'); -describe('Flash Fallback Integration', () => { +// Update the description to reflect that this tests the retry utility's integration +describe('Retry Utility Fallback Integration', () => { let config: Config; beforeEach(() => { @@ -41,25 +44,28 @@ describe('Flash Fallback Integration', () => { resetRequestCounter(); }); - it('should automatically accept fallback', async () => { - // Set up a minimal flash fallback handler for testing - const flashFallbackHandler = async (): Promise => true; + // This test validates the Config's ability to store and execute the handler contract. + it('should execute the injected FallbackHandler contract correctly', async () => { + // Set up a minimal handler for testing, ensuring it matches the new type. + const fallbackHandler: FallbackModelHandler = async () => 'retry'; - config.setFlashFallbackHandler(flashFallbackHandler); + // Use the generalized setter + config.setFallbackModelHandler(fallbackHandler); - // Call the handler directly to test - const result = await config.flashFallbackHandler!( + // Call the handler directly via the config property + const result = await config.fallbackModelHandler!( 'gemini-2.5-pro', DEFAULT_GEMINI_FLASH_MODEL, ); - // Verify it automatically accepts - expect(result).toBe(true); + // Verify it returns the correct intent + expect(result).toBe('retry'); }); - it('should trigger fallback after 2 consecutive 429 errors for OAuth users', async () => { + // This test validates the retry utility's logic for triggering the callback. + it('should trigger onPersistent429 after 2 consecutive 429 errors for OAuth users', async () => { let fallbackCalled = false; - let fallbackModel = ''; + // Removed fallbackModel variable as it's no longer relevant here. // Mock function that simulates exactly 2 429 errors, then succeeds after fallback const mockApiCall = vi @@ -68,11 +74,11 @@ describe('Flash Fallback Integration', () => { .mockRejectedValueOnce(createSimulated429Error()) .mockResolvedValueOnce('success after fallback'); - // Mock fallback handler - const mockFallbackHandler = vi.fn(async (_authType?: string) => { + // Mock the onPersistent429 callback (this is what client.ts/geminiChat.ts provides) + const mockPersistent429Callback = vi.fn(async (_authType?: string) => { fallbackCalled = true; - fallbackModel = DEFAULT_GEMINI_FLASH_MODEL; - return fallbackModel; + // Return true to signal retryWithBackoff to reset attempts and continue. + return true; }); // Test with OAuth personal auth type, with maxAttempts = 2 to ensure fallback triggers @@ -84,14 +90,13 @@ describe('Flash Fallback Integration', () => { const status = (error as Error & { status?: number }).status; return status === 429; }, - onPersistent429: mockFallbackHandler, + onPersistent429: mockPersistent429Callback, authType: AuthType.LOGIN_WITH_GOOGLE, }); - // Verify fallback was triggered + // Verify fallback mechanism was triggered expect(fallbackCalled).toBe(true); - expect(fallbackModel).toBe(DEFAULT_GEMINI_FLASH_MODEL); - expect(mockFallbackHandler).toHaveBeenCalledWith( + expect(mockPersistent429Callback).toHaveBeenCalledWith( AuthType.LOGIN_WITH_GOOGLE, expect.any(Error), ); @@ -100,16 +105,16 @@ describe('Flash Fallback Integration', () => { expect(mockApiCall).toHaveBeenCalledTimes(3); }); - it('should not trigger fallback for API key users', async () => { + it('should not trigger onPersistent429 for API key users', async () => { let fallbackCalled = false; // Mock function that simulates 429 errors const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error()); - // Mock fallback handler - const mockFallbackHandler = vi.fn(async () => { + // Mock the callback + const mockPersistent429Callback = vi.fn(async () => { fallbackCalled = true; - return DEFAULT_GEMINI_FLASH_MODEL; + return true; }); // Test with API key auth type - should not trigger fallback @@ -122,7 +127,7 @@ describe('Flash Fallback Integration', () => { const status = (error as Error & { status?: number }).status; return status === 429; }, - onPersistent429: mockFallbackHandler, + onPersistent429: mockPersistent429Callback, authType: AuthType.USE_GEMINI, // API key auth type }); } catch (error) { @@ -132,10 +137,11 @@ describe('Flash Fallback Integration', () => { // Verify fallback was NOT triggered for API key users expect(fallbackCalled).toBe(false); - expect(mockFallbackHandler).not.toHaveBeenCalled(); + expect(mockPersistent429Callback).not.toHaveBeenCalled(); }); - it('should properly disable simulation state after fallback', () => { + // This test validates the test utilities themselves. + it('should properly disable simulation state after fallback (Test Utility)', () => { // Enable simulation setSimulate429(true);