diff --git a/packages/cli/src/ui/components/ModelDialog.test.tsx b/packages/cli/src/ui/components/ModelDialog.test.tsx index c9ee077bc8..1c8546344a 100644 --- a/packages/cli/src/ui/components/ModelDialog.test.tsx +++ b/packages/cli/src/ui/components/ModelDialog.test.tsx @@ -14,6 +14,11 @@ import { DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_FLASH_LITE_MODEL, + PREVIEW_GEMINI_MODEL, + PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, + PREVIEW_GEMINI_FLASH_MODEL, + AuthType, } from '@google/gemini-cli-core'; import type { Config, ModelSlashCommandEvent } from '@google/gemini-cli-core'; @@ -42,12 +47,14 @@ describe('', () => { const mockGetModel = vi.fn(); const mockOnClose = vi.fn(); const mockGetHasAccessToPreviewModel = vi.fn(); + const mockGetGemini31LaunchedSync = vi.fn(); interface MockConfig extends Partial { setModel: (model: string, isTemporary?: boolean) => void; getModel: () => string; getHasAccessToPreviewModel: () => boolean; getIdeMode: () => boolean; + getGemini31LaunchedSync: () => boolean; } const mockConfig: MockConfig = { @@ -55,12 +62,14 @@ describe('', () => { getModel: mockGetModel, getHasAccessToPreviewModel: mockGetHasAccessToPreviewModel, getIdeMode: () => false, + getGemini31LaunchedSync: mockGetGemini31LaunchedSync, }; beforeEach(() => { vi.resetAllMocks(); mockGetModel.mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO); mockGetHasAccessToPreviewModel.mockReturnValue(false); + mockGetGemini31LaunchedSync.mockReturnValue(false); // Default implementation for getDisplayString mockGetDisplayString.mockImplementation((val: string) => { @@ -72,7 +81,21 @@ describe('', () => { const renderComponent = (configValue = mockConfig as Config) => renderWithProviders(, { + const renderComponent = async ( + configValue = mockConfig as Config, + authType = AuthType.LOGIN_WITH_GOOGLE, + ) => { + const settings = createMockSettings({ + security: { + auth: { + selectedType: authType, + }, + }, + }); + + const result = 
renderWithProviders(, { config: configValue, + settings, }); it('renders the initial "main" view correctly', () => { @@ -210,4 +233,98 @@ describe('', () => { expect(lastFrame()).toContain('Manual'); }); }); + + it('shows the preferred manual model in the main view option', async () => { + mockGetModel.mockReturnValue(DEFAULT_GEMINI_MODEL); + const { lastFrame, unmount } = await renderComponent(); + + expect(lastFrame()).toContain(`Manual (${DEFAULT_GEMINI_MODEL})`); + unmount(); + }); + + describe('Preview Models', () => { + beforeEach(() => { + mockGetHasAccessToPreviewModel.mockReturnValue(true); + }); + + it('shows Auto (Preview) in main view when access is granted', async () => { + const { lastFrame, unmount } = await renderComponent(); + expect(lastFrame()).toContain('Auto (Preview)'); + unmount(); + }); + + it('shows Gemini 3 models in manual view when Gemini 3.1 is NOT launched', async () => { + mockGetGemini31LaunchedSync.mockReturnValue(false); + const { lastFrame, stdin, waitUntilReady, unmount } = + await renderComponent(); + + // Go to manual view + await act(async () => { + stdin.write('\u001B[B'); // Manual + }); + await waitUntilReady(); + await act(async () => { + stdin.write('\r'); + }); + await waitUntilReady(); + + const output = lastFrame(); + expect(output).toContain(PREVIEW_GEMINI_MODEL); + expect(output).toContain(PREVIEW_GEMINI_FLASH_MODEL); + unmount(); + }); + + it('shows Gemini 3.1 models in manual view when Gemini 3.1 IS launched', async () => { + mockGetGemini31LaunchedSync.mockReturnValue(true); + const { lastFrame, stdin, waitUntilReady, unmount } = + await renderComponent(mockConfig as Config, AuthType.USE_VERTEX_AI); + + // Go to manual view + await act(async () => { + stdin.write('\u001B[B'); // Manual + }); + await waitUntilReady(); + await act(async () => { + stdin.write('\r'); + }); + await waitUntilReady(); + + const output = lastFrame(); + expect(output).toContain(PREVIEW_GEMINI_3_1_MODEL); + 
expect(output).toContain(PREVIEW_GEMINI_FLASH_MODEL); + unmount(); + }); + + it('uses custom tools model when Gemini 3.1 IS launched and auth is Gemini API Key', async () => { + mockGetGemini31LaunchedSync.mockReturnValue(true); + const { stdin, waitUntilReady, unmount } = await renderComponent( + mockConfig as Config, + AuthType.USE_GEMINI, + ); + + // Go to manual view + await act(async () => { + stdin.write('\u001B[B'); // Manual + }); + await waitUntilReady(); + await act(async () => { + stdin.write('\r'); + }); + await waitUntilReady(); + + // Select Gemini 3.1 (first item in preview section) + await act(async () => { + stdin.write('\r'); + }); + await waitUntilReady(); + + await waitFor(() => { + expect(mockSetModel).toHaveBeenCalledWith( + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, + true, + ); + }); + unmount(); + }); + }); }); diff --git a/packages/cli/src/ui/components/ModelDialog.tsx b/packages/cli/src/ui/components/ModelDialog.tsx index 8b5d27c138..5205ceee71 100644 --- a/packages/cli/src/ui/components/ModelDialog.tsx +++ b/packages/cli/src/ui/components/ModelDialog.tsx @@ -19,11 +19,14 @@ import { ModelSlashCommandEvent, logModelSlashCommand, getDisplayString, + AuthType, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, } from '@google/gemini-cli-core'; import { useKeypress } from '../hooks/useKeypress.js'; import { theme } from '../semantic-colors.js'; import { DescriptiveRadioButtonSelect } from './shared/DescriptiveRadioButtonSelect.js'; import { ConfigContext } from '../contexts/ConfigContext.js'; +import { useSettings } from '../contexts/SettingsContext.js'; interface ModelDialogProps { onClose: () => void; @@ -31,6 +34,7 @@ interface ModelDialogProps { export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { const config = useContext(ConfigContext); + const settings = useSettings(); const [view, setView] = useState<'main' | 'manual'>('main'); const [persistMode, setPersistMode] = useState(false); @@ -39,6 +43,9 @@ export function 
ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { const shouldShowPreviewModels = config?.getHasAccessToPreviewModel(); const useGemini31 = config?.getGemini31LaunchedSync?.() ?? false; + const selectedAuthType = settings.merged.security.auth.selectedType; + const useCustomToolModel = + useGemini31 && selectedAuthType !== AuthType.USE_VERTEX_AI; const manualModelSelected = useMemo(() => { const manualModels = [ @@ -47,6 +54,7 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { DEFAULT_GEMINI_FLASH_LITE_MODEL, PREVIEW_GEMINI_MODEL, PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, PREVIEW_GEMINI_FLASH_MODEL, ]; if (manualModels.includes(preferredModel)) { @@ -126,11 +134,19 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { ]; if (shouldShowPreviewModels) { + const previewProModel = useGemini31 + ? PREVIEW_GEMINI_3_1_MODEL + : PREVIEW_GEMINI_MODEL; + + const previewProValue = useCustomToolModel + ? PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL + : previewProModel; + list.unshift( { - value: useGemini31 ? PREVIEW_GEMINI_3_1_MODEL : PREVIEW_GEMINI_MODEL, - title: useGemini31 ? PREVIEW_GEMINI_3_1_MODEL : PREVIEW_GEMINI_MODEL, - key: useGemini31 ? PREVIEW_GEMINI_3_1_MODEL : PREVIEW_GEMINI_MODEL, + value: previewProValue, + title: previewProModel, + key: previewProModel, }, { value: PREVIEW_GEMINI_FLASH_MODEL, @@ -140,7 +156,7 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { ); } return list; - }, [shouldShowPreviewModels, useGemini31]); + }, [shouldShowPreviewModels, useGemini31, useCustomToolModel]); const options = view === 'main' ? 
mainOptions : manualOptions; diff --git a/packages/cli/src/ui/components/StatsDisplay.tsx b/packages/cli/src/ui/components/StatsDisplay.tsx index da4e6e901c..fc3f80778f 100644 --- a/packages/cli/src/ui/components/StatsDisplay.tsx +++ b/packages/cli/src/ui/components/StatsDisplay.tsx @@ -26,6 +26,7 @@ import { isActiveModel, getDisplayString, isAutoModel, + AuthType, } from '@google/gemini-cli-core'; import { useSettings } from '../contexts/SettingsContext.js'; import { useConfig } from '../contexts/ConfigContext.js'; @@ -84,9 +85,12 @@ const buildModelRows = ( models: Record, quotas?: RetrieveUserQuotaResponse, useGemini3_1 = false, + useCustomToolModel = false, ) => { const getBaseModelName = (name: string) => name.replace('-001', ''); - const usedModelNames = new Set(Object.keys(models).map(getBaseModelName)); + const usedModelNames = new Set( + Object.keys(models).map(getBaseModelName).map(getDisplayString), + ); // 1. Models with active usage const activeRows = Object.entries(models).map(([name, metrics]) => { @@ -95,7 +99,7 @@ const buildModelRows = ( const inputTokens = metrics.tokens.input; return { key: name, - modelName, + modelName: getDisplayString(modelName), requests: metrics.api.totalRequests, cachedTokens: cachedTokens.toLocaleString(), inputTokens: inputTokens.toLocaleString(), @@ -111,12 +115,12 @@ const buildModelRows = ( ?.filter( (b) => b.modelId && - isActiveModel(b.modelId, useGemini3_1) && - !usedModelNames.has(b.modelId), + isActiveModel(b.modelId, useGemini3_1, useCustomToolModel) && + !usedModelNames.has(getDisplayString(b.modelId)), ) .map((bucket) => ({ key: bucket.modelId!, - modelName: bucket.modelId!, + modelName: getDisplayString(bucket.modelId!), requests: '-', cachedTokens: '-', inputTokens: '-', @@ -138,6 +142,7 @@ const ModelUsageTable: React.FC<{ pooledLimit?: number; pooledResetTime?: string; useGemini3_1?: boolean; + useCustomToolModel?: boolean; }> = ({ models, quotas, @@ -147,9 +152,10 @@ const ModelUsageTable: React.FC<{ 
pooledRemaining, pooledLimit, pooledResetTime, - useGemini3_1 = false, + useGemini3_1, + useCustomToolModel, }) => { - const rows = buildModelRows(models, quotas, useGemini3_1); + const rows = buildModelRows(models, quotas, useGemini3_1, useCustomToolModel); if (rows.length === 0) { return null; @@ -407,7 +413,9 @@ export const StatsDisplay: React.FC = ({ const settings = useSettings(); const config = useConfig(); const useGemini3_1 = config.getGemini31LaunchedSync?.() ?? false; - + const useCustomToolModel = + useGemini3_1 && + config.getContentGeneratorConfig().authType !== AuthType.USE_VERTEX_AI; const pooledRemaining = quotaStats?.remaining; const pooledLimit = quotaStats?.limit; const pooledResetTime = quotaStats?.resetTime; @@ -542,6 +550,7 @@ export const StatsDisplay: React.FC = ({ pooledLimit={pooledLimit} pooledResetTime={pooledResetTime} useGemini3_1={useGemini3_1} + useCustomToolModel={useCustomToolModel} /> ); diff --git a/packages/cli/src/ui/components/messages/ModelMessage.tsx b/packages/cli/src/ui/components/messages/ModelMessage.tsx index bddbae8e8b..b313dab6f1 100644 --- a/packages/cli/src/ui/components/messages/ModelMessage.tsx +++ b/packages/cli/src/ui/components/messages/ModelMessage.tsx @@ -7,6 +7,7 @@ import type React from 'react'; import { Text, Box } from 'ink'; import { theme } from '../../semantic-colors.js'; +import { getDisplayString } from '@google/gemini-cli-core'; interface ModelMessageProps { model: string; @@ -15,7 +16,7 @@ interface ModelMessageProps { export const ModelMessage: React.FC = ({ model }) => ( - Responding with {model} + Responding with {getDisplayString(model)} ); diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts index d0f81bea60..5d6db5abfa 100644 --- a/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts @@ -155,9 +155,10 @@ describe('useQuotaAndFallback', () => { 
expect(request?.isTerminalQuotaError).toBe(true); const message = request!.message; - expect(message).toContain('Usage limit reached for gemini-pro.'); + expect(message).toContain('Usage limit reached for all Pro models.'); expect(message).toContain('Access resets at'); // From getResetTimeMessage expect(message).toContain('/stats model for usage details'); + expect(message).toContain('/model to switch models.'); expect(message).toContain('/auth to switch to API key.'); expect(mockHistoryManager.addItem).not.toHaveBeenCalled(); @@ -176,6 +177,77 @@ describe('useQuotaAndFallback', () => { expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(1); }); + it('should show the model name for a terminal quota error on a non-pro model', async () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, + }), + ); + + const handler = setFallbackHandlerSpy.mock + .calls[0][0] as FallbackModelHandler; + + let promise: Promise; + const error = new TerminalQuotaError( + 'flash quota', + mockGoogleApiError, + 1000 * 60 * 5, + ); + act(() => { + promise = handler('gemini-flash', 'gemini-pro', error); + }); + + const request = result.current.proQuotaRequest; + expect(request).not.toBeNull(); + expect(request?.failedModel).toBe('gemini-flash'); + + const message = request!.message; + expect(message).toContain('Usage limit reached for gemini-flash.'); + expect(message).not.toContain('all Pro models'); + + act(() => { + result.current.handleProQuotaChoice('retry_later'); + }); + + await promise!; + }); + + it('should handle terminal quota error without retry delay', async () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setModelSwitchedFromQuotaError: 
mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, + }), + ); + + const handler = setFallbackHandlerSpy.mock + .calls[0][0] as FallbackModelHandler; + + let promise: Promise; + const error = new TerminalQuotaError('no delay', mockGoogleApiError); + act(() => { + promise = handler('gemini-pro', 'gemini-flash', error); + }); + + const request = result.current.proQuotaRequest; + const message = request!.message; + expect(message).not.toContain('Access resets at'); + expect(message).toContain('Usage limit reached for all Pro models.'); + + act(() => { + result.current.handleProQuotaChoice('retry_later'); + }); + + await promise!; + }); + it('should handle race conditions by stopping subsequent requests', async () => { const { result } = renderHook(() => useQuotaAndFallback({ diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts index 8a62ace716..1ba03f2a47 100644 --- a/packages/cli/src/ui/hooks/useQuotaAndFallback.ts +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts @@ -16,6 +16,7 @@ import { type UserTierId, VALID_GEMINI_MODELS, isProModel, + getDisplayString, } from '@google/gemini-cli-core'; import { useCallback, useEffect, useRef, useState } from 'react'; import { type UseHistoryManagerReturn } from './useHistoryManager.js'; @@ -84,7 +85,7 @@ export function useQuotaAndFallback({ isModelNotFoundError = true; if (VALID_GEMINI_MODELS.has(failedModel)) { const messageLines = [ - `It seems like you don't have access to ${failedModel}.`, + `It seems like you don't have access to ${getDisplayString(failedModel)}.`, `Your admin might have disabled the access. 
Contact them to enable the Preview Release Channel.`, ]; message = messageLines.join('\n'); diff --git a/packages/core/src/config/models.test.ts b/packages/core/src/config/models.test.ts index bfc6b23c9c..7e6a619cdf 100644 --- a/packages/core/src/config/models.test.ts +++ b/packages/core/src/config/models.test.ts @@ -26,8 +26,40 @@ import { PREVIEW_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO, isActiveModel, + PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, + isPreviewModel, + isProModel, } from './models.js'; +describe('isPreviewModel', () => { + it('should return true for preview models', () => { + expect(isPreviewModel(PREVIEW_GEMINI_MODEL)).toBe(true); + expect(isPreviewModel(PREVIEW_GEMINI_3_1_MODEL)).toBe(true); + expect(isPreviewModel(PREVIEW_GEMINI_FLASH_MODEL)).toBe(true); + expect(isPreviewModel(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true); + }); + + it('should return false for non-preview models', () => { + expect(isPreviewModel(DEFAULT_GEMINI_MODEL)).toBe(false); + expect(isPreviewModel('gemini-1.5-pro')).toBe(false); + }); +}); + +describe('isProModel', () => { + it('should return true for models containing "pro"', () => { + expect(isProModel('gemini-3-pro-preview')).toBe(true); + expect(isProModel('gemini-2.5-pro')).toBe(true); + expect(isProModel('pro')).toBe(true); + }); + + it('should return false for models without "pro"', () => { + expect(isProModel('gemini-3-flash-preview')).toBe(false); + expect(isProModel('gemini-2.5-flash')).toBe(false); + expect(isProModel('auto')).toBe(false); + }); +}); + describe('isCustomModel', () => { it('should return true for models not starting with gemini-', () => { expect(isCustomModel('testing')).toBe(true); @@ -116,6 +148,12 @@ describe('getDisplayString', () => { ); }); + it('should return PREVIEW_GEMINI_3_1_MODEL for PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL', () => { + expect(getDisplayString(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL)).toBe( + PREVIEW_GEMINI_3_1_MODEL, + ); + }); + it('should return the 
model name as is for other models', () => { expect(getDisplayString('custom-model')).toBe('custom-model'); expect(getDisplayString(DEFAULT_GEMINI_FLASH_LITE_MODEL)).toBe( @@ -147,6 +185,16 @@ describe('resolveModel', () => { expect(model).toBe(PREVIEW_GEMINI_MODEL); }); + it('should return Gemini 3.1 Pro when auto-gemini-3 is requested and useGemini3_1 is true', () => { + const model = resolveModel(PREVIEW_GEMINI_MODEL_AUTO, true); + expect(model).toBe(PREVIEW_GEMINI_3_1_MODEL); + }); + + it('should return Gemini 3.1 Pro Custom Tools when auto-gemini-3 is requested, useGemini3_1 is true, and useCustomToolModel is true', () => { + const model = resolveModel(PREVIEW_GEMINI_MODEL_AUTO, true, true); + expect(model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL); + }); + it('should return the Default Pro model when auto-gemini-2.5 is requested', () => { const model = resolveModel(DEFAULT_GEMINI_MODEL_AUTO); expect(model).toBe(DEFAULT_GEMINI_MODEL); @@ -240,6 +288,27 @@ describe('resolveClassifierModel', () => { resolveClassifierModel(PREVIEW_GEMINI_MODEL_AUTO, GEMINI_MODEL_ALIAS_PRO), ).toBe(PREVIEW_GEMINI_MODEL); }); + + it('should return Gemini 3.1 Pro when alias is pro and useGemini3_1 is true', () => { + expect( + resolveClassifierModel( + PREVIEW_GEMINI_MODEL_AUTO, + GEMINI_MODEL_ALIAS_PRO, + true, + ), + ).toBe(PREVIEW_GEMINI_3_1_MODEL); + }); + + it('should return Gemini 3.1 Pro Custom Tools when alias is pro, useGemini3_1 is true, and useCustomToolModel is true', () => { + expect( + resolveClassifierModel( + PREVIEW_GEMINI_MODEL_AUTO, + GEMINI_MODEL_ALIAS_PRO, + true, + true, + ), + ).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL); + }); }); describe('isActiveModel', () => { @@ -249,9 +318,9 @@ describe('isActiveModel', () => { expect(isActiveModel(DEFAULT_GEMINI_FLASH_MODEL)).toBe(true); }); - it('should return false for invalid models', () => { - expect(isActiveModel('invalid-model')).toBe(false); - expect(isActiveModel(GEMINI_MODEL_ALIAS_AUTO)).toBe(false); + 
it('should return true for unknown models and aliases (to support test models)', () => { + expect(isActiveModel('invalid-model')).toBe(true); + expect(isActiveModel(GEMINI_MODEL_ALIAS_AUTO)).toBe(true); }); it('should return false for PREVIEW_GEMINI_MODEL when useGemini3_1 is true', () => { @@ -261,4 +330,29 @@ describe('isActiveModel', () => { it('should return true for other valid models when useGemini3_1 is true', () => { expect(isActiveModel(DEFAULT_GEMINI_MODEL, true)).toBe(true); }); + + it('should correctly filter Gemini 3.1 models based on useCustomToolModel when useGemini3_1 is true', () => { + // When custom tools are preferred, standard 3.1 should be inactive + expect(isActiveModel(PREVIEW_GEMINI_3_1_MODEL, true, true)).toBe(false); + expect( + isActiveModel(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, true, true), + ).toBe(true); + + // When custom tools are NOT preferred, custom tools 3.1 should be inactive + expect(isActiveModel(PREVIEW_GEMINI_3_1_MODEL, true, false)).toBe(true); + expect( + isActiveModel(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, true, false), + ).toBe(false); + }); + + it('should return false for both Gemini 3.1 models when useGemini3_1 is false', () => { + expect(isActiveModel(PREVIEW_GEMINI_3_1_MODEL, false, true)).toBe(false); + expect(isActiveModel(PREVIEW_GEMINI_3_1_MODEL, false, false)).toBe(false); + expect( + isActiveModel(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, false, true), + ).toBe(false); + expect( + isActiveModel(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, false, false), + ).toBe(false); + }); }); diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index 9ee8485fd1..5e3b0a2984 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -6,6 +6,8 @@ export const PREVIEW_GEMINI_MODEL = 'gemini-3-pro-preview'; export const PREVIEW_GEMINI_3_1_MODEL = 'gemini-3.1-pro-preview'; +export const PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL = + 'gemini-3.1-pro-preview-customtools'; export 
const PREVIEW_GEMINI_FLASH_MODEL = 'gemini-3-flash-preview'; export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro'; export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash'; @@ -14,6 +16,7 @@ export const DEFAULT_GEMINI_FLASH_LITE_MODEL = 'gemini-2.5-flash-lite'; export const VALID_GEMINI_MODELS = new Set([ PREVIEW_GEMINI_MODEL, PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, PREVIEW_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_FLASH_MODEL, @@ -45,19 +48,23 @@ export const DEFAULT_THINKING_MODE = 8192; export function resolveModel( requestedModel: string, useGemini3_1: boolean = false, + useCustomToolModel: boolean = false, ): string { switch (requestedModel) { case PREVIEW_GEMINI_MODEL: - case PREVIEW_GEMINI_MODEL_AUTO: { - return useGemini3_1 ? PREVIEW_GEMINI_3_1_MODEL : PREVIEW_GEMINI_MODEL; + case PREVIEW_GEMINI_MODEL_AUTO: + case GEMINI_MODEL_ALIAS_AUTO: + case GEMINI_MODEL_ALIAS_PRO: { + if (useGemini3_1) { + return useCustomToolModel + ? PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL + : PREVIEW_GEMINI_3_1_MODEL; + } + return PREVIEW_GEMINI_MODEL; } case DEFAULT_GEMINI_MODEL_AUTO: { return DEFAULT_GEMINI_MODEL; } - case GEMINI_MODEL_ALIAS_AUTO: - case GEMINI_MODEL_ALIAS_PRO: { - return useGemini3_1 ? 
PREVIEW_GEMINI_3_1_MODEL : PREVIEW_GEMINI_MODEL; - } case GEMINI_MODEL_ALIAS_FLASH: { return PREVIEW_GEMINI_FLASH_MODEL; } @@ -80,6 +87,8 @@ export function resolveModel( export function resolveClassifierModel( requestedModel: string, modelAlias: string, + useGemini3_1: boolean = false, + useCustomToolModel: boolean = false, ): string { if (modelAlias === GEMINI_MODEL_ALIAS_FLASH) { if ( @@ -96,7 +105,7 @@ export function resolveClassifierModel( } return resolveModel(GEMINI_MODEL_ALIAS_FLASH); } - return resolveModel(requestedModel); + return resolveModel(requestedModel, useGemini3_1, useCustomToolModel); } export function getDisplayString(model: string) { switch (model) { @@ -108,6 +117,8 @@ export function getDisplayString(model: string) { return PREVIEW_GEMINI_MODEL; case GEMINI_MODEL_ALIAS_FLASH: return PREVIEW_GEMINI_FLASH_MODEL; + case PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL: + return PREVIEW_GEMINI_3_1_MODEL; default: return model; } @@ -135,11 +146,7 @@ export function isPreviewModel(model: string): boolean { * @returns True if the model is a Pro model. 
*/ export function isProModel(model: string): boolean { - return ( - model === PREVIEW_GEMINI_MODEL || - model === PREVIEW_GEMINI_3_1_MODEL || - model === DEFAULT_GEMINI_MODEL - ); + return model.toLowerCase().includes('pro'); } /** @@ -221,13 +228,24 @@ export function supportsMultimodalFunctionResponse(model: string): boolean { export function isActiveModel( model: string, useGemini3_1: boolean = false, + useCustomToolModel: boolean = false, ): boolean { if (!VALID_GEMINI_MODELS.has(model)) { - return false; + return true; } if (useGemini3_1) { - return model !== PREVIEW_GEMINI_MODEL; + if (model === PREVIEW_GEMINI_MODEL) { + return false; + } + if (useCustomToolModel) { + return model !== PREVIEW_GEMINI_3_1_MODEL; + } else { + return model !== PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL; + } } else { - return model !== PREVIEW_GEMINI_3_1_MODEL; + return ( + model !== PREVIEW_GEMINI_3_1_MODEL && + model !== PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL + ); } } diff --git a/packages/core/src/routing/strategies/classifierStrategy.test.ts b/packages/core/src/routing/strategies/classifierStrategy.test.ts index b2c7a8797e..c9f33b11ec 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.test.ts @@ -18,11 +18,13 @@ import { DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL_AUTO, PREVIEW_GEMINI_MODEL_AUTO, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, } from '../../config/models.js'; import { promptIdContext } from '../../utils/promptIdContext.js'; import type { Content } from '@google/genai'; import type { ResolvedModelConfig } from '../../services/modelConfigService.js'; import { debugLogger } from '../../utils/debugLogger.js'; +import { AuthType } from '../../core/contentGenerator.js'; vi.mock('../../core/baseLlmClient.js'); @@ -53,6 +55,10 @@ describe('ClassifierStrategy', () => { }, getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO), getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false), + 
getGemini31Launched: vi.fn().mockResolvedValue(false), + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: AuthType.LOGIN_WITH_GOOGLE, + }), } as unknown as Config; mockBaseLlmClient = { generateJson: vi.fn(), @@ -339,4 +345,49 @@ describe('ClassifierStrategy', () => { // Since requestedModel is Pro, and choice is flash, it should resolve to Flash expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL); }); + + describe('Gemini 3.1 and Custom Tools Routing', () => { + it('should route to PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL when Gemini 3.1 is launched and auth is LOGIN_WITH_GOOGLE', async () => { + vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); + vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO); + const mockApiResponse = { + reasoning: 'Complex task', + model_choice: 'pro', + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL); + }); + + it('should route to PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL when Gemini 3.1 is launched and auth is USE_GEMINI', async () => { + vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); + vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO); + vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ + authType: AuthType.USE_GEMINI, + }); + const mockApiResponse = { + reasoning: 'Complex task', + model_choice: 'pro', + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL); + }); + }); }); diff --git a/packages/core/src/routing/strategies/classifierStrategy.ts b/packages/core/src/routing/strategies/classifierStrategy.ts index 980e89829d..94627c5377 100644 ---
a/packages/core/src/routing/strategies/classifierStrategy.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.ts @@ -21,6 +21,7 @@ import { } from '../../utils/messageInspectors.js'; import { debugLogger } from '../../utils/debugLogger.js'; import { LlmRole } from '../../telemetry/types.js'; +import { AuthType } from '../../core/contentGenerator.js'; // The number of recent history turns to provide to the router for context. const HISTORY_TURNS_FOR_CONTEXT = 4; @@ -169,9 +170,15 @@ export class ClassifierStrategy implements RoutingStrategy { const reasoning = routerResponse.reasoning; const latencyMs = Date.now() - startTime; + const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false; + const useCustomToolModel = + useGemini3_1 && + config.getContentGeneratorConfig().authType !== AuthType.USE_VERTEX_AI; const selectedModel = resolveClassifierModel( model, routerResponse.model_choice, + useGemini3_1, + useCustomToolModel, ); return { diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts index 8767709f68..37e9d18af7 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts @@ -12,6 +12,8 @@ import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import { PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_MODEL, + PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, PREVIEW_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL, @@ -20,6 +22,7 @@ import { promptIdContext } from '../../utils/promptIdContext.js'; import type { Content } from '@google/genai'; import type { ResolvedModelConfig } from '../../services/modelConfigService.js'; import { debugLogger } from '../../utils/debugLogger.js'; +import { AuthType } from '../../core/contentGenerator.js'; vi.mock('../../core/baseLlmClient.js'); @@ -52,6 +55,10 
@@ describe('NumericalClassifierStrategy', () => { getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50) getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true), getClassifierThreshold: vi.fn().mockResolvedValue(undefined), + getGemini31Launched: vi.fn().mockResolvedValue(false), + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: AuthType.LOGIN_WITH_GOOGLE, + }), } as unknown as Config; mockBaseLlmClient = { generateJson: vi.fn(), @@ -535,4 +542,68 @@ describe('NumericalClassifierStrategy', () => { ), ); }); + + describe('Gemini 3.1 and Custom Tools Routing', () => { + it('should route to PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL when Gemini 3.1 is launched and auth is LOGIN_WITH_GOOGLE', async () => { + vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); + const mockApiResponse = { + complexity_reasoning: 'Complex task', + complexity_score: 80, + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL); + }); + it('should route to PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL when Gemini 3.1 is launched and auth is USE_GEMINI', async () => { + vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); + vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ + authType: AuthType.USE_GEMINI, + }); + const mockApiResponse = { + complexity_reasoning: 'Complex task', + complexity_score: 80, + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL); + }); + + it('should NOT route to custom tools model when auth is USE_VERTEX_AI', async () => { + vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); +
vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ + authType: AuthType.USE_VERTEX_AI, + }); + const mockApiResponse = { + complexity_reasoning: 'Complex task', + complexity_score: 80, + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL); + }); + }); }); diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts index d4ddf99b8d..b8ebc0a885 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts @@ -17,6 +17,7 @@ import { createUserContent, Type } from '@google/genai'; import type { Config } from '../../config/config.js'; import { debugLogger } from '../../utils/debugLogger.js'; import { LlmRole } from '../../telemetry/types.js'; +import { AuthType } from '../../core/contentGenerator.js'; // The number of recent history turns to provide to the router for context. const HISTORY_TURNS_FOR_CONTEXT = 8; @@ -182,8 +183,16 @@ export class NumericalClassifierStrategy implements RoutingStrategy { config, config.getSessionId() || 'unknown-session', ); - - const selectedModel = resolveClassifierModel(model, modelAlias); + const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false; + const useCustomToolModel = + useGemini3_1 && + config.getContentGeneratorConfig().authType !== AuthType.USE_VERTEX_AI; + const selectedModel = resolveClassifierModel( + model, + modelAlias, + useGemini3_1, + useCustomToolModel, + ); const latencyMs = Date.now() - startTime;