feat(models): support Gemini 3.1 Pro Preview and fixes (#19676)

This commit is contained in:
Sehoon Shon
2026-02-20 14:19:21 -05:00
committed by GitHub
parent 788a40c445
commit f97b04cc9a
25 changed files with 670 additions and 50 deletions
+4 -1
View File
@@ -9,6 +9,7 @@ import { Box, Text } from 'ink';
import { theme } from '../semantic-colors.js'; import { theme } from '../semantic-colors.js';
import { GIT_COMMIT_INFO } from '../../generated/git-commit.js'; import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
import { useSettings } from '../contexts/SettingsContext.js'; import { useSettings } from '../contexts/SettingsContext.js';
import { getDisplayString } from '@google/gemini-cli-core';
interface AboutBoxProps { interface AboutBoxProps {
cliVersion: string; cliVersion: string;
@@ -79,7 +80,9 @@ export const AboutBox: React.FC<AboutBoxProps> = ({
</Text> </Text>
</Box> </Box>
<Box> <Box>
<Text color={theme.text.primary}>{modelVersion}</Text> <Text color={theme.text.primary}>
{getDisplayString(modelVersion)}
</Text>
</Box> </Box>
</Box> </Box>
<Box flexDirection="row"> <Box flexDirection="row">
@@ -9,11 +9,17 @@ import { act } from 'react';
import { ModelDialog } from './ModelDialog.js'; import { ModelDialog } from './ModelDialog.js';
import { renderWithProviders } from '../../test-utils/render.js'; import { renderWithProviders } from '../../test-utils/render.js';
import { waitFor } from '../../test-utils/async.js'; import { waitFor } from '../../test-utils/async.js';
import { createMockSettings } from '../../test-utils/settings.js';
import { import {
DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_FLASH_LITE_MODEL, DEFAULT_GEMINI_FLASH_LITE_MODEL,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_3_1_MODEL,
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
AuthType,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import type { Config, ModelSlashCommandEvent } from '@google/gemini-cli-core'; import type { Config, ModelSlashCommandEvent } from '@google/gemini-cli-core';
@@ -42,12 +48,14 @@ describe('<ModelDialog />', () => {
const mockGetModel = vi.fn(); const mockGetModel = vi.fn();
const mockOnClose = vi.fn(); const mockOnClose = vi.fn();
const mockGetHasAccessToPreviewModel = vi.fn(); const mockGetHasAccessToPreviewModel = vi.fn();
const mockGetGemini31LaunchedSync = vi.fn();
interface MockConfig extends Partial<Config> { interface MockConfig extends Partial<Config> {
setModel: (model: string, isTemporary?: boolean) => void; setModel: (model: string, isTemporary?: boolean) => void;
getModel: () => string; getModel: () => string;
getHasAccessToPreviewModel: () => boolean; getHasAccessToPreviewModel: () => boolean;
getIdeMode: () => boolean; getIdeMode: () => boolean;
getGemini31LaunchedSync: () => boolean;
} }
const mockConfig: MockConfig = { const mockConfig: MockConfig = {
@@ -55,12 +63,14 @@ describe('<ModelDialog />', () => {
getModel: mockGetModel, getModel: mockGetModel,
getHasAccessToPreviewModel: mockGetHasAccessToPreviewModel, getHasAccessToPreviewModel: mockGetHasAccessToPreviewModel,
getIdeMode: () => false, getIdeMode: () => false,
getGemini31LaunchedSync: mockGetGemini31LaunchedSync,
}; };
beforeEach(() => { beforeEach(() => {
vi.resetAllMocks(); vi.resetAllMocks();
mockGetModel.mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO); mockGetModel.mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
mockGetHasAccessToPreviewModel.mockReturnValue(false); mockGetHasAccessToPreviewModel.mockReturnValue(false);
mockGetGemini31LaunchedSync.mockReturnValue(false);
// Default implementation for getDisplayString // Default implementation for getDisplayString
mockGetDisplayString.mockImplementation((val: string) => { mockGetDisplayString.mockImplementation((val: string) => {
@@ -70,9 +80,21 @@ describe('<ModelDialog />', () => {
}); });
}); });
const renderComponent = async (configValue = mockConfig as Config) => { const renderComponent = async (
configValue = mockConfig as Config,
authType = AuthType.LOGIN_WITH_GOOGLE,
) => {
const settings = createMockSettings({
security: {
auth: {
selectedType: authType,
},
},
});
const result = renderWithProviders(<ModelDialog onClose={mockOnClose} />, { const result = renderWithProviders(<ModelDialog onClose={mockOnClose} />, {
config: configValue, config: configValue,
settings,
}); });
await result.waitUntilReady(); await result.waitUntilReady();
return result; return result;
@@ -241,4 +263,98 @@ describe('<ModelDialog />', () => {
}); });
unmount(); unmount();
}); });
it('shows the preferred manual model in the main view option', async () => {
mockGetModel.mockReturnValue(DEFAULT_GEMINI_MODEL);
const { lastFrame, unmount } = await renderComponent();
expect(lastFrame()).toContain(`Manual (${DEFAULT_GEMINI_MODEL})`);
unmount();
});
describe('Preview Models', () => {
beforeEach(() => {
mockGetHasAccessToPreviewModel.mockReturnValue(true);
});
it('shows Auto (Preview) in main view when access is granted', async () => {
const { lastFrame, unmount } = await renderComponent();
expect(lastFrame()).toContain('Auto (Preview)');
unmount();
});
it('shows Gemini 3 models in manual view when Gemini 3.1 is NOT launched', async () => {
mockGetGemini31LaunchedSync.mockReturnValue(false);
const { lastFrame, stdin, waitUntilReady, unmount } =
await renderComponent();
// Go to manual view
await act(async () => {
stdin.write('\u001B[B'); // Manual
});
await waitUntilReady();
await act(async () => {
stdin.write('\r');
});
await waitUntilReady();
const output = lastFrame();
expect(output).toContain(PREVIEW_GEMINI_MODEL);
expect(output).toContain(PREVIEW_GEMINI_FLASH_MODEL);
unmount();
});
it('shows Gemini 3.1 models in manual view when Gemini 3.1 IS launched', async () => {
mockGetGemini31LaunchedSync.mockReturnValue(true);
const { lastFrame, stdin, waitUntilReady, unmount } =
await renderComponent(mockConfig as Config, AuthType.USE_VERTEX_AI);
// Go to manual view
await act(async () => {
stdin.write('\u001B[B'); // Manual
});
await waitUntilReady();
await act(async () => {
stdin.write('\r');
});
await waitUntilReady();
const output = lastFrame();
expect(output).toContain(PREVIEW_GEMINI_3_1_MODEL);
expect(output).toContain(PREVIEW_GEMINI_FLASH_MODEL);
unmount();
});
it('uses custom tools model when Gemini 3.1 IS launched and auth is Gemini API Key', async () => {
mockGetGemini31LaunchedSync.mockReturnValue(true);
const { stdin, waitUntilReady, unmount } = await renderComponent(
mockConfig as Config,
AuthType.USE_GEMINI,
);
// Go to manual view
await act(async () => {
stdin.write('\u001B[B'); // Manual
});
await waitUntilReady();
await act(async () => {
stdin.write('\r');
});
await waitUntilReady();
// Select Gemini 3.1 (first item in preview section)
await act(async () => {
stdin.write('\r');
});
await waitUntilReady();
await waitFor(() => {
expect(mockSetModel).toHaveBeenCalledWith(
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
true,
);
});
unmount();
});
});
}); });
+27 -7
View File
@@ -9,6 +9,7 @@ import { useCallback, useContext, useMemo, useState } from 'react';
import { Box, Text } from 'ink'; import { Box, Text } from 'ink';
import { import {
PREVIEW_GEMINI_MODEL, PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_3_1_MODEL,
PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL_AUTO, PREVIEW_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL,
@@ -18,11 +19,14 @@ import {
ModelSlashCommandEvent, ModelSlashCommandEvent,
logModelSlashCommand, logModelSlashCommand,
getDisplayString, getDisplayString,
AuthType,
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { useKeypress } from '../hooks/useKeypress.js'; import { useKeypress } from '../hooks/useKeypress.js';
import { theme } from '../semantic-colors.js'; import { theme } from '../semantic-colors.js';
import { DescriptiveRadioButtonSelect } from './shared/DescriptiveRadioButtonSelect.js'; import { DescriptiveRadioButtonSelect } from './shared/DescriptiveRadioButtonSelect.js';
import { ConfigContext } from '../contexts/ConfigContext.js'; import { ConfigContext } from '../contexts/ConfigContext.js';
import { useSettings } from '../contexts/SettingsContext.js';
interface ModelDialogProps { interface ModelDialogProps {
onClose: () => void; onClose: () => void;
@@ -30,6 +34,7 @@ interface ModelDialogProps {
export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element { export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element {
const config = useContext(ConfigContext); const config = useContext(ConfigContext);
const settings = useSettings();
const [view, setView] = useState<'main' | 'manual'>('main'); const [view, setView] = useState<'main' | 'manual'>('main');
const [persistMode, setPersistMode] = useState(false); const [persistMode, setPersistMode] = useState(false);
@@ -37,6 +42,10 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element {
const preferredModel = config?.getModel() || DEFAULT_GEMINI_MODEL_AUTO; const preferredModel = config?.getModel() || DEFAULT_GEMINI_MODEL_AUTO;
const shouldShowPreviewModels = config?.getHasAccessToPreviewModel(); const shouldShowPreviewModels = config?.getHasAccessToPreviewModel();
const useGemini31 = config?.getGemini31LaunchedSync?.() ?? false;
const selectedAuthType = settings.merged.security.auth.selectedType;
const useCustomToolModel =
useGemini31 && selectedAuthType === AuthType.USE_GEMINI;
const manualModelSelected = useMemo(() => { const manualModelSelected = useMemo(() => {
const manualModels = [ const manualModels = [
@@ -44,6 +53,8 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element {
DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_FLASH_LITE_MODEL, DEFAULT_GEMINI_FLASH_LITE_MODEL,
PREVIEW_GEMINI_MODEL, PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_3_1_MODEL,
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_FLASH_MODEL,
]; ];
if (manualModels.includes(preferredModel)) { if (manualModels.includes(preferredModel)) {
@@ -94,13 +105,14 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element {
list.unshift({ list.unshift({
value: PREVIEW_GEMINI_MODEL_AUTO, value: PREVIEW_GEMINI_MODEL_AUTO,
title: getDisplayString(PREVIEW_GEMINI_MODEL_AUTO), title: getDisplayString(PREVIEW_GEMINI_MODEL_AUTO),
description: description: useGemini31
'Let Gemini CLI decide the best model for the task: gemini-3-pro, gemini-3-flash', ? 'Let Gemini CLI decide the best model for the task: gemini-3.1-pro, gemini-3-flash'
: 'Let Gemini CLI decide the best model for the task: gemini-3-pro, gemini-3-flash',
key: PREVIEW_GEMINI_MODEL_AUTO, key: PREVIEW_GEMINI_MODEL_AUTO,
}); });
} }
return list; return list;
}, [shouldShowPreviewModels, manualModelSelected]); }, [shouldShowPreviewModels, manualModelSelected, useGemini31]);
const manualOptions = useMemo(() => { const manualOptions = useMemo(() => {
const list = [ const list = [
@@ -122,11 +134,19 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element {
]; ];
if (shouldShowPreviewModels) { if (shouldShowPreviewModels) {
const previewProModel = useGemini31
? PREVIEW_GEMINI_3_1_MODEL
: PREVIEW_GEMINI_MODEL;
const previewProValue = useCustomToolModel
? PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL
: previewProModel;
list.unshift( list.unshift(
{ {
value: PREVIEW_GEMINI_MODEL, value: previewProValue,
title: PREVIEW_GEMINI_MODEL, title: previewProModel,
key: PREVIEW_GEMINI_MODEL, key: previewProModel,
}, },
{ {
value: PREVIEW_GEMINI_FLASH_MODEL, value: PREVIEW_GEMINI_FLASH_MODEL,
@@ -136,7 +156,7 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element {
); );
} }
return list; return list;
}, [shouldShowPreviewModels]); }, [shouldShowPreviewModels, useGemini31, useCustomToolModel]);
const options = view === 'main' ? mainOptions : manualOptions; const options = view === 'main' ? mainOptions : manualOptions;
@@ -23,11 +23,13 @@ import {
import { computeSessionStats } from '../utils/computeStats.js'; import { computeSessionStats } from '../utils/computeStats.js';
import { import {
type RetrieveUserQuotaResponse, type RetrieveUserQuotaResponse,
VALID_GEMINI_MODELS, isActiveModel,
getDisplayString, getDisplayString,
isAutoModel, isAutoModel,
AuthType,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { useSettings } from '../contexts/SettingsContext.js'; import { useSettings } from '../contexts/SettingsContext.js';
import { useConfig } from '../contexts/ConfigContext.js';
import type { QuotaStats } from '../types.js'; import type { QuotaStats } from '../types.js';
import { QuotaStatsInfo } from './QuotaStatsInfo.js'; import { QuotaStatsInfo } from './QuotaStatsInfo.js';
@@ -82,9 +84,13 @@ const Section: React.FC<SectionProps> = ({ title, children }) => (
const buildModelRows = ( const buildModelRows = (
models: Record<string, ModelMetrics>, models: Record<string, ModelMetrics>,
quotas?: RetrieveUserQuotaResponse, quotas?: RetrieveUserQuotaResponse,
useGemini3_1 = false,
useCustomToolModel = false,
) => { ) => {
const getBaseModelName = (name: string) => name.replace('-001', ''); const getBaseModelName = (name: string) => name.replace('-001', '');
const usedModelNames = new Set(Object.keys(models).map(getBaseModelName)); const usedModelNames = new Set(
Object.keys(models).map(getBaseModelName).map(getDisplayString),
);
// 1. Models with active usage // 1. Models with active usage
const activeRows = Object.entries(models).map(([name, metrics]) => { const activeRows = Object.entries(models).map(([name, metrics]) => {
@@ -93,7 +99,7 @@ const buildModelRows = (
const inputTokens = metrics.tokens.input; const inputTokens = metrics.tokens.input;
return { return {
key: name, key: name,
modelName, modelName: getDisplayString(modelName),
requests: metrics.api.totalRequests, requests: metrics.api.totalRequests,
cachedTokens: cachedTokens.toLocaleString(), cachedTokens: cachedTokens.toLocaleString(),
inputTokens: inputTokens.toLocaleString(), inputTokens: inputTokens.toLocaleString(),
@@ -109,12 +115,12 @@ const buildModelRows = (
?.filter( ?.filter(
(b) => (b) =>
b.modelId && b.modelId &&
VALID_GEMINI_MODELS.has(b.modelId) && isActiveModel(b.modelId, useGemini3_1, useCustomToolModel) &&
!usedModelNames.has(b.modelId), !usedModelNames.has(getDisplayString(b.modelId)),
) )
.map((bucket) => ({ .map((bucket) => ({
key: bucket.modelId!, key: bucket.modelId!,
modelName: bucket.modelId!, modelName: getDisplayString(bucket.modelId!),
requests: '-', requests: '-',
cachedTokens: '-', cachedTokens: '-',
inputTokens: '-', inputTokens: '-',
@@ -135,6 +141,8 @@ const ModelUsageTable: React.FC<{
pooledRemaining?: number; pooledRemaining?: number;
pooledLimit?: number; pooledLimit?: number;
pooledResetTime?: string; pooledResetTime?: string;
useGemini3_1?: boolean;
useCustomToolModel?: boolean;
}> = ({ }> = ({
models, models,
quotas, quotas,
@@ -144,8 +152,10 @@ const ModelUsageTable: React.FC<{
pooledRemaining, pooledRemaining,
pooledLimit, pooledLimit,
pooledResetTime, pooledResetTime,
useGemini3_1,
useCustomToolModel,
}) => { }) => {
const rows = buildModelRows(models, quotas); const rows = buildModelRows(models, quotas, useGemini3_1, useCustomToolModel);
if (rows.length === 0) { if (rows.length === 0) {
return null; return null;
@@ -403,7 +413,11 @@ export const StatsDisplay: React.FC<StatsDisplayProps> = ({
const { models, tools, files } = metrics; const { models, tools, files } = metrics;
const computed = computeSessionStats(metrics); const computed = computeSessionStats(metrics);
const settings = useSettings(); const settings = useSettings();
const config = useConfig();
const useGemini3_1 = config.getGemini31LaunchedSync?.() ?? false;
const useCustomToolModel =
useGemini3_1 &&
config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI;
const pooledRemaining = quotaStats?.remaining; const pooledRemaining = quotaStats?.remaining;
const pooledLimit = quotaStats?.limit; const pooledLimit = quotaStats?.limit;
const pooledResetTime = quotaStats?.resetTime; const pooledResetTime = quotaStats?.resetTime;
@@ -544,6 +558,8 @@ export const StatsDisplay: React.FC<StatsDisplayProps> = ({
pooledRemaining={pooledRemaining} pooledRemaining={pooledRemaining}
pooledLimit={pooledLimit} pooledLimit={pooledLimit}
pooledResetTime={pooledResetTime} pooledResetTime={pooledResetTime}
useGemini3_1={useGemini3_1}
useCustomToolModel={useCustomToolModel}
/> />
{renderFooter()} {renderFooter()}
</Box> </Box>
@@ -7,6 +7,7 @@
import type React from 'react'; import type React from 'react';
import { Text, Box } from 'ink'; import { Text, Box } from 'ink';
import { theme } from '../../semantic-colors.js'; import { theme } from '../../semantic-colors.js';
import { getDisplayString } from '@google/gemini-cli-core';
interface ModelMessageProps { interface ModelMessageProps {
model: string; model: string;
@@ -15,7 +16,7 @@ interface ModelMessageProps {
export const ModelMessage: React.FC<ModelMessageProps> = ({ model }) => ( export const ModelMessage: React.FC<ModelMessageProps> = ({ model }) => (
<Box marginLeft={2}> <Box marginLeft={2}>
<Text color={theme.ui.comment} italic> <Text color={theme.ui.comment} italic>
Responding with {model} Responding with {getDisplayString(model)}
</Text> </Text>
</Box> </Box>
); );
@@ -155,9 +155,10 @@ describe('useQuotaAndFallback', () => {
expect(request?.isTerminalQuotaError).toBe(true); expect(request?.isTerminalQuotaError).toBe(true);
const message = request!.message; const message = request!.message;
expect(message).toContain('Usage limit reached for gemini-pro.'); expect(message).toContain('Usage limit reached for all Pro models.');
expect(message).toContain('Access resets at'); // From getResetTimeMessage expect(message).toContain('Access resets at'); // From getResetTimeMessage
expect(message).toContain('/stats model for usage details'); expect(message).toContain('/stats model for usage details');
expect(message).toContain('/model to switch models.');
expect(message).toContain('/auth to switch to API key.'); expect(message).toContain('/auth to switch to API key.');
expect(mockHistoryManager.addItem).not.toHaveBeenCalled(); expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
@@ -176,6 +177,77 @@ describe('useQuotaAndFallback', () => {
expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(1); expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(1);
}); });
it('should show the model name for a terminal quota error on a non-pro model', async () => {
const { result } = renderHook(() =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}),
);
const handler = setFallbackHandlerSpy.mock
.calls[0][0] as FallbackModelHandler;
let promise: Promise<FallbackIntent | null>;
const error = new TerminalQuotaError(
'flash quota',
mockGoogleApiError,
1000 * 60 * 5,
);
act(() => {
promise = handler('gemini-flash', 'gemini-pro', error);
});
const request = result.current.proQuotaRequest;
expect(request).not.toBeNull();
expect(request?.failedModel).toBe('gemini-flash');
const message = request!.message;
expect(message).toContain('Usage limit reached for gemini-flash.');
expect(message).not.toContain('all Pro models');
act(() => {
result.current.handleProQuotaChoice('retry_later');
});
await promise!;
});
it('should handle terminal quota error without retry delay', async () => {
const { result } = renderHook(() =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}),
);
const handler = setFallbackHandlerSpy.mock
.calls[0][0] as FallbackModelHandler;
let promise: Promise<FallbackIntent | null>;
const error = new TerminalQuotaError('no delay', mockGoogleApiError);
act(() => {
promise = handler('gemini-pro', 'gemini-flash', error);
});
const request = result.current.proQuotaRequest;
const message = request!.message;
expect(message).not.toContain('Access resets at');
expect(message).toContain('Usage limit reached for all Pro models.');
act(() => {
result.current.handleProQuotaChoice('retry_later');
});
await promise!;
});
it('should handle race conditions by stopping subsequent requests', async () => { it('should handle race conditions by stopping subsequent requests', async () => {
const { result } = renderHook(() => const { result } = renderHook(() =>
useQuotaAndFallback({ useQuotaAndFallback({
@@ -14,9 +14,9 @@ import {
TerminalQuotaError, TerminalQuotaError,
ModelNotFoundError, ModelNotFoundError,
type UserTierId, type UserTierId,
PREVIEW_GEMINI_MODEL,
DEFAULT_GEMINI_MODEL,
VALID_GEMINI_MODELS, VALID_GEMINI_MODELS,
isProModel,
getDisplayString,
} from '@google/gemini-cli-core'; } from '@google/gemini-cli-core';
import { useCallback, useEffect, useRef, useState } from 'react'; import { useCallback, useEffect, useRef, useState } from 'react';
import { type UseHistoryManagerReturn } from './useHistoryManager.js'; import { type UseHistoryManagerReturn } from './useHistoryManager.js';
@@ -67,11 +67,9 @@ export function useQuotaAndFallback({
let message: string; let message: string;
let isTerminalQuotaError = false; let isTerminalQuotaError = false;
let isModelNotFoundError = false; let isModelNotFoundError = false;
const usageLimitReachedModel = const usageLimitReachedModel = isProModel(failedModel)
failedModel === DEFAULT_GEMINI_MODEL || ? 'all Pro models'
failedModel === PREVIEW_GEMINI_MODEL : failedModel;
? 'all Pro models'
: failedModel;
if (error instanceof TerminalQuotaError) { if (error instanceof TerminalQuotaError) {
isTerminalQuotaError = true; isTerminalQuotaError = true;
// Common part of the message for both tiers // Common part of the message for both tiers
@@ -87,7 +85,7 @@ export function useQuotaAndFallback({
isModelNotFoundError = true; isModelNotFoundError = true;
if (VALID_GEMINI_MODELS.has(failedModel)) { if (VALID_GEMINI_MODELS.has(failedModel)) {
const messageLines = [ const messageLines = [
`It seems like you don't have access to ${failedModel}.`, `It seems like you don't have access to ${getDisplayString(failedModel)}.`,
`Your admin might have disabled the access. Contact them to enable the Preview Release Channel.`, `Your admin might have disabled the access. Contact them to enable the Preview Release Channel.`,
]; ];
message = messageLines.join('\n'); message = messageLines.join('\n');
@@ -520,7 +520,10 @@ export class Session {
const functionCalls: FunctionCall[] = []; const functionCalls: FunctionCall[] = [];
try { try {
const model = resolveModel(this.config.getModel()); const model = resolveModel(
this.config.getModel(),
(await this.config.getGemini31Launched?.()) ?? false,
);
const responseStream = await chat.sendMessageStream( const responseStream = await chat.sendMessageStream(
{ model }, { model },
nextMessage?.parts ?? [], nextMessage?.parts ?? [],
@@ -44,7 +44,10 @@ export function resolvePolicyChain(
const configuredModel = config.getModel(); const configuredModel = config.getModel();
let chain; let chain;
const resolvedModel = resolveModel(modelFromConfig); const resolvedModel = resolveModel(
modelFromConfig,
config.getGemini31LaunchedSync?.() ?? false,
);
const isAutoPreferred = preferredModel ? isAutoModel(preferredModel) : false; const isAutoPreferred = preferredModel ? isAutoModel(preferredModel) : false;
const isAutoConfigured = isAutoModel(configuredModel); const isAutoConfigured = isAutoModel(configuredModel);
const hasAccessToPreview = config.getHasAccessToPreviewModel?.() ?? true; const hasAccessToPreview = config.getHasAccessToPreviewModel?.() ?? true;
@@ -16,6 +16,7 @@ export const ExperimentFlags = {
MASKING_PROTECTION_THRESHOLD: 45758817, MASKING_PROTECTION_THRESHOLD: 45758817,
MASKING_PRUNABLE_THRESHOLD: 45758818, MASKING_PRUNABLE_THRESHOLD: 45758818,
MASKING_PROTECT_LATEST_TURN: 45758819, MASKING_PROTECT_LATEST_TURN: 45758819,
GEMINI_3_1_PRO_LAUNCHED: 45760185,
} as const; } as const;
export type ExperimentFlagName = export type ExperimentFlagName =
+48 -3
View File
@@ -1071,6 +1071,12 @@ export class Config {
// Reset availability status when switching auth (e.g. from limited key to OAuth) // Reset availability status when switching auth (e.g. from limited key to OAuth)
this.modelAvailabilityService.reset(); this.modelAvailabilityService.reset();
// Clear stale authType to ensure getGemini31LaunchedSync doesn't return stale results
// during the transition.
if (this.contentGeneratorConfig) {
this.contentGeneratorConfig.authType = undefined;
}
const newContentGeneratorConfig = await createContentGeneratorConfig( const newContentGeneratorConfig = await createContentGeneratorConfig(
this, this,
authMethod, authMethod,
@@ -1350,7 +1356,10 @@ export class Config {
if (pooled.remaining !== undefined) { if (pooled.remaining !== undefined) {
return pooled.remaining; return pooled.remaining;
} }
const primaryModel = resolveModel(this.getModel()); const primaryModel = resolveModel(
this.getModel(),
this.getGemini31LaunchedSync(),
);
return this.modelQuotas.get(primaryModel)?.remaining; return this.modelQuotas.get(primaryModel)?.remaining;
} }
@@ -1359,7 +1368,10 @@ export class Config {
if (pooled.limit !== undefined) { if (pooled.limit !== undefined) {
return pooled.limit; return pooled.limit;
} }
const primaryModel = resolveModel(this.getModel()); const primaryModel = resolveModel(
this.getModel(),
this.getGemini31LaunchedSync(),
);
return this.modelQuotas.get(primaryModel)?.limit; return this.modelQuotas.get(primaryModel)?.limit;
} }
@@ -1368,7 +1380,10 @@ export class Config {
if (pooled.resetTime !== undefined) { if (pooled.resetTime !== undefined) {
return pooled.resetTime; return pooled.resetTime;
} }
const primaryModel = resolveModel(this.getModel()); const primaryModel = resolveModel(
this.getModel(),
this.getGemini31LaunchedSync(),
);
return this.modelQuotas.get(primaryModel)?.resetTime; return this.modelQuotas.get(primaryModel)?.resetTime;
} }
@@ -2253,6 +2268,36 @@ export class Config {
); );
} }
/**
* Returns whether Gemini 3.1 has been launched.
* This method is async and ensures that experiments are loaded before returning the result.
*/
async getGemini31Launched(): Promise<boolean> {
await this.ensureExperimentsLoaded();
return this.getGemini31LaunchedSync();
}
/**
* Returns whether Gemini 3.1 has been launched.
*
* Note: This method should only be called after startup, once experiments have been loaded.
* If you need to call this during startup or from an async context, use
* getGemini31Launched instead.
*/
getGemini31LaunchedSync(): boolean {
const authType = this.contentGeneratorConfig?.authType;
if (
authType === AuthType.USE_GEMINI ||
authType === AuthType.USE_VERTEX_AI
) {
return true;
}
return (
this.experiments?.flags[ExperimentFlags.GEMINI_3_1_PRO_LAUNCHED]
?.boolValue ?? false
);
}
private async ensureExperimentsLoaded(): Promise<void> { private async ensureExperimentsLoaded(): Promise<void> {
if (!this.experimentsPromise) { if (!this.experimentsPromise) {
return; return;
+116
View File
@@ -25,8 +25,41 @@ import {
PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL_AUTO, PREVIEW_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO,
isActiveModel,
PREVIEW_GEMINI_3_1_MODEL,
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
isPreviewModel,
isProModel,
} from './models.js'; } from './models.js';
describe('isPreviewModel', () => {
it('should return true for preview models', () => {
expect(isPreviewModel(PREVIEW_GEMINI_MODEL)).toBe(true);
expect(isPreviewModel(PREVIEW_GEMINI_3_1_MODEL)).toBe(true);
expect(isPreviewModel(PREVIEW_GEMINI_FLASH_MODEL)).toBe(true);
expect(isPreviewModel(PREVIEW_GEMINI_MODEL_AUTO)).toBe(true);
});
it('should return false for non-preview models', () => {
expect(isPreviewModel(DEFAULT_GEMINI_MODEL)).toBe(false);
expect(isPreviewModel('gemini-1.5-pro')).toBe(false);
});
});
describe('isProModel', () => {
it('should return true for models containing "pro"', () => {
expect(isProModel('gemini-3-pro-preview')).toBe(true);
expect(isProModel('gemini-2.5-pro')).toBe(true);
expect(isProModel('pro')).toBe(true);
});
it('should return false for models without "pro"', () => {
expect(isProModel('gemini-3-flash-preview')).toBe(false);
expect(isProModel('gemini-2.5-flash')).toBe(false);
expect(isProModel('auto')).toBe(false);
});
});
describe('isCustomModel', () => { describe('isCustomModel', () => {
it('should return true for models not starting with gemini-', () => { it('should return true for models not starting with gemini-', () => {
expect(isCustomModel('testing')).toBe(true); expect(isCustomModel('testing')).toBe(true);
@@ -115,6 +148,12 @@ describe('getDisplayString', () => {
); );
}); });
it('should return PREVIEW_GEMINI_3_1_MODEL for PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL', () => {
expect(getDisplayString(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL)).toBe(
PREVIEW_GEMINI_3_1_MODEL,
);
});
it('should return the model name as is for other models', () => { it('should return the model name as is for other models', () => {
expect(getDisplayString('custom-model')).toBe('custom-model'); expect(getDisplayString('custom-model')).toBe('custom-model');
expect(getDisplayString(DEFAULT_GEMINI_FLASH_LITE_MODEL)).toBe( expect(getDisplayString(DEFAULT_GEMINI_FLASH_LITE_MODEL)).toBe(
@@ -146,6 +185,16 @@ describe('resolveModel', () => {
expect(model).toBe(PREVIEW_GEMINI_MODEL); expect(model).toBe(PREVIEW_GEMINI_MODEL);
}); });
it('should return Gemini 3.1 Pro when auto-gemini-3 is requested and useGemini3_1 is true', () => {
const model = resolveModel(PREVIEW_GEMINI_MODEL_AUTO, true);
expect(model).toBe(PREVIEW_GEMINI_3_1_MODEL);
});
it('should return Gemini 3.1 Pro Custom Tools when auto-gemini-3 is requested, useGemini3_1 is true, and useCustomToolModel is true', () => {
const model = resolveModel(PREVIEW_GEMINI_MODEL_AUTO, true, true);
expect(model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL);
});
it('should return the Default Pro model when auto-gemini-2.5 is requested', () => { it('should return the Default Pro model when auto-gemini-2.5 is requested', () => {
const model = resolveModel(DEFAULT_GEMINI_MODEL_AUTO); const model = resolveModel(DEFAULT_GEMINI_MODEL_AUTO);
expect(model).toBe(DEFAULT_GEMINI_MODEL); expect(model).toBe(DEFAULT_GEMINI_MODEL);
@@ -239,4 +288,71 @@ describe('resolveClassifierModel', () => {
resolveClassifierModel(PREVIEW_GEMINI_MODEL_AUTO, GEMINI_MODEL_ALIAS_PRO), resolveClassifierModel(PREVIEW_GEMINI_MODEL_AUTO, GEMINI_MODEL_ALIAS_PRO),
).toBe(PREVIEW_GEMINI_MODEL); ).toBe(PREVIEW_GEMINI_MODEL);
}); });
it('should return Gemini 3.1 Pro when alias is pro and useGemini3_1 is true', () => {
expect(
resolveClassifierModel(
PREVIEW_GEMINI_MODEL_AUTO,
GEMINI_MODEL_ALIAS_PRO,
true,
),
).toBe(PREVIEW_GEMINI_3_1_MODEL);
});
it('should return Gemini 3.1 Pro Custom Tools when alias is pro, useGemini3_1 is true, and useCustomToolModel is true', () => {
expect(
resolveClassifierModel(
PREVIEW_GEMINI_MODEL_AUTO,
GEMINI_MODEL_ALIAS_PRO,
true,
true,
),
).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL);
});
});
describe('isActiveModel', () => {
it('should return true for valid models when useGemini3_1 is false', () => {
expect(isActiveModel(DEFAULT_GEMINI_MODEL)).toBe(true);
expect(isActiveModel(PREVIEW_GEMINI_MODEL)).toBe(true);
expect(isActiveModel(DEFAULT_GEMINI_FLASH_MODEL)).toBe(true);
});
it('should return true for unknown models and aliases', () => {
expect(isActiveModel('invalid-model')).toBe(false);
expect(isActiveModel(GEMINI_MODEL_ALIAS_AUTO)).toBe(false);
});
it('should return false for PREVIEW_GEMINI_MODEL when useGemini3_1 is true', () => {
expect(isActiveModel(PREVIEW_GEMINI_MODEL, true)).toBe(false);
});
it('should return true for other valid models when useGemini3_1 is true', () => {
expect(isActiveModel(DEFAULT_GEMINI_MODEL, true)).toBe(true);
});
it('should correctly filter Gemini 3.1 models based on useCustomToolModel when useGemini3_1 is true', () => {
// When custom tools are preferred, standard 3.1 should be inactive
expect(isActiveModel(PREVIEW_GEMINI_3_1_MODEL, true, true)).toBe(false);
expect(
isActiveModel(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, true, true),
).toBe(true);
// When custom tools are NOT preferred, custom tools 3.1 should be inactive
expect(isActiveModel(PREVIEW_GEMINI_3_1_MODEL, true, false)).toBe(true);
expect(
isActiveModel(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, true, false),
).toBe(false);
});
it('should return false for both Gemini 3.1 models when useGemini3_1 is false', () => {
expect(isActiveModel(PREVIEW_GEMINI_3_1_MODEL, false, true)).toBe(false);
expect(isActiveModel(PREVIEW_GEMINI_3_1_MODEL, false, false)).toBe(false);
expect(
isActiveModel(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, false, true),
).toBe(false);
expect(
isActiveModel(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL, false, false),
).toBe(false);
});
}); });
+68 -7
View File
@@ -5,6 +5,9 @@
*/ */
export const PREVIEW_GEMINI_MODEL = 'gemini-3-pro-preview'; export const PREVIEW_GEMINI_MODEL = 'gemini-3-pro-preview';
export const PREVIEW_GEMINI_3_1_MODEL = 'gemini-3.1-pro-preview';
export const PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL =
'gemini-3.1-pro-preview-customtools';
export const PREVIEW_GEMINI_FLASH_MODEL = 'gemini-3-flash-preview'; export const PREVIEW_GEMINI_FLASH_MODEL = 'gemini-3-flash-preview';
export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro'; export const DEFAULT_GEMINI_MODEL = 'gemini-2.5-pro';
export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash'; export const DEFAULT_GEMINI_FLASH_MODEL = 'gemini-2.5-flash';
@@ -12,6 +15,8 @@ export const DEFAULT_GEMINI_FLASH_LITE_MODEL = 'gemini-2.5-flash-lite';
export const VALID_GEMINI_MODELS = new Set([ export const VALID_GEMINI_MODELS = new Set([
PREVIEW_GEMINI_MODEL, PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_3_1_MODEL,
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_FLASH_MODEL,
@@ -37,20 +42,29 @@ export const DEFAULT_THINKING_MODE = 8192;
* to a concrete model name. * to a concrete model name.
* *
* @param requestedModel The model alias or concrete model name requested by the user. * @param requestedModel The model alias or concrete model name requested by the user.
* @param useGemini3_1 Whether to use Gemini 3.1 Pro Preview for auto/pro aliases.
* @returns The resolved concrete model name. * @returns The resolved concrete model name.
*/ */
export function resolveModel(requestedModel: string): string { export function resolveModel(
requestedModel: string,
useGemini3_1: boolean = false,
useCustomToolModel: boolean = false,
): string {
switch (requestedModel) { switch (requestedModel) {
case PREVIEW_GEMINI_MODEL_AUTO: { case PREVIEW_GEMINI_MODEL:
case PREVIEW_GEMINI_MODEL_AUTO:
case GEMINI_MODEL_ALIAS_AUTO:
case GEMINI_MODEL_ALIAS_PRO: {
if (useGemini3_1) {
return useCustomToolModel
? PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL
: PREVIEW_GEMINI_3_1_MODEL;
}
return PREVIEW_GEMINI_MODEL; return PREVIEW_GEMINI_MODEL;
} }
case DEFAULT_GEMINI_MODEL_AUTO: { case DEFAULT_GEMINI_MODEL_AUTO: {
return DEFAULT_GEMINI_MODEL; return DEFAULT_GEMINI_MODEL;
} }
case GEMINI_MODEL_ALIAS_AUTO:
case GEMINI_MODEL_ALIAS_PRO: {
return PREVIEW_GEMINI_MODEL;
}
case GEMINI_MODEL_ALIAS_FLASH: { case GEMINI_MODEL_ALIAS_FLASH: {
return PREVIEW_GEMINI_FLASH_MODEL; return PREVIEW_GEMINI_FLASH_MODEL;
} }
@@ -73,6 +87,8 @@ export function resolveModel(requestedModel: string): string {
export function resolveClassifierModel( export function resolveClassifierModel(
requestedModel: string, requestedModel: string,
modelAlias: string, modelAlias: string,
useGemini3_1: boolean = false,
useCustomToolModel: boolean = false,
): string { ): string {
if (modelAlias === GEMINI_MODEL_ALIAS_FLASH) { if (modelAlias === GEMINI_MODEL_ALIAS_FLASH) {
if ( if (
@@ -89,7 +105,7 @@ export function resolveClassifierModel(
} }
return resolveModel(GEMINI_MODEL_ALIAS_FLASH); return resolveModel(GEMINI_MODEL_ALIAS_FLASH);
} }
return resolveModel(requestedModel); return resolveModel(requestedModel, useGemini3_1, useCustomToolModel);
} }
export function getDisplayString(model: string) { export function getDisplayString(model: string) {
switch (model) { switch (model) {
@@ -101,6 +117,8 @@ export function getDisplayString(model: string) {
return PREVIEW_GEMINI_MODEL; return PREVIEW_GEMINI_MODEL;
case GEMINI_MODEL_ALIAS_FLASH: case GEMINI_MODEL_ALIAS_FLASH:
return PREVIEW_GEMINI_FLASH_MODEL; return PREVIEW_GEMINI_FLASH_MODEL;
case PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL:
return PREVIEW_GEMINI_3_1_MODEL;
default: default:
return model; return model;
} }
@@ -115,11 +133,22 @@ export function getDisplayString(model: string) {
export function isPreviewModel(model: string): boolean { export function isPreviewModel(model: string): boolean {
return ( return (
model === PREVIEW_GEMINI_MODEL || model === PREVIEW_GEMINI_MODEL ||
model === PREVIEW_GEMINI_3_1_MODEL ||
model === PREVIEW_GEMINI_FLASH_MODEL || model === PREVIEW_GEMINI_FLASH_MODEL ||
model === PREVIEW_GEMINI_MODEL_AUTO model === PREVIEW_GEMINI_MODEL_AUTO
); );
} }
/**
* Checks if the model is a Pro model.
*
* @param model The model name to check.
* @returns True if the model is a Pro model.
*/
export function isProModel(model: string): boolean {
return model.toLowerCase().includes('pro');
}
/** /**
* Checks if the model is a Gemini 3 model. * Checks if the model is a Gemini 3 model.
* *
@@ -188,3 +217,35 @@ export function isAutoModel(model: string): boolean {
export function supportsMultimodalFunctionResponse(model: string): boolean { export function supportsMultimodalFunctionResponse(model: string): boolean {
return model.startsWith('gemini-3-'); return model.startsWith('gemini-3-');
} }
/**
* Checks if the given model is considered active based on the current configuration.
*
* @param model The model name to check.
* @param useGemini3_1 Whether Gemini 3.1 Pro Preview is enabled.
* @returns True if the model is active.
*/
export function isActiveModel(
model: string,
useGemini3_1: boolean = false,
useCustomToolModel: boolean = false,
): boolean {
if (!VALID_GEMINI_MODELS.has(model)) {
return false;
}
if (useGemini3_1) {
if (model === PREVIEW_GEMINI_MODEL) {
return false;
}
if (useCustomToolModel) {
return model !== PREVIEW_GEMINI_3_1_MODEL;
} else {
return model !== PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL;
}
} else {
return (
model !== PREVIEW_GEMINI_3_1_MODEL &&
model !== PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL
);
}
}
+4 -1
View File
@@ -542,7 +542,10 @@ export class GeminiClient {
// Availability logic: The configured model is the source of truth, // Availability logic: The configured model is the source of truth,
// including any permanent fallbacks (config.setModel) or manual overrides. // including any permanent fallbacks (config.setModel) or manual overrides.
return resolveModel(this.config.getActiveModel()); return resolveModel(
this.config.getActiveModel(),
this.config.getGemini31LaunchedSync?.() ?? false,
);
} }
private async *processTurn( private async *processTurn(
+6 -1
View File
@@ -146,7 +146,12 @@ export async function createContentGenerator(
return new LoggingContentGenerator(fakeGenerator, gcConfig); return new LoggingContentGenerator(fakeGenerator, gcConfig);
} }
const version = await getVersion(); const version = await getVersion();
const model = resolveModel(gcConfig.getModel()); const model = resolveModel(
gcConfig.getModel(),
config.authType === AuthType.USE_GEMINI ||
config.authType === AuthType.USE_VERTEX_AI ||
((await gcConfig.getGemini31Launched?.()) ?? false),
);
const customHeadersEnv = const customHeadersEnv =
process.env['GEMINI_CLI_CUSTOM_HEADERS'] || undefined; process.env['GEMINI_CLI_CUSTOM_HEADERS'] || undefined;
const userAgent = `GeminiCLI/${version}/${model} (${process.platform}; ${process.arch})`; const userAgent = `GeminiCLI/${version}/${model} (${process.platform}; ${process.arch})`;
+3 -2
View File
@@ -496,13 +496,14 @@ export class GeminiChat {
const initialActiveModel = this.config.getActiveModel(); const initialActiveModel = this.config.getActiveModel();
const apiCall = async () => { const apiCall = async () => {
const useGemini3_1 = (await this.config.getGemini31Launched?.()) ?? false;
// Default to the last used model (which respects arguments/availability selection) // Default to the last used model (which respects arguments/availability selection)
let modelToUse = resolveModel(lastModelToUse); let modelToUse = resolveModel(lastModelToUse, useGemini3_1);
// If the active model has changed (e.g. due to a fallback updating the config), // If the active model has changed (e.g. due to a fallback updating the config),
// we switch to the new active model. // we switch to the new active model.
if (this.config.getActiveModel() !== initialActiveModel) { if (this.config.getActiveModel() !== initialActiveModel) {
modelToUse = resolveModel(this.config.getActiveModel()); modelToUse = resolveModel(this.config.getActiveModel(), useGemini3_1);
} }
if (modelToUse !== lastModelToUse) { if (modelToUse !== lastModelToUse) {
+8 -2
View File
@@ -58,7 +58,10 @@ export class PromptProvider {
const enabledToolNames = new Set(toolNames); const enabledToolNames = new Set(toolNames);
const approvedPlanPath = config.getApprovedPlanPath(); const approvedPlanPath = config.getApprovedPlanPath();
const desiredModel = resolveModel(config.getActiveModel()); const desiredModel = resolveModel(
config.getActiveModel(),
config.getGemini31LaunchedSync?.() ?? false,
);
const isModernModel = supportsModernFeatures(desiredModel); const isModernModel = supportsModernFeatures(desiredModel);
const activeSnippets = isModernModel ? snippets : legacySnippets; const activeSnippets = isModernModel ? snippets : legacySnippets;
const contextFilenames = getAllGeminiMdFilenames(); const contextFilenames = getAllGeminiMdFilenames();
@@ -231,7 +234,10 @@ export class PromptProvider {
} }
getCompressionPrompt(config: Config): string { getCompressionPrompt(config: Config): string {
const desiredModel = resolveModel(config.getActiveModel()); const desiredModel = resolveModel(
config.getActiveModel(),
config.getGemini31LaunchedSync?.() ?? false,
);
const isModernModel = supportsModernFeatures(desiredModel); const isModernModel = supportsModernFeatures(desiredModel);
const activeSnippets = isModernModel ? snippets : legacySnippets; const activeSnippets = isModernModel ? snippets : legacySnippets;
return activeSnippets.getCompressionPrompt(); return activeSnippets.getCompressionPrompt();
@@ -18,11 +18,14 @@ import {
DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL,
DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_MODEL_AUTO, PREVIEW_GEMINI_MODEL_AUTO,
PREVIEW_GEMINI_3_1_MODEL,
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
} from '../../config/models.js'; } from '../../config/models.js';
import { promptIdContext } from '../../utils/promptIdContext.js'; import { promptIdContext } from '../../utils/promptIdContext.js';
import type { Content } from '@google/genai'; import type { Content } from '@google/genai';
import type { ResolvedModelConfig } from '../../services/modelConfigService.js'; import type { ResolvedModelConfig } from '../../services/modelConfigService.js';
import { debugLogger } from '../../utils/debugLogger.js'; import { debugLogger } from '../../utils/debugLogger.js';
import { AuthType } from '../../core/contentGenerator.js';
vi.mock('../../core/baseLlmClient.js'); vi.mock('../../core/baseLlmClient.js');
@@ -53,6 +56,10 @@ describe('ClassifierStrategy', () => {
}, },
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO), getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false), getNumericalRoutingEnabled: vi.fn().mockResolvedValue(false),
getGemini31Launched: vi.fn().mockResolvedValue(false),
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: AuthType.LOGIN_WITH_GOOGLE,
}),
} as unknown as Config; } as unknown as Config;
mockBaseLlmClient = { mockBaseLlmClient = {
generateJson: vi.fn(), generateJson: vi.fn(),
@@ -339,4 +346,49 @@ describe('ClassifierStrategy', () => {
// Since requestedModel is Pro, and choice is flash, it should resolve to Flash // Since requestedModel is Pro, and choice is flash, it should resolve to Flash
expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL); expect(decision?.model).toBe(DEFAULT_GEMINI_FLASH_MODEL);
}); });
describe('Gemini 3.1 and Custom Tools Routing', () => {
it('should route to PREVIEW_GEMINI_3_1_MODEL when Gemini 3.1 is launched', async () => {
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);
const mockApiResponse = {
reasoning: 'Complex task',
model_choice: 'pro',
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL);
});
it('should route to PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL when Gemini 3.1 is launched and auth is USE_GEMINI', async () => {
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
vi.mocked(mockConfig.getModel).mockReturnValue(PREVIEW_GEMINI_MODEL_AUTO);
vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({
authType: AuthType.USE_GEMINI,
});
const mockApiResponse = {
reasoning: 'Complex task',
model_choice: 'pro',
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL);
});
});
}); });
@@ -21,6 +21,7 @@ import {
} from '../../utils/messageInspectors.js'; } from '../../utils/messageInspectors.js';
import { debugLogger } from '../../utils/debugLogger.js'; import { debugLogger } from '../../utils/debugLogger.js';
import { LlmRole } from '../../telemetry/types.js'; import { LlmRole } from '../../telemetry/types.js';
import { AuthType } from '../../core/contentGenerator.js';
// The number of recent history turns to provide to the router for context. // The number of recent history turns to provide to the router for context.
const HISTORY_TURNS_FOR_CONTEXT = 4; const HISTORY_TURNS_FOR_CONTEXT = 4;
@@ -169,9 +170,15 @@ export class ClassifierStrategy implements RoutingStrategy {
const reasoning = routerResponse.reasoning; const reasoning = routerResponse.reasoning;
const latencyMs = Date.now() - startTime; const latencyMs = Date.now() - startTime;
const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false;
const useCustomToolModel =
useGemini3_1 &&
config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI;
const selectedModel = resolveClassifierModel( const selectedModel = resolveClassifierModel(
model, model,
routerResponse.model_choice, routerResponse.model_choice,
useGemini3_1,
useCustomToolModel,
); );
return { return {
@@ -21,7 +21,10 @@ export class DefaultStrategy implements TerminalStrategy {
config: Config, config: Config,
_baseLlmClient: BaseLlmClient, _baseLlmClient: BaseLlmClient,
): Promise<RoutingDecision> { ): Promise<RoutingDecision> {
const defaultModel = resolveModel(config.getModel()); const defaultModel = resolveModel(
config.getModel(),
config.getGemini31LaunchedSync?.() ?? false,
);
return { return {
model: defaultModel, model: defaultModel,
metadata: { metadata: {
@@ -23,7 +23,10 @@ export class FallbackStrategy implements RoutingStrategy {
_baseLlmClient: BaseLlmClient, _baseLlmClient: BaseLlmClient,
): Promise<RoutingDecision | null> { ): Promise<RoutingDecision | null> {
const requestedModel = context.requestedModel ?? config.getModel(); const requestedModel = context.requestedModel ?? config.getModel();
const resolvedModel = resolveModel(requestedModel); const resolvedModel = resolveModel(
requestedModel,
config.getGemini31LaunchedSync?.() ?? false,
);
const service = config.getModelAvailabilityService(); const service = config.getModelAvailabilityService();
const snapshot = service.snapshot(resolvedModel); const snapshot = service.snapshot(resolvedModel);
@@ -12,6 +12,8 @@ import type { BaseLlmClient } from '../../core/baseLlmClient.js';
import { import {
PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_MODEL, PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_3_1_MODEL,
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
PREVIEW_GEMINI_MODEL_AUTO, PREVIEW_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_MODEL_AUTO, DEFAULT_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL,
@@ -20,6 +22,7 @@ import { promptIdContext } from '../../utils/promptIdContext.js';
import type { Content } from '@google/genai'; import type { Content } from '@google/genai';
import type { ResolvedModelConfig } from '../../services/modelConfigService.js'; import type { ResolvedModelConfig } from '../../services/modelConfigService.js';
import { debugLogger } from '../../utils/debugLogger.js'; import { debugLogger } from '../../utils/debugLogger.js';
import { AuthType } from '../../core/contentGenerator.js';
vi.mock('../../core/baseLlmClient.js'); vi.mock('../../core/baseLlmClient.js');
@@ -52,6 +55,10 @@ describe('NumericalClassifierStrategy', () => {
getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50) getSessionId: vi.fn().mockReturnValue('control-group-id'), // Default to Control Group (Hash 71 >= 50)
getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true), getNumericalRoutingEnabled: vi.fn().mockResolvedValue(true),
getClassifierThreshold: vi.fn().mockResolvedValue(undefined), getClassifierThreshold: vi.fn().mockResolvedValue(undefined),
getGemini31Launched: vi.fn().mockResolvedValue(false),
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: AuthType.LOGIN_WITH_GOOGLE,
}),
} as unknown as Config; } as unknown as Config;
mockBaseLlmClient = { mockBaseLlmClient = {
generateJson: vi.fn(), generateJson: vi.fn(),
@@ -535,4 +542,68 @@ describe('NumericalClassifierStrategy', () => {
), ),
); );
}); });
describe('Gemini 3.1 and Custom Tools Routing', () => {
it('should route to PREVIEW_GEMINI_3_1_MODEL when Gemini 3.1 is launched', async () => {
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 80,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL);
});
it('should route to PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL when Gemini 3.1 is launched and auth is USE_GEMINI', async () => {
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({
authType: AuthType.USE_GEMINI,
});
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 80,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL);
});
it('should NOT route to custom tools model when auth is USE_VERTEX_AI', async () => {
vi.mocked(mockConfig.getGemini31Launched).mockResolvedValue(true);
vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({
authType: AuthType.USE_VERTEX_AI,
});
const mockApiResponse = {
complexity_reasoning: 'Complex task',
complexity_score: 80,
};
vi.mocked(mockBaseLlmClient.generateJson).mockResolvedValue(
mockApiResponse,
);
const decision = await strategy.route(
mockContext,
mockConfig,
mockBaseLlmClient,
);
expect(decision?.model).toBe(PREVIEW_GEMINI_3_1_MODEL);
});
});
}); });
@@ -17,6 +17,7 @@ import { createUserContent, Type } from '@google/genai';
import type { Config } from '../../config/config.js'; import type { Config } from '../../config/config.js';
import { debugLogger } from '../../utils/debugLogger.js'; import { debugLogger } from '../../utils/debugLogger.js';
import { LlmRole } from '../../telemetry/types.js'; import { LlmRole } from '../../telemetry/types.js';
import { AuthType } from '../../core/contentGenerator.js';
// The number of recent history turns to provide to the router for context. // The number of recent history turns to provide to the router for context.
const HISTORY_TURNS_FOR_CONTEXT = 8; const HISTORY_TURNS_FOR_CONTEXT = 8;
@@ -182,8 +183,16 @@ export class NumericalClassifierStrategy implements RoutingStrategy {
config, config,
config.getSessionId() || 'unknown-session', config.getSessionId() || 'unknown-session',
); );
const useGemini3_1 = (await config.getGemini31Launched?.()) ?? false;
const selectedModel = resolveClassifierModel(model, modelAlias); const useCustomToolModel =
useGemini3_1 &&
config.getContentGeneratorConfig().authType === AuthType.USE_GEMINI;
const selectedModel = resolveClassifierModel(
model,
modelAlias,
useGemini3_1,
useCustomToolModel,
);
const latencyMs = Date.now() - startTime; const latencyMs = Date.now() - startTime;
@@ -33,7 +33,10 @@ export class OverrideStrategy implements RoutingStrategy {
// Return the overridden model name. // Return the overridden model name.
return { return {
model: resolveModel(overrideModel), model: resolveModel(
overrideModel,
config.getGemini31LaunchedSync?.() ?? false,
),
metadata: { metadata: {
source: this.name, source: this.name,
latencyMs: 0, latencyMs: 0,
@@ -29,6 +29,7 @@ import {
DEFAULT_GEMINI_MODEL, DEFAULT_GEMINI_MODEL,
PREVIEW_GEMINI_MODEL, PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_FLASH_MODEL, PREVIEW_GEMINI_FLASH_MODEL,
PREVIEW_GEMINI_3_1_MODEL,
} from '../config/models.js'; } from '../config/models.js';
import { PreCompressTrigger } from '../hooks/types.js'; import { PreCompressTrigger } from '../hooks/types.js';
import { LlmRole } from '../telemetry/types.js'; import { LlmRole } from '../telemetry/types.js';
@@ -101,6 +102,7 @@ export function findCompressSplitPoint(
export function modelStringToModelConfigAlias(model: string): string { export function modelStringToModelConfigAlias(model: string): string {
switch (model) { switch (model) {
case PREVIEW_GEMINI_MODEL: case PREVIEW_GEMINI_MODEL:
case PREVIEW_GEMINI_3_1_MODEL:
return 'chat-compression-3-pro'; return 'chat-compression-3-pro';
case PREVIEW_GEMINI_FLASH_MODEL: case PREVIEW_GEMINI_FLASH_MODEL:
return 'chat-compression-3-flash'; return 'chat-compression-3-flash';