mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-12 23:21:27 -07:00
feat(core): implement towards policy-driven model fallback mechanism (#13781)
This commit is contained in:
@@ -16,6 +16,7 @@ import {
|
||||
} from 'vitest';
|
||||
import { handleFallback } from './handler.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { ModelAvailabilityService } from '../availability/modelAvailabilityService.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
@@ -25,6 +26,11 @@ import {
|
||||
import { logFlashFallback } from '../telemetry/index.js';
|
||||
import type { FallbackModelHandler } from './types.js';
|
||||
import { ModelNotFoundError } from '../utils/httpErrors.js';
|
||||
import { openBrowserSecurely } from '../utils/secure-browser-launcher.js';
|
||||
import { coreEvents } from '../utils/events.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import * as policyHelpers from '../availability/policyHelpers.js';
|
||||
import { createDefaultPolicy } from '../availability/policyCatalog.js';
|
||||
import {
|
||||
RetryableQuotaError,
|
||||
TerminalQuotaError,
|
||||
@@ -35,22 +41,46 @@ vi.mock('../telemetry/index.js', () => ({
|
||||
logFlashFallback: vi.fn(),
|
||||
FlashFallbackEvent: class {},
|
||||
}));
|
||||
vi.mock('../utils/secure-browser-launcher.js', () => ({
|
||||
openBrowserSecurely: vi.fn(),
|
||||
}));
|
||||
|
||||
const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL;
|
||||
const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE;
|
||||
const AUTH_API_KEY = AuthType.USE_GEMINI;
|
||||
|
||||
function createAvailabilityMock(
|
||||
result: ReturnType<ModelAvailabilityService['selectFirstAvailable']>,
|
||||
): ModelAvailabilityService {
|
||||
return {
|
||||
markTerminal: vi.fn(),
|
||||
markHealthy: vi.fn(),
|
||||
markRetryOncePerTurn: vi.fn(),
|
||||
consumeStickyAttempt: vi.fn(),
|
||||
snapshot: vi.fn(),
|
||||
selectFirstAvailable: vi.fn().mockReturnValue(result),
|
||||
resetTurn: vi.fn(),
|
||||
} as unknown as ModelAvailabilityService;
|
||||
}
|
||||
|
||||
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
|
||||
({
|
||||
isInFallbackMode: vi.fn(() => false),
|
||||
setFallbackMode: vi.fn(),
|
||||
isModelAvailabilityServiceEnabled: vi.fn(() => false),
|
||||
isPreviewModelFallbackMode: vi.fn(() => false),
|
||||
setPreviewModelFallbackMode: vi.fn(),
|
||||
isPreviewModelBypassMode: vi.fn(() => false),
|
||||
setPreviewModelBypassMode: vi.fn(),
|
||||
fallbackHandler: undefined,
|
||||
getFallbackModelHandler: vi.fn(),
|
||||
getModelAvailabilityService: vi.fn(() =>
|
||||
createAvailabilityMock({ selectedModel: FALLBACK_MODEL, skipped: [] }),
|
||||
),
|
||||
getModel: vi.fn(() => MOCK_PRO_MODEL),
|
||||
getPreviewFeatures: vi.fn(() => false),
|
||||
getUserTier: vi.fn(() => undefined),
|
||||
isInteractive: vi.fn(() => false),
|
||||
...overrides,
|
||||
}) as unknown as Config;
|
||||
@@ -59,6 +89,7 @@ describe('handleFallback', () => {
|
||||
let mockConfig: Config;
|
||||
let mockHandler: Mock<FallbackModelHandler>;
|
||||
let consoleErrorSpy: MockInstance;
|
||||
let fallbackEventSpy: MockInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
@@ -68,10 +99,12 @@ describe('handleFallback', () => {
|
||||
fallbackModelHandler: mockHandler,
|
||||
});
|
||||
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
fallbackEventSpy = vi.spyOn(coreEvents, 'emitFallbackModeChanged');
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
consoleErrorSpy.mockRestore();
|
||||
fallbackEventSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should return null immediately if authType is not OAuth', async () => {
|
||||
@@ -140,6 +173,53 @@ describe('handleFallback', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should return false without toggling fallback when handler returns "retry_later"', async () => {
|
||||
mockHandler.mockResolvedValue('retry_later');
|
||||
|
||||
const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
expect(logFlashFallback).not.toHaveBeenCalled();
|
||||
expect(fallbackEventSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should launch upgrade flow and avoid fallback mode when handler returns "upgrade"', async () => {
|
||||
mockHandler.mockResolvedValue('upgrade');
|
||||
vi.mocked(openBrowserSecurely).mockResolvedValue(undefined);
|
||||
|
||||
const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(openBrowserSecurely).toHaveBeenCalledWith(
|
||||
'https://goo.gle/set-up-gemini-code-assist',
|
||||
);
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
expect(logFlashFallback).not.toHaveBeenCalled();
|
||||
expect(fallbackEventSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should log a warning and continue when upgrade flow fails to open a browser', async () => {
|
||||
mockHandler.mockResolvedValue('upgrade');
|
||||
const debugWarnSpy = vi.spyOn(debugLogger, 'warn');
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, 'warn')
|
||||
.mockImplementation(() => {});
|
||||
vi.mocked(openBrowserSecurely).mockRejectedValue(new Error('blocked'));
|
||||
|
||||
const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(debugWarnSpy).toHaveBeenCalledWith(
|
||||
'Failed to open browser automatically:',
|
||||
'blocked',
|
||||
);
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
expect(fallbackEventSpy).not.toHaveBeenCalled();
|
||||
debugWarnSpy.mockRestore();
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
|
||||
describe('when handler returns an unexpected value', () => {
|
||||
it('should log an error and return null', async () => {
|
||||
mockHandler.mockResolvedValue(null);
|
||||
@@ -450,4 +530,142 @@ describe('handleFallback', () => {
|
||||
expect(result).toBe(true);
|
||||
expect(mockHandler).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
describe('policy-driven flow', () => {
|
||||
let policyConfig: Config;
|
||||
let availability: ModelAvailabilityService;
|
||||
let policyHandler: Mock<FallbackModelHandler>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
availability = createAvailabilityMock({
|
||||
selectedModel: 'gemini-1.5-flash',
|
||||
skipped: [],
|
||||
});
|
||||
policyHandler = vi.fn().mockResolvedValue('retry_once');
|
||||
policyConfig = createMockConfig();
|
||||
vi.spyOn(
|
||||
policyConfig,
|
||||
'isModelAvailabilityServiceEnabled',
|
||||
).mockReturnValue(true);
|
||||
vi.spyOn(policyConfig, 'getModelAvailabilityService').mockReturnValue(
|
||||
availability,
|
||||
);
|
||||
vi.spyOn(policyConfig, 'getFallbackModelHandler').mockReturnValue(
|
||||
policyHandler,
|
||||
);
|
||||
});
|
||||
|
||||
it('uses availability selection when enabled', async () => {
|
||||
await handleFallback(policyConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
|
||||
expect(availability.selectFirstAvailable).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('falls back to last resort when availability returns null', async () => {
|
||||
availability.selectFirstAvailable = vi
|
||||
.fn()
|
||||
.mockReturnValue({ selectedModel: null, skipped: [] });
|
||||
policyHandler.mockResolvedValue('retry_once');
|
||||
|
||||
await handleFallback(policyConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
|
||||
|
||||
expect(policyHandler).toHaveBeenCalledWith(
|
||||
MOCK_PRO_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('executes silent policy action without invoking UI handler', async () => {
|
||||
const proPolicy = createDefaultPolicy(MOCK_PRO_MODEL);
|
||||
const flashPolicy = createDefaultPolicy(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
flashPolicy.actions = {
|
||||
...flashPolicy.actions,
|
||||
terminal: 'silent',
|
||||
unknown: 'silent',
|
||||
};
|
||||
flashPolicy.isLastResort = true;
|
||||
|
||||
const silentChain = [proPolicy, flashPolicy];
|
||||
const chainSpy = vi
|
||||
.spyOn(policyHelpers, 'resolvePolicyChain')
|
||||
.mockReturnValue(silentChain);
|
||||
|
||||
try {
|
||||
availability.selectFirstAvailable = vi.fn().mockReturnValue({
|
||||
selectedModel: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
skipped: [],
|
||||
});
|
||||
|
||||
const result = await handleFallback(
|
||||
policyConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
|
||||
expect(policyConfig.setFallbackMode).toHaveBeenCalledWith(true);
|
||||
} finally {
|
||||
chainSpy.mockRestore();
|
||||
}
|
||||
});
|
||||
|
||||
it('logs and returns null when handler resolves to null', async () => {
|
||||
policyHandler.mockResolvedValue(null);
|
||||
const debugLoggerErrorSpy = vi.spyOn(debugLogger, 'error');
|
||||
const result = await handleFallback(
|
||||
policyConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(debugLoggerErrorSpy).toHaveBeenCalledWith(
|
||||
'Fallback handler failed:',
|
||||
new Error(
|
||||
'Unexpected fallback intent received from fallbackModelHandler: "null"',
|
||||
),
|
||||
);
|
||||
debugLoggerErrorSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('successfully follows expected availability response for Preview Chain', async () => {
|
||||
availability.selectFirstAvailable = vi
|
||||
.fn()
|
||||
.mockReturnValue({ selectedModel: DEFAULT_GEMINI_MODEL, skipped: [] });
|
||||
policyHandler.mockResolvedValue('retry_once');
|
||||
vi.spyOn(policyConfig, 'getPreviewFeatures').mockReturnValue(true);
|
||||
vi.spyOn(policyConfig, 'getModel').mockReturnValue(PREVIEW_GEMINI_MODEL);
|
||||
|
||||
const result = await handleFallback(
|
||||
policyConfig,
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(availability.selectFirstAvailable).toHaveBeenCalledWith([
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
]);
|
||||
expect(policyHandler).toHaveBeenCalledWith(
|
||||
PREVIEW_GEMINI_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('short-circuits when the failed model is already the last-resort policy', async () => {
|
||||
const result = await handleFallback(
|
||||
policyConfig,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(policyConfig.getModelAvailabilityService).not.toHaveBeenCalled();
|
||||
expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user