feat(models): support Gemini 3.1 custom tool model (#131)

* feat(models): add support for Gemini 3.1 and custom tool models

* test(routing): fix classifier and numerical classifier strategy tests

* test(routing): add Gemini 3.1 tests for classifier strategy

* fix(models): correctly filter active Gemini 3.1 models

* fix(routing): ensure useCustomToolModel is only true when Gemini 3.1 is enabled

* fix(test-utils): prevent double newline in lastFrame() on Windows

* fix(test-utils): surgically fix double newline in lastFrame() on Windows

* use custom_tools_model string for api key only

* fix(ui): correct useCustomToolModel logic and update tests

* fix(ui): correct useCustomToolModel logic in StatsDisplay

* fix(routing): ensure test models are active and sync useCustomToolModel logic
This commit is contained in:
Sehoon Shon
2026-02-19 12:25:59 -05:00
parent 2ef6149684
commit e36dfc9fc9
12 changed files with 502 additions and 38 deletions

View File

@@ -14,8 +14,14 @@ import {
DEFAULT_GEMINI_MODEL_AUTO,
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_FLASH_LITE_MODEL,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_3_1_MODEL,
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
AuthType,
} from '@google/gemini-cli-core';
import type { Config, ModelSlashCommandEvent } from '@google/gemini-cli-core';
import { createMockSettings } from '../../test-utils/settings.js';
// Mock dependencies
const mockGetDisplayString = vi.fn();
@@ -42,12 +48,14 @@ describe('<ModelDialog />', () => {
const mockGetModel = vi.fn();
const mockOnClose = vi.fn();
const mockGetHasAccessToPreviewModel = vi.fn();
const mockGetGemini31LaunchedSync = vi.fn();
interface MockConfig extends Partial<Config> {
setModel: (model: string, isTemporary?: boolean) => void;
getModel: () => string;
getHasAccessToPreviewModel: () => boolean;
getIdeMode: () => boolean;
getGemini31LaunchedSync: () => boolean;
}
const mockConfig: MockConfig = {
@@ -55,12 +63,14 @@ describe('<ModelDialog />', () => {
getModel: mockGetModel,
getHasAccessToPreviewModel: mockGetHasAccessToPreviewModel,
getIdeMode: () => false,
getGemini31LaunchedSync: mockGetGemini31LaunchedSync,
};
beforeEach(() => {
vi.resetAllMocks();
mockGetModel.mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO);
mockGetHasAccessToPreviewModel.mockReturnValue(false);
mockGetGemini31LaunchedSync.mockReturnValue(false);
// Default implementation for getDisplayString
mockGetDisplayString.mockImplementation((val: string) => {
@@ -70,11 +80,24 @@ describe('<ModelDialog />', () => {
});
});
const renderComponent = (configValue = mockConfig as Config) =>
renderWithProviders(<ModelDialog onClose={mockOnClose} />, {
config: configValue,
const renderComponent = (
configValue = mockConfig as Config,
authType = AuthType.LOGIN_WITH_GOOGLE,
) => {
const settings = createMockSettings({
security: {
auth: {
selectedType: authType,
},
},
});
return renderWithProviders(<ModelDialog onClose={mockOnClose} />, {
config: configValue,
settings,
});
};
it('renders the initial "main" view correctly', () => {
const { lastFrame } = renderComponent();
expect(lastFrame()).toContain('Select Model');
@@ -210,4 +233,96 @@ describe('<ModelDialog />', () => {
expect(lastFrame()).toContain('Manual');
});
});
it('shows the preferred manual model in the main view option', () => {
mockGetModel.mockReturnValue(DEFAULT_GEMINI_MODEL);
const { lastFrame, unmount } = renderComponent();
expect(lastFrame()).toContain(`Manual (${DEFAULT_GEMINI_MODEL})`);
unmount();
});
describe('Preview Models', () => {
beforeEach(() => {
mockGetHasAccessToPreviewModel.mockReturnValue(true);
});
it('shows Auto (Preview) in main view when access is granted', () => {
const { lastFrame, unmount } = renderComponent();
expect(lastFrame()).toContain('Auto (Preview)');
unmount();
});
it('shows Gemini 3 models in manual view when Gemini 3.1 is NOT launched', async () => {
mockGetGemini31LaunchedSync.mockReturnValue(false);
const { lastFrame, stdin, unmount } = renderComponent();
// Go to manual view
await act(async () => {
stdin.write('\u001B[B'); // Manual
});
await act(async () => {
stdin.write('\r');
});
await waitFor(() => {
const output = lastFrame();
expect(output).toContain(PREVIEW_GEMINI_MODEL);
expect(output).toContain(PREVIEW_GEMINI_FLASH_MODEL);
});
unmount();
});
it('shows Gemini 3.1 models in manual view when Gemini 3.1 IS launched', async () => {
mockGetGemini31LaunchedSync.mockReturnValue(true);
const { lastFrame, stdin, unmount } = renderComponent(
mockConfig as Config,
AuthType.USE_VERTEX_AI,
);
// Go to manual view
await act(async () => {
stdin.write('\u001B[B'); // Manual
});
await act(async () => {
stdin.write('\r');
});
await waitFor(() => {
const output = lastFrame();
expect(output).toContain(PREVIEW_GEMINI_3_1_MODEL);
expect(output).toContain(PREVIEW_GEMINI_FLASH_MODEL);
});
unmount();
});
it('uses custom tools model when Gemini 3.1 IS launched and auth is Gemini API Key', async () => {
mockGetGemini31LaunchedSync.mockReturnValue(true);
const { stdin, unmount } = renderComponent(
mockConfig as Config,
AuthType.USE_GEMINI,
);
// Go to manual view
await act(async () => {
stdin.write('\u001B[B'); // Manual
});
await act(async () => {
stdin.write('\r');
});
// Select Gemini 3.1 (first item in preview section)
await act(async () => {
stdin.write('\r');
});
await waitFor(() => {
expect(mockSetModel).toHaveBeenCalledWith(
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
true,
);
});
unmount();
});
});
});

View File

@@ -19,11 +19,14 @@ import {
ModelSlashCommandEvent,
logModelSlashCommand,
getDisplayString,
AuthType,
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
} from '@google/gemini-cli-core';
import { useKeypress } from '../hooks/useKeypress.js';
import { theme } from '../semantic-colors.js';
import { DescriptiveRadioButtonSelect } from './shared/DescriptiveRadioButtonSelect.js';
import { ConfigContext } from '../contexts/ConfigContext.js';
import { useSettings } from '../contexts/SettingsContext.js';
interface ModelDialogProps {
onClose: () => void;
@@ -31,6 +34,7 @@ interface ModelDialogProps {
export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element {
const config = useContext(ConfigContext);
const settings = useSettings();
const [view, setView] = useState<'main' | 'manual'>('main');
const [persistMode, setPersistMode] = useState(false);
@@ -39,6 +43,9 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element {
const shouldShowPreviewModels = config?.getHasAccessToPreviewModel();
const useGemini31 = config?.getGemini31LaunchedSync?.() ?? false;
const selectedAuthType = settings.merged.security.auth.selectedType;
const useCustomToolModel =
useGemini31 && selectedAuthType !== AuthType.USE_VERTEX_AI;
const manualModelSelected = useMemo(() => {
const manualModels = [
@@ -47,6 +54,7 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element {
DEFAULT_GEMINI_FLASH_LITE_MODEL,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_3_1_MODEL,
PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
];
if (manualModels.includes(preferredModel)) {
@@ -126,11 +134,19 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element {
];
if (shouldShowPreviewModels) {
const previewProModel = useGemini31
? PREVIEW_GEMINI_3_1_MODEL
: PREVIEW_GEMINI_MODEL;
const previewProValue = useCustomToolModel
? PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL
: previewProModel;
list.unshift(
{
value: useGemini31 ? PREVIEW_GEMINI_3_1_MODEL : PREVIEW_GEMINI_MODEL,
title: useGemini31 ? PREVIEW_GEMINI_3_1_MODEL : PREVIEW_GEMINI_MODEL,
key: useGemini31 ? PREVIEW_GEMINI_3_1_MODEL : PREVIEW_GEMINI_MODEL,
value: previewProValue,
title: previewProModel,
key: previewProModel,
},
{
value: PREVIEW_GEMINI_FLASH_MODEL,
@@ -140,7 +156,7 @@ export function ModelDialog({ onClose }: ModelDialogProps): React.JSX.Element {
);
}
return list;
}, [shouldShowPreviewModels, useGemini31]);
}, [shouldShowPreviewModels, useGemini31, useCustomToolModel]);
const options = view === 'main' ? mainOptions : manualOptions;

View File

@@ -26,6 +26,7 @@ import {
isActiveModel,
getDisplayString,
isAutoModel,
AuthType,
} from '@google/gemini-cli-core';
import { useSettings } from '../contexts/SettingsContext.js';
import { useConfig } from '../contexts/ConfigContext.js';
@@ -84,9 +85,12 @@ const buildModelRows = (
models: Record<string, ModelMetrics>,
quotas?: RetrieveUserQuotaResponse,
useGemini3_1 = false,
useCustomToolModel = false,
) => {
const getBaseModelName = (name: string) => name.replace('-001', '');
const usedModelNames = new Set(Object.keys(models).map(getBaseModelName));
const usedModelNames = new Set(
Object.keys(models).map(getBaseModelName).map(getDisplayString),
);
// 1. Models with active usage
const activeRows = Object.entries(models).map(([name, metrics]) => {
@@ -95,7 +99,7 @@ const buildModelRows = (
const inputTokens = metrics.tokens.input;
return {
key: name,
modelName,
modelName: getDisplayString(modelName),
requests: metrics.api.totalRequests,
cachedTokens: cachedTokens.toLocaleString(),
inputTokens: inputTokens.toLocaleString(),
@@ -111,12 +115,12 @@ const buildModelRows = (
?.filter(
(b) =>
b.modelId &&
isActiveModel(b.modelId, useGemini3_1) &&
!usedModelNames.has(b.modelId),
isActiveModel(b.modelId, useGemini3_1, useCustomToolModel) &&
!usedModelNames.has(getDisplayString(b.modelId)),
)
.map((bucket) => ({
key: bucket.modelId!,
modelName: bucket.modelId!,
modelName: getDisplayString(bucket.modelId!),
requests: '-',
cachedTokens: '-',
inputTokens: '-',
@@ -138,6 +142,7 @@ const ModelUsageTable: React.FC<{
pooledLimit?: number;
pooledResetTime?: string;
useGemini3_1?: boolean;
useCustomToolModel?: boolean;
}> = ({
models,
quotas,
@@ -147,9 +152,10 @@ const ModelUsageTable: React.FC<{
pooledRemaining,
pooledLimit,
pooledResetTime,
useGemini3_1 = false,
useGemini3_1,
useCustomToolModel,
}) => {
const rows = buildModelRows(models, quotas, useGemini3_1);
const rows = buildModelRows(models, quotas, useGemini3_1, useCustomToolModel);
if (rows.length === 0) {
return null;
@@ -407,7 +413,9 @@ export const StatsDisplay: React.FC<StatsDisplayProps> = ({
const settings = useSettings();
const config = useConfig();
const useGemini3_1 = config.getGemini31LaunchedSync?.() ?? false;
const useCustomToolModel =
useGemini3_1 &&
config.getContentGeneratorConfig().authType !== AuthType.USE_VERTEX_AI;
const pooledRemaining = quotaStats?.remaining;
const pooledLimit = quotaStats?.limit;
const pooledResetTime = quotaStats?.resetTime;
@@ -542,6 +550,7 @@ export const StatsDisplay: React.FC<StatsDisplayProps> = ({
pooledLimit={pooledLimit}
pooledResetTime={pooledResetTime}
useGemini3_1={useGemini3_1}
useCustomToolModel={useCustomToolModel}
/>
</Box>
);

View File

@@ -7,6 +7,7 @@
import type React from 'react';
import { Text, Box } from 'ink';
import { theme } from '../../semantic-colors.js';
import { getDisplayString } from '@google/gemini-cli-core';
interface ModelMessageProps {
model: string;
@@ -15,7 +16,7 @@ interface ModelMessageProps {
export const ModelMessage: React.FC<ModelMessageProps> = ({ model }) => (
<Box marginLeft={2}>
<Text color={theme.ui.comment} italic>
Responding with {model}
Responding with {getDisplayString(model)}
</Text>
</Box>
);

View File

@@ -155,9 +155,10 @@ describe('useQuotaAndFallback', () => {
expect(request?.isTerminalQuotaError).toBe(true);
const message = request!.message;
expect(message).toContain('Usage limit reached for gemini-pro.');
expect(message).toContain('Usage limit reached for all Pro models.');
expect(message).toContain('Access resets at'); // From getResetTimeMessage
expect(message).toContain('/stats model for usage details');
expect(message).toContain('/model to switch models.');
expect(message).toContain('/auth to switch to API key.');
expect(mockHistoryManager.addItem).not.toHaveBeenCalled();
@@ -176,6 +177,77 @@ describe('useQuotaAndFallback', () => {
expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(1);
});
it('should show the model name for a terminal quota error on a non-pro model', async () => {
const { result } = renderHook(() =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}),
);
const handler = setFallbackHandlerSpy.mock
.calls[0][0] as FallbackModelHandler;
let promise: Promise<FallbackIntent | null>;
const error = new TerminalQuotaError(
'flash quota',
mockGoogleApiError,
1000 * 60 * 5,
);
act(() => {
promise = handler('gemini-flash', 'gemini-pro', error);
});
const request = result.current.proQuotaRequest;
expect(request).not.toBeNull();
expect(request?.failedModel).toBe('gemini-flash');
const message = request!.message;
expect(message).toContain('Usage limit reached for gemini-flash.');
expect(message).not.toContain('all Pro models');
act(() => {
result.current.handleProQuotaChoice('retry_later');
});
await promise!;
});
it('should handle terminal quota error without retry delay', async () => {
const { result } = renderHook(() =>
useQuotaAndFallback({
config: mockConfig,
historyManager: mockHistoryManager,
userTier: UserTierId.FREE,
setModelSwitchedFromQuotaError: mockSetModelSwitchedFromQuotaError,
onShowAuthSelection: mockOnShowAuthSelection,
}),
);
const handler = setFallbackHandlerSpy.mock
.calls[0][0] as FallbackModelHandler;
let promise: Promise<FallbackIntent | null>;
const error = new TerminalQuotaError('no delay', mockGoogleApiError);
act(() => {
promise = handler('gemini-pro', 'gemini-flash', error);
});
const request = result.current.proQuotaRequest;
const message = request!.message;
expect(message).not.toContain('Access resets at');
expect(message).toContain('Usage limit reached for all Pro models.');
act(() => {
result.current.handleProQuotaChoice('retry_later');
});
await promise!;
});
it('should handle race conditions by stopping subsequent requests', async () => {
const { result } = renderHook(() =>
useQuotaAndFallback({

View File

@@ -16,6 +16,7 @@ import {
type UserTierId,
VALID_GEMINI_MODELS,
isProModel,
getDisplayString,
} from '@google/gemini-cli-core';
import { useCallback, useEffect, useRef, useState } from 'react';
import { type UseHistoryManagerReturn } from './useHistoryManager.js';
@@ -86,7 +87,7 @@ export function useQuotaAndFallback({
) {
isModelNotFoundError = true;
const messageLines = [
`It seems like you don't have access to ${failedModel}.`,
`It seems like you don't have access to ${getDisplayString(failedModel)}.`,
`Your admin might have disabled the access. Contact them to enable the Preview Release Channel.`,
];
message = messageLines.join('\n');