mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-21 10:34:35 -07:00
refactor: Centralize and improve model fallback handling (#7634)
This commit is contained in:
@@ -52,6 +52,7 @@ import type { FileSystemService } from '../services/fileSystemService.js';
|
||||
import { StandardFileSystemService } from '../services/fileSystemService.js';
|
||||
import { logCliConfiguration, logIdeConnection } from '../telemetry/loggers.js';
|
||||
import { IdeConnectionEvent, IdeConnectionType } from '../telemetry/types.js';
|
||||
import type { FallbackModelHandler } from '../fallback/types.js';
|
||||
|
||||
// Re-export OAuth config type
|
||||
export type { MCPOAuthConfig, AnyToolInvocation };
|
||||
@@ -157,12 +158,6 @@ export interface SandboxConfig {
|
||||
image: string;
|
||||
}
|
||||
|
||||
export type FlashFallbackHandler = (
|
||||
currentModel: string,
|
||||
fallbackModel: string,
|
||||
error?: unknown,
|
||||
) => Promise<boolean | string | null>;
|
||||
|
||||
export interface ConfigParameters {
|
||||
sessionId: string;
|
||||
embeddingModel?: string;
|
||||
@@ -281,7 +276,7 @@ export class Config {
|
||||
name: string;
|
||||
extensionName: string;
|
||||
}>;
|
||||
flashFallbackHandler?: FlashFallbackHandler;
|
||||
fallbackModelHandler?: FallbackModelHandler;
|
||||
private quotaErrorOccurred: boolean = false;
|
||||
private readonly summarizeToolOutput:
|
||||
| Record<string, SummarizeToolOutputSettings>
|
||||
@@ -490,8 +485,8 @@ export class Config {
|
||||
this.inFallbackMode = active;
|
||||
}
|
||||
|
||||
setFlashFallbackHandler(handler: FlashFallbackHandler): void {
|
||||
this.flashFallbackHandler = handler;
|
||||
setFallbackModelHandler(handler: FallbackModelHandler): void {
|
||||
this.fallbackModelHandler = handler;
|
||||
}
|
||||
|
||||
getMaxSessionTurns(): number {
|
||||
|
||||
@@ -4,7 +4,15 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mock,
|
||||
} from 'vitest';
|
||||
|
||||
import type { Content, GenerateContentResponse, Part } from '@google/genai';
|
||||
import {
|
||||
@@ -212,16 +220,19 @@ describe('Gemini Client (client.ts)', () => {
|
||||
let mockContentGenerator: ContentGenerator;
|
||||
let mockConfig: Config;
|
||||
let client: GeminiClient;
|
||||
let mockGenerateContentFn: Mock;
|
||||
beforeEach(async () => {
|
||||
vi.resetAllMocks();
|
||||
|
||||
mockGenerateContentFn = vi.fn().mockResolvedValue({
|
||||
candidates: [{ content: { parts: [{ text: '{"key": "value"}' }] } }],
|
||||
});
|
||||
|
||||
// Disable 429 simulation for tests
|
||||
setSimulate429(false);
|
||||
|
||||
mockContentGenerator = {
|
||||
generateContent: vi.fn().mockResolvedValue({
|
||||
candidates: [{ content: { parts: [{ text: '{"key": "value"}' }] } }],
|
||||
}),
|
||||
generateContent: mockGenerateContentFn,
|
||||
generateContentStream: vi.fn(),
|
||||
countTokens: vi.fn(),
|
||||
embedContent: vi.fn(),
|
||||
@@ -270,6 +281,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
getDirectories: vi.fn().mockReturnValue(['/test/dir']),
|
||||
}),
|
||||
getGeminiClient: vi.fn(),
|
||||
isInFallbackMode: vi.fn().mockReturnValue(false),
|
||||
setFallbackMode: vi.fn(),
|
||||
getChatCompression: vi.fn().mockReturnValue(undefined),
|
||||
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
|
||||
@@ -453,6 +465,27 @@ describe('Gemini Client (client.ts)', () => {
|
||||
'test-session-id',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the Flash model when fallback mode is active', async () => {
|
||||
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
|
||||
const schema = { type: 'string' };
|
||||
const abortSignal = new AbortController().signal;
|
||||
const requestedModel = 'gemini-2.5-pro'; // A non-flash model
|
||||
|
||||
// Mock config to be in fallback mode
|
||||
// We access the mock via the client instance which holds the mocked config
|
||||
vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true);
|
||||
|
||||
await client.generateJson(contents, schema, abortSignal, requestedModel);
|
||||
|
||||
// Assert that the Flash model was used, not the requested model
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
}),
|
||||
'test-session-id',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('addHistory', () => {
|
||||
@@ -2210,32 +2243,28 @@ ${JSON.stringify(
|
||||
'test-session-id',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('handleFlashFallback', () => {
|
||||
it('should use current model from config when checking for fallback', async () => {
|
||||
const initialModel = client['config'].getModel();
|
||||
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
it('should use the Flash model when fallback mode is active', async () => {
|
||||
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
|
||||
const generationConfig = { temperature: 0.5 };
|
||||
const abortSignal = new AbortController().signal;
|
||||
const requestedModel = 'gemini-2.5-pro'; // A non-flash model
|
||||
|
||||
// mock config been changed
|
||||
const currentModel = initialModel + '-changed';
|
||||
const getModelSpy = vi.spyOn(client['config'], 'getModel');
|
||||
getModelSpy.mockReturnValue(currentModel);
|
||||
// Mock config to be in fallback mode
|
||||
vi.spyOn(client['config'], 'isInFallbackMode').mockReturnValue(true);
|
||||
|
||||
const mockFallbackHandler = vi.fn().mockResolvedValue(true);
|
||||
client['config'].flashFallbackHandler = mockFallbackHandler;
|
||||
client['config'].setModel = vi.fn();
|
||||
|
||||
const result = await client['handleFlashFallback'](
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
await client.generateContent(
|
||||
contents,
|
||||
generationConfig,
|
||||
abortSignal,
|
||||
requestedModel,
|
||||
);
|
||||
|
||||
expect(result).toBe(fallbackModel);
|
||||
|
||||
expect(mockFallbackHandler).toHaveBeenCalledWith(
|
||||
currentModel,
|
||||
fallbackModel,
|
||||
undefined,
|
||||
expect(mockGenerateContentFn).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
}),
|
||||
'test-session-id',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -31,7 +31,6 @@ import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { tokenLimit } from './tokenLimits.js';
|
||||
import type { ChatRecordingService } from '../services/chatRecordingService.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import { AuthType } from './contentGenerator.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_THINKING_MODE,
|
||||
@@ -49,6 +48,7 @@ import {
|
||||
NextSpeakerCheckEvent,
|
||||
} from '../telemetry/types.js';
|
||||
import type { IdeContext, File } from '../ide/ideContext.js';
|
||||
import { handleFallback } from '../fallback/handler.js';
|
||||
|
||||
export function isThinkingSupported(model: string) {
|
||||
if (model.startsWith('gemini-2.5')) return true;
|
||||
@@ -550,6 +550,8 @@ export class GeminiClient {
|
||||
model: string,
|
||||
config: GenerateContentConfig = {},
|
||||
): Promise<Record<string, unknown>> {
|
||||
let currentAttemptModel: string = model;
|
||||
|
||||
try {
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const systemInstruction = getCoreSystemPrompt(userMemory);
|
||||
@@ -559,10 +561,15 @@ export class GeminiClient {
|
||||
...config,
|
||||
};
|
||||
|
||||
const apiCall = () =>
|
||||
this.getContentGeneratorOrFail().generateContent(
|
||||
const apiCall = () => {
|
||||
const modelToUse = this.config.isInFallbackMode()
|
||||
? DEFAULT_GEMINI_FLASH_MODEL
|
||||
: model;
|
||||
currentAttemptModel = modelToUse;
|
||||
|
||||
return this.getContentGeneratorOrFail().generateContent(
|
||||
{
|
||||
model,
|
||||
model: modelToUse,
|
||||
config: {
|
||||
...requestConfig,
|
||||
systemInstruction,
|
||||
@@ -573,10 +580,17 @@ export class GeminiClient {
|
||||
},
|
||||
this.lastPromptId,
|
||||
);
|
||||
};
|
||||
|
||||
const onPersistent429Callback = async (
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
) =>
|
||||
// Pass the captured model to the centralized handler.
|
||||
await handleFallback(this.config, currentAttemptModel, authType, error);
|
||||
|
||||
const result = await retryWithBackoff(apiCall, {
|
||||
onPersistent429: async (authType?: string, error?: unknown) =>
|
||||
await this.handleFlashFallback(authType, error),
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
});
|
||||
|
||||
@@ -599,7 +613,7 @@ export class GeminiClient {
|
||||
if (text.startsWith(prefix) && text.endsWith(suffix)) {
|
||||
logMalformedJsonResponse(
|
||||
this.config,
|
||||
new MalformedJsonResponseEvent(model),
|
||||
new MalformedJsonResponseEvent(currentAttemptModel),
|
||||
);
|
||||
text = text
|
||||
.substring(prefix.length, text.length - suffix.length)
|
||||
@@ -655,6 +669,8 @@ export class GeminiClient {
|
||||
abortSignal: AbortSignal,
|
||||
model: string,
|
||||
): Promise<GenerateContentResponse> {
|
||||
let currentAttemptModel: string = model;
|
||||
|
||||
const configToUse: GenerateContentConfig = {
|
||||
...this.generateContentConfig,
|
||||
...generationConfig,
|
||||
@@ -670,19 +686,30 @@ export class GeminiClient {
|
||||
systemInstruction,
|
||||
};
|
||||
|
||||
const apiCall = () =>
|
||||
this.getContentGeneratorOrFail().generateContent(
|
||||
const apiCall = () => {
|
||||
const modelToUse = this.config.isInFallbackMode()
|
||||
? DEFAULT_GEMINI_FLASH_MODEL
|
||||
: model;
|
||||
currentAttemptModel = modelToUse;
|
||||
|
||||
return this.getContentGeneratorOrFail().generateContent(
|
||||
{
|
||||
model,
|
||||
model: modelToUse,
|
||||
config: requestConfig,
|
||||
contents,
|
||||
},
|
||||
this.lastPromptId,
|
||||
);
|
||||
};
|
||||
const onPersistent429Callback = async (
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
) =>
|
||||
// Pass the captured model to the centralized handler.
|
||||
await handleFallback(this.config, currentAttemptModel, authType, error);
|
||||
|
||||
const result = await retryWithBackoff(apiCall, {
|
||||
onPersistent429: async (authType?: string, error?: unknown) =>
|
||||
await this.handleFlashFallback(authType, error),
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
});
|
||||
return result;
|
||||
@@ -693,7 +720,7 @@ export class GeminiClient {
|
||||
|
||||
await reportError(
|
||||
error,
|
||||
`Error generating content via API with model ${model}.`,
|
||||
`Error generating content via API with model ${currentAttemptModel}.`,
|
||||
{
|
||||
requestContents: contents,
|
||||
requestConfig: configToUse,
|
||||
@@ -701,7 +728,7 @@ export class GeminiClient {
|
||||
'generateContent-api',
|
||||
);
|
||||
throw new Error(
|
||||
`Failed to generate content with model ${model}: ${getErrorMessage(error)}`,
|
||||
`Failed to generate content with model ${currentAttemptModel}: ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -880,53 +907,6 @@ export class GeminiClient {
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles falling back to Flash model when persistent 429 errors occur for OAuth users.
|
||||
* Uses a fallback handler if provided by the config; otherwise, returns null.
|
||||
*/
|
||||
private async handleFlashFallback(
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
): Promise<string | null> {
|
||||
// Only handle fallback for OAuth users
|
||||
if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const currentModel = this.config.getModel();
|
||||
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
|
||||
// Don't fallback if already using Flash model
|
||||
if (currentModel === fallbackModel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Check if config has a fallback handler (set by CLI package)
|
||||
const fallbackHandler = this.config.flashFallbackHandler;
|
||||
if (typeof fallbackHandler === 'function') {
|
||||
try {
|
||||
const accepted = await fallbackHandler(
|
||||
currentModel,
|
||||
fallbackModel,
|
||||
error,
|
||||
);
|
||||
if (accepted !== false && accepted !== null) {
|
||||
this.config.setModel(fallbackModel);
|
||||
this.config.setFallbackMode(true);
|
||||
return fallbackModel;
|
||||
}
|
||||
// Check if the model was switched manually in the handler
|
||||
if (this.config.getModel() === fallbackModel) {
|
||||
return null; // Model was switched but don't continue with current prompt
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Flash fallback handler failed:', error);
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export const TEST_ONLY = {
|
||||
|
||||
@@ -20,6 +20,9 @@ import {
|
||||
} from './geminiChat.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { setSimulate429 } from '../utils/testUtils.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { AuthType } from './contentGenerator.js';
|
||||
import { type RetryOptions } from '../utils/retry.js';
|
||||
|
||||
// Mock fs module to prevent actual file system operations during tests
|
||||
const mockFileSystem = new Map<string, string>();
|
||||
@@ -47,6 +50,23 @@ vi.mock('node:fs', () => {
|
||||
};
|
||||
});
|
||||
|
||||
const { mockHandleFallback } = vi.hoisted(() => ({
|
||||
mockHandleFallback: vi.fn(),
|
||||
}));
|
||||
|
||||
// Add mock for the retry utility
|
||||
const { mockRetryWithBackoff } = vi.hoisted(() => ({
|
||||
mockRetryWithBackoff: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/retry.js', () => ({
|
||||
retryWithBackoff: mockRetryWithBackoff,
|
||||
}));
|
||||
|
||||
vi.mock('../fallback/handler.js', () => ({
|
||||
handleFallback: mockHandleFallback,
|
||||
}));
|
||||
|
||||
const { mockLogInvalidChunk, mockLogContentRetry, mockLogContentRetryFailure } =
|
||||
vi.hoisted(() => ({
|
||||
mockLogInvalidChunk: vi.fn(),
|
||||
@@ -76,17 +96,21 @@ describe('GeminiChat', () => {
|
||||
batchEmbedContents: vi.fn(),
|
||||
} as unknown as ContentGenerator;
|
||||
|
||||
mockHandleFallback.mockClear();
|
||||
// Default mock implementation for tests that don't care about retry logic
|
||||
mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall());
|
||||
mockConfig = {
|
||||
getSessionId: () => 'test-session-id',
|
||||
getTelemetryLogPromptsEnabled: () => true,
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getContentGeneratorConfig: () => ({
|
||||
authType: 'oauth-personal',
|
||||
getContentGeneratorConfig: vi.fn().mockReturnValue({
|
||||
authType: 'oauth-personal', // Ensure this is set for fallback tests
|
||||
model: 'test-model',
|
||||
}),
|
||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
setModel: vi.fn(),
|
||||
isInFallbackMode: vi.fn().mockReturnValue(false),
|
||||
getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
|
||||
setQuotaErrorOccurred: vi.fn(),
|
||||
flashFallbackHandler: undefined,
|
||||
@@ -1476,8 +1500,176 @@ describe('GeminiChat', () => {
|
||||
expect(turn4.parts[0].text).toBe('second response');
|
||||
});
|
||||
|
||||
describe('Model Resolution', () => {
|
||||
const mockResponse = {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'response' }], role: 'model' },
|
||||
finishReason: 'STOP',
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
|
||||
it('should use the configured model when not in fallback mode (sendMessage)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-pro');
|
||||
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
mockResponse,
|
||||
);
|
||||
|
||||
await chat.sendMessage({ message: 'test' }, 'prompt-id-res1');
|
||||
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: 'gemini-2.5-pro',
|
||||
}),
|
||||
'prompt-id-res1',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the FLASH model when in fallback mode (sendMessage)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-pro');
|
||||
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
|
||||
vi.mocked(mockContentGenerator.generateContent).mockResolvedValue(
|
||||
mockResponse,
|
||||
);
|
||||
|
||||
await chat.sendMessage({ message: 'test' }, 'prompt-id-res2');
|
||||
|
||||
expect(mockContentGenerator.generateContent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
}),
|
||||
'prompt-id-res2',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use the FLASH model when in fallback mode (sendMessageStream)', async () => {
|
||||
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro');
|
||||
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
|
||||
vi.mocked(mockContentGenerator.generateContentStream).mockImplementation(
|
||||
async () =>
|
||||
(async function* () {
|
||||
yield mockResponse;
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ message: 'test' },
|
||||
'prompt-id-res3',
|
||||
);
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
}
|
||||
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
model: DEFAULT_GEMINI_FLASH_MODEL,
|
||||
}),
|
||||
'prompt-id-res3',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Fallback Integration (Retries)', () => {
|
||||
const error429 = Object.assign(new Error('API Error 429: Quota exceeded'), {
|
||||
status: 429,
|
||||
});
|
||||
|
||||
/**
 * Test double for `retryWithBackoff` used by the tests in this describe block.
 *
 * It makes one attempt at `apiCall`. If that attempt rejects and an
 * `onPersistent429` callback was supplied, the callback is invoked with
 * `options.authType` and the caught error (the "persistent" condition is
 * assumed immediately, for simplicity). A truthy callback result triggers
 * exactly one more attempt; otherwise the original error is rethrown.
 *
 * @param apiCall - The operation to attempt.
 * @param options - Retry options; only `onPersistent429` and `authType` are read.
 * @returns The value resolved by the first successful attempt.
 * @throws The first attempt's error when there is no callback or the callback
 *   result is falsy; any error raised by the callback or by the second attempt
 *   propagates unchanged.
 */
const simulateRetryBehavior = async <T>(
  apiCall: () => Promise<T>,
  options: Partial<RetryOptions>,
) => {
  let firstFailure: unknown;
  try {
    return await apiCall();
  } catch (caught) {
    firstFailure = caught;
  }

  // No callback registered: nothing can approve a retry.
  if (!options.onPersistent429) {
    throw firstFailure;
  }

  // We simulate the "persistent" trigger here for simplicity.
  const retryApproved = await options.onPersistent429(
    options.authType,
    firstFailure,
  );
  if (!retryApproved) {
    throw firstFailure; // Stop if callback returns false/null
  }

  return await apiCall();
};
|
||||
|
||||
beforeEach(() => {
|
||||
mockRetryWithBackoff.mockImplementation(simulateRetryBehavior);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
mockRetryWithBackoff.mockImplementation(async (apiCall) => apiCall());
|
||||
});
|
||||
|
||||
// Verifies the happy path of the fallback integration: the first attempt fails
// with a 429, handleFallback is told which model actually failed, and the
// request is retried once the handler approves.
it('should call handleFallback with the specific failed model and retry if handler returns true', async () => {
  // Arrange: an OAuth session using the Pro model, not yet in fallback mode.
  const FAILED_MODEL = 'gemini-2.5-pro';
  vi.mocked(mockConfig.getModel).mockReturnValue(FAILED_MODEL);
  const authType = AuthType.LOGIN_WITH_GOOGLE;
  vi.mocked(mockConfig.getContentGeneratorConfig).mockReturnValue({
    authType,
    model: FAILED_MODEL,
  });

  const isInFallbackModeSpy = vi.spyOn(mockConfig, 'isInFallbackMode');
  isInFallbackModeSpy.mockReturnValue(false);

  vi.mocked(mockContentGenerator.generateContent)
    .mockRejectedValueOnce(error429) // Attempt 1 fails
    .mockResolvedValueOnce({
      candidates: [{ content: { parts: [{ text: 'Success on retry' }] } }],
    } as unknown as GenerateContentResponse); // Attempt 2 succeeds

  // Queue the handler behaviour for a single call only. mockHandleFallback is a
  // module-level hoisted mock and the suite setup visible in this file resets
  // it with mockClear(), which keeps implementations; a persistent
  // mockImplementation() would therefore leak into later tests.
  mockHandleFallback.mockImplementationOnce(async () => {
    // Simulate the handler switching the config into fallback mode.
    isInFallbackModeSpy.mockReturnValue(true);
    return true; // Signal retry
  });

  // Act
  const result = await chat.sendMessage(
    { message: 'trigger 429' },
    'prompt-id-fb1',
  );

  // Assert: one retry wrapper invocation, two API attempts, one handler call.
  expect(mockRetryWithBackoff).toHaveBeenCalledTimes(1);
  expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(2);
  expect(mockHandleFallback).toHaveBeenCalledTimes(1);

  // The handler must receive the model that failed, not a value re-read later.
  expect(mockHandleFallback).toHaveBeenCalledWith(
    mockConfig,
    FAILED_MODEL,
    authType,
    error429,
  );

  expect(result.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
    'Success on retry',
  );
});
|
||||
|
||||
// Verifies that a falsy result from handleFallback aborts the retry: the
// original 429 error surfaces to the caller and no second API attempt is made.
it('should stop retrying if handleFallback returns false (e.g., auth intent)', async () => {
  vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro');
  vi.mocked(mockContentGenerator.generateContent).mockRejectedValue(
    error429,
  );
  // Queue the result for a single call only. mockHandleFallback is a
  // module-level hoisted mock and the suite setup visible in this file resets
  // it with mockClear(), which keeps implementations; a persistent
  // mockResolvedValue(false) would therefore leak into later tests.
  mockHandleFallback.mockResolvedValueOnce(false);

  await expect(
    chat.sendMessage({ message: 'test stop' }, 'prompt-id-fb2'),
  ).rejects.toThrow(error429);

  // Exactly one API attempt and one handler consultation: no retry happened.
  expect(mockContentGenerator.generateContent).toHaveBeenCalledTimes(1);
  expect(mockHandleFallback).toHaveBeenCalledTimes(1);
});
|
||||
});
|
||||
|
||||
it('should discard valid partial content from a failed attempt upon retry', async () => {
|
||||
// ARRANGE: Mock the stream to fail on the first attempt after yielding some valid content.
|
||||
// Mock the stream to fail on the first attempt after yielding some valid content.
|
||||
vi.mocked(mockContentGenerator.generateContentStream)
|
||||
.mockImplementationOnce(async () =>
|
||||
// First attempt: yields one valid chunk, then one invalid chunk
|
||||
@@ -1512,7 +1704,7 @@ describe('GeminiChat', () => {
|
||||
})(),
|
||||
);
|
||||
|
||||
// ACT: Send a message and consume the stream
|
||||
// Send a message and consume the stream
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ message: 'test' },
|
||||
'prompt-id-discard-test',
|
||||
@@ -1522,7 +1714,6 @@ describe('GeminiChat', () => {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// ASSERT
|
||||
// Check that a retry happened
|
||||
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(2);
|
||||
expect(events.some((e) => e.type === StreamEventType.RETRY)).toBe(true);
|
||||
|
||||
@@ -18,7 +18,6 @@ import type {
|
||||
import { toParts } from '../code_assist/converter.js';
|
||||
import { createUserContent } from '@google/genai';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import { AuthType } from './contentGenerator.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { hasCycleInSchema } from '../tools/tools.js';
|
||||
@@ -35,6 +34,7 @@ import {
|
||||
ContentRetryFailureEvent,
|
||||
InvalidChunkEvent,
|
||||
} from '../telemetry/types.js';
|
||||
import { handleFallback } from '../fallback/handler.js';
|
||||
import { isFunctionResponse } from '../utils/messageInspectors.js';
|
||||
import { partListUnionToString } from './geminiRequest.js';
|
||||
|
||||
@@ -179,53 +179,6 @@ export class GeminiChat {
|
||||
this.chatRecordingService.initialize();
|
||||
}
|
||||
|
||||
/**
|
||||
* Handles falling back to Flash model when persistent 429 errors occur for OAuth users.
|
||||
* Uses a fallback handler if provided by the config; otherwise, returns null.
|
||||
*/
|
||||
private async handleFlashFallback(
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
): Promise<string | null> {
|
||||
// Only handle fallback for OAuth users
|
||||
if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const currentModel = this.config.getModel();
|
||||
const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
|
||||
// Don't fallback if already using Flash model
|
||||
if (currentModel === fallbackModel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Check if config has a fallback handler (set by CLI package)
|
||||
const fallbackHandler = this.config.flashFallbackHandler;
|
||||
if (typeof fallbackHandler === 'function') {
|
||||
try {
|
||||
const accepted = await fallbackHandler(
|
||||
currentModel,
|
||||
fallbackModel,
|
||||
error,
|
||||
);
|
||||
if (accepted !== false && accepted !== null) {
|
||||
this.config.setModel(fallbackModel);
|
||||
this.config.setFallbackMode(true);
|
||||
return fallbackModel;
|
||||
}
|
||||
// Check if the model was switched manually in the handler
|
||||
if (this.config.getModel() === fallbackModel) {
|
||||
return null; // Model was switched but don't continue with current prompt
|
||||
}
|
||||
} catch (error) {
|
||||
console.warn('Flash fallback handler failed:', error);
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
setSystemInstruction(sysInstr: string) {
|
||||
this.generationConfig.systemInstruction = sysInstr;
|
||||
}
|
||||
@@ -272,8 +225,13 @@ export class GeminiChat {
|
||||
let response: GenerateContentResponse;
|
||||
|
||||
try {
|
||||
let currentAttemptModel: string | undefined;
|
||||
|
||||
const apiCall = () => {
|
||||
const modelToUse = this.config.getModel() || DEFAULT_GEMINI_FLASH_MODEL;
|
||||
const modelToUse = this.config.isInFallbackMode()
|
||||
? DEFAULT_GEMINI_FLASH_MODEL
|
||||
: this.config.getModel();
|
||||
currentAttemptModel = modelToUse;
|
||||
|
||||
// Prevent Flash model calls immediately after quota error
|
||||
if (
|
||||
@@ -295,6 +253,19 @@ export class GeminiChat {
|
||||
);
|
||||
};
|
||||
|
||||
const onPersistent429Callback = async (
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
) => {
|
||||
if (!currentAttemptModel) return null;
|
||||
return await handleFallback(
|
||||
this.config,
|
||||
currentAttemptModel,
|
||||
authType,
|
||||
error,
|
||||
);
|
||||
};
|
||||
|
||||
response = await retryWithBackoff(apiCall, {
|
||||
shouldRetry: (error: unknown) => {
|
||||
// Check for known error messages and codes.
|
||||
@@ -305,8 +276,7 @@ export class GeminiChat {
|
||||
}
|
||||
return false; // Don't retry other errors by default
|
||||
},
|
||||
onPersistent429: async (authType?: string, error?: unknown) =>
|
||||
await this.handleFlashFallback(authType, error),
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
});
|
||||
|
||||
@@ -484,8 +454,13 @@ export class GeminiChat {
|
||||
prompt_id: string,
|
||||
userContent: Content,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
let currentAttemptModel: string | undefined;
|
||||
|
||||
const apiCall = () => {
|
||||
const modelToUse = this.config.getModel();
|
||||
const modelToUse = this.config.isInFallbackMode()
|
||||
? DEFAULT_GEMINI_FLASH_MODEL
|
||||
: this.config.getModel();
|
||||
currentAttemptModel = modelToUse;
|
||||
|
||||
if (
|
||||
this.config.getQuotaErrorOccurred() &&
|
||||
@@ -506,6 +481,19 @@ export class GeminiChat {
|
||||
);
|
||||
};
|
||||
|
||||
const onPersistent429Callback = async (
|
||||
authType?: string,
|
||||
error?: unknown,
|
||||
) => {
|
||||
if (!currentAttemptModel) return null;
|
||||
return await handleFallback(
|
||||
this.config,
|
||||
currentAttemptModel,
|
||||
authType,
|
||||
error,
|
||||
);
|
||||
};
|
||||
|
||||
const streamResponse = await retryWithBackoff(apiCall, {
|
||||
shouldRetry: (error: unknown) => {
|
||||
if (error instanceof Error && error.message) {
|
||||
@@ -515,8 +503,7 @@ export class GeminiChat {
|
||||
}
|
||||
return false;
|
||||
},
|
||||
onPersistent429: async (authType?: string, error?: unknown) =>
|
||||
await this.handleFlashFallback(authType, error),
|
||||
onPersistent429: onPersistent429Callback,
|
||||
authType: this.config.getContentGeneratorConfig()?.authType,
|
||||
});
|
||||
|
||||
|
||||
@@ -0,0 +1,218 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
type Mock,
|
||||
type MockInstance,
|
||||
afterEach,
|
||||
} from 'vitest';
|
||||
import { handleFallback } from './handler.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import {
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
} from '../config/models.js';
|
||||
import { logFlashFallback } from '../telemetry/index.js';
|
||||
import type { FallbackModelHandler } from './types.js';
|
||||
|
||||
// Mock the telemetry logger and event class
|
||||
vi.mock('../telemetry/index.js', () => ({
|
||||
logFlashFallback: vi.fn(),
|
||||
FlashFallbackEvent: class {},
|
||||
}));
|
||||
|
||||
const MOCK_PRO_MODEL = DEFAULT_GEMINI_MODEL;
|
||||
const FALLBACK_MODEL = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
const AUTH_OAUTH = AuthType.LOGIN_WITH_GOOGLE;
|
||||
const AUTH_API_KEY = AuthType.USE_GEMINI;
|
||||
|
||||
/**
 * Builds a minimal `Config` stand-in for the handleFallback tests.
 *
 * Only the members listed here are stubbed; the object is force-cast to
 * `Config`, so the compiler does not check it against the real class shape.
 *
 * @param overrides - Members that replace or extend the defaults below.
 * @returns The stub object typed as `Config`.
 */
const createMockConfig = (overrides: Partial<Config> = {}): Config =>
  ({
    isInFallbackMode: vi.fn(() => false),
    setFallbackMode: vi.fn(),
    // Must match the real Config property name (`fallbackModelHandler`), which
    // is also the key the tests override; the double cast below would silently
    // accept a misspelled key.
    fallbackModelHandler: undefined,
    ...overrides,
  }) as unknown as Config;
|
||||
|
||||
describe('handleFallback', () => {
|
||||
let mockConfig: Config;
|
||||
let mockHandler: Mock<FallbackModelHandler>;
|
||||
let consoleErrorSpy: MockInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockHandler = vi.fn();
|
||||
// Default setup: OAuth user, Pro model failed, handler injected
|
||||
mockConfig = createMockConfig({
|
||||
fallbackModelHandler: mockHandler,
|
||||
});
|
||||
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
consoleErrorSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should return null immediately if authType is not OAuth', async () => {
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_API_KEY,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
expect(mockHandler).not.toHaveBeenCalled();
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return null if the failed model is already the fallback model', async () => {
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
FALLBACK_MODEL, // Failed model is Flash
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
expect(mockHandler).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return null if no fallbackHandler is injected in config', async () => {
|
||||
const configWithoutHandler = createMockConfig({
|
||||
fallbackModelHandler: undefined,
|
||||
});
|
||||
const result = await handleFallback(
|
||||
configWithoutHandler,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
describe('when handler returns "retry"', () => {
|
||||
it('should activate fallback mode, log telemetry, and return true', async () => {
|
||||
mockHandler.mockResolvedValue('retry');
|
||||
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBe(true);
|
||||
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
|
||||
expect(logFlashFallback).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('when handler returns "stop"', () => {
|
||||
it('should activate fallback mode, log telemetry, and return false', async () => {
|
||||
mockHandler.mockResolvedValue('stop');
|
||||
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockConfig.setFallbackMode).toHaveBeenCalledWith(true);
|
||||
expect(logFlashFallback).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('when handler returns "auth"', () => {
|
||||
it('should NOT activate fallback mode and return false', async () => {
|
||||
mockHandler.mockResolvedValue('auth');
|
||||
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBe(false);
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
expect(logFlashFallback).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('when handler returns an unexpected value', () => {
|
||||
it('should log an error and return null', async () => {
|
||||
mockHandler.mockResolvedValue(null);
|
||||
|
||||
const result = await handleFallback(
|
||||
mockConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'Fallback UI handler failed:',
|
||||
new Error(
|
||||
'Unexpected fallback intent received from fallbackModelHandler: "null"',
|
||||
),
|
||||
);
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it('should pass the correct context (failedModel, fallbackModel, error) to the handler', async () => {
|
||||
const mockError = new Error('Quota Exceeded');
|
||||
mockHandler.mockResolvedValue('retry');
|
||||
|
||||
await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH, mockError);
|
||||
|
||||
expect(mockHandler).toHaveBeenCalledWith(
|
||||
MOCK_PRO_MODEL,
|
||||
FALLBACK_MODEL,
|
||||
mockError,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not call setFallbackMode or log telemetry if already in fallback mode', async () => {
|
||||
// Setup config where fallback mode is already active
|
||||
const activeFallbackConfig = createMockConfig({
|
||||
fallbackModelHandler: mockHandler,
|
||||
isInFallbackMode: vi.fn(() => true), // Already active
|
||||
setFallbackMode: vi.fn(),
|
||||
});
|
||||
|
||||
mockHandler.mockResolvedValue('retry');
|
||||
|
||||
const result = await handleFallback(
|
||||
activeFallbackConfig,
|
||||
MOCK_PRO_MODEL,
|
||||
AUTH_OAUTH,
|
||||
);
|
||||
|
||||
// Should still return true to allow the retry (which will use the active fallback mode)
|
||||
expect(result).toBe(true);
|
||||
// Should still consult the handler
|
||||
expect(mockHandler).toHaveBeenCalled();
|
||||
// But should not mutate state or log telemetry again
|
||||
expect(activeFallbackConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
expect(logFlashFallback).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should catch errors from the handler, log an error, and return null', async () => {
|
||||
const handlerError = new Error('UI interaction failed');
|
||||
mockHandler.mockRejectedValue(handlerError);
|
||||
|
||||
const result = await handleFallback(mockConfig, MOCK_PRO_MODEL, AUTH_OAUTH);
|
||||
|
||||
expect(result).toBeNull();
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'Fallback UI handler failed:',
|
||||
handlerError,
|
||||
);
|
||||
expect(mockConfig.setFallbackMode).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,69 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '../config/config.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { logFlashFallback, FlashFallbackEvent } from '../telemetry/index.js';
|
||||
|
||||
/**
 * Decides how to proceed after a request to `failedModel` has failed and a
 * switch to the fallback model (`DEFAULT_GEMINI_FLASH_MODEL`) may apply.
 *
 * The decision itself is delegated to `config.fallbackModelHandler`; this
 * function only checks applicability, invokes the handler, and applies the
 * returned intent to `config`.
 *
 * @param config Configuration that exposes the handler and the fallback-mode
 *   state.
 * @param failedModel Name of the model whose request failed.
 * @param authType Authentication type in use. Anything other than
 *   `AuthType.LOGIN_WITH_GOOGLE` makes this function a no-op.
 * @param error The error that triggered the fallback; forwarded unchanged to
 *   the handler.
 * @returns `true` when the handler answers `'retry'`, `false` when it answers
 *   `'stop'` or `'auth'`, and `null` when fallback does not apply (wrong auth
 *   type, the failed model already is the fallback model, no handler is
 *   registered) or when the handler throws or returns an unexpected value.
 */
export async function handleFallback(
  config: Config,
  failedModel: string,
  authType?: string,
  error?: unknown,
): Promise<string | boolean | null> {
  // Applicability checks.
  if (authType !== AuthType.LOGIN_WITH_GOOGLE) {
    return null;
  }

  const fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
  if (failedModel === fallbackModel) {
    // Nothing further to fall back to.
    return null;
  }

  // Consult the UI handler for the intent.
  const fallbackModelHandler = config.fallbackModelHandler;
  if (typeof fallbackModelHandler !== 'function') {
    return null;
  }

  try {
    // Pass the specific failed model to the UI handler.
    const intent = await fallbackModelHandler(
      failedModel,
      fallbackModel,
      error,
    );

    // Process the intent and update state.
    switch (intent) {
      case 'retry':
        // Activate fallback mode. The NEXT retry attempt will pick this up.
        activateFallbackMode(config, authType);
        return true; // Signal retryWithBackoff to continue.

      case 'stop':
        // Use the fallback for future requests, but end the current one.
        activateFallbackMode(config, authType);
        return false;

      case 'auth':
        // End the current request without touching fallback state.
        return false;

      default:
        // Thrown deliberately so the catch below logs it like any other
        // handler failure.
        throw new Error(
          `Unexpected fallback intent received from fallbackModelHandler: "${intent}"`,
        );
    }
  } catch (handlerError) {
    console.error('Fallback UI handler failed:', handlerError);
    return null;
  }
}
|
||||
|
||||
/**
 * Puts `config` into fallback mode unless it is already active.
 *
 * Telemetry (a `FlashFallbackEvent`) is logged only on the transition into
 * fallback mode, and only when `authType` is a non-empty string.
 *
 * @param config Configuration whose fallback-mode flag is read and updated.
 * @param authType Authentication type recorded in the telemetry event.
 */
function activateFallbackMode(
  config: Config,
  authType: string | undefined,
): void {
  if (config.isInFallbackMode()) {
    // Already active: do not mutate state or log the event a second time.
    return;
  }

  config.setFallbackMode(true);

  if (authType) {
    logFlashFallback(config, new FlashFallbackEvent(authType));
  }
}
|
||||
@@ -0,0 +1,23 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
 * Defines the intent returned by the UI layer during a fallback scenario.
 */
export type FallbackIntent =
  /** Immediately retry the current request with the fallback model. */
  | 'retry'
  /** Switch to fallback for future requests, but stop the current request. */
  | 'stop'
  /** Stop the current request; user intends to change authentication. */
  | 'auth';
|
||||
|
||||
/**
 * The interface for the handler provided by the UI layer (e.g., the CLI)
 * to interact with the user during a fallback scenario.
 *
 * @param failedModel Name of the model whose request failed.
 * @param fallbackModel Name of the model proposed as the replacement.
 * @param error The error that triggered the fallback, if one is available.
 * @returns A promise resolving to a `FallbackIntent`, or to `null`.
 */
export type FallbackModelHandler = (
  failedModel: string,
  fallbackModel: string,
  error?: unknown,
) => Promise<FallbackIntent | null>;
|
||||
@@ -20,6 +20,8 @@ export * from './core/geminiRequest.js';
|
||||
export * from './core/coreToolScheduler.js';
|
||||
export * from './core/nonInteractiveToolExecutor.js';
|
||||
|
||||
export * from './fallback/types.js';
|
||||
|
||||
export * from './code_assist/codeAssist.js';
|
||||
export * from './code_assist/oauth2.js';
|
||||
export * from './code_assist/server.js';
|
||||
|
||||
+32
-26
@@ -17,10 +17,13 @@ import {
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { retryWithBackoff } from './retry.js';
|
||||
import { AuthType } from '../core/contentGenerator.js';
|
||||
// Import the new types (Assuming this test file is in packages/core/src/utils/)
|
||||
import type { FallbackModelHandler } from '../fallback/types.js';
|
||||
|
||||
vi.mock('node:fs');
|
||||
|
||||
describe('Flash Fallback Integration', () => {
|
||||
// Update the description to reflect that this tests the retry utility's integration
|
||||
describe('Retry Utility Fallback Integration', () => {
|
||||
let config: Config;
|
||||
|
||||
beforeEach(() => {
|
||||
@@ -41,25 +44,28 @@ describe('Flash Fallback Integration', () => {
|
||||
resetRequestCounter();
|
||||
});
|
||||
|
||||
it('should automatically accept fallback', async () => {
|
||||
// Set up a minimal flash fallback handler for testing
|
||||
const flashFallbackHandler = async (): Promise<boolean> => true;
|
||||
// This test validates the Config's ability to store and execute the handler contract.
|
||||
it('should execute the injected FallbackHandler contract correctly', async () => {
|
||||
// Set up a minimal handler for testing, ensuring it matches the new type.
|
||||
const fallbackHandler: FallbackModelHandler = async () => 'retry';
|
||||
|
||||
config.setFlashFallbackHandler(flashFallbackHandler);
|
||||
// Use the generalized setter
|
||||
config.setFallbackModelHandler(fallbackHandler);
|
||||
|
||||
// Call the handler directly to test
|
||||
const result = await config.flashFallbackHandler!(
|
||||
// Call the handler directly via the config property
|
||||
const result = await config.fallbackModelHandler!(
|
||||
'gemini-2.5-pro',
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
);
|
||||
|
||||
// Verify it automatically accepts
|
||||
expect(result).toBe(true);
|
||||
// Verify it returns the correct intent
|
||||
expect(result).toBe('retry');
|
||||
});
|
||||
|
||||
it('should trigger fallback after 2 consecutive 429 errors for OAuth users', async () => {
|
||||
// This test validates the retry utility's logic for triggering the callback.
|
||||
it('should trigger onPersistent429 after 2 consecutive 429 errors for OAuth users', async () => {
|
||||
let fallbackCalled = false;
|
||||
let fallbackModel = '';
|
||||
// Removed fallbackModel variable as it's no longer relevant here.
|
||||
|
||||
// Mock function that simulates exactly 2 429 errors, then succeeds after fallback
|
||||
const mockApiCall = vi
|
||||
@@ -68,11 +74,11 @@ describe('Flash Fallback Integration', () => {
|
||||
.mockRejectedValueOnce(createSimulated429Error())
|
||||
.mockResolvedValueOnce('success after fallback');
|
||||
|
||||
// Mock fallback handler
|
||||
const mockFallbackHandler = vi.fn(async (_authType?: string) => {
|
||||
// Mock the onPersistent429 callback (this is what client.ts/geminiChat.ts provides)
|
||||
const mockPersistent429Callback = vi.fn(async (_authType?: string) => {
|
||||
fallbackCalled = true;
|
||||
fallbackModel = DEFAULT_GEMINI_FLASH_MODEL;
|
||||
return fallbackModel;
|
||||
// Return true to signal retryWithBackoff to reset attempts and continue.
|
||||
return true;
|
||||
});
|
||||
|
||||
// Test with OAuth personal auth type, with maxAttempts = 2 to ensure fallback triggers
|
||||
@@ -84,14 +90,13 @@ describe('Flash Fallback Integration', () => {
|
||||
const status = (error as Error & { status?: number }).status;
|
||||
return status === 429;
|
||||
},
|
||||
onPersistent429: mockFallbackHandler,
|
||||
onPersistent429: mockPersistent429Callback,
|
||||
authType: AuthType.LOGIN_WITH_GOOGLE,
|
||||
});
|
||||
|
||||
// Verify fallback was triggered
|
||||
// Verify fallback mechanism was triggered
|
||||
expect(fallbackCalled).toBe(true);
|
||||
expect(fallbackModel).toBe(DEFAULT_GEMINI_FLASH_MODEL);
|
||||
expect(mockFallbackHandler).toHaveBeenCalledWith(
|
||||
expect(mockPersistent429Callback).toHaveBeenCalledWith(
|
||||
AuthType.LOGIN_WITH_GOOGLE,
|
||||
expect.any(Error),
|
||||
);
|
||||
@@ -100,16 +105,16 @@ describe('Flash Fallback Integration', () => {
|
||||
expect(mockApiCall).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it('should not trigger fallback for API key users', async () => {
|
||||
it('should not trigger onPersistent429 for API key users', async () => {
|
||||
let fallbackCalled = false;
|
||||
|
||||
// Mock function that simulates 429 errors
|
||||
const mockApiCall = vi.fn().mockRejectedValue(createSimulated429Error());
|
||||
|
||||
// Mock fallback handler
|
||||
const mockFallbackHandler = vi.fn(async () => {
|
||||
// Mock the callback
|
||||
const mockPersistent429Callback = vi.fn(async () => {
|
||||
fallbackCalled = true;
|
||||
return DEFAULT_GEMINI_FLASH_MODEL;
|
||||
return true;
|
||||
});
|
||||
|
||||
// Test with API key auth type - should not trigger fallback
|
||||
@@ -122,7 +127,7 @@ describe('Flash Fallback Integration', () => {
|
||||
const status = (error as Error & { status?: number }).status;
|
||||
return status === 429;
|
||||
},
|
||||
onPersistent429: mockFallbackHandler,
|
||||
onPersistent429: mockPersistent429Callback,
|
||||
authType: AuthType.USE_GEMINI, // API key auth type
|
||||
});
|
||||
} catch (error) {
|
||||
@@ -132,10 +137,11 @@ describe('Flash Fallback Integration', () => {
|
||||
|
||||
// Verify fallback was NOT triggered for API key users
|
||||
expect(fallbackCalled).toBe(false);
|
||||
expect(mockFallbackHandler).not.toHaveBeenCalled();
|
||||
expect(mockPersistent429Callback).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should properly disable simulation state after fallback', () => {
|
||||
// This test validates the test utilities themselves.
|
||||
it('should properly disable simulation state after fallback (Test Utility)', () => {
|
||||
// Enable simulation
|
||||
setSimulate429(true);
|
||||
|
||||
Reference in New Issue
Block a user