/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import { describe, it, expect, vi, beforeEach } from 'vitest';
import {
  resolvePolicyChain,
  buildFallbackPolicyContext,
  applyModelSelection,
} from './policyHelpers.js';
import { createDefaultPolicy } from './policyCatalog.js';
import type { Config } from '../config/config.js';
import {
  DEFAULT_GEMINI_FLASH_LITE_MODEL,
  DEFAULT_GEMINI_MODEL_AUTO,
  PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL,
  PREVIEW_GEMINI_3_1_MODEL,
} from '../config/models.js';
import { AuthType } from '../core/contentGenerator.js';

/**
 * Builds a minimal `Config` stub for the policy-helper tests.
 *
 * Defaults: no user tier, model `gemini-2.5-pro`, Gemini 3.1 not launched,
 * and a content-generator config with no `authType`. Every key in
 * `overrides` replaces the default of the same name.
 *
 * `getUseCustomToolModelSync` is derived from the stub's own
 * `getGemini31LaunchedSync()` and `getContentGeneratorConfig().authType`,
 * so overriding either of those inputs also changes its result.
 *
 * The object is force-cast to `Config`; any `Config` member not listed here
 * (or supplied via `overrides`) is absent at runtime.
 *
 * @param overrides - Config members to replace on the stub.
 * @returns The stub, typed as `Config`.
 */
const createMockConfig = (overrides: Partial<Config> = {}): Config => {
  const config = {
    getUserTier: () => undefined,
    getModel: () => 'gemini-2.5-pro',
    getGemini31LaunchedSync: () => false,
    getUseCustomToolModelSync: () => {
      // Read through the merged `config` object (not the literal defaults)
      // so that overrides of these two inputs are honoured.
      const useGemini31 = config.getGemini31LaunchedSync();
      const authType = config.getContentGeneratorConfig().authType;
      return useGemini31 && authType === AuthType.USE_GEMINI;
    },
    getContentGeneratorConfig: () => ({ authType: undefined }),
    ...overrides,
  } as unknown as Config;
  return config;
};

