refactor: Centralize and improve model fallback handling (#7634)

This commit is contained in:
Abhi
2025-09-08 16:19:52 -04:00
committed by GitHub
parent 9c71d3dd64
commit f6f2fff724
20 changed files with 1543 additions and 380 deletions
+4 -9
View File
@@ -52,6 +52,7 @@ import type { FileSystemService } from '../services/fileSystemService.js';
import { StandardFileSystemService } from '../services/fileSystemService.js';
import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js';
import { IdeConnectionEvent, IdeConnectionType } from '../telemetry/types.js';
import type { FallbackModelHandler } from '../fallback/types.js';
// Re-export OAuth config type
export type { MCPOAuthConfig, AnyToolInvocation };
@@ -157,12 +158,6 @@ export interface SandboxConfig {
image: string;
}
export type FlashFallbackHandler = (
currentModel: string,
fallbackModel: string,
error?: unknown,
) => Promise<boolean | string | null>;
export interface ConfigParameters {
sessionId: string;
embeddingModel?: string;
@@ -281,7 +276,7 @@ export class Config {
name: string;
extensionName: string;
}>;
flashFallbackHandler?: FlashFallbackHandler;
fallbackModelHandler?: FallbackModelHandler;
private quotaErrorOccurred: boolean = false;
private readonly summarizeToolOutput:
| Record<string, SummarizeToolOutputSettings>
@@ -490,8 +485,8 @@ export class Config {
this.inFallbackMode = active;
}
setFlashFallbackHandler(handler: FlashFallbackHandler): void {
this.flashFallbackHandler = handler;
setFallbackModelHandler(handler: FallbackModelHandler): void {
this.fallbackModelHandler = handler;
}
getMaxSessionTurns(): number {
+54 -25
View File
@@ -4,7 +4,15 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
describe,
it,
expect,
vi,
beforeEach,
afterEach,
type Mock,
} from 'vitest';
import type { Content, GenerateContentResponse, Part } from '@google/genai';
import {
@@ -212,16 +220,19 @@ describe('Gemini Client (client.ts)', () => {
let mockContentGenerator: ContentGenerator;
let mockConfig: Config;
let client: GeminiClient;
let mockGenerateContentFn: Mock;
beforeEach(async () => {
vi.resetAllMocks();
mockGenerateContentFn = vi.fn().mockResolvedValue({
candidates: [{ content: { parts: [{ text: '{"key": "value"}' }] } }],
});
// Disable 429 simulation for tests
setSimulate429(false);
mockContentGenerator = {
generateContent: vi.fn().mockResolvedValue({
candidates: [{ content: { parts: [{ text: '{"key": "value"}' }] } }],
}),
generateContent: mockGenerateContentFn,
generateContentStream: vi.fn(),
countTokens: vi.fn(),
embedContent: vi.fn(),
@@ -270,6 +281,7 @@ describe('Gemini Client (client.ts)', () => {
getDirectories: vi.fn().mockReturnValue(['/test/dir']),
}),
getGeminiClient: vi.fn(),
isInFallbackMode: vi.fn().mockReturnValue(false),
setFallbackMode: vi.fn(),
getChatCompression: vi.fn().mockReturnValue(undefined),
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
@@ -453,6 +465,27 @@ describe('Gemini Client (client.ts)', () => {
'test-session-id',
);
});
it('should use the Flash model when fallback mode is active', async () => {
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
const schema = { type: 'string' };
const abortSignal = new AbortController().signal;
const requestedModel = 'gemini-2.5-pro'; // A non-flash model
// Mock config to be in fallback mode
// We access the mock via the client instance which holds the mocked config
vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true);
await client.generateJson(contents, schema, abortSignal, requestedModel);
// Assert that the Flash model was used, not the requested model
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
model: DEFAULT_GEMINI_FLASH_MODEL,
}),
'test-session-id',
);
});
});
describe('addHistory', () => {
@@ -2210,32 +2243,28 @@ ${JSON.stringify(
'test-session-id',
);
});
});
describe('handleFlashFallback', () => {
it('should use current model from config when checking for fallback', async () => {
const initialModel = client['config'].getModel();
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
it('should use the Flash model when fallback mode is active', async () => {
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
const generationConfig = { temperature: 0.5 };
const abortSignal = new AbortController().signal;
const requestedModel = 'gemini-2.5-pro'; // A non-flash model
// mock config been changed
const currentModel = initialModel + '-changed';
const getModelSpy = vi.spyOn(client['config'], 'getModel');
getModelSpy.mockReturnValue(currentModel);
// Mock config to be in fallback mode
vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true);
const mockFallbackHandler = vi.fn().mockResolvedValue(true);
client['config'].flashFallbackHandler = mockFallbackHandler;
client['config'].setModel = vi.fn();
const result = await client['handleFlashFallback'](
AuthType.LOGIN_WITH_GOOGLE,
await client.generateContent(
contents,
generationConfig,
abortSignal,
requestedModel,
);
expect(result).toBe(fallbackModel);
expect(mockFallbackHandler).toHaveBeenCalledWith(
currentModel,
fallbackModel,
undefined,
expect(mockGenerateContentFn).toHaveBeenCalledWith(
expect.objectContaining({
model: DEFAULT_GEMINI_FLASH_MODEL,
}),
'test-session-id',
);
});
});
+41 -61
View File
@@ -31,7 +31,6 @@ import { isFunctionResponse } from '../utils/messageInspectors.js';
import { tokenLimit } from './tokenLimits.js';
import type { ChatRecordingService } from '../services/chatRecordingService.js';
import type { ContentGenerator } from './contentGenerator.js';
import { AuthType } from './contentGenerator.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_THINKING_MODE,
@@ -49,6 +48,7 @@ import {
NextSpeakerCheckEvent,
} from '../telemetry/types.js';
import type { IdeContext, File } from '../ide/ideContext.js';
import { handleFallback } from '../fallback/handler.js';
export function isThinkingSupported(model: string) {
if (model.startsWith('gemini-2.5')) return true;
@@ -550,6 +550,8 @@ export class GeminiClient {
model: string,
config: GenerateContentConfig = {},
): Promise<Record<string, unknown>> {
let currentAttemptModel: string = model;
try {
const userMemory = this.config.getUserMemory();
const systemInstruction = getCoreSystemPrompt(userMemory);
@@ -559,10 +561,15 @@ export class GeminiClient {
...config,
};
const apiCall = () =>
this.getContentGeneratorOrFail().generateContent(
const apiCall = () => {
const modelToUse = this.config.isInFallbackMode()
? DEFAULT_GEMINI_FLASH_MODEL
: model;
currentAttemptModel = modelToUse;
return this.getContentGeneratorOrFail().generateContent(
{
model,
model: modelToUse,
config: {
...requestConfig,
systemInstruction,
@@ -573,10 +580,17 @@ export class GeminiClient {
},
this.lastPromptId,
);
};
const onPersistent429Callback = async (
authType?: string,
error?: unknown,
) =>
// Pass the captured model to the centralized handler.
await handleFallback(this.config, currentAttemptModel, authType, error);
const result = await retryWithBackoff(apiCall, {
onPersistent429: async (authType?: string, error?: unknown) =>
await this.handleFlashFallback(authType, error),
onPersistent429: onPersistent429Callback,
authType: this.config.getContentGeneratorConfig()?.authType,
});
@@ -599,7 +613,7 @@ export class GeminiClient {
if (text.startsWith(prefix) && text.endsWith(suffix)) {
logMalformedJsonResponse(
this.config,
new MalformedJsonResponseEvent(model),
new MalformedJsonResponseEvent(currentAttemptModel),
);
text = text
.substring(prefix.length, text.length - suffix.length)
@@ -655,6 +669,8 @@ export class GeminiClient {
abortSignal: AbortSignal,
model: string,
): Promise<GenerateContentResponse> {
let currentAttemptModel: string = model;
const configToUse: GenerateContentConfig = {
...this.generateContentConfig,
...generationConfig,
@@ -670,19 +686,30 @@ export class GeminiClient {
systemInstruction,
};
const apiCall = () =>
this.getContentGeneratorOrFail().generateContent(
const apiCall = () => {
const modelToUse = this.config.isInFallbackMode()
? DEFAULT_GEMINI_FLASH_MODEL
: model;
currentAttemptModel = modelToUse;
return this.getContentGeneratorOrFail().generateContent(
{
model,
model: modelToUse,
config: requestConfig,
contents,
},
this.lastPromptId,
);
};
const onPersistent429Callback = async (
authType?: string,
error?: unknown,
) =>
// Pass the captured model to the centralized handler.
await handleFallback(this.config, currentAttemptModel, authType, error);
const result = await retryWithBackoff(apiCall, {
onPersistent429: async (authType?: string, error?: unknown) =>
await this.handleFlashFallback(authType, error),
onPersistent429: onPersistent429Callback,
authType: this.config.getContentGeneratorConfig()?.authType,
});
return result;
@@ -693,7 +720,7 @@ export class GeminiClient {
await reportError(
error,
`Error generating content via API with model ${model}.`,
`Error generating content via API with model ${currentAttemptModel}.`,
{
requestContents: contents,
requestConfig: configToUse,
@@ -701,7 +728,7 @@ export class GeminiClient {
'generateContent-api',
);
throw new Error(
`Failed to generate content with model ${model}: ${getErrorMessage(error)}`,
`Failed to generate content with model ${currentAttemptModel}: ${getErrorMessage(error)}`,
);
}
}
@@ -880,53 +907,6 @@ export class GeminiClient {
compressionStatus: CompressionStatus.COMPRESSED,
};
}
/**
* Handles falling back to Flash model when persistent 429 errors occur for OAuth users.
* Uses a fallback handler if provided by the config; otherwise, returns null.
*/
private async handleFlashFallback(
authType?: string,
error?: unknown,
): Promise<string | null> {
// Only handle fallback for OAuth users
if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
return null;
}
const currentModel = this.config.getModel();
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
// Don't fallback if already using Flash model
if (currentModel === fallbackModel) {
return null;
}
// Check if config has a fallback handler (set by CLI package)
const fallbackHandler = this.config.flashFallbackHandler;
if (typeof fallbackHandler === 'function') {
try {
const accepted = await fallbackHandler(
currentModel,
fallbackModel,
error,
);
if (accepted !== false && accepted !== null) {
this.config.setModel(fallbackModel);
this.config.setFallbackMode(true);
return fallbackModel;
}
// Check if the model was switched manually in the handler
if (this.config.getModel() === fallbackModel) {
return null; // Model was switched but don't continue with current prompt
}
} catch (error) {
console.warn('Flash fallback handler failed:', error);
}
}
return null;
}
}
export const TEST_ONLY = {
+196 -5
View File
@@ -20,6 +20,9 @@ import {
} from './geminiChat.js';
import type { Config } from '../config/config.js';
import { setSimulate429 } from '../utils/testUtils.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { AuthType } from './contentGenerator.js';
import { type RetryOptions } from '../utils/retry.js';
// Mock fs module to prevent actual file system operations during tests
const mockFileSystem = new Map<string, string>();
@@ -47,6 +50,23 @@ vi.mock('node:fs', () => {
};
});
const { mockHandleFallback } = vi.hoisted(() => ({
mockHandleFallback: vi.fn(),
}));
// Add mock for the retry utility
const { mockRetryWithBackoff } = vi.hoisted(() => ({
mockRetryWithBackoff: vi.fn(),
}));
vi.mock('../utils/retry.js', () => ({
retryWithBackoff: mockRetryWithBackoff,
}));
vi.mock('../fallback/handler.js', () => ({
handleFallback: mockHandleFallback,
}));
const { mockLogInvalidChunk, mockLogContentRetry, mockLogContentRetryFailure } =
vi.hoisted(() => ({
mockLogInvalidChunk: vi.fn(),
@@ -76,17 +96,21 @@ describe('GeminiChat', () => {
batchEmbedContents: vi.fn(),
} as unknown as ContentGenerator;
mockHandleFallback.mockClear();
// Default mock implementation for tests that don't care about retry logic
mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall());
mockConfig = {
getSessionId: () => 'test-session-id',
getTelemetryLogPromptsEnabled: () => true,
getUsageStatisticsEnabled: () => true,
getDebugMode: () => false,
getContentGeneratorConfig: () => ({
authType: 'oauth-personal',
getContentGeneratorConfig: vi.fn().mockReturnValue({
authType: 'oauth-personal', // Ensure this is set for fallback tests
model: 'test-model',
}),
getModel: vi.fn().mockReturnValue('gemini-pro'),
setModel: vi.fn(),
isInFallbackMode: vi.fn().mockReturnValue(false),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
setQuotaErrorOccurred: vi.fn(),
flashFallbackHandler: undefined,
@@ -1476,8 +1500,176 @@ describe('GeminiChat', () => {
expect(turn4.parts[0].text).toBe('second response');
});
describe('Model Resolution', () => {
const mockResponse = {
candidates: [
{
content: { parts: [{ text: 'response' }], role: 'model' },
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
it('should use the configured model when not in fallback mode (sendMessage)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-pro');
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
mockResponse,
);
await chat.sendMessage({ message: 'test' }, 'prompt-id-res1');
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
model: 'gemini-2.5-pro',
}),
'prompt-id-res1',
);
});
it('should use the FLASH model when in fallback mode (sendMessage)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-pro');
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
mockResponse,
);
await chat.sendMessage({ message: 'test' }, 'prompt-id-res2');
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
expect.objectContaining({
model: DEFAULT_GEMINI_FLASH_MODEL,
}),
'prompt-id-res2',
);
});
it('should use the FLASH model when in fallback mode (sendMessageStream)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro');
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
vi.mocked(mockContentGenerator.generateContentStream).mockImplementation(
async () =>
(async function* () {
yield mockResponse;
})(),
);
const stream = await chat.sendMessageStream(
{ message: 'test' },
'prompt-id-res3',
);
for await (const _ of stream) {
// consume stream
}
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith(
expect.objectContaining({
model: DEFAULT_GEMINI_FLASH_MODEL,
}),
'prompt-id-res3',
);
});
});
describe('Fallback Integration (Retries)', () => {
const error429 = Object.assign(new Error('API Error 429: Quota exceeded'), {
status: 429,
});
// Define the simulated behavior for retryWithBackoff for these tests.
// This simulation tries the apiCall, if it fails, it calls the callback,
// and then tries the apiCall again if the callback returns true.
/**
 * Minimal stand-in for `retryWithBackoff` used by the fallback tests.
 *
 * Runs `apiCall` once. If that attempt rejects and an `onPersistent429`
 * callback was supplied, the callback is consulted with the configured
 * `authType` and the caught error; a truthy answer triggers exactly one more
 * `apiCall`. In every other case the original rejection is re-thrown.
 */
const simulateRetryBehavior = async <T>(
  apiCall: () => Promise<T>,
  options: Partial<RetryOptions>,
) => {
  let firstFailure: unknown;
  try {
    return await apiCall();
  } catch (caught) {
    // Remember the failure; the retry decision is made below.
    firstFailure = caught;
  }

  // Every failure is treated as "persistent" here to keep the simulation simple.
  const wantsRetry = options.onPersistent429
    ? await options.onPersistent429(options.authType, firstFailure)
    : false;

  if (!wantsRetry) {
    // Callback missing, or it answered false/null: surface the first error.
    throw firstFailure;
  }
  return await apiCall();
};
beforeEach(() => {
mockRetryWithBackoff.mockImplementation(simulateRetryBehavior);
});
afterEach(() => {
mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall());
});
it('should call handleFallback with the specific failed model and retry if handler returns true', async () => {
const FAILED_MODEL = 'gemini-2.5-pro';
vi.mocked(mockConfig.getModel).mockReturnValue(FAILED_MODEL);
const authType = AuthType.LOGIN_WITH_GOOGLE;
vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({
authType,
model: FAILED_MODEL,
});
const isInFallbackModeSpy = vi.spyOn(mockConfig, 'isInFallbackMode');
isInFallbackModeSpy.mockReturnValue(false);
vi.mocked(mockContentGenerator.generateContent)
.mockRejectedValueOnce(error429) // Attempt 1 fails
.mockResolvedValueOnce({
candidates: [{ content: { parts: [{ text: 'Success on retry' }] } }],
} as unknown as GenerateContentResponse); // Attempt 2 succeeds
mockHandleFallback.mockImplementation(async () => {
isInFallbackModeSpy.mockReturnValue(true);
return true; // Signal retry
});
const result = await chat.sendMessage(
{ message: 'trigger 429' },
'prompt-id-fb1',
);
expect(mockRetryWithBackoff).toHaveBeenCalledTimes(1);
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(2);
expect(mockHandleFallback).toHaveBeenCalledTimes(1);
expect(mockHandleFallback).toHaveBeenCalledWith(
mockConfig,
FAILED_MODEL,
authType,
error429,
);
expect(result.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
'Success on retry',
);
});
it('should stop retrying if handleFallback returns false (e.g., auth intent)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro');
vi.mocked(mockContentGenerator.generateContent).mockRejectedValue(
error429,
);
mockHandleFallback.mockResolvedValue(false);
await expect(
chat.sendMessage({ message: 'test stop' }, 'prompt-id-fb2'),
).rejects.toThrow(error429);
expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(1);
expect(mockHandleFallback).toHaveBeenCalledTimes(1);
});
});
it('should discard valid partial content from a failed attempt upon retry', async () => {
// ARRANGE: Mock the stream to fail on the first attempt after yielding some valid content.
// Mock the stream to fail on the first attempt after yielding some valid content.
vi.mocked(mockContentGenerator.generateContentStream)
.mockImplementationOnce(async () =>
// First attempt: yields one valid chunk, then one invalid chunk
@@ -1512,7 +1704,7 @@ describe('GeminiChat', () => {
})(),
);
// ACT: Send a message and consume the stream
// Send a message and consume the stream
const stream = await chat.sendMessageStream(
{ message: 'test' },
'prompt-id-discard-test',
@@ -1522,7 +1714,6 @@ describe('GeminiChat', () => {
events.push(event);
}
// ASSERT
// Check that a retry happened
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2);
expect(events.some((e) => e.type === StreamEventType.RETRY)).toBe(true);
+41 -54
View File
@@ -18,7 +18,6 @@ import type {
import { toParts } from '../code_assist/converter.js';
import { createUserContent } from '@google/genai';
import { retryWithBackoff } from '../utils/retry.js';
import { AuthType } from './contentGenerator.js';
import type { Config } from '../config/config.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { hasCycleInSchema } from '../tools/tools.js';
@@ -35,6 +34,7 @@ import {
ContentRetryFailureEvent,
InvalidChunkEvent,
} from '../telemetry/types.js';
import { handleFallback } from '../fallback/handler.js';
import { isFunctionResponse } from '../utils/messageInspectors.js';
import { partListUnionToString } from './geminiRequest.js';
@@ -179,53 +179,6 @@ export class GeminiChat {
this.chatRecordingService.initialize();
}
/**
* Handles falling back to Flash model when persistent 429 errors occur for OAuth users.
* Uses a fallback handler if provided by the config; otherwise, returns null.
*/
private async handleFlashFallback(
authType?: string,
error?: unknown,
): Promise<string | null> {
// Only handle fallback for OAuth users
if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
return null;
}
const currentModel = this.config.getModel();
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
// Don't fallback if already using Flash model
if (currentModel === fallbackModel) {
return null;
}
// Check if config has a fallback handler (set by CLI package)
const fallbackHandler = this.config.flashFallbackHandler;
if (typeof fallbackHandler === 'function') {
try {
const accepted = await fallbackHandler(
currentModel,
fallbackModel,
error,
);
if (accepted !== false && accepted !== null) {
this.config.setModel(fallbackModel);
this.config.setFallbackMode(true);
return fallbackModel;
}
// Check if the model was switched manually in the handler
if (this.config.getModel() === fallbackModel) {
return null; // Model was switched but don't continue with current prompt
}
} catch (error) {
console.warn('Flash fallback handler failed:', error);
}
}
return null;
}
setSystemInstruction(sysInstr: string) {
this.generationConfig.systemInstruction = sysInstr;
}
@@ -272,8 +225,13 @@ export class GeminiChat {
let response: GenerateContentResponse;
try {
let currentAttemptModel: string | undefined;
const apiCall = () => {
const modelToUse = this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
const modelToUse = this.config.isInFallbackMode()
? DEFAULT_GEMINI_FLASH_MODEL
: this.config.getModel();
currentAttemptModel = modelToUse;
// Prevent Flash model calls immediately after quota error
if (
@@ -295,6 +253,19 @@ export class GeminiChat {
);
};
const onPersistent429Callback = async (
authType?: string,
error?: unknown,
) => {
if (!currentAttemptModel) return null;
return await handleFallback(
this.config,
currentAttemptModel,
authType,
error,
);
};
response = await retryWithBackoff(apiCall, {
shouldRetry: (error: unknown) => {
// Check for known error messages and codes.
@@ -305,8 +276,7 @@ export class GeminiChat {
}
return false; // Don't retry other errors by default
},
onPersistent429: async (authType?: string, error?: unknown) =>
await this.handleFlashFallback(authType, error),
onPersistent429: onPersistent429Callback,
authType: this.config.getContentGeneratorConfig()?.authType,
});
@@ -484,8 +454,13 @@ export class GeminiChat {
prompt_id: string,
userContent: Content,
): Promise<AsyncGenerator<GenerateContentResponse>> {
let currentAttemptModel: string | undefined;
const apiCall = () => {
const modelToUse = this.config.getModel();
const modelToUse = this.config.isInFallbackMode()
? DEFAULT_GEMINI_FLASH_MODEL
: this.config.getModel();
currentAttemptModel = modelToUse;
if (
this.config.getQuotaErrorOccurred() &&
@@ -506,6 +481,19 @@ export class GeminiChat {
);
};
const onPersistent429Callback = async (
authType?: string,
error?: unknown,
) => {
if (!currentAttemptModel) return null;
return await handleFallback(
this.config,
currentAttemptModel,
authType,
error,
);
};
const streamResponse = await retryWithBackoff(apiCall, {
shouldRetry: (error: unknown) => {
if (error instanceof Error && error.message) {
@@ -515,8 +503,7 @@ export class GeminiChat {
}
return false;
},
onPersistent429: async (authType?: string, error?: unknown) =>
await this.handleFlashFallback(authType, error),
onPersistent429: onPersistent429Callback,
authType: this.config.getContentGeneratorConfig()?.authType,
});
+218
View File
@@ -0,0 +1,218 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
describe,
it,
expect,
vi,
beforeEach,
type Mock,
type MockInstance,
afterEach,
} from 'vitest';
import { handleFallback } from './handler.js';
import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import {
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL,
} from '../config/models.js';
import { logFlashFallback } from '../telemetry/index.js';
import type { FallbackModelHandler } from './types.js';
// Mock the telemetry logger and event class
vi.mock('../telemetry/index.js', () => ({
logFlashFallback: vi.fn(),
FlashFallbackEvent: class {},
}));
const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL;
const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL;
const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE;
const AUTH_API_KEY = AuthType.USE_GEMINI;
/**
 * Builds a minimal `Config` test double for `handleFallback`.
 *
 * Defaults: not in fallback mode, a spy for `setFallbackMode`, and no
 * `fallbackModelHandler` injected. Any key in `overrides` replaces the
 * corresponding default.
 *
 * The default key must be `fallbackModelHandler` (the property that
 * `handleFallback` reads); the `as unknown as Config` cast would otherwise
 * hide a misspelled or stale property name.
 */
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
  ({
    isInFallbackMode: vi.fn(() => false),
    setFallbackMode: vi.fn(),
    fallbackModelHandler: undefined,
    ...overrides,
  }) as unknown as Config;
// Unit tests for handleFallback: applicability guards, each intent the
// injected handler can return, and failures raised by the handler itself.
describe('handleFallback', () => {
let mockConfig: Config;
let mockHandler: Mock<FallbackModelHandler>;
let consoleErrorSpy: MockInstance;
beforeEach(() => {
vi.clearAllMocks();
mockHandler = vi.fn();
// Default setup: OAuth user, Pro model failed, handler injected
mockConfig = createMockConfig({
fallbackModelHandler: mockHandler,
});
// Silence console.error and record its calls so error-path tests can assert on them.
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
});
afterEach(() => {
consoleErrorSpy.mockRestore();
});
// Guard: any auth type other than LOGIN_WITH_GOOGLE short-circuits before the handler runs.
it('should return null immediately if authType is not OAuth', async () => {
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_API_KEY,
);
expect(result).toBeNull();
expect(mockHandler).not.toHaveBeenCalled();
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
// Guard: nothing to fall back to when the failing model is the fallback model itself.
it('should return null if the failed model is already the fallback model', async () => {
const result = await handleFallback(
mockConfig,
FALLBACK_MODEL, // Failed model is Flash
AUTH_OAUTH,
);
expect(result).toBeNull();
expect(mockHandler).not.toHaveBeenCalled();
});
// Guard: without an injected fallbackModelHandler there is nobody to ask.
it('should return null if no fallbackHandler is injected in config', async () => {
const configWithoutHandler = createMockConfig({
fallbackModelHandler: undefined,
});
const result = await handleFallback(
configWithoutHandler,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBeNull();
});
describe('when handler returns "retry"', () => {
it('should activate fallback mode, log telemetry, and return true', async () => {
mockHandler.mockResolvedValue('retry');
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBe(true);
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
expect(logFlashFallback).toHaveBeenCalled();
});
});
describe('when handler returns "stop"', () => {
it('should activate fallback mode, log telemetry, and return false', async () => {
mockHandler.mockResolvedValue('stop');
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBe(false);
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
expect(logFlashFallback).toHaveBeenCalled();
});
});
describe('when handler returns "auth"', () => {
it('should NOT activate fallback mode and return false', async () => {
mockHandler.mockResolvedValue('auth');
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBe(false);
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
expect(logFlashFallback).not.toHaveBeenCalled();
});
});
describe('when handler returns an unexpected value', () => {
it('should log an error and return null', async () => {
// null is permitted by the handler's return type but is not a recognised intent.
mockHandler.mockResolvedValue(null);
const result = await handleFallback(
mockConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
expect(result).toBeNull();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Fallback UI handler failed:',
new Error(
'Unexpected fallback intent received from fallbackModelHandler: "null"',
),
);
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
});
it('should pass the correct context (failedModel, fallbackModel, error) to the handler', async () => {
const mockError = new Error('Quota Exceeded');
mockHandler.mockResolvedValue('retry');
await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH, mockError);
expect(mockHandler).toHaveBeenCalledWith(
MOCK_PRO_MODEL,
FALLBACK_MODEL,
mockError,
);
});
it('should not call setFallbackMode or log telemetry if already in fallback mode', async () => {
// Setup config where fallback mode is already active
const activeFallbackConfig = createMockConfig({
fallbackModelHandler: mockHandler,
isInFallbackMode: vi.fn(() => true), // Already active
setFallbackMode: vi.fn(),
});
mockHandler.mockResolvedValue('retry');
const result = await handleFallback(
activeFallbackConfig,
MOCK_PRO_MODEL,
AUTH_OAUTH,
);
// Should still return true to allow the retry (which will use the active fallback mode)
expect(result).toBe(true);
// Should still consult the handler
expect(mockHandler).toHaveBeenCalled();
// But should not mutate state or log telemetry again
expect(activeFallbackConfig.setFallbackMode).not.toHaveBeenCalled();
expect(logFlashFallback).not.toHaveBeenCalled();
});
it('should catch errors from the handler, log an error, and return null', async () => {
const handlerError = new Error('UI interaction failed');
// A rejecting handler must not propagate out of handleFallback.
mockHandler.mockRejectedValue(handlerError);
const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
expect(result).toBeNull();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'Fallback UI handler failed:',
handlerError,
);
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
});
});
+69
View File
@@ -0,0 +1,69 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '../config/config.js';
import { AuthType } from '../core/contentGenerator.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js';
/**
 * Central entry point for deciding what to do after a model request keeps
 * failing (invoked from the `onPersistent429` retry callback).
 *
 * Fallback is only considered for `AuthType.LOGIN_WITH_GOOGLE` users whose
 * failed model is not already `DEFAULT_GEMINI_FLASH_MODEL`, and only when the
 * config carries a `fallbackModelHandler`. The handler is asked for an intent:
 *
 * - `'retry'`: fallback mode is activated and `true` is returned.
 * - `'stop'`: fallback mode is activated and `false` is returned.
 * - `'auth'`: nothing is changed and `false` is returned.
 *
 * @param config Config that owns the fallback handler and fallback-mode flag.
 * @param failedModel Model used by the attempt that failed.
 * @param authType Auth type of the current session, if known.
 * @param error The error that triggered the fallback, forwarded to the handler.
 * @returns `true`/`false` as described above, or `null` when fallback does not
 *   apply, the handler returns an unrecognised intent, or the handler throws
 *   (the latter two are also reported via `console.error`).
 */
export async function handleFallback(
  config: Config,
  failedModel: string,
  authType?: string,
  error?: unknown,
): Promise<string | boolean | null> {
  // Applicability: OAuth sessions only.
  if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
    return null;
  }

  // Applicability: the fallback model itself has nothing to fall back to.
  const targetModel = DEFAULT_GEMINI_FLASH_MODEL;
  if (failedModel === targetModel) {
    return null;
  }

  // Applicability: the UI layer must have injected a handler.
  const askUi = config.fallbackModelHandler;
  if (typeof askUi !== 'function') {
    return null;
  }

  try {
    // Report the specific model that failed, not whatever the config holds now.
    const intent = await askUi(failedModel, targetModel, error);

    if (intent === 'retry') {
      // The next retry attempt picks up fallback mode when it selects a model.
      activateFallbackMode(config, authType);
      return true;
    }
    if (intent === 'stop') {
      activateFallbackMode(config, authType);
      return false;
    }
    if (intent === 'auth') {
      return false;
    }

    // Anything else (including null) is routed to the catch below.
    throw new Error(
      `Unexpected fallback intent received from fallbackModelHandler: "${intent}"`,
    );
  } catch (handlerError) {
    console.error('Fallback UI handler failed:', handlerError);
    return null;
  }
}
/**
 * Switches the config into fallback mode once.
 *
 * Does nothing when fallback mode is already active. Otherwise it sets the
 * flag and, when an auth type is available, records a single
 * `FlashFallbackEvent` via `logFlashFallback`.
 */
function activateFallbackMode(config: Config, authType: string | undefined) {
  // Already active: avoid re-setting the flag and double-logging telemetry.
  if (config.isInFallbackMode()) {
    return;
  }

  config.setFallbackMode(true);

  // The telemetry event is built from the auth type; skip it when absent.
  if (!authType) {
    return;
  }
  logFlashFallback(config, new FlashFallbackEvent(authType));
}
+23
View File
@@ -0,0 +1,23 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/**
 * Defines the intent returned by the UI layer during a fallback scenario.
 *
 * `handleFallback` maps these to its result: 'retry' activates fallback mode
 * and yields `true`; 'stop' activates fallback mode and yields `false`;
 * 'auth' yields `false` without activating fallback mode.
 */
export type FallbackIntent =
| 'retry' // Immediately retry the current request with the fallback model.
| 'stop' // Switch to fallback for future requests, but stop the current request.
| 'auth'; // Stop the current request; user intends to change authentication.
/**
 * The interface for the handler provided by the UI layer (e.g., the CLI)
 * to interact with the user during a fallback scenario.
 *
 * @param failedModel The model used by the request that failed.
 * @param fallbackModel The model that would be used instead.
 * @param error The error that triggered the fallback, when available.
 * @returns The user's intent. `null` is allowed by the type, but
 *   `handleFallback` treats it as an unexpected intent: it logs an error and
 *   returns `null` without activating fallback mode.
 */
export type FallbackModelHandler = (
failedModel: string,
fallbackModel: string,
error?: unknown,
) => Promise<FallbackIntent | null>;
+2
View File
@@ -20,6 +20,8 @@ export * from './core/geminiRequest.js';
export * from './core/coreToolScheduler.js';
export * from './core/nonInteractiveToolExecutor.js';
export * from './fallback/types.js';
export * from './code_assist/codeAssist.js';
export * from './code_assist/oauth2.js';
export * from './code_assist/server.js';
@@ -17,10 +17,13 @@ import {
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { retryWithBackoff } from './retry.js';
import { AuthType } from '../core/contentGenerator.js';
// Contract type for the UI-provided fallback handler (from packages/core/src/fallback/).
import type { FallbackModelHandler } from '../fallback/types.js';
vi.mock('node:fs');
describe('Flash Fallback Integration', () => {
// Exercises how the retry utility integrates with the fallback handler contract.
describe('Retry Utility Fallback Integration', () => {
let config: Config;
beforeEach(() => {
@@ -41,25 +44,28 @@ describe('Flash Fallback Integration', () => {
resetRequestCounter();
});
it('should automatically accept fallback', async () => {
// Set up a minimal flash fallback handler for testing
const flashFallbackHandler = async (): Promise<boolean> => true;
// This test validates the Config's ability to store and execute the handler contract.
it('should execute the injected FallbackHandler contract correctly', async () => {
// Set up a minimal handler for testing, ensuring it matches the new type.
const fallbackHandler: FallbackModelHandler = async () => 'retry';
config.setFlashFallbackHandler(flashFallbackHandler);
// Use the generalized setter
config.setFallbackModelHandler(fallbackHandler);
// Call the handler directly to test
const result = await config.flashFallbackHandler!(
// Call the handler directly via the config property
const result = await config.fallbackModelHandler!(
'gemini-2.5-pro',
DEFAULT_GEMINI_FLASH_MODEL,
);
// Verify it automatically accepts
expect(result).toBe(true);
// Verify it returns the correct intent
expect(result).toBe('retry');
});
it('should trigger fallback after 2 consecutive 429 errors for OAuth users', async () => {
// This test validates the retry utility's logic for triggering the callback.
it('should trigger onPersistent429 after 2 consecutive 429 errors for OAuth users', async () => {
let fallbackCalled = false;
let fallbackModel = '';
// Removed fallbackModel variable as it's no longer relevant here.
// Mock function that simulates exactly 2 429 errors, then succeeds after fallback
const mockApiCall = vi
@@ -68,11 +74,11 @@ describe('Flash Fallback Integration', () => {
.mockRejectedValueOnce(createSimulated429Error())
.mockResolvedValueOnce('success after fallback');
// Mock fallback handler
const mockFallbackHandler = vi.fn(async (_authType?: string) => {
// Mock the onPersistent429 callback (this is what client.ts/geminiChat.ts provides)
const mockPersistent429Callback = vi.fn(async (_authType?: string) => {
fallbackCalled = true;
fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
return fallbackModel;
// Return true to signal retryWithBackoff to reset attempts and continue.
return true;
});
// Test with OAuth personal auth type, with maxAttempts = 2 to ensure fallback triggers
@@ -84,14 +90,13 @@ describe('Flash Fallback Integration', () => {
const status = (error as Error & { status?: number }).status;
return status === 429;
},
onPersistent429: mockFallbackHandler,
onPersistent429: mockPersistent429Callback,
authType: AuthType.LOGIN_WITH_GOOGLE,
});
// Verify fallback was triggered
// Verify fallback mechanism was triggered
expect(fallbackCalled).toBe(true);
expect(fallbackModel).toBe(DEFAULT_GEMINI_FLASH_MODEL);
expect(mockFallbackHandler).toHaveBeenCalledWith(
expect(mockPersistent429Callback).toHaveBeenCalledWith(
AuthType.LOGIN_WITH_GOOGLE,
expect.any(Error),
);
@@ -100,16 +105,16 @@ describe('Flash Fallback Integration', () => {
expect(mockApiCall).toHaveBeenCalledTimes(3);
});
it('should not trigger fallback for API key users', async () => {
it('should not trigger onPersistent429 for API key users', async () => {
let fallbackCalled = false;
// Mock function that simulates 429 errors
const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error());
// Mock fallback handler
const mockFallbackHandler = vi.fn(async () => {
// Mock the callback
const mockPersistent429Callback = vi.fn(async () => {
fallbackCalled = true;
return DEFAULT_GEMINI_FLASH_MODEL;
return true;
});
// Test with API key auth type - should not trigger fallback
@@ -122,7 +127,7 @@ describe('Flash Fallback Integration', () => {
const status = (error as Error & { status?: number }).status;
return status === 429;
},
onPersistent429: mockFallbackHandler,
onPersistent429: mockPersistent429Callback,
authType: AuthType.USE_GEMINI, // API key auth type
});
} catch (error) {
@@ -132,10 +137,11 @@ describe('Flash Fallback Integration', () => {
// Verify fallback was NOT triggered for API key users
expect(fallbackCalled).toBe(false);
expect(mockFallbackHandler).not.toHaveBeenCalled();
expect(mockPersistent429Callback).not.toHaveBeenCalled();
});
it('should properly disable simulation state after fallback', () => {
// This test validates the test utilities themselves.
it('should properly disable simulation state after fallback (Test Utility)', () => {
// Enable simulation
setSimulate429(true);