mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-29 22:44:45 -07:00
refactor: Centralize and improve model fallback handling (#7634)
This commit is contained in:
@@ -88,12 +88,11 @@ export class Task {
|
||||
this.eventBus = eventBus;
|
||||
this.completedToolCalls = [];
|
||||
this._resetToolCompletionPromise();
|
||||
this.config.setFlashFallbackHandler(
|
||||
async (currentModel: string, fallbackModel: string): Promise<boolean> => {
|
||||
config.setModel(fallbackModel); // gemini-cli-core sets to DEFAULT_GEMINI_FLASH_MODEL
|
||||
// Switch model for future use but return false to stop current retry
|
||||
return false;
|
||||
},
|
||||
this.config.setFallbackModelHandler(
|
||||
// For a2a-server, we want to automatically switch to the fallback model
|
||||
// for future requests without retrying the current one. The 'stop'
|
||||
// intent achieves this.
|
||||
async () => 'stop',
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ export function createMockConfig(
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({ model: 'gemini-pro' }),
|
||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
|
||||
setFlashFallbackHandler: vi.fn(),
|
||||
setFallbackModelHandler: vi.fn(),
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
getProxy: vi.fn().mockReturnValue(undefined),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
|
||||
@@ -4,31 +4,55 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { render, cleanup } from 'ink-testing-library';
|
||||
import { AppContainer } from './AppContainer.js';
|
||||
import { type Config, makeFakeConfig } from '@google/gemini-cli-core';
|
||||
import type { LoadedSettings } from '../config/settings.js';
|
||||
import type { InitializationResult } from '../core/initializer.js';
|
||||
import { useQuotaAndFallback } from './hooks/useQuotaAndFallback.js';
|
||||
import { UIStateContext, type UIState } from './contexts/UIStateContext.js';
|
||||
import {
|
||||
UIActionsContext,
|
||||
type UIActions,
|
||||
} from './contexts/UIActionsContext.js';
|
||||
import { useContext } from 'react';
|
||||
|
||||
// Helper component will read the context values provided by AppContainer
|
||||
// so we can assert against them in our tests.
|
||||
let capturedUIState: UIState;
|
||||
let capturedUIActions: UIActions;
|
||||
function TestContextConsumer() {
|
||||
capturedUIState = useContext(UIStateContext)!;
|
||||
capturedUIActions = useContext(UIActionsContext)!;
|
||||
return null;
|
||||
}
|
||||
|
||||
// Mock App component to isolate AppContainer testing
|
||||
vi.mock('./App.js', () => ({
|
||||
App: () => 'App Component',
|
||||
App: TestContextConsumer,
|
||||
}));
|
||||
|
||||
// Mock all the hooks and utilities
|
||||
vi.mock('./hooks/useHistory.js');
|
||||
vi.mock('./hooks/useQuotaAndFallback.js');
|
||||
vi.mock('./hooks/useHistoryManager.js');
|
||||
vi.mock('./hooks/useThemeCommand.js');
|
||||
vi.mock('./hooks/useAuthCommand.js');
|
||||
vi.mock('./auth/useAuth.js');
|
||||
vi.mock('./hooks/useEditorSettings.js');
|
||||
vi.mock('./hooks/useSettingsCommand.js');
|
||||
vi.mock('./hooks/useSlashCommandProcessor.js');
|
||||
vi.mock('./hooks/slashCommandProcessor.js');
|
||||
vi.mock('./hooks/useConsoleMessages.js');
|
||||
vi.mock('./hooks/useTerminalSize.js', () => ({
|
||||
useTerminalSize: vi.fn(() => ({ columns: 80, rows: 24 })),
|
||||
}));
|
||||
vi.mock('./hooks/useGeminiStream.js');
|
||||
vi.mock('./hooks/useVim.js');
|
||||
vi.mock('./hooks/vim.js');
|
||||
vi.mock('./hooks/useFocus.js');
|
||||
vi.mock('./hooks/useBracketedPaste.js');
|
||||
vi.mock('./hooks/useKeypress.js');
|
||||
@@ -40,7 +64,7 @@ vi.mock('./hooks/useWorkspaceMigration.js');
|
||||
vi.mock('./hooks/useGitBranchName.js');
|
||||
vi.mock('./contexts/VimModeContext.js');
|
||||
vi.mock('./contexts/SessionContext.js');
|
||||
vi.mock('./hooks/useTextBuffer.js');
|
||||
vi.mock('./components/shared/text-buffer.js');
|
||||
vi.mock('./hooks/useLogger.js');
|
||||
|
||||
// Mock external utilities
|
||||
@@ -49,14 +73,153 @@ vi.mock('../utils/handleAutoUpdate.js');
|
||||
vi.mock('./utils/ConsolePatcher.js');
|
||||
vi.mock('../utils/cleanup.js');
|
||||
|
||||
import { useHistory } from './hooks/useHistoryManager.js';
|
||||
import { useThemeCommand } from './hooks/useThemeCommand.js';
|
||||
import { useAuthCommand } from './auth/useAuth.js';
|
||||
import { useEditorSettings } from './hooks/useEditorSettings.js';
|
||||
import { useSettingsCommand } from './hooks/useSettingsCommand.js';
|
||||
import { useSlashCommandProcessor } from './hooks/slashCommandProcessor.js';
|
||||
import { useConsoleMessages } from './hooks/useConsoleMessages.js';
|
||||
import { useGeminiStream } from './hooks/useGeminiStream.js';
|
||||
import { useVim } from './hooks/vim.js';
|
||||
import { useFolderTrust } from './hooks/useFolderTrust.js';
|
||||
import { useMessageQueue } from './hooks/useMessageQueue.js';
|
||||
import { useAutoAcceptIndicator } from './hooks/useAutoAcceptIndicator.js';
|
||||
import { useWorkspaceMigration } from './hooks/useWorkspaceMigration.js';
|
||||
import { useGitBranchName } from './hooks/useGitBranchName.js';
|
||||
import { useVimMode } from './contexts/VimModeContext.js';
|
||||
import { useSessionStats } from './contexts/SessionContext.js';
|
||||
import { useTextBuffer } from './components/shared/text-buffer.js';
|
||||
import { useLogger } from './hooks/useLogger.js';
|
||||
import { useLoadingIndicator } from './hooks/useLoadingIndicator.js';
|
||||
|
||||
describe('AppContainer State Management', () => {
|
||||
let mockConfig: Config;
|
||||
let mockSettings: LoadedSettings;
|
||||
let mockInitResult: InitializationResult;
|
||||
|
||||
// Create typed mocks for all hooks
|
||||
const mockedUseQuotaAndFallback = useQuotaAndFallback as Mock;
|
||||
const mockedUseHistory = useHistory as Mock;
|
||||
const mockedUseThemeCommand = useThemeCommand as Mock;
|
||||
const mockedUseAuthCommand = useAuthCommand as Mock;
|
||||
const mockedUseEditorSettings = useEditorSettings as Mock;
|
||||
const mockedUseSettingsCommand = useSettingsCommand as Mock;
|
||||
const mockedUseSlashCommandProcessor = useSlashCommandProcessor as Mock;
|
||||
const mockedUseConsoleMessages = useConsoleMessages as Mock;
|
||||
const mockedUseGeminiStream = useGeminiStream as Mock;
|
||||
const mockedUseVim = useVim as Mock;
|
||||
const mockedUseFolderTrust = useFolderTrust as Mock;
|
||||
const mockedUseMessageQueue = useMessageQueue as Mock;
|
||||
const mockedUseAutoAcceptIndicator = useAutoAcceptIndicator as Mock;
|
||||
const mockedUseWorkspaceMigration = useWorkspaceMigration as Mock;
|
||||
const mockedUseGitBranchName = useGitBranchName as Mock;
|
||||
const mockedUseVimMode = useVimMode as Mock;
|
||||
const mockedUseSessionStats = useSessionStats as Mock;
|
||||
const mockedUseTextBuffer = useTextBuffer as Mock;
|
||||
const mockedUseLogger = useLogger as Mock;
|
||||
const mockedUseLoadingIndicator = useLoadingIndicator as Mock;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
capturedUIState = null!;
|
||||
capturedUIActions = null!;
|
||||
|
||||
// **Provide a default return value for EVERY mocked hook.**
|
||||
mockedUseQuotaAndFallback.mockReturnValue({
|
||||
proQuotaRequest: null,
|
||||
handleProQuotaChoice: vi.fn(),
|
||||
});
|
||||
mockedUseHistory.mockReturnValue({
|
||||
history: [],
|
||||
addItem: vi.fn(),
|
||||
updateItem: vi.fn(),
|
||||
clearItems: vi.fn(),
|
||||
loadHistory: vi.fn(),
|
||||
});
|
||||
mockedUseThemeCommand.mockReturnValue({
|
||||
isThemeDialogOpen: false,
|
||||
openThemeDialog: vi.fn(),
|
||||
handleThemeSelect: vi.fn(),
|
||||
handleThemeHighlight: vi.fn(),
|
||||
});
|
||||
mockedUseAuthCommand.mockReturnValue({
|
||||
authState: 'authenticated',
|
||||
setAuthState: vi.fn(),
|
||||
authError: null,
|
||||
onAuthError: vi.fn(),
|
||||
});
|
||||
mockedUseEditorSettings.mockReturnValue({
|
||||
isEditorDialogOpen: false,
|
||||
openEditorDialog: vi.fn(),
|
||||
handleEditorSelect: vi.fn(),
|
||||
exitEditorDialog: vi.fn(),
|
||||
});
|
||||
mockedUseSettingsCommand.mockReturnValue({
|
||||
isSettingsDialogOpen: false,
|
||||
openSettingsDialog: vi.fn(),
|
||||
closeSettingsDialog: vi.fn(),
|
||||
});
|
||||
mockedUseSlashCommandProcessor.mockReturnValue({
|
||||
handleSlashCommand: vi.fn(),
|
||||
slashCommands: [],
|
||||
pendingHistoryItems: [],
|
||||
commandContext: {},
|
||||
shellConfirmationRequest: null,
|
||||
confirmationRequest: null,
|
||||
});
|
||||
mockedUseConsoleMessages.mockReturnValue({
|
||||
consoleMessages: [],
|
||||
handleNewMessage: vi.fn(),
|
||||
clearConsoleMessages: vi.fn(),
|
||||
});
|
||||
mockedUseGeminiStream.mockReturnValue({
|
||||
streamingState: 'idle',
|
||||
submitQuery: vi.fn(),
|
||||
initError: null,
|
||||
pendingHistoryItems: [],
|
||||
thought: null,
|
||||
cancelOngoingRequest: vi.fn(),
|
||||
});
|
||||
mockedUseVim.mockReturnValue({ handleInput: vi.fn() });
|
||||
mockedUseFolderTrust.mockReturnValue({
|
||||
isFolderTrustDialogOpen: false,
|
||||
handleFolderTrustSelect: vi.fn(),
|
||||
isRestarting: false,
|
||||
});
|
||||
mockedUseMessageQueue.mockReturnValue({
|
||||
messageQueue: [],
|
||||
addMessage: vi.fn(),
|
||||
clearQueue: vi.fn(),
|
||||
getQueuedMessagesText: vi.fn().mockReturnValue(''),
|
||||
});
|
||||
mockedUseAutoAcceptIndicator.mockReturnValue(false);
|
||||
mockedUseWorkspaceMigration.mockReturnValue({
|
||||
showWorkspaceMigrationDialog: false,
|
||||
workspaceExtensions: [],
|
||||
onWorkspaceMigrationDialogOpen: vi.fn(),
|
||||
onWorkspaceMigrationDialogClose: vi.fn(),
|
||||
});
|
||||
mockedUseGitBranchName.mockReturnValue('main');
|
||||
mockedUseVimMode.mockReturnValue({
|
||||
isVimEnabled: false,
|
||||
toggleVimEnabled: vi.fn(),
|
||||
});
|
||||
mockedUseSessionStats.mockReturnValue({ stats: {} });
|
||||
mockedUseTextBuffer.mockReturnValue({
|
||||
text: '',
|
||||
setText: vi.fn(),
|
||||
// Add other properties if AppContainer uses them
|
||||
});
|
||||
mockedUseLogger.mockReturnValue({
|
||||
getPreviousUserMessages: vi.fn().mockResolvedValue([]),
|
||||
});
|
||||
mockedUseLoadingIndicator.mockReturnValue({
|
||||
elapsedTime: '0.0s',
|
||||
currentLoadingPhrase: '',
|
||||
});
|
||||
|
||||
// Mock Config
|
||||
mockConfig = makeFakeConfig();
|
||||
|
||||
@@ -325,7 +488,73 @@ describe('AppContainer State Management', () => {
|
||||
expect(() => unmount()).not.toThrow();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// TODO: Add comprehensive integration test once all hook mocks are complete
|
||||
// For now, the 14 passing unit tests provide good coverage of AppContainer functionality
|
||||
describe('Quota and Fallback Integration', () => {
|
||||
it('passes a null proQuotaRequest to UIStateContext by default', () => {
|
||||
// The default mock from beforeEach already sets proQuotaRequest to null
|
||||
render(
|
||||
<AppContainer
|
||||
config={mockConfig}
|
||||
settings={mockSettings}
|
||||
version="1.0.0"
|
||||
initializationResult={mockInitResult}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Assert that the context value is as expected
|
||||
expect(capturedUIState.proQuotaRequest).toBeNull();
|
||||
});
|
||||
|
||||
it('passes a valid proQuotaRequest to UIStateContext when provided by the hook', () => {
|
||||
// Arrange: Create a mock request object that a UI dialog would receive
|
||||
const mockRequest = {
|
||||
failedModel: 'gemini-pro',
|
||||
fallbackModel: 'gemini-flash',
|
||||
resolve: vi.fn(),
|
||||
};
|
||||
mockedUseQuotaAndFallback.mockReturnValue({
|
||||
proQuotaRequest: mockRequest,
|
||||
handleProQuotaChoice: vi.fn(),
|
||||
});
|
||||
|
||||
// Act: Render the container
|
||||
render(
|
||||
<AppContainer
|
||||
config={mockConfig}
|
||||
settings={mockSettings}
|
||||
version="1.0.0"
|
||||
initializationResult={mockInitResult}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Assert: The mock request is correctly passed through the context
|
||||
expect(capturedUIState.proQuotaRequest).toEqual(mockRequest);
|
||||
});
|
||||
|
||||
it('passes the handleProQuotaChoice function to UIActionsContext', () => {
|
||||
// Arrange: Create a mock handler function
|
||||
const mockHandler = vi.fn();
|
||||
mockedUseQuotaAndFallback.mockReturnValue({
|
||||
proQuotaRequest: null,
|
||||
handleProQuotaChoice: mockHandler,
|
||||
});
|
||||
|
||||
// Act: Render the container
|
||||
render(
|
||||
<AppContainer
|
||||
config={mockConfig}
|
||||
settings={mockSettings}
|
||||
version="1.0.0"
|
||||
initializationResult={mockInitResult}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Assert: The action in the context is the mock handler we provided
|
||||
expect(capturedUIActions.handleProQuotaChoice).toBe(mockHandler);
|
||||
|
||||
// You can even verify that the plumbed function is callable
|
||||
capturedUIActions.handleProQuotaChoice('auth');
|
||||
expect(mockHandler).toHaveBeenCalledWith('auth');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -24,18 +24,15 @@ import { MessageType, StreamingState } from './types.js';
|
||||
import {
|
||||
type EditorType,
|
||||
type Config,
|
||||
IdeClient,
|
||||
type DetectedIde,
|
||||
ideContext,
|
||||
type IdeContext,
|
||||
type UserTierId,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
IdeClient,
|
||||
ideContext,
|
||||
getErrorMessage,
|
||||
getAllGeminiMdFilenames,
|
||||
UserTierId,
|
||||
AuthType,
|
||||
isProQuotaExceededError,
|
||||
isGenericQuotaExceededError,
|
||||
logFlashFallback,
|
||||
FlashFallbackEvent,
|
||||
clearCachedCredentialFile,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { validateAuthMethod } from '../config/auth.js';
|
||||
@@ -44,6 +41,7 @@ import process from 'node:process';
|
||||
import { useHistory } from './hooks/useHistoryManager.js';
|
||||
import { useThemeCommand } from './hooks/useThemeCommand.js';
|
||||
import { useAuthCommand } from './auth/useAuth.js';
|
||||
import { useQuotaAndFallback } from './hooks/useQuotaAndFallback.js';
|
||||
import { useEditorSettings } from './hooks/useEditorSettings.js';
|
||||
import { useSettingsCommand } from './hooks/useSettingsCommand.js';
|
||||
import { useSlashCommandProcessor } from './hooks/slashCommandProcessor.js';
|
||||
@@ -123,12 +121,18 @@ export const AppContainer = (props: AppContainerProps) => {
|
||||
const [isTrustedFolder, setIsTrustedFolder] = useState<boolean | undefined>(
|
||||
config.isTrustedFolder(),
|
||||
);
|
||||
const [currentModel, setCurrentModel] = useState(config.getModel());
|
||||
|
||||
// Helper to determine the effective model, considering the fallback state.
|
||||
const getEffectiveModel = useCallback(() => {
|
||||
if (config.isInFallbackMode()) {
|
||||
return DEFAULT_GEMINI_FLASH_MODEL;
|
||||
}
|
||||
return config.getModel();
|
||||
}, [config]);
|
||||
|
||||
const [currentModel, setCurrentModel] = useState(getEffectiveModel());
|
||||
|
||||
const [userTier, setUserTier] = useState<UserTierId | undefined>(undefined);
|
||||
const [isProQuotaDialogOpen, setIsProQuotaDialogOpen] = useState(false);
|
||||
const [proQuotaDialogResolver, setProQuotaDialogResolver] = useState<
|
||||
((value: boolean) => void) | null
|
||||
>(null);
|
||||
|
||||
// Auto-accept indicator
|
||||
const showAutoAcceptIndicator = useAutoAcceptIndicator({
|
||||
@@ -167,18 +171,17 @@ export const AppContainer = (props: AppContainerProps) => {
|
||||
// Watch for model changes (e.g., from Flash fallback)
|
||||
useEffect(() => {
|
||||
const checkModelChange = () => {
|
||||
const configModel = config.getModel();
|
||||
if (configModel !== currentModel) {
|
||||
setCurrentModel(configModel);
|
||||
const effectiveModel = getEffectiveModel();
|
||||
if (effectiveModel !== currentModel) {
|
||||
setCurrentModel(effectiveModel);
|
||||
}
|
||||
};
|
||||
|
||||
// Check immediately and then periodically
|
||||
checkModelChange();
|
||||
const interval = setInterval(checkModelChange, 1000); // Check every second
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, [config, currentModel]);
|
||||
}, [config, currentModel, getEffectiveModel]);
|
||||
|
||||
const {
|
||||
consoleMessages,
|
||||
@@ -273,6 +276,14 @@ export const AppContainer = (props: AppContainerProps) => {
|
||||
config,
|
||||
);
|
||||
|
||||
const { proQuotaRequest, handleProQuotaChoice } = useQuotaAndFallback({
|
||||
config,
|
||||
historyManager,
|
||||
userTier,
|
||||
setAuthState,
|
||||
setModelSwitchedFromQuotaError,
|
||||
});
|
||||
|
||||
// Derive auth state variables for backward compatibility with UIStateContext
|
||||
const isAuthDialogOpen = authState === AuthState.Updating;
|
||||
const isAuthenticating = authState === AuthState.Unauthenticated;
|
||||
@@ -477,132 +488,6 @@ Logging in with Google... Please restart Gemini CLI to continue.
|
||||
}
|
||||
}, [config, historyManager, settings.merged]);
|
||||
|
||||
// Set up Flash fallback handler
|
||||
useEffect(() => {
|
||||
const flashFallbackHandler = async (
|
||||
currentModel: string,
|
||||
fallbackModel: string,
|
||||
error?: unknown,
|
||||
): Promise<boolean> => {
|
||||
// Check if we've already switched to the fallback model
|
||||
if (config.isInFallbackMode()) {
|
||||
// If we're already in fallback mode, don't show the dialog again
|
||||
return false;
|
||||
}
|
||||
|
||||
let message: string;
|
||||
|
||||
if (
|
||||
config.getContentGeneratorConfig().authType ===
|
||||
AuthType.LOGIN_WITH_GOOGLE
|
||||
) {
|
||||
// Use actual user tier if available; otherwise, default to FREE tier behavior (safe default)
|
||||
const isPaidTier =
|
||||
userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD;
|
||||
|
||||
// Check if this is a Pro quota exceeded error
|
||||
if (error && isProQuotaExceededError(error)) {
|
||||
if (isPaidTier) {
|
||||
message = `⚡ You have reached your daily ${currentModel} quota limit.
|
||||
⚡ You can choose to authenticate with a paid API key or continue with the fallback model.
|
||||
⚡ To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
|
||||
} else {
|
||||
message = `⚡ You have reached your daily ${currentModel} quota limit.
|
||||
⚡ You can choose to authenticate with a paid API key or continue with the fallback model.
|
||||
⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist
|
||||
⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key
|
||||
⚡ You can switch authentication methods by typing /auth`;
|
||||
}
|
||||
} else if (error && isGenericQuotaExceededError(error)) {
|
||||
if (isPaidTier) {
|
||||
message = `⚡ You have reached your daily quota limit.
|
||||
⚡ Automatically switching from ${currentModel} to ${fallbackModel} for the remainder of this session.
|
||||
⚡ To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
|
||||
} else {
|
||||
message = `⚡ You have reached your daily quota limit.
|
||||
⚡ Automatically switching from ${currentModel} to ${fallbackModel} for the remainder of this session.
|
||||
⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist
|
||||
⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key
|
||||
⚡ You can switch authentication methods by typing /auth`;
|
||||
}
|
||||
} else {
|
||||
if (isPaidTier) {
|
||||
// Default fallback message for other cases (like consecutive 429s)
|
||||
message = `⚡ Automatically switching from ${currentModel} to ${fallbackModel} for faster responses for the remainder of this session.
|
||||
⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${currentModel} quota limit
|
||||
⚡ To continue accessing the ${currentModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`;
|
||||
} else {
|
||||
// Default fallback message for other cases (like consecutive 429s)
|
||||
message = `⚡ Automatically switching from ${currentModel} to ${fallbackModel} for faster responses for the remainder of this session.
|
||||
⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${currentModel} quota limit
|
||||
⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist
|
||||
⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key
|
||||
⚡ You can switch authentication methods by typing /auth`;
|
||||
}
|
||||
}
|
||||
|
||||
// Add message to UI history
|
||||
historyManager.addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: message,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
|
||||
// For Pro quota errors, show the dialog and wait for user's choice
|
||||
if (error && isProQuotaExceededError(error)) {
|
||||
// Set the flag to prevent tool continuation
|
||||
setModelSwitchedFromQuotaError(true);
|
||||
// Set global quota error flag to prevent Flash model calls
|
||||
config.setQuotaErrorOccurred(true);
|
||||
|
||||
// Show the ProQuotaDialog and wait for user's choice
|
||||
const shouldContinueWithFallback = await new Promise<boolean>(
|
||||
(resolve) => {
|
||||
setIsProQuotaDialogOpen(true);
|
||||
setProQuotaDialogResolver(() => resolve);
|
||||
},
|
||||
);
|
||||
|
||||
// If user chose to continue with fallback, we don't need to stop the current prompt
|
||||
if (shouldContinueWithFallback) {
|
||||
// Switch to fallback model for future use
|
||||
config.setModel(fallbackModel);
|
||||
config.setFallbackMode(true);
|
||||
logFlashFallback(
|
||||
config,
|
||||
new FlashFallbackEvent(
|
||||
config.getContentGeneratorConfig().authType!,
|
||||
),
|
||||
);
|
||||
return true; // Continue with current prompt using fallback model
|
||||
}
|
||||
|
||||
// If user chose to authenticate, stop current prompt
|
||||
return false;
|
||||
}
|
||||
|
||||
// For other quota errors, automatically switch to fallback model
|
||||
// Set the flag to prevent tool continuation
|
||||
setModelSwitchedFromQuotaError(true);
|
||||
// Set global quota error flag to prevent Flash model calls
|
||||
config.setQuotaErrorOccurred(true);
|
||||
}
|
||||
|
||||
// Switch model for future use but return false to stop current retry
|
||||
config.setModel(fallbackModel);
|
||||
config.setFallbackMode(true);
|
||||
logFlashFallback(
|
||||
config,
|
||||
new FlashFallbackEvent(config.getContentGeneratorConfig().authType!),
|
||||
);
|
||||
return false; // Don't continue with current prompt
|
||||
};
|
||||
|
||||
config.setFlashFallbackHandler(flashFallbackHandler);
|
||||
}, [config, historyManager, userTier]);
|
||||
|
||||
const cancelHandlerRef = useRef<() => void>(() => {});
|
||||
|
||||
const {
|
||||
@@ -681,22 +566,6 @@ Logging in with Google... Please restart Gemini CLI to continue.
|
||||
refreshStatic();
|
||||
}, [historyManager, clearConsoleMessagesState, refreshStatic]);
|
||||
|
||||
const handleProQuotaChoice = useCallback(
|
||||
(choice: 'auth' | 'continue') => {
|
||||
setIsProQuotaDialogOpen(false);
|
||||
if (proQuotaDialogResolver) {
|
||||
if (choice === 'auth') {
|
||||
proQuotaDialogResolver(false); // Don't continue with fallback, show auth dialog
|
||||
setAuthState(AuthState.Updating);
|
||||
} else {
|
||||
proQuotaDialogResolver(true); // Continue with fallback model
|
||||
}
|
||||
setProQuotaDialogResolver(null);
|
||||
}
|
||||
},
|
||||
[proQuotaDialogResolver, setAuthState],
|
||||
);
|
||||
|
||||
const { handleInput: vimHandleInput } = useVim(buffer, handleFinalSubmit);
|
||||
|
||||
/**
|
||||
@@ -712,7 +581,7 @@ Logging in with Google... Please restart Gemini CLI to continue.
|
||||
!isProcessing &&
|
||||
(streamingState === StreamingState.Idle ||
|
||||
streamingState === StreamingState.Responding) &&
|
||||
!isProQuotaDialogOpen;
|
||||
!proQuotaRequest;
|
||||
|
||||
// Compute available terminal height based on controls measurement
|
||||
const availableTerminalHeight = useMemo(() => {
|
||||
@@ -1029,7 +898,7 @@ Logging in with Google... Please restart Gemini CLI to continue.
|
||||
isAuthDialogOpen ||
|
||||
isEditorDialogOpen ||
|
||||
showPrivacyNotice ||
|
||||
isProQuotaDialogOpen,
|
||||
!!proQuotaRequest,
|
||||
[
|
||||
showWorkspaceMigrationDialog,
|
||||
shouldShowIdePrompt,
|
||||
@@ -1042,7 +911,7 @@ Logging in with Google... Please restart Gemini CLI to continue.
|
||||
isAuthDialogOpen,
|
||||
isEditorDialogOpen,
|
||||
showPrivacyNotice,
|
||||
isProQuotaDialogOpen,
|
||||
proQuotaRequest,
|
||||
],
|
||||
);
|
||||
|
||||
@@ -1101,11 +970,9 @@ Logging in with Google... Please restart Gemini CLI to continue.
|
||||
showAutoAcceptIndicator,
|
||||
showWorkspaceMigrationDialog,
|
||||
workspaceExtensions,
|
||||
// Use current state values instead of config.getModel()
|
||||
currentModel,
|
||||
userTier,
|
||||
isProQuotaDialogOpen,
|
||||
// New fields
|
||||
proQuotaRequest,
|
||||
contextFileNames,
|
||||
errorCount,
|
||||
availableTerminalHeight,
|
||||
@@ -1174,10 +1041,8 @@ Logging in with Google... Please restart Gemini CLI to continue.
|
||||
showAutoAcceptIndicator,
|
||||
showWorkspaceMigrationDialog,
|
||||
workspaceExtensions,
|
||||
// Quota-related state dependencies
|
||||
userTier,
|
||||
isProQuotaDialogOpen,
|
||||
// New fields dependencies
|
||||
proQuotaRequest,
|
||||
contextFileNames,
|
||||
errorCount,
|
||||
availableTerminalHeight,
|
||||
@@ -1196,7 +1061,6 @@ Logging in with Google... Please restart Gemini CLI to continue.
|
||||
updateInfo,
|
||||
showIdeRestartPrompt,
|
||||
isRestarting,
|
||||
// Quota-related dependencies
|
||||
currentModel,
|
||||
],
|
||||
);
|
||||
|
||||
@@ -22,7 +22,6 @@ import { useUIState } from '../contexts/UIStateContext.js';
|
||||
import { useUIActions } from '../contexts/UIActionsContext.js';
|
||||
import { useConfig } from '../contexts/ConfigContext.js';
|
||||
import { useSettings } from '../contexts/SettingsContext.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '@google/gemini-cli-core';
|
||||
import process from 'node:process';
|
||||
|
||||
// Props for DialogManager
|
||||
@@ -54,11 +53,11 @@ export const DialogManager = () => {
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (uiState.isProQuotaDialogOpen) {
|
||||
if (uiState.proQuotaRequest) {
|
||||
return (
|
||||
<ProQuotaDialog
|
||||
currentModel={uiState.currentModel}
|
||||
fallbackModel={DEFAULT_GEMINI_FLASH_MODEL}
|
||||
failedModel={uiState.proQuotaRequest.failedModel}
|
||||
fallbackModel={uiState.proQuotaRequest.fallbackModel}
|
||||
onChoice={uiActions.handleProQuotaChoice}
|
||||
/>
|
||||
);
|
||||
|
||||
@@ -22,7 +22,7 @@ describe('ProQuotaDialog', () => {
|
||||
it('should render with correct title and options', () => {
|
||||
const { lastFrame } = render(
|
||||
<ProQuotaDialog
|
||||
currentModel="gemini-2.5-pro"
|
||||
failedModel="gemini-2.5-pro"
|
||||
fallbackModel="gemini-2.5-flash"
|
||||
onChoice={() => {}}
|
||||
/>,
|
||||
@@ -53,7 +53,7 @@ describe('ProQuotaDialog', () => {
|
||||
const mockOnChoice = vi.fn();
|
||||
render(
|
||||
<ProQuotaDialog
|
||||
currentModel="gemini-2.5-pro"
|
||||
failedModel="gemini-2.5-pro"
|
||||
fallbackModel="gemini-2.5-flash"
|
||||
onChoice={mockOnChoice}
|
||||
/>,
|
||||
@@ -72,7 +72,7 @@ describe('ProQuotaDialog', () => {
|
||||
const mockOnChoice = vi.fn();
|
||||
render(
|
||||
<ProQuotaDialog
|
||||
currentModel="gemini-2.5-pro"
|
||||
failedModel="gemini-2.5-pro"
|
||||
fallbackModel="gemini-2.5-flash"
|
||||
onChoice={mockOnChoice}
|
||||
/>,
|
||||
|
||||
@@ -10,13 +10,13 @@ import { RadioButtonSelect } from './shared/RadioButtonSelect.js';
|
||||
import { Colors } from '../colors.js';
|
||||
|
||||
interface ProQuotaDialogProps {
|
||||
currentModel: string;
|
||||
failedModel: string;
|
||||
fallbackModel: string;
|
||||
onChoice: (choice: 'auth' | 'continue') => void;
|
||||
}
|
||||
|
||||
export function ProQuotaDialog({
|
||||
currentModel,
|
||||
failedModel,
|
||||
fallbackModel,
|
||||
onChoice,
|
||||
}: ProQuotaDialogProps): React.JSX.Element {
|
||||
@@ -38,7 +38,7 @@ export function ProQuotaDialog({
|
||||
return (
|
||||
<Box borderStyle="round" flexDirection="column" paddingX={1}>
|
||||
<Text bold color={Colors.AccentYellow}>
|
||||
Pro quota limit reached for {currentModel}.
|
||||
Pro quota limit reached for {failedModel}.
|
||||
</Text>
|
||||
<Box marginTop={1}>
|
||||
<RadioButtonSelect
|
||||
|
||||
@@ -21,11 +21,18 @@ import type {
|
||||
ApprovalMode,
|
||||
UserTierId,
|
||||
DetectedIde,
|
||||
FallbackIntent,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { DOMElement } from 'ink';
|
||||
import type { SessionStatsState } from '../contexts/SessionContext.js';
|
||||
import type { UpdateObject } from '../utils/updateCheck.js';
|
||||
|
||||
export interface ProQuotaDialogRequest {
|
||||
failedModel: string;
|
||||
fallbackModel: string;
|
||||
resolve: (intent: FallbackIntent) => void;
|
||||
}
|
||||
|
||||
export interface UIState {
|
||||
history: HistoryItem[];
|
||||
isThemeDialogOpen: boolean;
|
||||
@@ -78,9 +85,8 @@ export interface UIState {
|
||||
workspaceExtensions: any[]; // Extension[]
|
||||
// Quota-related state
|
||||
userTier: UserTierId | undefined;
|
||||
isProQuotaDialogOpen: boolean;
|
||||
proQuotaRequest: ProQuotaDialogRequest | null;
|
||||
currentModel: string;
|
||||
// New fields for complete state management
|
||||
contextFileNames: string[];
|
||||
errorCount: number;
|
||||
availableTerminalHeight: number | undefined;
|
||||
|
||||
@@ -0,0 +1,391 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
vi,
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
import { act, renderHook } from '@testing-library/react';
|
||||
import {
|
||||
type Config,
|
||||
type FallbackModelHandler,
|
||||
UserTierId,
|
||||
AuthType,
|
||||
isGenericQuotaExceededError,
|
||||
isProQuotaExceededError,
|
||||
makeFakeConfig,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { useQuotaAndFallback } from './useQuotaAndFallback.js';
|
||||
import type { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import { AuthState, MessageType } from '../types.js';
|
||||
|
||||
// Mock the error checking functions from the core package to control test scenarios
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const original =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
return {
|
||||
...original,
|
||||
isGenericQuotaExceededError: vi.fn(),
|
||||
isProQuotaExceededError: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
// Use a type alias for SpyInstance as it's not directly exported
|
||||
type SpyInstance = ReturnType<typeof vi.spyOn>;
|
||||
|
||||
describe('useQuotaAndFallback', () => {
  let mockConfig: Config;
  let mockHistoryManager: UseHistoryManagerReturn;
  let mockSetAuthState: Mock;
  let mockSetModelSwitchedFromQuotaError: Mock;
  let setFallbackHandlerSpy: SpyInstance;

  // The core predicates are replaced by vi.fn() in the module mock above.
  const genericQuotaCheck = isGenericQuotaExceededError as Mock;
  const proQuotaCheck = isProQuotaExceededError as Mock;

  /** Renders the hook against the per-test mocks for the given user tier. */
  const renderQuotaHook = (userTier: UserTierId = UserTierId.FREE) =>
    renderHook(() =>
      useQuotaAndFallback({
        config: mockConfig,
        historyManager: mockHistoryManager,
        userTier,
        setAuthState: mockSetAuthState,
        setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
      }),
    );

  /** Returns the handler passed to the first `setFallbackModelHandler` call. */
  const firstRegisteredHandler = (): FallbackModelHandler =>
    setFallbackHandlerSpy.mock.calls[0][0] as FallbackModelHandler;

  beforeEach(() => {
    mockConfig = makeFakeConfig();

    // Stub the content generator config via a spy instead of touching the
    // config's private state.
    vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({
      model: 'gemini-pro',
      authType: AuthType.LOGIN_WITH_GOOGLE,
    });

    mockHistoryManager = {
      addItem: vi.fn(),
      history: [],
      updateItem: vi.fn(),
      clearItems: vi.fn(),
      loadHistory: vi.fn(),
    };
    mockSetAuthState = vi.fn();
    mockSetModelSwitchedFromQuotaError = vi.fn();

    setFallbackHandlerSpy = vi.spyOn(mockConfig, 'setFallbackModelHandler');
    vi.spyOn(mockConfig, 'setQuotaErrorOccurred');

    // By default no error is classified as a quota error.
    genericQuotaCheck.mockReturnValue(false);
    proQuotaCheck.mockReturnValue(false);
  });

  afterEach(() => {
    vi.clearAllMocks();
  });

  it('should register a fallback handler on initialization', () => {
    renderQuotaHook();

    expect(setFallbackHandlerSpy).toHaveBeenCalledTimes(1);
    expect(setFallbackHandlerSpy.mock.calls[0][0]).toBeInstanceOf(Function);
  });

  describe('Fallback Handler Logic', () => {
    /** Renders the hook and hands back the handler it registered. */
    const getRegisteredHandler = (
      userTier: UserTierId = UserTierId.FREE,
    ): FallbackModelHandler => {
      renderQuotaHook(userTier);
      return firstRegisteredHandler();
    };

    it('should return null and take no action if already in fallback mode', async () => {
      vi.spyOn(mockConfig, 'isInFallbackMode').mockReturnValue(true);

      const handler = getRegisteredHandler();
      const outcome = await handler('gemini-pro', 'gemini-flash', new Error());

      expect(outcome).toBeNull();
      expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
    });

    it('should return null and take no action if authType is not LOGIN_WITH_GOOGLE', async () => {
      // Replace the beforeEach stub with a non-OAuth auth type.
      vi.spyOn(mockConfig, 'getContentGeneratorConfig').mockReturnValue({
        model: 'gemini-pro',
        authType: AuthType.USE_GEMINI,
      });

      const handler = getRegisteredHandler();
      const outcome = await handler('gemini-pro', 'gemini-flash', new Error());

      expect(outcome).toBeNull();
      expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
    });

    describe('Automatic Fallback Scenarios', () => {
      const scenarios = [
        {
          errorType: 'generic',
          tier: UserTierId.FREE,
          expectedMessageSnippets: [
            'Automatically switching from model-A to model-B',
            'upgrade to a Gemini Code Assist Standard or Enterprise plan',
          ],
        },
        {
          errorType: 'generic',
          tier: UserTierId.STANDARD, // Paid tier
          expectedMessageSnippets: [
            'Automatically switching from model-A to model-B',
            'switch to using a paid API key from AI Studio',
          ],
        },
        {
          errorType: 'other',
          tier: UserTierId.FREE,
          expectedMessageSnippets: [
            'Automatically switching from model-A to model-B for faster responses',
            'upgrade to a Gemini Code Assist Standard or Enterprise plan',
          ],
        },
        {
          errorType: 'other',
          tier: UserTierId.LEGACY, // Paid tier
          expectedMessageSnippets: [
            'Automatically switching from model-A to model-B for faster responses',
            'switch to using a paid API key from AI Studio',
          ],
        },
      ];

      scenarios.forEach(({ errorType, tier, expectedMessageSnippets }) => {
        it(`should handle ${errorType} error for ${tier} tier correctly`, async () => {
          genericQuotaCheck.mockReturnValue(errorType === 'generic');

          const handler = getRegisteredHandler(tier);
          const outcome = await handler(
            'model-A',
            'model-B',
            new Error('quota exceeded'),
          );

          // Non-interactive fallbacks always end the current request.
          expect(outcome).toBe('stop');

          expect(mockHistoryManager.addItem).toHaveBeenCalledWith(
            expect.objectContaining({ type: MessageType.INFO }),
            expect.any(Number),
          );

          const message = (mockHistoryManager.addItem as Mock).mock.calls[0][0]
            .text;
          expectedMessageSnippets.forEach((snippet) => {
            expect(message).toContain(snippet);
          });

          expect(mockSetModelSwitchedFromQuotaError).toHaveBeenCalledWith(true);
          expect(mockConfig.setQuotaErrorOccurred).toHaveBeenCalledWith(true);
        });
      });
    });

    describe('Interactive Fallback (Pro Quota Error)', () => {
      beforeEach(() => {
        proQuotaCheck.mockReturnValue(true);
      });

      it('should set an interactive request and wait for user choice', async () => {
        const { result } = renderQuotaHook();
        const handler = firstRegisteredHandler();

        // Deliberately not awaited yet: the intermediate state is inspected
        // while the handler is still waiting on the user.
        const pendingIntent = handler(
          'gemini-pro',
          'gemini-flash',
          new Error('pro quota'),
        );

        await act(async () => {});

        // A request for the UI to act on must now be exposed by the hook.
        expect(result.current.proQuotaRequest).not.toBeNull();
        expect(result.current.proQuotaRequest?.failedModel).toBe('gemini-pro');

        // The user picks "continue with the fallback model".
        act(() => {
          result.current.handleProQuotaChoice('continue');
        });

        // That choice settles the handler's promise...
        const intent = await pendingIntent;
        expect(intent).toBe('retry');

        // ...and clears the pending request.
        expect(result.current.proQuotaRequest).toBeNull();
      });

      it('should handle race conditions by stopping subsequent requests', async () => {
        const { result } = renderQuotaHook();
        const handler = firstRegisteredHandler();

        const firstPending = handler(
          'gemini-pro',
          'gemini-flash',
          new Error('pro quota 1'),
        );
        await act(async () => {});

        const firstRequest = result.current.proQuotaRequest;
        expect(firstRequest).not.toBeNull();

        const secondOutcome = await handler(
          'gemini-pro',
          'gemini-flash',
          new Error('pro quota 2'),
        );

        // While the first dialog is open, the second call is stopped and the
        // exposed request is left untouched.
        expect(secondOutcome).toBe('stop');
        expect(result.current.proQuotaRequest).toBe(firstRequest);

        act(() => {
          result.current.handleProQuotaChoice('continue');
        });

        const firstIntent = await firstPending;
        expect(firstIntent).toBe('retry');
        expect(result.current.proQuotaRequest).toBeNull();
      });
    });
  });

  describe('handleProQuotaChoice', () => {
    beforeEach(() => {
      proQuotaCheck.mockReturnValue(true);
    });

    it('should do nothing if there is no pending pro quota request', () => {
      const { result } = renderQuotaHook();

      act(() => {
        result.current.handleProQuotaChoice('auth');
      });

      expect(mockSetAuthState).not.toHaveBeenCalled();
      expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
    });

    it('should resolve intent to "auth" and trigger auth state update', async () => {
      const { result } = renderQuotaHook();
      const handler = firstRegisteredHandler();

      const pendingIntent = handler(
        'gemini-pro',
        'gemini-flash',
        new Error('pro quota'),
      );
      await act(async () => {}); // Let the hook publish the pending request

      act(() => {
        result.current.handleProQuotaChoice('auth');
      });

      const intent = await pendingIntent;
      expect(intent).toBe('auth');
      expect(mockSetAuthState).toHaveBeenCalledWith(AuthState.Updating);
      expect(result.current.proQuotaRequest).toBeNull();
    });

    it('should resolve intent to "retry" and add info message on continue', async () => {
      const { result } = renderQuotaHook();
      const handler = firstRegisteredHandler();

      // This call produces the first `addItem` (the quota notice itself).
      const pendingIntent = handler(
        'gemini-pro',
        'gemini-flash',
        new Error('pro quota'),
      );
      await act(async () => {}); // Let the hook publish the pending request

      act(() => {
        result.current.handleProQuotaChoice('continue');
      });

      const intent = await pendingIntent;
      expect(intent).toBe('retry');
      expect(result.current.proQuotaRequest).toBeNull();

      // The second `addItem` is the "Switched to fallback model" notice.
      expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(2);
      const lastCall = (mockHistoryManager.addItem as Mock).mock.calls[1][0];
      expect(lastCall.type).toBe(MessageType.INFO);
      expect(lastCall.text).toContain('Switched to fallback model.');
    });
  });
});
|
||||
@@ -0,0 +1,175 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
AuthType,
|
||||
type Config,
|
||||
type FallbackModelHandler,
|
||||
type FallbackIntent,
|
||||
isGenericQuotaExceededError,
|
||||
isProQuotaExceededError,
|
||||
UserTierId,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { type UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import { AuthState, MessageType } from '../types.js';
|
||||
import { type ProQuotaDialogRequest } from '../contexts/UIStateContext.js';
|
||||
|
||||
/** Inputs consumed by the `useQuotaAndFallback` hook. */
interface UseQuotaAndFallbackArgs {
  /** Config on which the hook registers its fallback model handler. */
  config: Config;
  /** History manager used to append INFO messages shown to the user. */
  historyManager: UseHistoryManagerReturn;
  /**
   * Tier of the current user. `LEGACY` and `STANDARD` get paid-tier wording;
   * any other value (including `undefined`) gets free-tier wording.
   */
  userTier: UserTierId | undefined;
  /** Called with `AuthState.Updating` when the user chooses to re-authenticate. */
  setAuthState: (state: AuthState) => void;
  /** Called with `true` whenever the fallback handler acts on an error. */
  setModelSwitchedFromQuotaError: (value: boolean) => void;
}
|
||||
|
||||
/** Category of failure that triggered the fallback; selects the notice wording. */
type FallbackErrorKind = 'pro-quota' | 'generic-quota' | 'other';

/**
 * Builds the multi-line notice shown in the history when a fallback is triggered.
 *
 * The notice is a two-line headline chosen by `kind`, followed by tier-specific
 * advice (one line for paid tiers, three lines otherwise), joined by newlines.
 */
function buildFallbackMessage(
  kind: FallbackErrorKind,
  failedModel: string,
  fallbackModel: string,
  isPaidTier: boolean,
): string {
  const tierAdvice = isPaidTier
    ? [
        `⚡ To continue accessing the ${failedModel} model today, consider using /auth to switch to using a paid API key from AI Studio at https://aistudio.google.com/apikey`,
      ]
    : [
        '⚡ To increase your limits, upgrade to a Gemini Code Assist Standard or Enterprise plan with higher limits at https://goo.gle/set-up-gemini-code-assist',
        '⚡ Or you can utilize a Gemini API Key. See: https://goo.gle/gemini-cli-docs-auth#gemini-api-key',
        '⚡ You can switch authentication methods by typing /auth',
      ];

  let headline: string[];
  switch (kind) {
    case 'pro-quota':
      // Interactive case: the user is asked to choose what happens next.
      headline = [
        `⚡ You have reached your daily ${failedModel} quota limit.`,
        '⚡ You can choose to authenticate with a paid API key or continue with the fallback model.',
      ];
      break;
    case 'generic-quota':
      headline = [
        '⚡ You have reached your daily quota limit.',
        `⚡ Automatically switching from ${failedModel} to ${fallbackModel} for the remainder of this session.`,
      ];
      break;
    default:
      // Consecutive 429s or any other error.
      headline = [
        `⚡ Automatically switching from ${failedModel} to ${fallbackModel} for faster responses for the remainder of this session.`,
        `⚡ Possible reasons for this are that you have received multiple consecutive capacity errors or you have reached your daily ${failedModel} quota limit`,
      ];
      break;
  }

  return [...headline, ...tierAdvice].join('\n');
}

/**
 * Registers a fallback model handler on `config` and exposes the state needed
 * to drive the interactive Pro-quota dialog.
 *
 * Handler behaviour:
 * - returns `null` when the config is already in fallback mode, or when the
 *   content generator config is missing or its auth type is not
 *   `LOGIN_WITH_GOOGLE`;
 * - otherwise adds an INFO notice to the history and flags the quota error on
 *   both `setModelSwitchedFromQuotaError` and `config`;
 * - for Pro-quota errors, publishes a `proQuotaRequest` and resolves with the
 *   intent chosen through `handleProQuotaChoice`; while such a request is
 *   pending, further Pro-quota errors resolve to `'stop'` without adding a
 *   duplicate notice;
 * - for every other error, resolves to `'stop'`.
 *
 * @returns The pending Pro-quota request (or `null`) and the callback the UI
 *   uses to answer it.
 */
export function useQuotaAndFallback({
  config,
  historyManager,
  userTier,
  setAuthState,
  setModelSwitchedFromQuotaError,
}: UseQuotaAndFallbackArgs) {
  const [proQuotaRequest, setProQuotaRequest] =
    useState<ProQuotaDialogRequest | null>(null);
  // A ref rather than state: the handler must see the latest value
  // synchronously, even when two failures arrive before a re-render.
  const isDialogPending = useRef(false);

  // Set up the fallback model handler.
  useEffect(() => {
    const fallbackHandler: FallbackModelHandler = async (
      failedModel,
      fallbackModel,
      error,
    ): Promise<FallbackIntent | null> => {
      if (config.isInFallbackMode()) {
        return null;
      }

      // Fallbacks are currently only handled for OAuth users.
      const contentGeneratorConfig = config.getContentGeneratorConfig();
      if (
        !contentGeneratorConfig ||
        contentGeneratorConfig.authType !== AuthType.LOGIN_WITH_GOOGLE
      ) {
        return null;
      }

      // Unknown or missing tiers get free-tier messaging.
      const isPaidTier =
        userTier === UserTierId.LEGACY || userTier === UserTierId.STANDARD;

      let kind: FallbackErrorKind;
      if (error && isProQuotaExceededError(error)) {
        kind = 'pro-quota';
      } else if (error && isGenericQuotaExceededError(error)) {
        kind = 'generic-quota';
      } else {
        kind = 'other';
      }

      const isInteractive = kind === 'pro-quota';
      // Read the lock before doing anything user-visible, so a request that is
      // about to be stopped does not repeat the notice already on screen.
      const dialogAlreadyOpen = isInteractive && isDialogPending.current;

      if (!dialogAlreadyOpen) {
        historyManager.addItem(
          {
            type: MessageType.INFO,
            text: buildFallbackMessage(
              kind,
              failedModel,
              fallbackModel,
              isPaidTier,
            ),
          },
          Date.now(),
        );
      }

      setModelSwitchedFromQuotaError(true);
      config.setQuotaErrorOccurred(true);

      if (!isInteractive) {
        return 'stop';
      }
      if (dialogAlreadyOpen) {
        return 'stop'; // A dialog is already active, so just stop this request.
      }

      // Interactive fallback: park the resolver in state and wait for the UI
      // to call handleProQuotaChoice.
      isDialogPending.current = true;
      return new Promise<FallbackIntent>((resolve) => {
        setProQuotaRequest({
          failedModel,
          fallbackModel,
          resolve,
        });
      });
    };

    config.setFallbackModelHandler(fallbackHandler);
  }, [config, historyManager, userTier, setModelSwitchedFromQuotaError]);

  const handleProQuotaChoice = useCallback(
    (choice: 'auth' | 'continue') => {
      if (!proQuotaRequest) return;

      // 'continue' maps to retrying the request; 'auth' is passed through.
      const intent: FallbackIntent = choice === 'auth' ? 'auth' : 'retry';
      proQuotaRequest.resolve(intent);
      setProQuotaRequest(null);
      isDialogPending.current = false; // Release the lock for future failures

      if (choice === 'auth') {
        setAuthState(AuthState.Updating);
      } else {
        historyManager.addItem(
          {
            type: MessageType.INFO,
            text: 'Switched to fallback model. Tip: Press Ctrl+P (or Up Arrow) to recall your previous prompt and submit it again if you wish.',
          },
          Date.now(),
        );
      }
    },
    [proQuotaRequest, setAuthState, historyManager],
  );

  return {
    proQuotaRequest,
    handleProQuotaChoice,
  };
}
|
||||
@@ -52,6 +52,7 @@ import type { FileSystemService } from '../services/fileSystemService.js';
|
||||
import { StandardFileSystemService } from '../services/fileSystemService.js';
|
||||
import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js';
|
||||
import { IdeConnectionEvent, IdeConnectionType } from '../telemetry/types.js';
|
||||
import type { FallbackModelHandler } from '../fallback/types.js';
|
||||
|
||||
// Re-export OAuth config type
|
||||
export type { MCPOAuthConfig, AnyToolInvocation };
|
||||
@@ -157,12 +158,6 @@ export interface SandboxConfig {
|
||||
image: string;
|
||||
}
|
||||
|
||||
export type FlashFallbackHandler = (
|
||||
currentModel: string,
|
||||
fallbackModel: string,
|
||||
error?: unknown,
|
||||
) => Promise<boolean | string | null>;
|
||||
|
||||
export interface ConfigParameters {
|
||||
sessionId: string;
|
||||
embeddingModel?: string;
|
||||
@@ -281,7 +276,7 @@ export class Config {
|
||||
name: string;
|
||||
extensionName: string;
|
||||
}>;
|
||||
flashFallbackHandler?: FlashFallbackHandler;
|
||||
fallbackModelHandler?: FallbackModelHandler;
|
||||
private quotaErrorOccurred: boolean = false;
|
||||
private readonly summarizeToolOutput:
|
||||
| Record<string, SummarizeToolOutputSettings>
|
||||
@@ -490,8 +485,8 @@ export class Config {
|
||||
this.inFallbackMode = active;
|
||||
}
|
||||
|
||||
setFlashFallbackHandler(handler: FlashFallbackHandler): void {
|
||||
this.flashFallbackHandler = handler;
|
||||
setFallbackModelHandler(handler: FallbackModelHandler): void {
|
||||
this.fallbackModelHandler = handler;
|
||||
}
|
||||
|
||||
getMaxSessionTurns(): number {
|
||||
|
||||
@@ -4,7 +4,15 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
|
||||
import type { Content, GenerateContentResponse, Part } from '@google/genai';
|
||||
import {
|
||||
@@ -212,16 +220,19 @@ describe('Gemini Client (client.ts)', () => {
|
||||
let mockContentGenerator: ContentGenerator;
|
||||
let mockConfig: Config;
|
||||
let client: GeminiClient;
|
||||
let mockGenerateContentFn: Mock;
|
||||
beforeEach(async () => {
|
||||
vi.resetAllMocks();
|
||||
|
||||
mockGenerateContentFn = vi.fn().mockResolvedValue({
|
||||
candidates: [{ content: { parts: [{ text: '{"key": "value"}' }] } }],
|
||||
});
|
||||
|
||||
// Disable 429 simulation for tests
|
||||
setSimulate429(false);
|
||||
|
||||
mockContentGenerator = {
|
||||
generateContent: vi.fn().mockResolvedValue({
|
||||
candidates: [{ content: { parts: [{ text: '{"key": "value"}' }] } }],
|
||||
}),
|
||||
generateContent: mockGenerateContentFn,
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
embedContent: vi.fn(),
|
||||
@@ -270,6 +281,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
getDirectories: vi.fn().mockReturnValue(['/test/dir']),
|
||||
}),
|
||||
getGeminiClient: vi.fn(),
|
||||
isInFallbackMode: vi.fn().mockReturnValue(false),
|
||||
setFallbackMode: vi.fn(),
|
||||
getChatCompression: vi.fn().mockReturnValue(undefined),
|
||||
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
|
||||
@@ -453,6 +465,27 @@ describe('Gemini Client (client.ts)', () => {
|
||||
'test-session-id',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the Flash model when fallback mode is active', async () => {
|
||||
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
|
||||
const schema = { type: 'string' };
|
||||
const abortSignal = new AbortController().signal;
|
||||
const requestedModel = 'gemini-2.5-pro'; // A non-flash model
|
||||
|
||||
// Mock config to be in fallback mode
|
||||
// We access the mock via the client instance which holds the mocked config
|
||||
vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true);
|
||||
|
||||
await client.generateJson(contents, schema, abortSignal, requestedModel);
|
||||
|
||||
// Assert that the Flash model was used, not the requested model
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
}),
|
||||
'test-session-id',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('addHistory', () => {
|
||||
@@ -2210,32 +2243,28 @@ ${JSON.stringify(
|
||||
'test-session-id',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleFlashFallback', () => {
|
||||
it('should use current model from config when checking for fallback', async () => {
|
||||
const initialModel = client['config'].getModel();
|
||||
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
it('should use the Flash model when fallback mode is active', async () => {
|
||||
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
|
||||
const generationConfig = { temperature: 0.5 };
|
||||
const abortSignal = new AbortController().signal;
|
||||
const requestedModel = 'gemini-2.5-pro'; // A non-flash model
|
||||
|
||||
// mock config been changed
|
||||
const currentModel = initialModel + '-changed';
|
||||
const getModelSpy = vi.spyOn(client['config'], 'getModel');
|
||||
getModelSpy.mockReturnValue(currentModel);
|
||||
// Mock config to be in fallback mode
|
||||
vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true);
|
||||
|
||||
const mockFallbackHandler = vi.fn().mockResolvedValue(true);
|
||||
client['config'].flashFallbackHandler = mockFallbackHandler;
|
||||
client['config'].setModel = vi.fn();
|
||||
|
||||
const result = await client['handleFlashFallback'](
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
await client.generateContent(
|
||||
contents,
|
||||
generationConfig,
|
||||
abortSignal,
|
||||
requestedModel,
|
||||
);
|
||||
|
||||
expect(result).toBe(fallbackModel);
|
||||
|
||||
expect(mockFallbackHandler).toHaveBeenCalledWith(
|
||||
currentModel,
|
||||
fallbackModel,
|
||||
undefined,
|
||||
expect(mockGenerateContentFn).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
}),
|
||||
'test-session-id',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -31,7 +31,6 @@ import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { tokenLimit } from './tokenLimits.js';
|
||||
import type { ChatRecordingService } from '../services/chatRecordingService.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import { AuthType } from './contentGenerator.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_THINKING_MODE,
|
||||
@@ -49,6 +48,7 @@ import {
|
||||
NextSpeakerCheckEvent,
|
||||
} from '../telemetry/types.js';
|
||||
import type { IdeContext, File } from '../ide/ideContext.js';
|
||||
import { handleFallback } from '../fallback/handler.js';
|
||||
|
||||
export function isThinkingSupported(model: string) {
|
||||
if (model.startsWith('gemini-2.5')) return true;
|
||||
@@ -550,6 +550,8 @@ export class GeminiClient {
|
||||
model: string,
|
||||
config: GenerateContentConfig = {},
|
||||
): Promise<Record<string, unknown>> {
|
||||
let currentAttemptModel: string = model;
|
||||
|
||||
try {
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const systemInstruction = getCoreSystemPrompt(userMemory);
|
||||
@@ -559,10 +561,15 @@ export class GeminiClient {
|
||||
...config,
|
||||
};
|
||||
|
||||
const apiCall = () =>
|
||||
this.getContentGeneratorOrFail().generateContent(
|
||||
const apiCall = () => {
|
||||
const modelToUse = this.config.isInFallbackMode()
|
||||
? DEFAULT_GEMINI_FLASH_MODEL
|
||||
: model;
|
||||
currentAttemptModel = modelToUse;
|
||||
|
||||
return this.getContentGeneratorOrFail().generateContent(
|
||||
{
|
||||
model,
|
||||
model: modelToUse,
|
||||
config: {
|
||||
...requestConfig,
|
||||
systemInstruction,
|
||||
@@ -573,10 +580,17 @@ export class GeminiClient {
|
||||
},
|
||||
this.lastPromptId,
|
||||
);
|
||||
};
|
||||
|
||||
const onPersistent429Callback = async (
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
) =>
|
||||
// Pass the captured model to the centralized handler.
|
||||
await handleFallback(this.config, currentAttemptModel, authType, error);
|
||||
|
||||
const result = await retryWithBackoff(apiCall, {
|
||||
onPersistent429: async (authType?: string, error?: unknown) =>
|
||||
await this.handleFlashFallback(authType, error),
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
});
|
||||
|
||||
@@ -599,7 +613,7 @@ export class GeminiClient {
|
||||
if (text.startsWith(prefix) && text.endsWith(suffix)) {
|
||||
logMalformedJsonResponse(
|
||||
this.config,
|
||||
new MalformedJsonResponseEvent(model),
|
||||
new MalformedJsonResponseEvent(currentAttemptModel),
|
||||
);
|
||||
text = text
|
||||
.substring(prefix.length, text.length - suffix.length)
|
||||
@@ -655,6 +669,8 @@ export class GeminiClient {
|
||||
abortSignal: AbortSignal,
|
||||
model: string,
|
||||
): Promise<GenerateContentResponse> {
|
||||
let currentAttemptModel: string = model;
|
||||
|
||||
const configToUse: GenerateContentConfig = {
|
||||
...this.generateContentConfig,
|
||||
...generationConfig,
|
||||
@@ -670,19 +686,30 @@ export class GeminiClient {
|
||||
systemInstruction,
|
||||
};
|
||||
|
||||
const apiCall = () =>
|
||||
this.getContentGeneratorOrFail().generateContent(
|
||||
const apiCall = () => {
|
||||
const modelToUse = this.config.isInFallbackMode()
|
||||
? DEFAULT_GEMINI_FLASH_MODEL
|
||||
: model;
|
||||
currentAttemptModel = modelToUse;
|
||||
|
||||
return this.getContentGeneratorOrFail().generateContent(
|
||||
{
|
||||
model,
|
||||
model: modelToUse,
|
||||
config: requestConfig,
|
||||
contents,
|
||||
},
|
||||
this.lastPromptId,
|
||||
);
|
||||
};
|
||||
const onPersistent429Callback = async (
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
) =>
|
||||
// Pass the captured model to the centralized handler.
|
||||
await handleFallback(this.config, currentAttemptModel, authType, error);
|
||||
|
||||
const result = await retryWithBackoff(apiCall, {
|
||||
onPersistent429: async (authType?: string, error?: unknown) =>
|
||||
await this.handleFlashFallback(authType, error),
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
});
|
||||
return result;
|
||||
@@ -693,7 +720,7 @@ export class GeminiClient {
|
||||
|
||||
await reportError(
|
||||
error,
|
||||
`Error generating content via API with model ${model}.`,
|
||||
`Error generating content via API with model ${currentAttemptModel}.`,
|
||||
{
|
||||
requestContents: contents,
|
||||
requestConfig: configToUse,
|
||||
@@ -701,7 +728,7 @@ export class GeminiClient {
|
||||
'generateContent-api',
|
||||
);
|
||||
throw new Error(
|
||||
`Failed to generate content with model ${model}: ${getErrorMessage(error)}`,
|
||||
`Failed to generate content with model ${currentAttemptModel}: ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -880,53 +907,6 @@ export class GeminiClient {
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles falling back to Flash model when persistent 429 errors occur for OAuth users.
|
||||
* Uses a fallback handler if provided by the config; otherwise, returns null.
|
||||
*/
|
||||
private async handleFlashFallback(
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
): Promise<string | null> {
|
||||
// Only handle fallback for OAuth users
|
||||
if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const currentModel = this.config.getModel();
|
||||
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
|
||||
// Don't fallback if already using Flash model
|
||||
if (currentModel === fallbackModel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Check if config has a fallback handler (set by CLI package)
|
||||
const fallbackHandler = this.config.flashFallbackHandler;
|
||||
if (typeof fallbackHandler === 'function') {
|
||||
try {
|
||||
const accepted = await fallbackHandler(
|
||||
currentModel,
|
||||
fallbackModel,
|
||||
error,
|
||||
);
|
||||
if (accepted !== false && accepted !== null) {
|
||||
this.config.setModel(fallbackModel);
|
||||
this.config.setFallbackMode(true);
|
||||
return fallbackModel;
|
||||
}
|
||||
// Check if the model was switched manually in the handler
|
||||
if (this.config.getModel() === fallbackModel) {
|
||||
return null; // Model was switched but don't continue with current prompt
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Flash fallback handler failed:', error);
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export const TEST_ONLY = {
|
||||
|
||||
@@ -20,6 +20,9 @@ import {
|
||||
} from './geminiChat.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { setSimulate429 } from '../utils/testUtils.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { AuthType } from './contentGenerator.js';
|
||||
import { type RetryOptions } from '../utils/retry.js';
|
||||
|
||||
// Mock fs module to prevent actual file system operations during tests
|
||||
const mockFileSystem = new Map<string, string>();
|
||||
@@ -47,6 +50,23 @@ vi.mock('node:fs', () => {
|
||||
};
|
||||
});
|
||||
|
||||
const { mockHandleFallback } = vi.hoisted(() => ({
|
||||
mockHandleFallback: vi.fn(),
|
||||
}));
|
||||
|
||||
// Add mock for the retry utility
|
||||
const { mockRetryWithBackoff } = vi.hoisted(() => ({
|
||||
mockRetryWithBackoff: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/retry.js', () => ({
|
||||
retryWithBackoff: mockRetryWithBackoff,
|
||||
}));
|
||||
|
||||
vi.mock('../fallback/handler.js', () => ({
|
||||
handleFallback: mockHandleFallback,
|
||||
}));
|
||||
|
||||
const { mockLogInvalidChunk, mockLogContentRetry, mockLogContentRetryFailure } =
|
||||
vi.hoisted(() => ({
|
||||
mockLogInvalidChunk: vi.fn(),
|
||||
@@ -76,17 +96,21 @@ describe('GeminiChat', () => {
|
||||
batchEmbedContents: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
mockHandleFallback.mockClear();
|
||||
// Default mock implementation for tests that don't care about retry logic
|
||||
mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall());
|
||||
mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getTelemetryLogPromptsEnabled: () => true,
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getContentGeneratorConfig: () => ({
|
||||
authType: 'oauth-personal',
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({
|
||||
authType: 'oauth-personal', // Ensure this is set for fallback tests
|
||||
model: 'test-model',
|
||||
}),
|
||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
setModel: vi.fn(),
|
||||
isInFallbackMode: vi.fn().mockReturnValue(false),
|
||||
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
|
||||
setQuotaErrorOccurred: vi.fn(),
|
||||
flashFallbackHandler: undefined,
|
||||
@@ -1476,8 +1500,176 @@ describe('GeminiChat', () => {
|
||||
expect(turn4.parts[0].text).toBe('second response');
|
||||
});
|
||||
|
||||
describe('Model Resolution', () => {
|
||||
const mockResponse = {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'response' }], role: 'model' },
|
||||
finishReason: 'STOP',
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
|
||||
it('should use the configured model when not in fallback mode (sendMessage)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-pro');
|
||||
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
mockResponse,
|
||||
);
|
||||
|
||||
await chat.sendMessage({ message: 'test' }, 'prompt-id-res1');
|
||||
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'gemini-2.5-pro',
|
||||
}),
|
||||
'prompt-id-res1',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the FLASH model when in fallback mode (sendMessage)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-pro');
|
||||
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
mockResponse,
|
||||
);
|
||||
|
||||
await chat.sendMessage({ message: 'test' }, 'prompt-id-res2');
|
||||
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
}),
|
||||
'prompt-id-res2',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the FLASH model when in fallback mode (sendMessageStream)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro');
|
||||
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockImplementation(
|
||||
async () =>
|
||||
(async function* () {
|
||||
yield mockResponse;
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ message: 'test' },
|
||||
'prompt-id-res3',
|
||||
);
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
}
|
||||
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
}),
|
||||
'prompt-id-res3',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Fallback Integration (Retries)', () => {
|
||||
const error429 = Object.assign(new Error('API Error 429: Quota exceeded'), {
|
||||
status: 429,
|
||||
});
|
||||
|
||||
// Define the simulated behavior for retryWithBackoff for these tests.
|
||||
// This simulation tries the apiCall, if it fails, it calls the callback,
|
||||
// and then tries the apiCall again if the callback returns true.
|
||||
const simulateRetryBehavior = async <T>(
|
||||
apiCall: () => Promise<T>,
|
||||
options: Partial<RetryOptions>,
|
||||
) => {
|
||||
try {
|
||||
return await apiCall();
|
||||
} catch (error) {
|
||||
if (options.onPersistent429) {
|
||||
// We simulate the "persistent" trigger here for simplicity.
|
||||
const shouldRetry = await options.onPersistent429(
|
||||
options.authType,
|
||||
error,
|
||||
);
|
||||
if (shouldRetry) {
|
||||
return await apiCall();
|
||||
}
|
||||
}
|
||||
throw error; // Stop if callback returns false/null or doesn't exist
|
||||
}
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
mockRetryWithBackoff.mockImplementation(simulateRetryBehavior);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall());
|
||||
});
|
||||
|
||||
it('should call handleFallback with the specific failed model and retry if handler returns true', async () => {
|
||||
const FAILED_MODEL = 'gemini-2.5-pro';
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue(FAILED_MODEL);
|
||||
const authType = AuthType.LOGIN_WITH_GOOGLE;
|
||||
vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({
|
||||
authType,
|
||||
model: FAILED_MODEL,
|
||||
});
|
||||
|
||||
const isInFallbackModeSpy = vi.spyOn(mockConfig, 'isInFallbackMode');
|
||||
isInFallbackModeSpy.mockReturnValue(false);
|
||||
|
||||
vi.mocked(mockContentGenerator.generateContent)
|
||||
.mockRejectedValueOnce(error429) // Attempt 1 fails
|
||||
.mockResolvedValueOnce({
|
||||
candidates: [{ content: { parts: [{ text: 'Success on retry' }] } }],
|
||||
} as unknown as GenerateContentResponse); // Attempt 2 succeeds
|
||||
|
||||
mockHandleFallback.mockImplementation(async () => {
|
||||
isInFallbackModeSpy.mockReturnValue(true);
|
||||
return true; // Signal retry
|
||||
});
|
||||
|
||||
const result = await chat.sendMessage(
|
||||
{ message: 'trigger 429' },
|
||||
'prompt-id-fb1',
|
||||
);
|
||||
|
||||
expect(mockRetryWithBackoff).toHaveBeenCalledTimes(1);
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(2);
|
||||
expect(mockHandleFallback).toHaveBeenCalledTimes(1);
|
||||
|
||||
expect(mockHandleFallback).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
FAILED_MODEL,
|
||||
authType,
|
||||
error429,
|
||||
);
|
||||
|
||||
expect(result.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
|
||||
'Success on retry',
|
||||
);
|
||||
});
|
||||
|
||||
it('should stop retrying if handleFallback returns false (e.g., auth intent)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro');
|
||||
vi.mocked(mockContentGenerator.generateContent).mockRejectedValue(
|
||||
error429,
|
||||
);
|
||||
mockHandleFallback.mockResolvedValue(false);
|
||||
|
||||
await expect(
|
||||
chat.sendMessage({ message: 'test stop' }, 'prompt-id-fb2'),
|
||||
).rejects.toThrow(error429);
|
||||
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(1);
|
||||
expect(mockHandleFallback).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
it('should discard valid partial content from a failed attempt upon retry', async () => {
|
||||
// ARRANGE: Mock the stream to fail on the first attempt after yielding some valid content.
|
||||
// Mock the stream to fail on the first attempt after yielding some valid content.
|
||||
vi.mocked(mockContentGenerator.generateContentStream)
|
||||
.mockImplementationOnce(async () =>
|
||||
// First attempt: yields one valid chunk, then one invalid chunk
|
||||
@@ -1512,7 +1704,7 @@ describe('GeminiChat', () => {
|
||||
})(),
|
||||
);
|
||||
|
||||
// ACT: Send a message and consume the stream
|
||||
// Send a message and consume the stream
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ message: 'test' },
|
||||
'prompt-id-discard-test',
|
||||
@@ -1522,7 +1714,6 @@ describe('GeminiChat', () => {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// ASSERT
|
||||
// Check that a retry happened
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2);
|
||||
expect(events.some((e) => e.type === StreamEventType.RETRY)).toBe(true);
|
||||
|
||||
@@ -18,7 +18,6 @@ import type {
|
||||
import { toParts } from '../code_assist/converter.js';
|
||||
import { createUserContent } from '@google/genai';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import { AuthType } from './contentGenerator.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { hasCycleInSchema } from '../tools/tools.js';
|
||||
@@ -35,6 +34,7 @@ import {
|
||||
ContentRetryFailureEvent,
|
||||
InvalidChunkEvent,
|
||||
} from '../telemetry/types.js';
|
||||
import { handleFallback } from '../fallback/handler.js';
|
||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { partListUnionToString } from './geminiRequest.js';
|
||||
|
||||
@@ -179,53 +179,6 @@ export class GeminiChat {
|
||||
this.chatRecordingService.initialize();
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles falling back to Flash model when persistent 429 errors occur for OAuth users.
|
||||
* Uses a fallback handler if provided by the config; otherwise, returns null.
|
||||
*/
|
||||
private async handleFlashFallback(
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
): Promise<string | null> {
|
||||
// Only handle fallback for OAuth users
|
||||
if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const currentModel = this.config.getModel();
|
||||
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
|
||||
// Don't fallback if already using Flash model
|
||||
if (currentModel === fallbackModel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Check if config has a fallback handler (set by CLI package)
|
||||
const fallbackHandler = this.config.flashFallbackHandler;
|
||||
if (typeof fallbackHandler === 'function') {
|
||||
try {
|
||||
const accepted = await fallbackHandler(
|
||||
currentModel,
|
||||
fallbackModel,
|
||||
error,
|
||||
);
|
||||
if (accepted !== false && accepted !== null) {
|
||||
this.config.setModel(fallbackModel);
|
||||
this.config.setFallbackMode(true);
|
||||
return fallbackModel;
|
||||
}
|
||||
// Check if the model was switched manually in the handler
|
||||
if (this.config.getModel() === fallbackModel) {
|
||||
return null; // Model was switched but don't continue with current prompt
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Flash fallback handler failed:', error);
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
setSystemInstruction(sysInstr: string) {
|
||||
this.generationConfig.systemInstruction = sysInstr;
|
||||
}
|
||||
@@ -272,8 +225,13 @@ export class GeminiChat {
|
||||
let response: GenerateContentResponse;
|
||||
|
||||
try {
|
||||
let currentAttemptModel: string | undefined;
|
||||
|
||||
const apiCall = () => {
|
||||
const modelToUse = this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
|
||||
const modelToUse = this.config.isInFallbackMode()
|
||||
? DEFAULT_GEMINI_FLASH_MODEL
|
||||
: this.config.getModel();
|
||||
currentAttemptModel = modelToUse;
|
||||
|
||||
// Prevent Flash model calls immediately after quota error
|
||||
if (
|
||||
@@ -295,6 +253,19 @@ export class GeminiChat {
|
||||
);
|
||||
};
|
||||
|
||||
const onPersistent429Callback = async (
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
) => {
|
||||
if (!currentAttemptModel) return null;
|
||||
return await handleFallback(
|
||||
this.config,
|
||||
currentAttemptModel,
|
||||
authType,
|
||||
error,
|
||||
);
|
||||
};
|
||||
|
||||
response = await retryWithBackoff(apiCall, {
|
||||
shouldRetry: (error: unknown) => {
|
||||
// Check for known error messages and codes.
|
||||
@@ -305,8 +276,7 @@ export class GeminiChat {
|
||||
}
|
||||
return false; // Don't retry other errors by default
|
||||
},
|
||||
onPersistent429: async (authType?: string, error?: unknown) =>
|
||||
await this.handleFlashFallback(authType, error),
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
});
|
||||
|
||||
@@ -484,8 +454,13 @@ export class GeminiChat {
|
||||
prompt_id: string,
|
||||
userContent: Content,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
let currentAttemptModel: string | undefined;
|
||||
|
||||
const apiCall = () => {
|
||||
const modelToUse = this.config.getModel();
|
||||
const modelToUse = this.config.isInFallbackMode()
|
||||
? DEFAULT_GEMINI_FLASH_MODEL
|
||||
: this.config.getModel();
|
||||
currentAttemptModel = modelToUse;
|
||||
|
||||
if (
|
||||
this.config.getQuotaErrorOccurred() &&
|
||||
@@ -506,6 +481,19 @@ export class GeminiChat {
|
||||
);
|
||||
};
|
||||
|
||||
const onPersistent429Callback = async (
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
) => {
|
||||
if (!currentAttemptModel) return null;
|
||||
return await handleFallback(
|
||||
this.config,
|
||||
currentAttemptModel,
|
||||
authType,
|
||||
error,
|
||||
);
|
||||
};
|
||||
|
||||
const streamResponse = await retryWithBackoff(apiCall, {
|
||||
shouldRetry: (error: unknown) => {
|
||||
if (error instanceof Error && error.message) {
|
||||
@@ -515,8 +503,7 @@ export class GeminiChat {
|
||||
}
|
||||
return false;
|
||||
},
|
||||
onPersistent429: async (authType?: string, error?: unknown) =>
|
||||
await this.handleFlashFallback(authType, error),
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
});
|
||||
|
||||
|
||||
@@ -0,0 +1,218 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
type Mock,
|
||||
type MockInstance,
|
||||
afterEach,
|
||||
} from 'vitest';
|
||||
import { handleFallback } from './handler.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
} from '../config/models.js';
|
||||
import { logFlashFallback } from '../telemetry/index.js';
|
||||
import type { FallbackModelHandler } from './types.js';
|
||||
|
||||
// Mock the telemetry logger and event class
|
||||
vi.mock('../telemetry/index.js', () => ({
|
||||
logFlashFallback: vi.fn(),
|
||||
FlashFallbackEvent: class {},
|
||||
}));
|
||||
|
||||
const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL;
|
||||
const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE;
|
||||
const AUTH_API_KEY = AuthType.USE_GEMINI;
|
||||
|
||||
/**
 * Builds a minimal stand-in for `Config` for these tests.
 *
 * Defaults: fallback mode is reported as inactive, `setFallbackMode` is a
 * spy, and no fallback handler is injected. The default key is
 * `fallbackModelHandler`, the property the handler under test reads from the
 * config, so the "no handler" default is explicit rather than a stray
 * unrelated key. Any entry in `overrides` replaces the matching default.
 *
 * @param overrides Properties to layer on top of the defaults.
 * @returns The mock object cast to `Config`.
 */
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
  ({
    isInFallbackMode: vi.fn(() => false),
    setFallbackMode: vi.fn(),
    fallbackModelHandler: undefined,
    ...overrides,
  }) as unknown as Config;
|
||||
|
||||
describe('handleFallback', () => {
|
||||
let mockConfig: Config;
|
||||
let mockHandler: Mock<FallbackModelHandler>;
|
||||
let consoleErrorSpy: MockInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockHandler = vi.fn();
|
||||
// Default setup: OAuth user, Pro model failed, handler injected
|
||||
mockConfig = createMockConfig({
|
||||
fallbackModelHandler: mockHandler,
|
||||
});
|
||||
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
consoleErrorSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should return null immediately if authType is not OAuth', async () => {
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_API_KEY,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
expect(mockHandler).not.toHaveBeenCalled();
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return null if the failed model is already the fallback model', async () => {
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
FALLBACK_MODEL, // Failed model is Flash
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
expect(mockHandler).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return null if no fallbackHandler is injected in config', async () => {
|
||||
const configWithoutHandler = createMockConfig({
|
||||
fallbackModelHandler: undefined,
|
||||
});
|
||||
const result = await handleFallback(
|
||||
configWithoutHandler,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
describe('when handler returns "retry"', () => {
|
||||
it('should activate fallback mode, log telemetry, and return true', async () => {
|
||||
mockHandler.mockResolvedValue('retry');
|
||||
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
|
||||
expect(logFlashFallback).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('when handler returns "stop"', () => {
|
||||
it('should activate fallback mode, log telemetry, and return false', async () => {
|
||||
mockHandler.mockResolvedValue('stop');
|
||||
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
|
||||
expect(logFlashFallback).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('when handler returns "auth"', () => {
|
||||
it('should NOT activate fallback mode and return false', async () => {
|
||||
mockHandler.mockResolvedValue('auth');
|
||||
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
expect(logFlashFallback).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('when handler returns an unexpected value', () => {
|
||||
it('should log an error and return null', async () => {
|
||||
mockHandler.mockResolvedValue(null);
|
||||
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'Fallback UI handler failed:',
|
||||
new Error(
|
||||
'Unexpected fallback intent received from fallbackModelHandler: "null"',
|
||||
),
|
||||
);
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it('should pass the correct context (failedModel, fallbackModel, error) to the handler', async () => {
|
||||
const mockError = new Error('Quota Exceeded');
|
||||
mockHandler.mockResolvedValue('retry');
|
||||
|
||||
await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH, mockError);
|
||||
|
||||
expect(mockHandler).toHaveBeenCalledWith(
|
||||
MOCK_PRO_MODEL,
|
||||
FALLBACK_MODEL,
|
||||
mockError,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not call setFallbackMode or log telemetry if already in fallback mode', async () => {
|
||||
// Setup config where fallback mode is already active
|
||||
const activeFallbackConfig = createMockConfig({
|
||||
fallbackModelHandler: mockHandler,
|
||||
isInFallbackMode: vi.fn(() => true), // Already active
|
||||
setFallbackMode: vi.fn(),
|
||||
});
|
||||
|
||||
mockHandler.mockResolvedValue('retry');
|
||||
|
||||
const result = await handleFallback(
|
||||
activeFallbackConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
// Should still return true to allow the retry (which will use the active fallback mode)
|
||||
expect(result).toBe(true);
|
||||
// Should still consult the handler
|
||||
expect(mockHandler).toHaveBeenCalled();
|
||||
// But should not mutate state or log telemetry again
|
||||
expect(activeFallbackConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
expect(logFlashFallback).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should catch errors from the handler, log an error, and return null', async () => {
|
||||
const handlerError = new Error('UI interaction failed');
|
||||
mockHandler.mockRejectedValue(handlerError);
|
||||
|
||||
const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'Fallback UI handler failed:',
|
||||
handlerError,
|
||||
);
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,69 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js';
|
||||
|
||||
/**
 * Asks the handler stored on `config.fallbackModelHandler` what to do after
 * `failedModel` failed, and converts the returned intent into a value for the
 * caller.
 *
 * @param config Provides `fallbackModelHandler` and the fallback-mode state.
 * @param failedModel Model that was used by the failed attempt.
 * @param authType Auth type of the session; anything other than
 *   `AuthType.LOGIN_WITH_GOOGLE` short-circuits to `null`.
 * @param error Error from the failed attempt, forwarded to the handler as-is.
 * @returns `true` for the 'retry' intent, `false` for 'stop' and 'auth', and
 *   `null` when fallback does not apply (wrong auth type, the failed model is
 *   already the fallback model, no handler function) or when the handler
 *   throws or yields an unrecognised intent.
 */
export async function handleFallback(
  config: Config,
  failedModel: string,
  authType?: string,
  error?: unknown,
): Promise<string | boolean | null> {
  const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;

  // Applicability: OAuth sessions only, and never when the model that failed
  // is already the fallback model.
  const isOAuthSession = authType === AuthType.LOGIN_WITH_GOOGLE;
  if (!isOAuthSession || failedModel === fallbackModel) {
    return null;
  }

  // Without an injected handler there is nobody to ask for an intent.
  const askForIntent = config.fallbackModelHandler;
  if (typeof askForIntent !== 'function') {
    return null;
  }

  try {
    // The handler receives the specific model that failed.
    const intent = await askForIntent(failedModel, fallbackModel, error);

    if (intent === 'retry' || intent === 'stop') {
      // Both intents switch to fallback mode; only 'retry' asks the caller to
      // continue with the current request.
      activateFallbackMode(config, authType);
      return intent === 'retry';
    }

    if (intent === 'auth') {
      // Fallback mode is left untouched for this intent.
      return false;
    }

    throw new Error(
      `Unexpected fallback intent received from fallbackModelHandler: "${intent}"`,
    );
  } catch (handlerError) {
    // Covers both a throwing handler and the unexpected-intent error above.
    console.error('Fallback UI handler failed:', handlerError);
    return null;
  }
}
|
||||
|
||||
/**
 * Turns fallback mode on if it is not already active.
 *
 * When the mode is newly activated and `authType` is truthy, a
 * `FlashFallbackEvent` is logged through `logFlashFallback`. If fallback mode
 * is already active, nothing is changed and nothing is logged.
 *
 * @param config Config whose fallback-mode flag is checked and set.
 * @param authType Auth type recorded on the telemetry event, when present.
 */
function activateFallbackMode(config: Config, authType: string | undefined) {
  if (config.isInFallbackMode()) {
    return;
  }

  config.setFallbackMode(true);

  if (!authType) {
    return;
  }

  logFlashFallback(config, new FlashFallbackEvent(authType));
}
|
||||
@@ -0,0 +1,23 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* Defines the intent returned by the UI layer during a fallback scenario.
|
||||
*/
|
||||
export type FallbackIntent =
|
||||
| 'retry' // Immediately retry the current request with the fallback model.
|
||||
| 'stop' // Switch to fallback for future requests, but stop the current request.
|
||||
| 'auth'; // Stop the current request; user intends to change authentication.
|
||||
|
||||
/**
|
||||
* The interface for the handler provided by the UI layer (e.g., the CLI)
|
||||
* to interact with the user during a fallback scenario.
|
||||
*/
|
||||
export type FallbackModelHandler = (
|
||||
failedModel: string,
|
||||
fallbackModel: string,
|
||||
error?: unknown,
|
||||
) => Promise<FallbackIntent | null>;
|
||||
@@ -20,6 +20,8 @@ export * from './core/geminiRequest.js';
|
||||
export * from './core/coreToolScheduler.js';
|
||||
export * from './core/nonInteractiveToolExecutor.js';
|
||||
|
||||
export * from './fallback/types.js';
|
||||
|
||||
export * from './code_assist/codeAssist.js';
|
||||
export * from './code_assist/oauth2.js';
|
||||
export * from './code_assist/server.js';
|
||||
|
||||
+32
-26
@@ -17,10 +17,13 @@ import {
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { retryWithBackoff } from './retry.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
// Type-only import of the handler contract that the UI layer implements for fallback decisions.
|
||||
import type { FallbackModelHandler } from '../fallback/types.js';
|
||||
|
||||
vi.mock('node:fs');
|
||||
|
||||
describe('Flash Fallback Integration', () => {
|
||||
// These tests cover the retry utility's integration with the fallback handler and the onPersistent429 callback.
|
||||
describe('Retry Utility Fallback Integration', () => {
|
||||
let config: Config;
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -41,25 +44,28 @@ describe('Flash Fallback Integration', () => {
|
||||
resetRequestCounter();
|
||||
});
|
||||
|
||||
it('should automatically accept fallback', async () => {
|
||||
// Set up a minimal flash fallback handler for testing
|
||||
const flashFallbackHandler = async (): Promise<boolean> => true;
|
||||
// This test validates the Config's ability to store and execute the handler contract.
|
||||
it('should execute the injected FallbackHandler contract correctly', async () => {
|
||||
// Set up a minimal handler for testing, ensuring it matches the new type.
|
||||
const fallbackHandler: FallbackModelHandler = async () => 'retry';
|
||||
|
||||
config.setFlashFallbackHandler(flashFallbackHandler);
|
||||
// Use the generalized setter
|
||||
config.setFallbackModelHandler(fallbackHandler);
|
||||
|
||||
// Call the handler directly to test
|
||||
const result = await config.flashFallbackHandler!(
|
||||
// Call the handler directly via the config property
|
||||
const result = await config.fallbackModelHandler!(
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
|
||||
// Verify it automatically accepts
|
||||
expect(result).toBe(true);
|
||||
// Verify it returns the correct intent
|
||||
expect(result).toBe('retry');
|
||||
});
|
||||
|
||||
it('should trigger fallback after 2 consecutive 429 errors for OAuth users', async () => {
|
||||
// This test validates the retry utility's logic for triggering the callback.
|
||||
it('should trigger onPersistent429 after 2 consecutive 429 errors for OAuth users', async () => {
|
||||
let fallbackCalled = false;
|
||||
let fallbackModel = '';
|
||||
// Only `fallbackCalled` is tracked here; the callback signals a retry with a boolean rather than returning a model name.
|
||||
|
||||
// Mock function that simulates exactly 2 429 errors, then succeeds after fallback
|
||||
const mockApiCall = vi
|
||||
@@ -68,11 +74,11 @@ describe('Flash Fallback Integration', () => {
|
||||
.mockRejectedValueOnce(createSimulated429Error())
|
||||
.mockResolvedValueOnce('success after fallback');
|
||||
|
||||
// Mock fallback handler
|
||||
const mockFallbackHandler = vi.fn(async (_authType?: string) => {
|
||||
// Mock the onPersistent429 callback (this is what client.ts/geminiChat.ts provides)
|
||||
const mockPersistent429Callback = vi.fn(async (_authType?: string) => {
|
||||
fallbackCalled = true;
|
||||
fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
return fallbackModel;
|
||||
// Return true to signal retryWithBackoff to reset attempts and continue.
|
||||
return true;
|
||||
});
|
||||
|
||||
// Test with OAuth personal auth type, with maxAttempts = 2 to ensure fallback triggers
|
||||
@@ -84,14 +90,13 @@ describe('Flash Fallback Integration', () => {
|
||||
const status = (error as Error & { status?: number }).status;
|
||||
return status === 429;
|
||||
},
|
||||
onPersistent429: mockFallbackHandler,
|
||||
onPersistent429: mockPersistent429Callback,
|
||||
authType: AuthType.LOGIN_WITH_GOOGLE,
|
||||
});
|
||||
|
||||
// Verify fallback was triggered
|
||||
// Verify fallback mechanism was triggered
|
||||
expect(fallbackCalled).toBe(true);
|
||||
expect(fallbackModel).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
expect(mockFallbackHandler).toHaveBeenCalledWith(
|
||||
expect(mockPersistent429Callback).toHaveBeenCalledWith(
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
expect.any(Error),
|
||||
);
|
||||
@@ -100,16 +105,16 @@ describe('Flash Fallback Integration', () => {
|
||||
expect(mockApiCall).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it('should not trigger fallback for API key users', async () => {
|
||||
it('should not trigger onPersistent429 for API key users', async () => {
|
||||
let fallbackCalled = false;
|
||||
|
||||
// Mock function that simulates 429 errors
|
||||
const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error());
|
||||
|
||||
// Mock fallback handler
|
||||
const mockFallbackHandler = vi.fn(async () => {
|
||||
// Mock the callback
|
||||
const mockPersistent429Callback = vi.fn(async () => {
|
||||
fallbackCalled = true;
|
||||
return DEFAULT_GEMINI_FLASH_MODEL;
|
||||
return true;
|
||||
});
|
||||
|
||||
// Test with API key auth type - should not trigger fallback
|
||||
@@ -122,7 +127,7 @@ describe('Flash Fallback Integration', () => {
|
||||
const status = (error as Error & { status?: number }).status;
|
||||
return status === 429;
|
||||
},
|
||||
onPersistent429: mockFallbackHandler,
|
||||
onPersistent429: mockPersistent429Callback,
|
||||
authType: AuthType.USE_GEMINI, // API key auth type
|
||||
});
|
||||
} catch (error) {
|
||||
@@ -132,10 +137,11 @@ describe('Flash Fallback Integration', () => {
|
||||
|
||||
// Verify fallback was NOT triggered for API key users
|
||||
expect(fallbackCalled).toBe(false);
|
||||
expect(mockFallbackHandler).not.toHaveBeenCalled();
|
||||
expect(mockPersistent429Callback).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should properly disable simulation state after fallback', () => {
|
||||
// This test validates the test utilities themselves.
|
||||
it('should properly disable simulation state after fallback (Test Utility)', () => {
|
||||
// Enable simulation
|
||||
setSimulate429(true);
|
||||
|
||||
Reference in New Issue
Block a user