From f97b04cc9a8323e33eb7649cb0c464a60ad2ebc9 Mon Sep 17 00:00:00 2001 From: Sehoon Shon Date: Fri, 20 Feb 2026 14:19:21 -0500 Subject: [PATCH] feat(models): support Gemini 3.1 Pro Preview and fixes (#19676) --- packages/cli/src/ui/components/AboutBox.tsx | 5 +- .../src/ui/components/ModelDialog.test.tsx | 118 +++++++++++++++++- .../cli/src/ui/components/ModelDialog.tsx | 34 +++-- .../cli/src/ui/components/StatsDisplay.tsx | 32 +++-- .../ui/components/messages/ModelMessage.tsx | 3 +- .../src/ui/hooks/useQuotaAndFallback.test.ts | 74 ++++++++++- .../cli/src/ui/hooks/useQuotaAndFallback.ts | 14 +-- .../cli/src/zed-integration/zedIntegration.ts | 5 +- .../core/src/availability/policyHelpers.ts | 5 +- .../src/code_assist/experiments/flagNames.ts | 1 + packages/core/src/config/config.ts | 51 +++++++- packages/core/src/config/models.test.ts | 116 +++++++++++++++++ packages/core/src/config/models.ts | 75 +++++++++-- packages/core/src/core/client.ts | 5 +- packages/core/src/core/contentGenerator.ts | 7 +- packages/core/src/core/geminiChat.ts | 5 +- packages/core/src/prompts/promptProvider.ts | 10 +- .../strategies/classifierStrategy.test.ts | 52 ++++++++ .../routing/strategies/classifierStrategy.ts | 7 ++ .../src/routing/strategies/defaultStrategy.ts | 5 +- .../routing/strategies/fallbackStrategy.ts | 5 +- .../numericalClassifierStrategy.test.ts | 71 +++++++++++ .../strategies/numericalClassifierStrategy.ts | 13 +- .../routing/strategies/overrideStrategy.ts | 5 +- .../src/services/chatCompressionService.ts | 2 + 25 files changed, 670 insertions(+), 50 deletions(-) diff --git a/packages/cli/src/ui/components/AboutBox.tsx b/packages/cli/src/ui/components/AboutBox.tsx index ea5512b48d..7ea744b0fe 100644 --- a/packages/cli/src/ui/components/AboutBox.tsx +++ b/packages/cli/src/ui/components/AboutBox.tsx @@ -9,6 +9,7 @@ import { Box, Text } from 'ink'; import { theme } from '../semantic-colors.js'; import { GIT_COMMIT_INFO } from '../../generated/git-commit.js'; import { useSettings } from '../contexts/SettingsContext.js'; +import { getDisplayString } from '@google/gemini-cli-core'; interface AboutBoxProps { cliVersion: string; @@ -79,7 +80,9 @@ export const AboutBox: React.FC = ({ - {modelVersion} + + {getDisplayString(modelVersion)} + diff --git a/packages/cli/src/ui/components/ModelDialog.test.tsx b/packages/cli/src/ui/components/ModelDialog.test.tsx index e96694eeaf..6f347faa1d 100644 --- a/packages/cli/src/ui/components/ModelDialog.test.tsx +++ b/packages/cli/src/ui/components/ModelDialog.test.tsx @@ -9,11 +9,17 @@ import { act } from 'react'; import { ModelDialog } from './ModelDialog.js'; import { renderWithProviders } from '../../test-utils/render.js'; import { waitFor } from '../../test-utils/async.js'; +import { createMockSettings } from '../../test-utils/settings.js'; import { DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_FLASH_LITE_MODEL, + PREVIEW_GEMINI_MODEL, + PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, + PREVIEW_GEMINI_FLASH_MODEL, + AuthType, } from '@google/gemini-cli-core'; import type { Config, ModelSlashCommandEvent } from '@google/gemini-cli-core'; @@ -42,12 +48,14 @@ describe('', () => { const mockGetModel = vi.fn(); const mockOnClose = vi.fn(); const mockGetHasAccessToPreviewModel = vi.fn(); + const mockGetGemini31LaunchedSync = vi.fn(); interface MockConfig extends Partial { setModel: (model: string, isTemporary?: boolean) => void; getModel: () => string; getHasAccessToPreviewModel: () => boolean; getIdeMode: () => boolean; + getGemini31LaunchedSync: () => boolean; } const mockConfig: MockConfig = { @@ -55,12 +63,14 @@ describe('', () => { getModel: mockGetModel, getHasAccessToPreviewModel: mockGetHasAccessToPreviewModel, getIdeMode: () => false, + getGemini31LaunchedSync: mockGetGemini31LaunchedSync, }; beforeEach(() => { vi.resetAllMocks(); mockGetModel.mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO); mockGetHasAccessToPreviewModel.mockReturnValue(false); + mockGetGemini31LaunchedSync.mockReturnValue(false); // Default implementation for getDisplayString mockGetDisplayString.mockImplementation((val: string) => { @@ -70,9 +80,21 @@ describe('', () => { }); }); - const renderComponent = async (configValue = mockConfig as Config) => { + const renderComponent = async ( + configValue = mockConfig as Config, + authType = AuthType.LOGIN_WITH_GOOGLE, + ) => { + const settings = createMockSettings({ + security: { + auth: { + selectedType: authType, + }, + }, + }); + const result = renderWithProviders(, { config: configValue, + settings, }); await result.waitUntilReady(); return result; @@ -241,4 +263,98 @@ describe('', () => { }); unmount(); }); + + it('shows the preferred manual model in the main view option', async () => { + mockGetModel.mockReturnValue(DEFAULT_GEMINI_MODEL); + const { lastFrame, unmount } = await renderComponent(); + + expect(lastFrame()).toContain(`Manual (${DEFAULT_GEMINI_MODEL})`); + unmount(); + }); + + describe('Preview Models', () => { + beforeEach(() => { + mockGetHasAccessToPreviewModel.mockReturnValue(true); + }); + + it('shows Auto (Preview) in main view when access is granted', async () => { + const { lastFrame, unmount } = await renderComponent(); + expect(lastFrame()).toContain('Auto (Preview)'); + unmount(); + }); + + it('shows Gemini 3 models in manual view when Gemini 3.1 is NOT launched', async () => { + mockGetGemini31LaunchedSync.mockReturnValue(false); + const { lastFrame, stdin, waitUntilReady, unmount } = + await renderComponent(); + + // Go to manual view + await act(async () => { + stdin.write('\u001B[B'); // Manual + }); + await waitUntilReady(); + await act(async () => { + stdin.write('\r'); + }); + await waitUntilReady(); + + const output = lastFrame(); + expect(output).toContain(PREVIEW_GEMINI_MODEL); + expect(output).toContain(PREVIEW_GEMINI_FLASH_MODEL); + unmount(); + }); + + it('shows Gemini 3.1 models in manual view when Gemini 3.1 IS launched', async () => { + mockGetGemini31LaunchedSync.mockReturnValue(true); + const { lastFrame, stdin, waitUntilReady, unmount } = + await renderComponent(mockConfig as Config, AuthType.USE_VERTEX_AI); + + // Go to manual view + await act(async () => { + stdin.write('\u001B[B'); // Manual + }); + await waitUntilReady(); + await act(async () => { + stdin.write('\r'); + }); + await waitUntilReady(); + + const output = lastFrame(); + expect(output).toContain(PREVIEW_GEMINI_3_1_MODEL); + expect(output).toContain(PREVIEW_GEMINI_FLASH_MODEL); + unmount(); + }); + + it('uses custom tools model when Gemini 3.1 IS launched and auth is Gemini API Key', async () => { + mockGetGemini31LaunchedSync.mockReturnValue(true); + const { stdin, waitUntilReady, unmount } = await renderComponent( + mockConfig as Config, + AuthType.USE_GEMINI, + ); + + // Go to manual view + await act(async () => { + stdin.write('\u001B[B'); // Manual + }); + await waitUntilReady(); + await act(async () => { + stdin.write('\r'); + }); + await waitUntilReady(); + + // Select Gemini 3.1 (first item in preview section) + await act(async () => { + stdin.write('\r'); + }); + await waitUntilReady(); + + await waitFor(() => { + expect(mockSetModel).toHaveBeenCalledWith( + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, + true, + ); + }); + unmount(); + }); + }); }); diff --git a/packages/cli/src/ui/components/ModelDialog.tsx b/packages/cli/src/ui/components/ModelDialog.tsx index 88be57b841..d50e9b7153 100644 --- a/packages/cli/src/ui/components/ModelDialog.tsx +++ b/packages/cli/src/ui/components/ModelDialog.tsx @@ -9,6 +9,7 @@ import { useCallback, useContext, useMemo, useState } from 'react'; import { Box, Text } from 'ink'; import { PREVIEW_GEMINI_MODEL, + PREVIEW_GEMINI_3_1_MODEL, PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL, @@ -18,11 +19,14 @@ import { ModelSlashCommandEvent, logModelSlashCommand, getDisplayString, + AuthType, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, } from '@google/gemini-cli-core'; import { useKeypress } from '../hooks/useKeypress.js'; import { theme } from '../semantic-colors.js'; import { DescriptiveRadioButtonSelect } from './shared/DescriptiveRadioButtonSelect.js'; import { ConfigContext } from '../contexts/ConfigContext.js'; +import { useSettings } from '../contexts/SettingsContext.js'; interface ModelDialogProps { onClose: () => void; @@ -30,6 +34,7 @@ interface ModelDialogProps { export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { const config = useContext(ConfigContext); + const settings = useSettings(); const [view, setView] = useState<'main' | 'manual'>('main'); const [persistMode, setPersistMode] = useState(false); @@ -37,6 +42,10 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { const preferredModel = config?.getModel() || DEFAULT_GEMINI_MODEL_AUTO; const shouldShowPreviewModels = config?.getHasAccessToPreviewModel(); + const useGemini31 = config?.getGemini31LaunchedSync?.() ?? false; + const selectedAuthType = settings.merged.security.auth.selectedType; + const useCustomToolModel = + useGemini31 && selectedAuthType === AuthType.USE_GEMINI; const manualModelSelected = useMemo(() => { const manualModels = [ @@ -44,6 +53,8 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_FLASH_LITE_MODEL, PREVIEW_GEMINI_MODEL, + PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, PREVIEW_GEMINI_FLASH_MODEL, ]; if (manualModels.includes(preferredModel)) { @@ -94,13 +105,14 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { list.unshift({ value: PREVIEW_GEMINI_MODEL_AUTO, title: getDisplayString(PREVIEW_GEMINI_MODEL_AUTO), - description: - 'Let Gemini CLI decide the best model for the task: gemini-3-pro, gemini-3-flash', + description: useGemini31 + ? 'Let Gemini CLI decide the best model for the task: gemini-3.1-pro, gemini-3-flash' + : 'Let Gemini CLI decide the best model for the task: gemini-3-pro, gemini-3-flash', key: PREVIEW_GEMINI_MODEL_AUTO, }); } return list; - }, [shouldShowPreviewModels, manualModelSelected]); + }, [shouldShowPreviewModels, manualModelSelected, useGemini31]); const manualOptions = useMemo(() => { const list = [ @@ -122,11 +134,19 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { ]; if (shouldShowPreviewModels) { + const previewProModel = useGemini31 + ? PREVIEW_GEMINI_3_1_MODEL + : PREVIEW_GEMINI_MODEL; + + const previewProValue = useCustomToolModel + ? PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL + : previewProModel; + list.unshift( { - value: PREVIEW_GEMINI_MODEL, - title: PREVIEW_GEMINI_MODEL, - key: PREVIEW_GEMINI_MODEL, + value: previewProValue, + title: previewProModel, + key: previewProModel, }, { value: PREVIEW_GEMINI_FLASH_MODEL, @@ -136,7 +156,7 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { ); } return list; - }, [shouldShowPreviewModels]); + }, [shouldShowPreviewModels, useGemini31, useCustomToolModel]); const options = view === 'main' ? mainOptions : manualOptions; diff --git a/packages/cli/src/ui/components/StatsDisplay.tsx b/packages/cli/src/ui/components/StatsDisplay.tsx index 3b42512424..d12dd4eb07 100644 --- a/packages/cli/src/ui/components/StatsDisplay.tsx +++ b/packages/cli/src/ui/components/StatsDisplay.tsx @@ -23,11 +23,13 @@ import { import { computeSessionStats } from '../utils/computeStats.js'; import { type RetrieveUserQuotaResponse, - VALID_GEMINI_MODELS, + isActiveModel, getDisplayString, isAutoModel, + AuthType, } from '@google/gemini-cli-core'; import { useSettings } from '../contexts/SettingsContext.js'; +import { useConfig } from '../contexts/ConfigContext.js'; import type { QuotaStats } from '../types.js'; import { QuotaStatsInfo } from './QuotaStatsInfo.js'; @@ -82,9 +84,13 @@ const Section: React.FC = ({ title, children }) => ( const buildModelRows = ( models: Record, quotas?: RetrieveUserQuotaResponse, + useGemini3_1 = false, + useCustomToolModel = false, ) => { const getBaseModelName = (name: string) => name.replace('-001', ''); - const usedModelNames = new Set(Object.keys(models).map(getBaseModelName)); + const usedModelNames = new Set( + Object.keys(models).map(getBaseModelName).map(getDisplayString), + ); // 1. Models with active usage const activeRows = Object.entries(models).map(([name, metrics]) => { @@ -93,7 +99,7 @@ const buildModelRows = ( const inputTokens = metrics.tokens.input; return { key: name, - modelName, + modelName: getDisplayString(modelName), requests: metrics.api.totalRequests, cachedTokens: cachedTokens.toLocaleString(), inputTokens: inputTokens.toLocaleString(), @@ -109,12 +115,12 @@ const buildModelRows = ( ?.filter( (b) => b.modelId && - VALID_GEMINI_MODELS.has(b.modelId) && - !usedModelNames.has(b.modelId), + isActiveModel(b.modelId, useGemini3_1, useCustomToolModel) && + !usedModelNames.has(getDisplayString(b.modelId)), ) .map((bucket) => ({ key: bucket.modelId!, - modelName: bucket.modelId!, + modelName: getDisplayString(bucket.modelId!), requests: '-', cachedTokens: '-', inputTokens: '-', @@ -135,6 +141,8 @@ const ModelUsageTable: React.FC<{ pooledRemaining?: number; pooledLimit?: number; pooledResetTime?: string; + useGemini3_1?: boolean; + useCustomToolModel?: boolean; }> = ({ models, quotas, @@ -144,8 +152,10 @@ const ModelUsageTable: React.FC<{ pooledRemaining, pooledLimit, pooledResetTime, + useGemini3_1, + useCustomToolModel, }) => { - const rows = buildModelRows(models, quotas); + const rows = buildModelRows(models, quotas, useGemini3_1, useCustomToolModel); if (rows.length === 0) { return null; @@ -403,7 +413,11 @@ export const StatsDisplay: React.FC = ({ const { models, tools, files } = metrics; const computed = computeSessionStats(metrics); const settings = useSettings(); - + const config = useConfig(); + const useGemini3_1 = config.getGemini31LaunchedSync?.() ?? false; + const useCustomToolModel = + useGemini3_1 && + config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI; const pooledRemaining = quotaStats?.remaining; const pooledLimit = quotaStats?.limit; const pooledResetTime = quotaStats?.resetTime; @@ -544,6 +558,8 @@ export const StatsDisplay: React.FC = ({ pooledRemaining={pooledRemaining} pooledLimit={pooledLimit} pooledResetTime={pooledResetTime} + useGemini3_1={useGemini3_1} + useCustomToolModel={useCustomToolModel} /> {renderFooter()} diff --git a/packages/cli/src/ui/components/messages/ModelMessage.tsx b/packages/cli/src/ui/components/messages/ModelMessage.tsx index bddbae8e8b..b313dab6f1 100644 --- a/packages/cli/src/ui/components/messages/ModelMessage.tsx +++ b/packages/cli/src/ui/components/messages/ModelMessage.tsx @@ -7,6 +7,7 @@ import type React from 'react'; import { Text, Box } from 'ink'; import { theme } from '../../semantic-colors.js'; +import { getDisplayString } from '@google/gemini-cli-core'; interface ModelMessageProps { model: string; @@ -15,7 +16,7 @@ interface ModelMessageProps { export const ModelMessage: React.FC = ({ model }) => ( - Responding with {model} + Responding with {getDisplayString(model)} ); diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts index d0f81bea60..5d6db5abfa 100644 --- a/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.test.ts @@ -155,9 +155,10 @@ describe('useQuotaAndFallback', () => { expect(request?.isTerminalQuotaError).toBe(true); const message = request!.message; - expect(message).toContain('Usage limit reached for gemini-pro.'); + expect(message).toContain('Usage limit reached for all Pro models.'); expect(message).toContain('Access resets at'); // From getResetTimeMessage expect(message).toContain('/stats model for usage details'); + expect(message).toContain('/model to switch models.'); expect(message).toContain('/auth to switch to API key.'); expect(mockHistoryManager.addItem).not.toHaveBeenCalled(); @@ -176,6 +177,77 @@ describe('useQuotaAndFallback', () => { expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(1); }); + it('should show the model name for a terminal quota error on a non-pro model', async () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, + }), + ); + + const handler = setFallbackHandlerSpy.mock + .calls[0][0] as FallbackModelHandler; + + let promise: Promise; + const error = new TerminalQuotaError( + 'flash quota', + mockGoogleApiError, + 1000 * 60 * 5, + ); + act(() => { + promise = handler('gemini-flash', 'gemini-pro', error); + }); + + const request = result.current.proQuotaRequest; + expect(request).not.toBeNull(); + expect(request?.failedModel).toBe('gemini-flash'); + + const message = request!.message; + expect(message).toContain('Usage limit reached for gemini-flash.'); + expect(message).not.toContain('all Pro models'); + + act(() => { + result.current.handleProQuotaChoice('retry_later'); + }); + + await promise!; + }); + + it('should handle terminal quota error without retry delay', async () => { + const { result } = renderHook(() => + useQuotaAndFallback({ + config: mockConfig, + historyManager: mockHistoryManager, + userTier: UserTierId.FREE, + setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError, + onShowAuthSelection: mockOnShowAuthSelection, + }), + ); + + const handler = setFallbackHandlerSpy.mock + .calls[0][0] as FallbackModelHandler; + + let promise: Promise; + const error = new TerminalQuotaError('no delay', mockGoogleApiError); + act(() => { + promise = handler('gemini-pro', 'gemini-flash', error); + }); + + const request = result.current.proQuotaRequest; + const message = request!.message; + expect(message).not.toContain('Access resets at'); + expect(message).toContain('Usage limit reached for all Pro models.'); + + act(() => { + result.current.handleProQuotaChoice('retry_later'); + }); + + await promise!; + }); + it('should handle race conditions by stopping subsequent requests', async () => { const { result } = renderHook(() => useQuotaAndFallback({ diff --git a/packages/cli/src/ui/hooks/useQuotaAndFallback.ts b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts index 60c91f3143..1ba03f2a47 100644 --- a/packages/cli/src/ui/hooks/useQuotaAndFallback.ts +++ b/packages/cli/src/ui/hooks/useQuotaAndFallback.ts @@ -14,9 +14,9 @@ import { TerminalQuotaError, ModelNotFoundError, type UserTierId, - PREVIEW_GEMINI_MODEL, - DEFAULT_GEMINI_MODEL, VALID_GEMINI_MODELS, + isProModel, + getDisplayString, } from '@google/gemini-cli-core'; import { useCallback, useEffect, useRef, useState } from 'react'; import { type UseHistoryManagerReturn } from './useHistoryManager.js'; @@ -67,11 +67,9 @@ export function useQuotaAndFallback({ let message: string; let isTerminalQuotaError = false; let isModelNotFoundError = false; - const usageLimitReachedModel = - failedModel === DEFAULT_GEMINI_MODEL || - failedModel === PREVIEW_GEMINI_MODEL - ? 'all Pro models' - : failedModel; + const usageLimitReachedModel = isProModel(failedModel) + ? 'all Pro models' + : failedModel; if (error instanceof TerminalQuotaError) { isTerminalQuotaError = true; // Common part of the message for both tiers @@ -87,7 +85,7 @@ export function useQuotaAndFallback({ isModelNotFoundError = true; if (VALID_GEMINI_MODELS.has(failedModel)) { const messageLines = [ - `It seems like you don't have access to ${failedModel}.`, + `It seems like you don't have access to ${getDisplayString(failedModel)}.`, `Your admin might have disabled the access. Contact them to enable the Preview Release Channel.`, ]; message = messageLines.join('\n'); diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index 44b1890ce2..f04caf01f7 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -520,7 +520,10 @@ export class Session { const functionCalls: FunctionCall[] = []; try { - const model = resolveModel(this.config.getModel()); + const model = resolveModel( + this.config.getModel(), + (await this.config.getGemini31Launched?.()) ?? false, + ); const responseStream = await chat.sendMessageStream( { model }, nextMessage?.parts ?? [], diff --git a/packages/core/src/availability/policyHelpers.ts b/packages/core/src/availability/policyHelpers.ts index 569157561f..6cf22d6388 100644 --- a/packages/core/src/availability/policyHelpers.ts +++ b/packages/core/src/availability/policyHelpers.ts @@ -44,7 +44,10 @@ export function resolvePolicyChain( const configuredModel = config.getModel(); let chain; - const resolvedModel = resolveModel(modelFromConfig); + const resolvedModel = resolveModel( + modelFromConfig, + config.getGemini31LaunchedSync?.() ?? false, + ); const isAutoPreferred = preferredModel ? isAutoModel(preferredModel) : false; const isAutoConfigured = isAutoModel(configuredModel); const hasAccessToPreview = config.getHasAccessToPreviewModel?.() ?? true; diff --git a/packages/core/src/code_assist/experiments/flagNames.ts b/packages/core/src/code_assist/experiments/flagNames.ts index 03b6aaac0a..e1ae2a1af2 100644 --- a/packages/core/src/code_assist/experiments/flagNames.ts +++ b/packages/core/src/code_assist/experiments/flagNames.ts @@ -16,6 +16,7 @@ export const ExperimentFlags = { MASKING_PROTECTION_THRESHOLD: 45758817, MASKING_PRUNABLE_THRESHOLD: 45758818, MASKING_PROTECT_LATEST_TURN: 45758819, + GEMINI_3_1_PRO_LAUNCHED: 45760185, } as const; export type ExperimentFlagName = diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 406835310a..fc4f7c2ff7 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -1071,6 +1071,12 @@ export class Config { // Reset availability status when switching auth (e.g. from limited key to OAuth) this.modelAvailabilityService.reset(); + // Clear stale authType to ensure getGemini31LaunchedSync doesn't return stale results + // during the transition. + if (this.contentGeneratorConfig) { + this.contentGeneratorConfig.authType = undefined; + } + const newContentGeneratorConfig = await createContentGeneratorConfig( this, authMethod, @@ -1350,7 +1356,10 @@ export class Config { if (pooled.remaining !== undefined) { return pooled.remaining; } - const primaryModel = resolveModel(this.getModel()); + const primaryModel = resolveModel( + this.getModel(), + this.getGemini31LaunchedSync(), + ); return this.modelQuotas.get(primaryModel)?.remaining; } @@ -1359,7 +1368,10 @@ export class Config { if (pooled.limit !== undefined) { return pooled.limit; } - const primaryModel = resolveModel(this.getModel()); + const primaryModel = resolveModel( + this.getModel(), + this.getGemini31LaunchedSync(), + ); return this.modelQuotas.get(primaryModel)?.limit; } @@ -1368,7 +1380,10 @@ export class Config { if (pooled.resetTime !== undefined) { return pooled.resetTime; } - const primaryModel = resolveModel(this.getModel()); + const primaryModel = resolveModel( + this.getModel(), + this.getGemini31LaunchedSync(), + ); return this.modelQuotas.get(primaryModel)?.resetTime; } @@ -2253,6 +2268,36 @@ export class Config { ); } + /** + * Returns whether Gemini 3.1 has been launched. + * This method is async and ensures that experiments are loaded before returning the result. + */ + async getGemini31Launched(): Promise { + await this.ensureExperimentsLoaded(); + return this.getGemini31LaunchedSync(); + } + + /** + * Returns whether Gemini 3.1 has been launched. + * + * Note: This method should only be called after startup, once experiments have been loaded. + * If you need to call this during startup or from an async context, use + * getGemini31Launched instead. + */ + getGemini31LaunchedSync(): boolean { + const authType = this.contentGeneratorConfig?.authType; + if ( + authType === AuthType.USE_GEMINI || + authType === AuthType.USE_VERTEX_AI + ) { + return true; + } + return ( + this.experiments?.flags[ExperimentFlags.GEMINI_3_1_PRO_LAUNCHED] + ?.boolValue ?? false + ); + } + private async ensureExperimentsLoaded(): Promise { if (!this.experimentsPromise) { return; diff --git a/packages/core/src/config/models.test.ts b/packages/core/src/config/models.test.ts index 2b2ddb1041..c16cf49781 100644 --- a/packages/core/src/config/models.test.ts +++ b/packages/core/src/config/models.test.ts @@ -25,8 +25,41 @@ import { PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO, + isActiveModel, + PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, + isPreviewModel, + isProModel, } from './models.js'; +describe('isPreviewModel', () => { + it('should return true for preview models', () => { + expect(isPreviewModel(PREVIEW_GEMINI_MODEL)).toBe(true); + expect(isPreviewModel(PREVIEW_GEMINI_3_1_MODEL)).toBe(true); + expect(isPreviewModel(PREVIEW_GEMINI_FLASH_MODEL)).toBe(true); + expect(isPreviewModel(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true); + }); + + it('should return false for non-preview models', () => { + expect(isPreviewModel(DEFAULT_GEMINI_MODEL)).toBe(false); + expect(isPreviewModel('gemini-1.5-pro')).toBe(false); + }); +}); + +describe('isProModel', () => { + it('should return true for models containing "pro"', () => { + expect(isProModel('gemini-3-pro-preview')).toBe(true); + expect(isProModel('gemini-2.5-pro')).toBe(true); + expect(isProModel('pro')).toBe(true); + }); + + it('should return false for models without "pro"', () => { + expect(isProModel('gemini-3-flash-preview')).toBe(false); + expect(isProModel('gemini-2.5-flash')).toBe(false); + expect(isProModel('auto')).toBe(false); + }); +}); + describe('isCustomModel', () => { it('should return true for models not starting with gemini-', () => { expect(isCustomModel('testing')).toBe(true); @@ -115,6 +148,12 @@ describe('getDisplayString', () => { ); }); + it('should return PREVIEW_GEMINI_3_1_MODEL for PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL', () => { + expect(getDisplayString(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL)).toBe( + PREVIEW_GEMINI_3_1_MODEL, + ); + }); + it('should return the model name as is for other models', () => { expect(getDisplayString('custom-model')).toBe('custom-model'); expect(getDisplayString(DEFAULT_GEMINI_FLASH_LITE_MODEL)).toBe( @@ -146,6 +185,16 @@ describe('resolveModel', () => { expect(model).toBe(PREVIEW_GEMINI_MODEL); }); + it('should return Gemini 3.1 Pro when auto-gemini-3 is requested and useGemini3_1 is true', () => { + const model = resolveModel(PREVIEW_GEMINI_MODEL_AUTO, true); + expect(model).toBe(PREVIEW_GEMINI_3_1_MODEL); + }); + + it('should return Gemini 3.1 Pro Custom Tools when auto-gemini-3 is requested, useGemini3_1 is true, and useCustomToolModel is true', () => { + const model = resolveModel(PREVIEW_GEMINI_MODEL_AUTO, true, true); + expect(model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL); + }); + it('should return the Default Pro model when auto-gemini-2.5 is requested', () => { const model = resolveModel(DEFAULT_GEMINI_MODEL_AUTO); expect(model).toBe(DEFAULT_GEMINI_MODEL); @@ -239,4 +288,71 @@ describe('resolveClassifierModel', () => { resolveClassifierModel(PREVIEW_GEMINI_MODEL_AUTO, GEMINI_MODEL_ALIAS_PRO), ).toBe(PREVIEW_GEMINI_MODEL); }); + + it('should return Gemini 3.1 Pro when alias is pro and useGemini3_1 is true', () => { + expect( + resolveClassifierModel( + PREVIEW_GEMINI_MODEL_AUTO, + GEMINI_MODEL_ALIAS_PRO, + true, + ), + ).toBe(PREVIEW_GEMINI_3_1_MODEL); + }); + + it('should return Gemini 3.1 Pro Custom Tools when alias is pro, useGemini3_1 is true, and useCustomToolModel is true', () => { + expect( + resolveClassifierModel( + PREVIEW_GEMINI_MODEL_AUTO, + GEMINI_MODEL_ALIAS_PRO, + true, + true, + ), + ).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL); + }); +}); + +describe('isActiveModel', () => { + it('should return true for valid models when useGemini3_1 is false', () => { + expect(isActiveModel(DEFAULT_GEMINI_MODEL)).toBe(true); + expect(isActiveModel(PREVIEW_GEMINI_MODEL)).toBe(true); + expect(isActiveModel(DEFAULT_GEMINI_FLASH_MODEL)).toBe(true); + }); + + it('should return true for unknown models and aliases', () => { + expect(isActiveModel('invalid-model')).toBe(false); + expect(isActiveModel(GEMINI_MODEL_ALIAS_AUTO)).toBe(false); + }); + + it('should return false for PREVIEW_GEMINI_MODEL when useGemini3_1 is true', () => { + expect(isActiveModel(PREVIEW_GEMINI_MODEL, true)).toBe(false); + }); + + it('should return true for other valid models when useGemini3_1 is true', () => { + expect(isActiveModel(DEFAULT_GEMINI_MODEL, true)).toBe(true); + }); + + it('should correctly filter Gemini 3.1 models based on useCustomToolModel when useGemini3_1 is true', () => { + // When custom tools are preferred, standard 3.1 should be inactive + expect(isActiveModel(PREVIEW_GEMINI_3_1_MODEL, true, true)).toBe(false); + expect( + isActiveModel(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, true, true), + ).toBe(true); + + // When custom tools are NOT preferred, custom tools 3.1 should be inactive + expect(isActiveModel(PREVIEW_GEMINI_3_1_MODEL, true, false)).toBe(true); + expect( + isActiveModel(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, true, false), + ).toBe(false); + }); + + it('should return false for both Gemini 3.1 models when useGemini3_1 is false', () => { + expect(isActiveModel(PREVIEW_GEMINI_3_1_MODEL, false, true)).toBe(false); + expect(isActiveModel(PREVIEW_GEMINI_3_1_MODEL, false, false)).toBe(false); + expect( + isActiveModel(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, false, true), + ).toBe(false); + expect( + isActiveModel(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, false, false), + ).toBe(false); + }); }); diff --git a/packages/core/src/config/models.ts b/packages/core/src/config/models.ts index 9f12944333..d0ec49f005 100644 --- a/packages/core/src/config/models.ts +++ b/packages/core/src/config/models.ts @@ -5,6 +5,9 @@ */ export const PREVIEW_GEMINI_MODEL = 'gemini-3-pro-preview'; +export const PREVIEW_GEMINI_3_1_MODEL = 'gemini-3.1-pro-preview'; +export const PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL = + 'gemini-3.1-pro-preview-customtools'; export const PREVIEW_GEMINI_FLASH_MODEL = 'gemini-3-flash-preview'; export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro'; export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash'; @@ -12,6 +15,8 @@ export const DEFAULT_GEMINI_FLASH_LITE_MODEL = 'gemini-2.5-flash-lite'; export const VALID_GEMINI_MODELS = new Set([ PREVIEW_GEMINI_MODEL, + PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, PREVIEW_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_FLASH_MODEL, @@ -37,20 +42,29 @@ export const DEFAULT_THINKING_MODE = 8192; * to a concrete model name. * * @param requestedModel The model alias or concrete model name requested by the user. + * @param useGemini3_1 Whether to use Gemini 3.1 Pro Preview for auto/pro aliases. * @returns The resolved concrete model name. */ -export function resolveModel(requestedModel: string): string { +export function resolveModel( + requestedModel: string, + useGemini3_1: boolean = false, + useCustomToolModel: boolean = false, +): string { switch (requestedModel) { - case PREVIEW_GEMINI_MODEL_AUTO: { + case PREVIEW_GEMINI_MODEL: + case PREVIEW_GEMINI_MODEL_AUTO: + case GEMINI_MODEL_ALIAS_AUTO: + case GEMINI_MODEL_ALIAS_PRO: { + if (useGemini3_1) { + return useCustomToolModel + ? PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL + : PREVIEW_GEMINI_3_1_MODEL; + } return PREVIEW_GEMINI_MODEL; } case DEFAULT_GEMINI_MODEL_AUTO: { return DEFAULT_GEMINI_MODEL; } - case GEMINI_MODEL_ALIAS_AUTO: - case GEMINI_MODEL_ALIAS_PRO: { - return PREVIEW_GEMINI_MODEL; - } case GEMINI_MODEL_ALIAS_FLASH: { return PREVIEW_GEMINI_FLASH_MODEL; } @@ -73,6 +87,8 @@ export function resolveModel(requestedModel: string): string { export function resolveClassifierModel( requestedModel: string, modelAlias: string, + useGemini3_1: boolean = false, + useCustomToolModel: boolean = false, ): string { if (modelAlias === GEMINI_MODEL_ALIAS_FLASH) { if ( @@ -89,7 +105,7 @@ export function resolveClassifierModel( } return resolveModel(GEMINI_MODEL_ALIAS_FLASH); } - return resolveModel(requestedModel); + return resolveModel(requestedModel, useGemini3_1, useCustomToolModel); } export function getDisplayString(model: string) { switch (model) { @@ -101,6 +117,8 @@ export function getDisplayString(model: string) { return PREVIEW_GEMINI_MODEL; case GEMINI_MODEL_ALIAS_FLASH: return PREVIEW_GEMINI_FLASH_MODEL; + case PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL: + return PREVIEW_GEMINI_3_1_MODEL; default: return model; } @@ -115,11 +133,22 @@ export function getDisplayString(model: string) { export function isPreviewModel(model: string): boolean { return ( model === PREVIEW_GEMINI_MODEL || + model === PREVIEW_GEMINI_3_1_MODEL || model === PREVIEW_GEMINI_FLASH_MODEL || model === PREVIEW_GEMINI_MODEL_AUTO ); } +/** + * Checks if the model is a Pro model. + * + * @param model The model name to check. + * @returns True if the model is a Pro model. + */ +export function isProModel(model: string): boolean { + return model.toLowerCase().includes('pro'); +} + /** * Checks if the model is a Gemini 3 model. * @@ -188,3 +217,35 @@ export function isAutoModel(model: string): boolean { export function supportsMultimodalFunctionResponse(model: string): boolean { return model.startsWith('gemini-3-'); } + +/** + * Checks if the given model is considered active based on the current configuration. + * + * @param model The model name to check. + * @param useGemini3_1 Whether Gemini 3.1 Pro Preview is enabled. + * @returns True if the model is active. + */ +export function isActiveModel( + model: string, + useGemini3_1: boolean = false, + useCustomToolModel: boolean = false, +): boolean { + if (!VALID_GEMINI_MODELS.has(model)) { + return false; + } + if (useGemini3_1) { + if (model === PREVIEW_GEMINI_MODEL) { + return false; + } + if (useCustomToolModel) { + return model !== PREVIEW_GEMINI_3_1_MODEL; + } else { + return model !== PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL; + } + } else { + return ( + model !== PREVIEW_GEMINI_3_1_MODEL && + model !== PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL + ); + } +} diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index 0951eb397b..efa35a868b 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -542,7 +542,10 @@ export class GeminiClient { // Availability logic: The configured model is the source of truth, // including any permanent fallbacks (config.setModel) or manual overrides. - return resolveModel(this.config.getActiveModel()); + return resolveModel( + this.config.getActiveModel(), + this.config.getGemini31LaunchedSync?.() ?? false, + ); } private async *processTurn( diff --git a/packages/core/src/core/contentGenerator.ts b/packages/core/src/core/contentGenerator.ts index bfd8221f75..7adae874aa 100644 --- a/packages/core/src/core/contentGenerator.ts +++ b/packages/core/src/core/contentGenerator.ts @@ -146,7 +146,12 @@ export async function createContentGenerator( return new LoggingContentGenerator(fakeGenerator, gcConfig); } const version = await getVersion(); - const model = resolveModel(gcConfig.getModel()); + const model = resolveModel( + gcConfig.getModel(), + config.authType === AuthType.USE_GEMINI || + config.authType === AuthType.USE_VERTEX_AI || + ((await gcConfig.getGemini31Launched?.()) ?? false), + ); const customHeadersEnv = process.env['GEMINI_CLI_CUSTOM_HEADERS'] || undefined; const userAgent = `GeminiCLI/${version}/${model} (${process.platform}; ${process.arch})`; diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 6b1ede738c..14f90cea9d 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -496,13 +496,14 @@ export class GeminiChat { const initialActiveModel = this.config.getActiveModel(); const apiCall = async () => { + const useGemini3_1 = (await this.config.getGemini31Launched?.()) ?? false; // Default to the last used model (which respects arguments/availability selection) - let modelToUse = resolveModel(lastModelToUse); + let modelToUse = resolveModel(lastModelToUse, useGemini3_1); // If the active model has changed (e.g. due to a fallback updating the config), // we switch to the new active model. if (this.config.getActiveModel() !== initialActiveModel) { - modelToUse = resolveModel(this.config.getActiveModel()); + modelToUse = resolveModel(this.config.getActiveModel(), useGemini3_1); } if (modelToUse !== lastModelToUse) { diff --git a/packages/core/src/prompts/promptProvider.ts b/packages/core/src/prompts/promptProvider.ts index 36ffddf71c..4f1a3afbff 100644 --- a/packages/core/src/prompts/promptProvider.ts +++ b/packages/core/src/prompts/promptProvider.ts @@ -58,7 +58,10 @@ export class PromptProvider { const enabledToolNames = new Set(toolNames); const approvedPlanPath = config.getApprovedPlanPath(); - const desiredModel = resolveModel(config.getActiveModel()); + const desiredModel = resolveModel( + config.getActiveModel(), + config.getGemini31LaunchedSync?.() ?? false, + ); const isModernModel = supportsModernFeatures(desiredModel); const activeSnippets = isModernModel ? snippets : legacySnippets; const contextFilenames = getAllGeminiMdFilenames(); @@ -231,7 +234,10 @@ export class PromptProvider { } getCompressionPrompt(config: Config): string { - const desiredModel = resolveModel(config.getActiveModel()); + const desiredModel = resolveModel( + config.getActiveModel(), + config.getGemini31LaunchedSync?.() ?? false, + ); const isModernModel = supportsModernFeatures(desiredModel); const activeSnippets = isModernModel ? snippets : legacySnippets; return activeSnippets.getCompressionPrompt(); diff --git a/packages/core/src/routing/strategies/classifierStrategy.test.ts b/packages/core/src/routing/strategies/classifierStrategy.test.ts index b2c7a8797e..7e024b790a 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.test.ts @@ -18,11 +18,14 @@ import { DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL_AUTO, PREVIEW_GEMINI_MODEL_AUTO, + PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, } from '../../config/models.js'; import { promptIdContext } from '../../utils/promptIdContext.js'; import type { Content } from '@google/genai'; import type { ResolvedModelConfig } from '../../services/modelConfigService.js'; import { debugLogger } from '../../utils/debugLogger.js'; +import { AuthType } from '../../core/contentGenerator.js'; vi.mock('../../core/baseLlmClient.js'); @@ -53,6 +56,10 @@ describe('ClassifierStrategy', () => { }, getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO), getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false), + getGemini31Launched: vi.fn().mockResolvedValue(false), + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: AuthType.LOGIN_WITH_GOOGLE, + }), } as unknown as Config; mockBaseLlmClient = { generateJson: vi.fn(), @@ -339,4 +346,49 @@ describe('ClassifierStrategy', () => { // Since requestedModel is Pro, and choice is flash, it should resolve to Flash expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL); }); + + describe('Gemini 3.1 and Custom Tools Routing', () => { + it('should route to PREVIEW_GEMINI_3_1_MODEL when Gemini 3.1 is launched', async () => { + vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); + vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO); + const mockApiResponse = { + reasoning: 'Complex task', + model_choice: 'pro', + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL); + }); + + it('should route to PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL when Gemini 3.1 is launched and auth is USE_GEMINI', async () => { + vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); + vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO); + vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ + authType: AuthType.USE_GEMINI, + }); + const mockApiResponse = { + reasoning: 'Complex task', + model_choice: 'pro', + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL); + }); + }); }); diff --git a/packages/core/src/routing/strategies/classifierStrategy.ts b/packages/core/src/routing/strategies/classifierStrategy.ts index 980e89829d..7e54d161de 100644 --- a/packages/core/src/routing/strategies/classifierStrategy.ts +++ b/packages/core/src/routing/strategies/classifierStrategy.ts @@ -21,6 +21,7 @@ import { } from '../../utils/messageInspectors.js'; import { debugLogger } from '../../utils/debugLogger.js'; import { LlmRole } from '../../telemetry/types.js'; +import { AuthType } from '../../core/contentGenerator.js'; // The number of recent history turns to provide to the router for context. const HISTORY_TURNS_FOR_CONTEXT = 4; @@ -169,9 +170,15 @@ export class ClassifierStrategy implements RoutingStrategy { const reasoning = routerResponse.reasoning; const latencyMs = Date.now() - startTime; + const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false; + const useCustomToolModel = + useGemini3_1 && + config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI; const selectedModel = resolveClassifierModel( model, routerResponse.model_choice, + useGemini3_1, + useCustomToolModel, ); return { diff --git a/packages/core/src/routing/strategies/defaultStrategy.ts b/packages/core/src/routing/strategies/defaultStrategy.ts index e5b89eb1b3..1f5b7e54c2 100644 --- a/packages/core/src/routing/strategies/defaultStrategy.ts +++ b/packages/core/src/routing/strategies/defaultStrategy.ts @@ -21,7 +21,10 @@ export class DefaultStrategy implements TerminalStrategy { config: Config, _baseLlmClient: BaseLlmClient, ): Promise { - const defaultModel = resolveModel(config.getModel()); + const defaultModel = resolveModel( + config.getModel(), + config.getGemini31LaunchedSync?.() ?? false, + ); return { model: defaultModel, metadata: { diff --git a/packages/core/src/routing/strategies/fallbackStrategy.ts b/packages/core/src/routing/strategies/fallbackStrategy.ts index d568039cbc..a18e4fc4dd 100644 --- a/packages/core/src/routing/strategies/fallbackStrategy.ts +++ b/packages/core/src/routing/strategies/fallbackStrategy.ts @@ -23,7 +23,10 @@ export class FallbackStrategy implements RoutingStrategy { _baseLlmClient: BaseLlmClient, ): Promise { const requestedModel = context.requestedModel ?? config.getModel(); - const resolvedModel = resolveModel(requestedModel); + const resolvedModel = resolveModel( + requestedModel, + config.getGemini31LaunchedSync?.() ?? false, + ); const service = config.getModelAvailabilityService(); const snapshot = service.snapshot(resolvedModel); diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts index 8767709f68..b8f6c50282 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.test.ts @@ -12,6 +12,8 @@ import type { BaseLlmClient } from '../../core/baseLlmClient.js'; import { PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_MODEL, + PREVIEW_GEMINI_3_1_MODEL, + PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, PREVIEW_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL, @@ -20,6 +22,7 @@ import { promptIdContext } from '../../utils/promptIdContext.js'; import type { Content } from '@google/genai'; import type { ResolvedModelConfig } from '../../services/modelConfigService.js'; import { debugLogger } from '../../utils/debugLogger.js'; +import { AuthType } from '../../core/contentGenerator.js'; vi.mock('../../core/baseLlmClient.js'); @@ -52,6 +55,10 @@ describe('NumericalClassifierStrategy', () => { getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50) getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true), getClassifierThreshold: vi.fn().mockResolvedValue(undefined), + getGemini31Launched: vi.fn().mockResolvedValue(false), + getContentGeneratorConfig: vi.fn().mockReturnValue({ + authType: AuthType.LOGIN_WITH_GOOGLE, + }), } as unknown as Config; mockBaseLlmClient = { generateJson: vi.fn(), @@ -535,4 +542,68 @@ describe('NumericalClassifierStrategy', () => { ), ); }); + + describe('Gemini 3.1 and Custom Tools Routing', () => { + it('should route to PREVIEW_GEMINI_3_1_MODEL when Gemini 3.1 is launched', async () => { + vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); + const mockApiResponse = { + complexity_reasoning: 'Complex task', + complexity_score: 80, + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL); + }); + it('should route to PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL when Gemini 3.1 is launched and auth is USE_GEMINI', async () => { + vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); + vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ + authType: AuthType.USE_GEMINI, + }); + const mockApiResponse = { + complexity_reasoning: 'Complex task', + complexity_score: 80, + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL); + }); + + it('should NOT route to custom tools model when auth is USE_VERTEX_AI', async () => { + vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true); + vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({ + authType: AuthType.USE_VERTEX_AI, + }); + const mockApiResponse = { + complexity_reasoning: 'Complex task', + complexity_score: 80, + }; + vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue( + mockApiResponse, + ); + + const decision = await strategy.route( + mockContext, + mockConfig, + mockBaseLlmClient, + ); + + expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL); + }); + }); }); diff --git a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts index d4ddf99b8d..32cc6ccbb7 100644 --- a/packages/core/src/routing/strategies/numericalClassifierStrategy.ts +++ b/packages/core/src/routing/strategies/numericalClassifierStrategy.ts @@ -17,6 +17,7 @@ import { createUserContent, Type } from '@google/genai'; import type { Config } from '../../config/config.js'; import { debugLogger } from '../../utils/debugLogger.js'; import { LlmRole } from '../../telemetry/types.js'; +import { AuthType } from '../../core/contentGenerator.js'; // The number of recent history turns to provide to the router for context. const HISTORY_TURNS_FOR_CONTEXT = 8; @@ -182,8 +183,16 @@ export class NumericalClassifierStrategy implements RoutingStrategy { config, config.getSessionId() || 'unknown-session', ); - - const selectedModel = resolveClassifierModel(model, modelAlias); + const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false; + const useCustomToolModel = + useGemini3_1 && + config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI; + const selectedModel = resolveClassifierModel( + model, + modelAlias, + useGemini3_1, + useCustomToolModel, + ); const latencyMs = Date.now() - startTime; diff --git a/packages/core/src/routing/strategies/overrideStrategy.ts b/packages/core/src/routing/strategies/overrideStrategy.ts index b8382407bd..5101ba9fe7 100644 --- a/packages/core/src/routing/strategies/overrideStrategy.ts +++ b/packages/core/src/routing/strategies/overrideStrategy.ts @@ -33,7 +33,10 @@ export class OverrideStrategy implements RoutingStrategy { // Return the overridden model name. return { - model: resolveModel(overrideModel), + model: resolveModel( + overrideModel, + config.getGemini31LaunchedSync?.() ?? false, + ), metadata: { source: this.name, latencyMs: 0, diff --git a/packages/core/src/services/chatCompressionService.ts b/packages/core/src/services/chatCompressionService.ts index 44ffe90cf2..432c08dd1e 100644 --- a/packages/core/src/services/chatCompressionService.ts +++ b/packages/core/src/services/chatCompressionService.ts @@ -29,6 +29,7 @@ import { DEFAULT_GEMINI_MODEL, PREVIEW_GEMINI_MODEL, PREVIEW_GEMINI_FLASH_MODEL, + PREVIEW_GEMINI_3_1_MODEL, } from '../config/models.js'; import { PreCompressTrigger } from '../hooks/types.js'; import { LlmRole } from '../telemetry/types.js'; @@ -101,6 +102,7 @@ export function findCompressSplitPoint( export function modelStringToModelConfigAlias(model: string): string { switch (model) { case PREVIEW_GEMINI_MODEL: + case PREVIEW_GEMINI_3_1_MODEL: return 'chat-compression-3-pro'; case PREVIEW_GEMINI_FLASH_MODEL: return 'chat-compression-3-flash';