describe('policyHelpers', () => {
  describe('resolvePolicyChain', () => {
    it('returns a single-model chain for a custom model', () => {
      const config = createMockConfig({
        getModel: () => 'custom-model',
      });
      const chain = resolvePolicyChain(config);
      expect(chain).toHaveLength(1);
      expect(chain[0]?.model).toBe('custom-model');
    });

    it('leaves catalog order untouched when active model already present', () => {
      const config = createMockConfig({
        getModel: () => 'gemini-2.5-pro',
      });
      const chain = resolvePolicyChain(config);
      expect(chain[0]?.model).toBe('gemini-2.5-pro');
    });

    it('returns the default chain when active model is "auto"', () => {
      const config = createMockConfig({
        getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
      });
      const chain = resolvePolicyChain(config);

      // Expect default chain [Pro, Flash]
      expect(chain).toHaveLength(2);
      expect(chain[0]?.model).toBe('gemini-2.5-pro');
      expect(chain[1]?.model).toBe('gemini-2.5-flash');
    });

    it('uses auto chain when preferred model is auto', () => {
      const config = createMockConfig({
        getModel: () => 'gemini-2.5-pro',
      });
      const chain = resolvePolicyChain(config, DEFAULT_GEMINI_MODEL_AUTO);
      expect(chain).toHaveLength(2);
      expect(chain[0]?.model).toBe('gemini-2.5-pro');
      expect(chain[1]?.model).toBe('gemini-2.5-flash');
    });

    it('uses auto chain when configured model is auto even if preferred is concrete', () => {
      const config = createMockConfig({
        getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
      });
      const chain = resolvePolicyChain(config, 'gemini-2.5-pro');
      expect(chain).toHaveLength(2);
      expect(chain[0]?.model).toBe('gemini-2.5-pro');
      expect(chain[1]?.model).toBe('gemini-2.5-flash');
    });

    it('starts chain from preferredModel when model is "auto"', () => {
      const config = createMockConfig({
        getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
      });
      const chain = resolvePolicyChain(config, 'gemini-2.5-flash');
      expect(chain).toHaveLength(1);
      expect(chain[0]?.model).toBe('gemini-2.5-flash');
    });

    it('returns flash-lite chain when preferred model is flash-lite', () => {
      const config = createMockConfig({
        getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
      });
      const chain = resolvePolicyChain(config, DEFAULT_GEMINI_FLASH_LITE_MODEL);
      expect(chain).toHaveLength(3);
      expect(chain[0]?.model).toBe('gemini-2.5-flash-lite');
      expect(chain[1]?.model).toBe('gemini-2.5-flash');
      expect(chain[2]?.model).toBe('gemini-2.5-pro');
    });

    it('returns flash-lite chain when configured model is flash-lite', () => {
      const config = createMockConfig({
        getModel: () => DEFAULT_GEMINI_FLASH_LITE_MODEL,
      });
      const chain = resolvePolicyChain(config);
      expect(chain).toHaveLength(3);
      expect(chain[0]?.model).toBe('gemini-2.5-flash-lite');
      expect(chain[1]?.model).toBe('gemini-2.5-flash');
      expect(chain[2]?.model).toBe('gemini-2.5-pro');
    });

    it('wraps around the chain when wrapsAround is true', () => {
      const config = createMockConfig({
        getModel: () => DEFAULT_GEMINI_MODEL_AUTO,
      });
      const chain = resolvePolicyChain(config, 'gemini-2.5-flash', true);
      expect(chain).toHaveLength(2);
      expect(chain[0]?.model).toBe('gemini-2.5-flash');
      expect(chain[1]?.model).toBe('gemini-2.5-pro');
    });

    it('proactively returns Gemini 2.5 chain if Gemini 3 requested but user lacks access', () => {
      const config = createMockConfig({
        getModel: () => 'auto-gemini-3',
        getHasAccessToPreviewModel: () => false,
      });
      const chain = resolvePolicyChain(config);

      // Should downgrade to [Pro 2.5, Flash 2.5]
      expect(chain).toHaveLength(2);
      expect(chain[0]?.model).toBe('gemini-2.5-pro');
      expect(chain[1]?.model).toBe('gemini-2.5-flash');
    });

    it('returns Gemini 3.1 Pro chain when launched and auto-gemini-3 requested', () => {
      const config = createMockConfig({
        getModel: () => 'auto-gemini-3',
        getGemini31LaunchedSync: () => true,
      });
      const chain = resolvePolicyChain(config);
      expect(chain[0]?.model).toBe(PREVIEW_GEMINI_3_1_MODEL);
      expect(chain[1]?.model).toBe('gemini-3-flash-preview');
    });

    it('returns Gemini 3.1 Pro Custom Tools chain when launched, auth is Gemini, and auto-gemini-3 requested', () => {
      const config = createMockConfig({
        getModel: () => 'auto-gemini-3',
        getGemini31LaunchedSync: () => true,
        getContentGeneratorConfig: () => ({ authType: AuthType.USE_GEMINI }),
      });
      const chain = resolvePolicyChain(config);
      expect(chain[0]?.model).toBe(PREVIEW_GEMINI_3_1_CUSTOM_TOOLS_MODEL);
      expect(chain[1]?.model).toBe('gemini-3-flash-preview');
    });
  });

  describe('buildFallbackPolicyContext', () => {
    it('returns remaining candidates after the failed model', () => {
      const chain = [
        createDefaultPolicy('a'),
        createDefaultPolicy('b'),
        createDefaultPolicy('c'),
      ];
      const context = buildFallbackPolicyContext(chain, 'b');
      expect(context.failedPolicy?.model).toBe('b');
      expect(context.candidates.map((p) => p.model)).toEqual(['c']);
    });

    it('wraps around when building fallback context if wrapsAround is true', () => {
      const chain = [
        createDefaultPolicy('a'),
        createDefaultPolicy('b'),
        createDefaultPolicy('c'),
      ];
      const context = buildFallbackPolicyContext(chain, 'b', true);
      expect(context.failedPolicy?.model).toBe('b');
      expect(context.candidates.map((p) => p.model)).toEqual(['c', 'a']);
    });

    it('returns full chain when model is not in policy list', () => {
      const chain = [createDefaultPolicy('a'), createDefaultPolicy('b')];
      const context = buildFallbackPolicyContext(chain, 'x');
      expect(context.failedPolicy).toBeUndefined();
      expect(context.candidates).toEqual(chain);
    });
  });

  describe('applyModelSelection', () => {
    // Shared across every test in this suite; each test configures the
    // return values it needs on these mocks.
    const mockModelConfigService = {
      getResolvedConfig: vi.fn(),
    };
    const mockAvailabilityService = {
      selectFirstAvailable: vi.fn(),
      consumeStickyAttempt: vi.fn(),
    };

    /**
     * Extends {@link createMockConfig} with the members `applyModelSelection`
     * is exercised against here: the shared availability-service and
     * model-config-service mocks, plus a fresh `setActiveModel` spy per call
     * (so spy assertions are scoped to the returned config).
     *
     * @param overrides - Config members to replace; applied after the defaults.
     * @returns The stub, typed as `Config`.
     */
    const createExtendedMockConfig = (
      overrides: Partial<Config> = {},
    ): Config => {
      const defaults = {
        getModelAvailabilityService: () => mockAvailabilityService,
        setActiveModel: vi.fn(),
        modelConfigService: mockModelConfigService,
      };
      return createMockConfig({ ...defaults, ...overrides } as Partial<Config>);
    };

    beforeEach(() => {
      vi.clearAllMocks();
    });

    it('returns requested model if it is available', () => {
      const config = createExtendedMockConfig();
      mockModelConfigService.getResolvedConfig.mockReturnValue({
        model: 'gemini-pro',
        generateContentConfig: {},
      });
      mockAvailabilityService.selectFirstAvailable.mockReturnValue({
        selectedModel: 'gemini-pro',
      });

      const result = applyModelSelection(config, {
        model: 'gemini-pro',
        isChatModel: true,
      });

      expect(result.model).toBe('gemini-pro');
      expect(result.maxAttempts).toBeUndefined();
      expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
    });

    it('switches to backup model and updates config if requested is unavailable', () => {
      const config = createExtendedMockConfig();
      mockModelConfigService.getResolvedConfig
        .mockReturnValueOnce({
          model: 'gemini-pro',
          generateContentConfig: { temperature: 0.9, topP: 1 },
        })
        .mockReturnValueOnce({
          model: 'gemini-flash',
          generateContentConfig: { temperature: 0.1, topP: 1 },
        });
      mockAvailabilityService.selectFirstAvailable.mockReturnValue({
        selectedModel: 'gemini-flash',
      });

      const result = applyModelSelection(config, {
        model: 'gemini-pro',
        isChatModel: true,
      });

      expect(result.model).toBe('gemini-flash');
      expect(result.config).toEqual({
        temperature: 0.1,
        topP: 1,
      });

      expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
        model: 'gemini-pro',
        isChatModel: true,
      });
      expect(mockModelConfigService.getResolvedConfig).toHaveBeenCalledWith({
        model: 'gemini-flash',
        isChatModel: true,
      });
      expect(config.setActiveModel).toHaveBeenCalledWith('gemini-flash');
    });

    it('does not call setActiveModel if isChatModel is false', () => {
      const config = createExtendedMockConfig();
      mockModelConfigService.getResolvedConfig.mockReturnValue({
        model: 'gemini-pro',
        generateContentConfig: {},
      });
      mockAvailabilityService.selectFirstAvailable.mockReturnValue({
        selectedModel: 'gemini-pro',
      });

      applyModelSelection(config, {
        model: 'gemini-pro',
        isChatModel: false,
      });

      expect(config.setActiveModel).not.toHaveBeenCalled();
    });

    it('consumes sticky attempt if indicated and isChatModel is true', () => {
      const config = createExtendedMockConfig();
      mockModelConfigService.getResolvedConfig.mockReturnValue({
        model: 'gemini-pro',
        generateContentConfig: {},
      });
      mockAvailabilityService.selectFirstAvailable.mockReturnValue({
        selectedModel: 'gemini-pro',
        attempts: 1,
      });

      const result = applyModelSelection(config, {
        model: 'gemini-pro',
        isChatModel: true,
      });

      expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
        'gemini-pro',
      );
      expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
      expect(result.maxAttempts).toBe(1);
    });

    it('consumes sticky attempt if indicated but does not call setActiveModel if isChatModel is false', () => {
      const config = createExtendedMockConfig();
      mockModelConfigService.getResolvedConfig.mockReturnValue({
        model: 'gemini-pro',
        generateContentConfig: {},
      });
      mockAvailabilityService.selectFirstAvailable.mockReturnValue({
        selectedModel: 'gemini-pro',
        attempts: 1,
      });

      const result = applyModelSelection(config, {
        model: 'gemini-pro',
        isChatModel: false,
      });

      expect(mockAvailabilityService.consumeStickyAttempt).toHaveBeenCalledWith(
        'gemini-pro',
      );
      expect(config.setActiveModel).not.toHaveBeenCalled();
      expect(result.maxAttempts).toBe(1);
    });

    it('does not consume sticky attempt if consumeAttempt is false', () => {
      const config = createExtendedMockConfig();
      mockModelConfigService.getResolvedConfig.mockReturnValue({
        model: 'gemini-pro',
        generateContentConfig: {},
      });
      mockAvailabilityService.selectFirstAvailable.mockReturnValue({
        selectedModel: 'gemini-pro',
        attempts: 1,
      });

      const result = applyModelSelection(
        config,
        { model: 'gemini-pro', isChatModel: true },
        {
          consumeAttempt: false,
        },
      );

      expect(
        mockAvailabilityService.consumeStickyAttempt,
      ).not.toHaveBeenCalled();
      expect(config.setActiveModel).toHaveBeenCalledWith('gemini-pro');
      expect(result.maxAttempts).toBe(1);
    });
  });
});