chore(core): remove legacy fallback flags and migrate loop detection (#15213)

This commit is contained in:
Adam Weidman
2025-12-17 17:14:33 -05:00
committed by GitHub
parent 3d486ec1e9
commit bf6d0485ce
17 changed files with 56 additions and 419 deletions
+11 -15
View File
@@ -6,24 +6,20 @@ and provides resilience when the primary model is unavailable.
## How it works ## How it works
Model routing is not based on prompt complexity, but is a fallback mechanism. Model routing is managed by the `ModelAvailabilityService`, which monitors model
Here's how it works: health and automatically routes requests to available models based on defined
policies.
1. **Model failure:** If the currently selected model fails to respond (for 1. **Model failure:** If the currently selected model fails (e.g., due to quota
example, due to a server error or other issue), the CLI will initiate the or server errors), the CLI will initiate the fallback process.
fallback process.
2. **User consent:** The CLI will prompt you to ask if you want to switch to 2. **User consent:** Depending on the failure and the model's policy, the CLI
the fallback model. This is handled by the `fallbackModelHandler`. may prompt you to switch to a fallback model (by default always prompts
you).
3. **Fallback activation:** If you consent, the CLI will activate the fallback 3. **Model switch:** If approved, or if the policy allows for silent fallback,
mode by calling `config.setFallbackMode(true)`. the CLI will use an available fallback model for the current turn or the
remainder of the session.
4. **Model switch:** On the next request, the CLI will use the
`DEFAULT_GEMINI_FLASH_MODEL` as the fallback model. This is handled by the
`resolveModel` function in
`packages/cli/src/zed-integration/zedIntegration.ts` which checks if
`isInFallbackMode()` is true.
### Model selection precedence ### Model selection precedence
@@ -280,7 +280,6 @@ describe('Session', () => {
getTool: vi.fn().mockReturnValue(mockTool), getTool: vi.fn().mockReturnValue(mockTool),
}; };
mockConfig = { mockConfig = {
isInFallbackMode: vi.fn().mockReturnValue(false),
getModel: vi.fn().mockReturnValue('gemini-pro'), getModel: vi.fn().mockReturnValue('gemini-pro'),
getPreviewFeatures: vi.fn().mockReturnValue({}), getPreviewFeatures: vi.fn().mockReturnValue({}),
getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry),
@@ -19,7 +19,6 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>
getPreviewFeatures: () => false, getPreviewFeatures: () => false,
getUserTier: () => undefined, getUserTier: () => undefined,
getModel: () => 'gemini-2.5-pro', getModel: () => 'gemini-2.5-pro',
isInFallbackMode: () => false,
...overrides, ...overrides,
}) as unknown as Config; }) as unknown as Config;
+17 -22
View File
@@ -344,10 +344,6 @@ describe('Server Config (config.ts)', () => {
mockContentConfig, mockContentConfig,
); );
// Set fallback mode to true to ensure it gets reset
config.setFallbackMode(true);
expect(config.isInFallbackMode()).toBe(true);
await config.refreshAuth(authType); await config.refreshAuth(authType);
expect(createContentGeneratorConfig).toHaveBeenCalledWith( expect(createContentGeneratorConfig).toHaveBeenCalledWith(
@@ -357,8 +353,6 @@ describe('Server Config (config.ts)', () => {
// Verify that contentGeneratorConfig is updated // Verify that contentGeneratorConfig is updated
expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig); expect(config.getContentGeneratorConfig()).toEqual(mockContentConfig);
expect(GeminiClient).toHaveBeenCalledWith(config); expect(GeminiClient).toHaveBeenCalledWith(config);
// Verify that fallback mode is reset
expect(config.isInFallbackMode()).toBe(false);
}); });
it('should reset model availability status', async () => { it('should reset model availability status', async () => {
@@ -1569,40 +1563,32 @@ describe('Config getHooks', () => {
}); });
describe('setModel', () => { describe('setModel', () => {
it('should allow setting a pro (any) model and disable fallback mode', () => { it('should allow setting a pro (any) model and reset availability', () => {
const config = new Config(baseParams); const config = new Config(baseParams);
const service = config.getModelAvailabilityService(); const service = config.getModelAvailabilityService();
const spy = vi.spyOn(service, 'reset'); const spy = vi.spyOn(service, 'reset');
config.setFallbackMode(true);
expect(config.isInFallbackMode()).toBe(true);
const proModel = 'gemini-2.5-pro'; const proModel = 'gemini-2.5-pro';
config.setModel(proModel); config.setModel(proModel);
expect(config.getModel()).toBe(proModel); expect(config.getModel()).toBe(proModel);
expect(config.isInFallbackMode()).toBe(false);
expect(mockCoreEvents.emitModelChanged).toHaveBeenCalledWith(proModel); expect(mockCoreEvents.emitModelChanged).toHaveBeenCalledWith(proModel);
expect(spy).toHaveBeenCalled(); expect(spy).toHaveBeenCalled();
}); });
it('should allow setting auto model from non-auto model and disable fallback mode', () => { it('should allow setting auto model from non-auto model and reset availability', () => {
const config = new Config(baseParams); const config = new Config(baseParams);
const service = config.getModelAvailabilityService(); const service = config.getModelAvailabilityService();
const spy = vi.spyOn(service, 'reset'); const spy = vi.spyOn(service, 'reset');
config.setFallbackMode(true);
expect(config.isInFallbackMode()).toBe(true);
config.setModel('auto'); config.setModel('auto');
expect(config.getModel()).toBe('auto'); expect(config.getModel()).toBe('auto');
expect(config.isInFallbackMode()).toBe(false);
expect(mockCoreEvents.emitModelChanged).toHaveBeenCalledWith('auto'); expect(mockCoreEvents.emitModelChanged).toHaveBeenCalledWith('auto');
expect(spy).toHaveBeenCalled(); expect(spy).toHaveBeenCalled();
}); });
it('should allow setting auto model from auto model if it is in the fallback mode', () => { it('should allow setting auto model from auto model and reset availability', () => {
const config = new Config({ const config = new Config({
cwd: '/tmp', cwd: '/tmp',
targetDir: '/path/to/target', targetDir: '/path/to/target',
@@ -1614,16 +1600,25 @@ describe('Config getHooks', () => {
const service = config.getModelAvailabilityService(); const service = config.getModelAvailabilityService();
const spy = vi.spyOn(service, 'reset'); const spy = vi.spyOn(service, 'reset');
config.setFallbackMode(true);
expect(config.isInFallbackMode()).toBe(true);
config.setModel('auto'); config.setModel('auto');
expect(config.getModel()).toBe('auto'); expect(config.getModel()).toBe('auto');
expect(config.isInFallbackMode()).toBe(false);
expect(mockCoreEvents.emitModelChanged).toHaveBeenCalledWith('auto');
expect(spy).toHaveBeenCalled(); expect(spy).toHaveBeenCalled();
}); });
it('should reset active model when setModel is called with the current model after a fallback', () => {
const config = new Config(baseParams);
const originalModel = config.getModel();
const fallbackModel = 'fallback-model';
config.setActiveModel(fallbackModel);
expect(config.getActiveModel()).toBe(fallbackModel);
config.setModel(originalModel);
expect(config.getModel()).toBe(originalModel);
expect(config.getActiveModel()).toBe(originalModel);
});
}); });
}); });
+3 -34
View File
@@ -388,7 +388,6 @@ export class Config {
private readonly folderTrust: boolean; private readonly folderTrust: boolean;
private ideMode: boolean; private ideMode: boolean;
private inFallbackMode = false;
private _activeModel: string; private _activeModel: string;
private readonly maxSessionTurns: number; private readonly maxSessionTurns: number;
private readonly listSessions: boolean; private readonly listSessions: boolean;
@@ -447,8 +446,6 @@ export class Config {
private experimentsPromise: Promise<void> | undefined; private experimentsPromise: Promise<void> | undefined;
private hookSystem?: HookSystem; private hookSystem?: HookSystem;
private previewModelFallbackMode = false;
private previewModelBypassMode = false;
private readonly enableAgents: boolean; private readonly enableAgents: boolean;
private readonly experimentalJitContext: boolean; private readonly experimentalJitContext: boolean;
@@ -774,9 +771,6 @@ export class Config {
this.setHasAccessToPreviewModel(true); this.setHasAccessToPreviewModel(true);
} }
// Reset the session flag since we're explicitly changing auth and using default model
this.inFallbackMode = false;
// Update model if user no longer has access to the preview model // Update model if user no longer has access to the preview model
if (!this.hasAccessToPreviewModel && isPreviewModel(this.model)) { if (!this.hasAccessToPreviewModel && isPreviewModel(this.model)) {
this.setModel(DEFAULT_GEMINI_MODEL_AUTO); this.setModel(DEFAULT_GEMINI_MODEL_AUTO);
@@ -847,13 +841,12 @@ export class Config {
} }
setModel(newModel: string): void { setModel(newModel: string): void {
if (this.model !== newModel || this.inFallbackMode) { if (this.model !== newModel || this._activeModel !== newModel) {
this.model = newModel; this.model = newModel;
// When the user explicitly sets a model, that becomes the active model. // When the user explicitly sets a model, that becomes the active model.
this._activeModel = newModel; this._activeModel = newModel;
coreEvents.emitModelChanged(newModel); coreEvents.emitModelChanged(newModel);
} }
this.setFallbackMode(false);
this.modelAvailabilityService.reset(); this.modelAvailabilityService.reset();
} }
@@ -867,18 +860,6 @@ export class Config {
} }
} }
resetTurn(): void {
this.modelAvailabilityService.resetTurn();
}
isInFallbackMode(): boolean {
return this.inFallbackMode;
}
setFallbackMode(active: boolean): void {
this.inFallbackMode = active;
}
setFallbackModelHandler(handler: FallbackModelHandler): void { setFallbackModelHandler(handler: FallbackModelHandler): void {
this.fallbackModelHandler = handler; this.fallbackModelHandler = handler;
} }
@@ -887,20 +868,8 @@ export class Config {
return this.fallbackModelHandler; return this.fallbackModelHandler;
} }
isPreviewModelFallbackMode(): boolean { resetTurn(): void {
return this.previewModelFallbackMode; this.modelAvailabilityService.resetTurn();
}
setPreviewModelFallbackMode(active: boolean): void {
this.previewModelFallbackMode = active;
}
isPreviewModelBypassMode(): boolean {
return this.previewModelBypassMode;
}
setPreviewModelBypassMode(active: boolean): void {
this.previewModelBypassMode = active;
} }
getMaxSessionTurns(): number { getMaxSessionTurns(): number {
@@ -37,27 +37,6 @@ describe('Flash Model Fallback Configuration', () => {
}; };
}); });
// These tests do not actually test fallback. isInFallbackMode() only returns true,
// when setFallbackMode is marked as true. This is to decouple setting a model
// with the fallback mechanism. This will be necessary we introduce more
// intelligent model routing.
describe('setModel', () => {
it('should only mark as switched if contentGeneratorConfig exists', () => {
// Create config without initializing contentGeneratorConfig
const newConfig = new Config({
sessionId: 'test-session-2',
targetDir: '/test',
debugMode: false,
cwd: '/test',
model: DEFAULT_GEMINI_MODEL,
});
// Should not crash when contentGeneratorConfig is undefined
newConfig.setModel(DEFAULT_GEMINI_FLASH_MODEL);
expect(newConfig.isInFallbackMode()).toBe(false);
});
});
describe('getModel', () => { describe('getModel', () => {
it('should return contentGeneratorConfig model if available', () => { it('should return contentGeneratorConfig model if available', () => {
// Simulate initialized content generator config // Simulate initialized content generator config
@@ -78,26 +57,4 @@ describe('Flash Model Fallback Configuration', () => {
expect(newConfig.getModel()).toBe('custom-model'); expect(newConfig.getModel()).toBe('custom-model');
}); });
}); });
describe('isInFallbackMode', () => {
it('should start as false for new session', () => {
expect(config.isInFallbackMode()).toBe(false);
});
it('should remain false if no model switch occurs', () => {
// Perform other operations that don't involve model switching
expect(config.isInFallbackMode()).toBe(false);
});
it('should persist switched state throughout session', () => {
config.setModel(DEFAULT_GEMINI_FLASH_MODEL);
// Setting state for fallback mode as is expected of clients
config.setFallbackMode(true);
expect(config.isInFallbackMode()).toBe(true);
// Should remain true even after getting model
config.getModel();
expect(config.isInFallbackMode()).toBe(true);
});
});
}); });
@@ -117,7 +117,6 @@ describe('BaseLlmClient', () => {
setActiveModel: vi.fn(), setActiveModel: vi.fn(),
getPreviewFeatures: vi.fn().mockReturnValue(false), getPreviewFeatures: vi.fn().mockReturnValue(false),
getUserTier: vi.fn().mockReturnValue(undefined), getUserTier: vi.fn().mockReturnValue(undefined),
isInFallbackMode: vi.fn().mockReturnValue(false),
getModel: vi.fn().mockReturnValue('test-model'), getModel: vi.fn().mockReturnValue('test-model'),
getActiveModel: vi.fn().mockReturnValue('test-model'), getActiveModel: vi.fn().mockReturnValue('test-model'),
} as unknown as Mocked<Config>; } as unknown as Mocked<Config>;
+1 -68
View File
@@ -30,10 +30,7 @@ import {
type ChatCompressionInfo, type ChatCompressionInfo,
} from './turn.js'; } from './turn.js';
import { getCoreSystemPrompt } from './prompts.js'; import { getCoreSystemPrompt } from './prompts.js';
import { import { DEFAULT_GEMINI_MODEL_AUTO } from '../config/models.js';
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_GEMINI_MODEL_AUTO,
} from '../config/models.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
import { setSimulate429 } from '../utils/testUtils.js'; import { setSimulate429 } from '../utils/testUtils.js';
import { tokenLimit } from './tokenLimits.js'; import { tokenLimit } from './tokenLimits.js';
@@ -234,8 +231,6 @@ describe('Gemini Client (client.ts)', () => {
.mockReturnValue(mockRouterService as unknown as ModelRouterService), .mockReturnValue(mockRouterService as unknown as ModelRouterService),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
getEnableHooks: vi.fn().mockReturnValue(false), getEnableHooks: vi.fn().mockReturnValue(false),
isInFallbackMode: vi.fn().mockReturnValue(false),
setFallbackMode: vi.fn(),
getChatCompression: vi.fn().mockReturnValue(undefined), getChatCompression: vi.fn().mockReturnValue(undefined),
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false), getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
getUseSmartEdit: vi.fn().mockReturnValue(false), getUseSmartEdit: vi.fn().mockReturnValue(false),
@@ -1535,68 +1530,6 @@ ${JSON.stringify(
expect.any(AbortSignal), expect.any(AbortSignal),
); );
}); });
it('should use the fallback model and bypass routing when in fallback mode', async () => {
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
mockRouterService.route.mockResolvedValue({
model: DEFAULT_GEMINI_FLASH_MODEL,
reason: 'fallback',
});
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-1',
);
await fromAsync(stream);
expect(mockTurnRunFn).toHaveBeenCalledWith(
{ model: DEFAULT_GEMINI_FLASH_MODEL },
[{ text: 'Hi' }],
expect.any(AbortSignal),
);
});
it('should stick to the fallback model for the entire sequence even if fallback mode ends', async () => {
// Start the sequence in fallback mode
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true);
mockRouterService.route.mockResolvedValue({
model: DEFAULT_GEMINI_FLASH_MODEL,
reason: 'fallback',
});
let stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-fallback-stickiness',
);
await fromAsync(stream);
// First call should use fallback model
expect(mockTurnRunFn).toHaveBeenCalledWith(
{ model: DEFAULT_GEMINI_FLASH_MODEL },
[{ text: 'Hi' }],
expect.any(AbortSignal),
);
// End fallback mode
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
// Second call in the same sequence
stream = client.sendMessageStream(
[{ text: 'Continue' }],
new AbortController().signal,
'prompt-fallback-stickiness',
);
await fromAsync(stream);
// Router should still not be called, and it should stick to the fallback model
expect(mockTurnRunFn).toHaveBeenCalledTimes(2); // Ensure it was called again
expect(mockTurnRunFn).toHaveBeenLastCalledWith(
{ model: DEFAULT_GEMINI_FLASH_MODEL }, // Still the fallback model
[{ text: 'Continue' }],
expect.any(AbortSignal),
);
});
}); });
it('should recursively call sendMessageStream with "Please continue." when InvalidStream event is received', async () => { it('should recursively call sendMessageStream with "Please continue." when InvalidStream event is received', async () => {
@@ -31,7 +31,6 @@ const mockConfig = {
getModel: vi.fn().mockReturnValue('gemini-pro'), getModel: vi.fn().mockReturnValue('gemini-pro'),
getProxy: vi.fn().mockReturnValue(undefined), getProxy: vi.fn().mockReturnValue(undefined),
getUsageStatisticsEnabled: vi.fn().mockReturnValue(true), getUsageStatisticsEnabled: vi.fn().mockReturnValue(true),
isInFallbackMode: vi.fn().mockReturnValue(false),
getPreviewFeatures: vi.fn().mockReturnValue(false), getPreviewFeatures: vi.fn().mockReturnValue(false),
} as unknown as Config; } as unknown as Config;
@@ -120,7 +119,6 @@ describe('createContentGenerator', () => {
getModel: vi.fn().mockReturnValue('gemini-pro'), getModel: vi.fn().mockReturnValue('gemini-pro'),
getProxy: vi.fn().mockReturnValue(undefined), getProxy: vi.fn().mockReturnValue(undefined),
getUsageStatisticsEnabled: () => true, getUsageStatisticsEnabled: () => true,
isInFallbackMode: vi.fn().mockReturnValue(false),
getPreviewFeatures: vi.fn().mockReturnValue(false), getPreviewFeatures: vi.fn().mockReturnValue(false),
} as unknown as Config; } as unknown as Config;
@@ -189,7 +187,6 @@ describe('createContentGenerator', () => {
getModel: vi.fn().mockReturnValue('gemini-pro'), getModel: vi.fn().mockReturnValue('gemini-pro'),
getProxy: vi.fn().mockReturnValue(undefined), getProxy: vi.fn().mockReturnValue(undefined),
getUsageStatisticsEnabled: () => false, getUsageStatisticsEnabled: () => false,
isInFallbackMode: vi.fn().mockReturnValue(false),
getPreviewFeatures: vi.fn().mockReturnValue(false), getPreviewFeatures: vi.fn().mockReturnValue(false),
} as unknown as Config; } as unknown as Config;
@@ -237,7 +234,6 @@ describe('createContentGenerator', () => {
getModel: vi.fn().mockReturnValue('gemini-pro'), getModel: vi.fn().mockReturnValue('gemini-pro'),
getProxy: vi.fn().mockReturnValue(undefined), getProxy: vi.fn().mockReturnValue(undefined),
getUsageStatisticsEnabled: () => false, getUsageStatisticsEnabled: () => false,
isInFallbackMode: vi.fn().mockReturnValue(false),
getPreviewFeatures: vi.fn().mockReturnValue(false), getPreviewFeatures: vi.fn().mockReturnValue(false),
} as unknown as Config; } as unknown as Config;
@@ -272,7 +268,6 @@ describe('createContentGenerator', () => {
getModel: vi.fn().mockReturnValue('gemini-pro'), getModel: vi.fn().mockReturnValue('gemini-pro'),
getProxy: vi.fn().mockReturnValue(undefined), getProxy: vi.fn().mockReturnValue(undefined),
getUsageStatisticsEnabled: () => false, getUsageStatisticsEnabled: () => false,
isInFallbackMode: vi.fn().mockReturnValue(false),
getPreviewFeatures: vi.fn().mockReturnValue(false), getPreviewFeatures: vi.fn().mockReturnValue(false),
} as unknown as Config; } as unknown as Config;
@@ -315,7 +310,6 @@ describe('createContentGenerator', () => {
const mockConfig = { const mockConfig = {
getModel: vi.fn().mockReturnValue('gemini-pro'), getModel: vi.fn().mockReturnValue('gemini-pro'),
getUsageStatisticsEnabled: () => false, getUsageStatisticsEnabled: () => false,
isInFallbackMode: vi.fn().mockReturnValue(false),
getPreviewFeatures: vi.fn().mockReturnValue(false), getPreviewFeatures: vi.fn().mockReturnValue(false),
} as unknown as Config; } as unknown as Config;
const mockGenerator = { const mockGenerator = {
+4 -145
View File
@@ -17,12 +17,7 @@ import {
} from './geminiChat.js'; } from './geminiChat.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import { setSimulate429 } from '../utils/testUtils.js'; import { setSimulate429 } from '../utils/testUtils.js';
import { import { DEFAULT_THINKING_MODE } from '../config/models.js';
DEFAULT_GEMINI_FLASH_MODEL,
DEFAULT_THINKING_MODE,
PREVIEW_GEMINI_MODEL,
PREVIEW_GEMINI_FLASH_MODEL,
} from '../config/models.js';
import { AuthType } from './contentGenerator.js'; import { AuthType } from './contentGenerator.js';
import { TerminalQuotaError } from '../utils/googleQuotaErrors.js'; import { TerminalQuotaError } from '../utils/googleQuotaErrors.js';
import { type RetryOptions } from '../utils/retry.js'; import { type RetryOptions } from '../utils/retry.js';
@@ -146,7 +141,6 @@ describe('GeminiChat', () => {
// When model is explicitly set, active model usually resets or updates to it // When model is explicitly set, active model usually resets or updates to it
currentActiveModel = m; currentActiveModel = m;
}), }),
isInFallbackMode: vi.fn().mockReturnValue(false),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false), getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
setQuotaErrorOccurred: vi.fn(), setQuotaErrorOccurred: vi.fn(),
flashFallbackHandler: undefined, flashFallbackHandler: undefined,
@@ -179,10 +173,6 @@ describe('GeminiChat', () => {
}; };
}), }),
}, },
isPreviewModelBypassMode: vi.fn().mockReturnValue(false),
setPreviewModelBypassMode: vi.fn(),
isPreviewModelFallbackMode: vi.fn().mockReturnValue(false),
setPreviewModelFallbackMode: vi.fn(),
isInteractive: vi.fn().mockReturnValue(false), isInteractive: vi.fn().mockReturnValue(false),
getEnableHooks: vi.fn().mockReturnValue(false), getEnableHooks: vi.fn().mockReturnValue(false),
getActiveModel: vi.fn().mockImplementation(() => currentActiveModel), getActiveModel: vi.fn().mockImplementation(() => currentActiveModel),
@@ -548,105 +538,6 @@ describe('GeminiChat', () => {
); );
}); });
it('should use maxAttempts=1 for retryWithBackoff when in Preview Model Fallback Mode', async () => {
vi.mocked(mockConfig.isPreviewModelFallbackMode).mockReturnValue(true);
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
(async function* () {
yield {
candidates: [
{
content: { parts: [{ text: 'Success' }] },
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})(),
);
const stream = await chat.sendMessageStream(
{ model: PREVIEW_GEMINI_MODEL },
'test',
'prompt-id-fast-retry',
new AbortController().signal,
);
for await (const _ of stream) {
// consume stream
}
expect(mockRetryWithBackoff).toHaveBeenCalledWith(
expect.any(Function),
expect.objectContaining({
maxAttempts: 1,
}),
);
});
it('should use maxAttempts=1 for retryWithBackoff when in Preview Model Fallback Mode (Flash)', async () => {
vi.mocked(mockConfig.isPreviewModelFallbackMode).mockReturnValue(true);
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
(async function* () {
yield {
candidates: [
{
content: { parts: [{ text: 'Success' }] },
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})(),
);
const stream = await chat.sendMessageStream(
{ model: PREVIEW_GEMINI_FLASH_MODEL },
'test',
'prompt-id-fast-retry-flash',
new AbortController().signal,
);
for await (const _ of stream) {
// consume stream
}
expect(mockRetryWithBackoff).toHaveBeenCalledWith(
expect.any(Function),
expect.objectContaining({
maxAttempts: 1,
}),
);
});
it('should NOT use maxAttempts=1 for other models even in Preview Model Fallback Mode', async () => {
vi.mocked(mockConfig.isPreviewModelFallbackMode).mockReturnValue(true);
vi.mocked(mockContentGenerator.generateContentStream).mockResolvedValue(
(async function* () {
yield {
candidates: [
{
content: { parts: [{ text: 'Success' }] },
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})(),
);
const stream = await chat.sendMessageStream(
{ model: DEFAULT_GEMINI_FLASH_MODEL },
'test',
'prompt-id-normal-retry',
new AbortController().signal,
);
for await (const _ of stream) {
// consume stream
}
expect(mockRetryWithBackoff).toHaveBeenCalledWith(
expect.any(Function),
expect.objectContaining({
maxAttempts: undefined, // Should use default
}),
);
});
it('should throw an error when a tool call is followed by an empty stream response', async () => { it('should throw an error when a tool call is followed by an empty stream response', async () => {
// 1. Setup: A history where the model has just made a function call. // 1. Setup: A history where the model has just made a function call.
const initialHistory: Content[] = [ const initialHistory: Content[] = [
@@ -1880,9 +1771,6 @@ describe('GeminiChat', () => {
authType, authType,
}); });
const isInFallbackModeSpy = vi.spyOn(mockConfig, 'isInFallbackMode');
isInFallbackModeSpy.mockReturnValue(false);
vi.mocked(mockContentGenerator.generateContentStream) vi.mocked(mockContentGenerator.generateContentStream)
.mockRejectedValueOnce(error429) // Attempt 1 fails .mockRejectedValueOnce(error429) // Attempt 1 fails
.mockResolvedValueOnce( .mockResolvedValueOnce(
@@ -1899,10 +1787,9 @@ describe('GeminiChat', () => {
})(), })(),
); );
mockHandleFallback.mockImplementation(async () => { mockHandleFallback.mockImplementation(
isInFallbackModeSpy.mockReturnValue(true); async () => true, // Signal retry
return true; // Signal retry );
});
const stream = await chat.sendMessageStream( const stream = await chat.sendMessageStream(
{ model: 'test-model' }, { model: 'test-model' },
@@ -1931,34 +1818,6 @@ describe('GeminiChat', () => {
const modelTurn = history[1]; const modelTurn = history[1];
expect(modelTurn.parts![0].text).toBe('Success on retry'); expect(modelTurn.parts![0].text).toBe('Success on retry');
}); });
it('should stop retrying if handleFallback returns false (e.g., auth intent)', async () => {
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-pro');
vi.mocked(mockContentGenerator.generateContentStream).mockRejectedValue(
error429,
);
mockHandleFallback.mockResolvedValue(false);
const stream = await chat.sendMessageStream(
{ model: 'gemini-2.0-flash' },
'test stop',
'prompt-id-fb2',
new AbortController().signal,
);
await expect(
(async () => {
for await (const _ of stream) {
/* consume stream */
}
})(),
).rejects.toThrow(error429);
expect(mockContentGenerator.generateContentStream).toHaveBeenCalledTimes(
1,
);
expect(mockHandleFallback).toHaveBeenCalledTimes(1);
});
}); });
it('should discard valid partial content from a failed attempt upon retry', async () => { it('should discard valid partial content from a failed attempt upon retry', async () => {
+2 -25
View File
@@ -263,10 +263,6 @@ export class GeminiChat {
): Promise<AsyncGenerator<StreamEvent>> { ): Promise<AsyncGenerator<StreamEvent>> {
await this.sendPromise; await this.sendPromise;
// Preview Model Bypass mode for the new request.
// This ensures that we attempt to use Preview Model for every new user turn
// (unless the "Always" fallback mode is active, which is handled separately).
this.config.setPreviewModelBypassMode(false);
let streamDoneResolver: () => void; let streamDoneResolver: () => void;
const streamDonePromise = new Promise<void>((resolve) => { const streamDonePromise = new Promise<void>((resolve) => {
streamDoneResolver = resolve; streamDoneResolver = resolve;
@@ -299,12 +295,7 @@ export class GeminiChat {
try { try {
let lastError: unknown = new Error('Request failed after all retries.'); let lastError: unknown = new Error('Request failed after all retries.');
let maxAttempts = INVALID_CONTENT_RETRY_OPTIONS.maxAttempts; const maxAttempts = INVALID_CONTENT_RETRY_OPTIONS.maxAttempts;
// If we are in Preview Model Fallback Mode, we want to fail fast (1 attempt)
// when probing the Preview Model.
if (this.config.isPreviewModelFallbackMode() && isPreviewModel(model)) {
maxAttempts = 1;
}
for (let attempt = 0; attempt < maxAttempts; attempt++) { for (let attempt = 0; attempt < maxAttempts; attempt++) {
let isConnectionPhase = true; let isConnectionPhase = true;
@@ -378,15 +369,6 @@ export class GeminiChat {
); );
} }
throw lastError; throw lastError;
} else {
// Preview Model successfully used, disable fallback mode.
// We only do this if we didn't bypass Preview Model (i.e. we actually used it).
if (
isPreviewModel(model) &&
!this.config.isPreviewModelBypassMode()
) {
this.config.setPreviewModelFallbackMode(false);
}
} }
} finally { } finally {
streamDoneResolver!(); streamDoneResolver!();
@@ -548,12 +530,7 @@ export class GeminiChat {
authType: this.config.getContentGeneratorConfig()?.authType, authType: this.config.getContentGeneratorConfig()?.authType,
retryFetchErrors: this.config.getRetryFetchErrors(), retryFetchErrors: this.config.getRetryFetchErrors(),
signal: abortSignal, signal: abortSignal,
maxAttempts: maxAttempts: availabilityMaxAttempts,
availabilityMaxAttempts ??
(this.config.isPreviewModelFallbackMode() &&
isPreviewModel(lastModelToUse)
? 1
: undefined),
getAvailabilityContext, getAvailabilityContext,
}); });
@@ -81,7 +81,6 @@ describe('GeminiChat Network Retries', () => {
getModel: vi.fn().mockReturnValue('gemini-pro'), getModel: vi.fn().mockReturnValue('gemini-pro'),
getActiveModel: vi.fn().mockReturnValue('gemini-pro'), getActiveModel: vi.fn().mockReturnValue('gemini-pro'),
setActiveModel: vi.fn(), setActiveModel: vi.fn(),
isInFallbackMode: vi.fn().mockReturnValue(false),
getQuotaErrorOccurred: vi.fn().mockReturnValue(false), getQuotaErrorOccurred: vi.fn().mockReturnValue(false),
getProjectRoot: vi.fn().mockReturnValue('/test/project/root'), getProjectRoot: vi.fn().mockReturnValue('/test/project/root'),
storage: { storage: {
@@ -96,11 +95,7 @@ describe('GeminiChat Network Retries', () => {
generateContentConfig: { temperature: 0 }, generateContentConfig: { temperature: 0 },
})), })),
}, },
isPreviewModelBypassMode: vi.fn().mockReturnValue(false),
setPreviewModelBypassMode: vi.fn(),
isPreviewModelFallbackMode: vi.fn().mockReturnValue(false),
getEnableHooks: vi.fn().mockReturnValue(false), getEnableHooks: vi.fn().mockReturnValue(false),
setPreviewModelFallbackMode: vi.fn(),
getModelAvailabilityService: vi getModelAvailabilityService: vi
.fn() .fn()
.mockReturnValue(createAvailabilityServiceMock()), .mockReturnValue(createAvailabilityServiceMock()),
-2
View File
@@ -69,7 +69,6 @@ describe('Core System Prompt (prompts.ts)', () => {
getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO), getModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL_AUTO),
getActiveModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL), getActiveModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL),
getPreviewFeatures: vi.fn().mockReturnValue(false), getPreviewFeatures: vi.fn().mockReturnValue(false),
isInFallbackMode: vi.fn().mockReturnValue(false),
getAgentRegistry: vi.fn().mockReturnValue({ getAgentRegistry: vi.fn().mockReturnValue({
getDirectoryContext: vi.fn().mockReturnValue('Mock Agent Directory'), getDirectoryContext: vi.fn().mockReturnValue('Mock Agent Directory'),
}), }),
@@ -173,7 +172,6 @@ describe('Core System Prompt (prompts.ts)', () => {
getModel: vi.fn().mockReturnValue('auto'), getModel: vi.fn().mockReturnValue('auto'),
getActiveModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL), getActiveModel: vi.fn().mockReturnValue(DEFAULT_GEMINI_MODEL),
getPreviewFeatures: vi.fn().mockReturnValue(false), getPreviewFeatures: vi.fn().mockReturnValue(false),
isInFallbackMode: vi.fn().mockReturnValue(false),
getAgentRegistry: vi.fn().mockReturnValue({ getAgentRegistry: vi.fn().mockReturnValue({
getDirectoryContext: vi.fn().mockReturnValue('Mock Agent Directory'), getDirectoryContext: vi.fn().mockReturnValue('Mock Agent Directory'),
}), }),
@@ -29,7 +29,6 @@ import {
} from '../config/models.js'; } from '../config/models.js';
import type { FallbackModelHandler } from './types.js'; import type { FallbackModelHandler } from './types.js';
import { openBrowserSecurely } from '../utils/secure-browser-launcher.js'; import { openBrowserSecurely } from '../utils/secure-browser-launcher.js';
import { coreEvents } from '../utils/events.js';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import * as policyHelpers from '../availability/policyHelpers.js'; import * as policyHelpers from '../availability/policyHelpers.js';
import { createDefaultPolicy } from '../availability/policyCatalog.js'; import { createDefaultPolicy } from '../availability/policyCatalog.js';
@@ -63,12 +62,6 @@ const AUTH_API_KEY = AuthType.USE_GEMINI;
const createMockConfig = (overrides: Partial<Config> = {}): Config => const createMockConfig = (overrides: Partial<Config> = {}): Config =>
({ ({
isInFallbackMode: vi.fn(() => false),
setFallbackMode: vi.fn(),
isPreviewModelFallbackMode: vi.fn(() => false),
setPreviewModelFallbackMode: vi.fn(),
isPreviewModelBypassMode: vi.fn(() => false),
setPreviewModelBypassMode: vi.fn(),
fallbackHandler: undefined, fallbackHandler: undefined,
getFallbackModelHandler: vi.fn(), getFallbackModelHandler: vi.fn(),
setActiveModel: vi.fn(), setActiveModel: vi.fn(),
@@ -90,7 +83,6 @@ describe('handleFallback', () => {
let mockConfig: Config; let mockConfig: Config;
let mockHandler: Mock<FallbackModelHandler>; let mockHandler: Mock<FallbackModelHandler>;
let consoleErrorSpy: MockInstance; let consoleErrorSpy: MockInstance;
let fallbackEventSpy: MockInstance;
beforeEach(() => { beforeEach(() => {
vi.clearAllMocks(); vi.clearAllMocks();
@@ -106,12 +98,10 @@ describe('handleFallback', () => {
// But tests might check console.error usage in legacy code if any? // But tests might check console.error usage in legacy code if any?
// The handler uses console.error in legacyHandleFallback. // The handler uses console.error in legacyHandleFallback.
consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
fallbackEventSpy = vi.spyOn(coreEvents, 'emitFallbackModeChanged');
}); });
afterEach(() => { afterEach(() => {
consoleErrorSpy.mockRestore(); consoleErrorSpy.mockRestore();
fallbackEventSpy.mockRestore();
}); });
describe('policy-driven flow', () => { describe('policy-driven flow', () => {
@@ -211,14 +201,6 @@ describe('handleFallback', () => {
expect(policyConfig.setActiveModel).toHaveBeenCalledWith( expect(policyConfig.setActiveModel).toHaveBeenCalledWith(
DEFAULT_GEMINI_FLASH_MODEL, DEFAULT_GEMINI_FLASH_MODEL,
); );
// Silent actions should not trigger the legacy fallback mode (via activateFallbackMode),
// but setActiveModel might trigger it via legacy sync if it switches to Flash.
// However, the test requirement is "doesn't emit fallback mode".
// Since we are mocking setActiveModel, we can verify setFallbackMode isn't called *independently*.
// But setActiveModel is mocked, so it won't trigger side effects unless the implementation does.
// We verified setActiveModel is called.
// We verify setFallbackMode is NOT called (which would happen if activateFallbackMode was called).
expect(policyConfig.setFallbackMode).not.toHaveBeenCalled();
} finally { } finally {
chainSpy.mockRestore(); chainSpy.mockRestore();
} }
@@ -410,7 +392,6 @@ describe('handleFallback', () => {
expect(result).toBe(true); expect(result).toBe(true);
expect(policyConfig.setActiveModel).toHaveBeenCalledWith(FALLBACK_MODEL); expect(policyConfig.setActiveModel).toHaveBeenCalledWith(FALLBACK_MODEL);
expect(policyConfig.setFallbackMode).not.toHaveBeenCalled();
// TODO: add logging expect statement // TODO: add logging expect statement
}); });
@@ -18,6 +18,7 @@ import { GeminiEventType } from '../core/turn.js';
import * as loggers from '../telemetry/loggers.js'; import * as loggers from '../telemetry/loggers.js';
import { LoopType } from '../telemetry/types.js'; import { LoopType } from '../telemetry/types.js';
import { LoopDetectionService } from './loopDetectionService.js'; import { LoopDetectionService } from './loopDetectionService.js';
import { createAvailabilityServiceMock } from '../availability/testUtils.js';
vi.mock('../telemetry/loggers.js', () => ({ vi.mock('../telemetry/loggers.js', () => ({
logLoopDetected: vi.fn(), logLoopDetected: vi.fn(),
@@ -37,6 +38,9 @@ describe('LoopDetectionService', () => {
mockConfig = { mockConfig = {
getTelemetryEnabled: () => true, getTelemetryEnabled: () => true,
isInteractive: () => false, isInteractive: () => false,
getModelAvailabilityService: vi
.fn()
.mockReturnValue(createAvailabilityServiceMock()),
} as unknown as Config; } as unknown as Config;
service = new LoopDetectionService(mockConfig); service = new LoopDetectionService(mockConfig);
vi.clearAllMocks(); vi.clearAllMocks();
@@ -732,13 +736,15 @@ describe('LoopDetectionService LLM Checks', () => {
generateJson: vi.fn(), generateJson: vi.fn(),
} as unknown as BaseLlmClient; } as unknown as BaseLlmClient;
const mockAvailability = createAvailabilityServiceMock();
vi.mocked(mockAvailability.snapshot).mockReturnValue({ available: true });
mockConfig = { mockConfig = {
getGeminiClient: () => mockGeminiClient, getGeminiClient: () => mockGeminiClient,
getBaseLlmClient: () => mockBaseLlmClient, getBaseLlmClient: () => mockBaseLlmClient,
getDebugMode: () => false, getDebugMode: () => false,
getTelemetryEnabled: () => true, getTelemetryEnabled: () => true,
getModel: vi.fn().mockReturnValue('cognitive-loop-v1'), getModel: vi.fn().mockReturnValue('cognitive-loop-v1'),
isInFallbackMode: vi.fn().mockReturnValue(false),
modelConfigService: { modelConfigService: {
getResolvedConfig: vi.fn().mockImplementation((key) => { getResolvedConfig: vi.fn().mockImplementation((key) => {
if (key.model === 'loop-detection') { if (key.model === 'loop-detection') {
@@ -751,6 +757,7 @@ describe('LoopDetectionService LLM Checks', () => {
}), }),
}, },
isInteractive: () => false, isInteractive: () => false,
getModelAvailabilityService: vi.fn().mockReturnValue(mockAvailability),
} as unknown as Config; } as unknown as Config;
service = new LoopDetectionService(mockConfig); service = new LoopDetectionService(mockConfig);
@@ -901,9 +908,6 @@ describe('LoopDetectionService LLM Checks', () => {
}); });
it('should detect a loop when confidence is exactly equal to the threshold (0.9)', async () => { it('should detect a loop when confidence is exactly equal to the threshold (0.9)', async () => {
// Mock isInFallbackMode to false so it double checks
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
mockBaseLlmClient.generateJson = vi mockBaseLlmClient.generateJson = vi
.fn() .fn()
.mockResolvedValueOnce({ .mockResolvedValueOnce({
@@ -944,9 +948,6 @@ describe('LoopDetectionService LLM Checks', () => {
}); });
it('should not detect a loop when Flash is confident (0.9) but Main model is not (0.89)', async () => { it('should not detect a loop when Flash is confident (0.9) but Main model is not (0.89)', async () => {
// Mock isInFallbackMode to false so it double checks
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(false);
mockBaseLlmClient.generateJson = vi mockBaseLlmClient.generateJson = vi
.fn() .fn()
.mockResolvedValueOnce({ .mockResolvedValueOnce({
@@ -988,9 +989,13 @@ describe('LoopDetectionService LLM Checks', () => {
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(3); expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(3);
}); });
it('should only call Flash model if in fallback mode', async () => { it('should only call Flash model if main model is unavailable', async () => {
// Mock isInFallbackMode to true // Mock availability to return unavailable for the main model
vi.mocked(mockConfig.isInFallbackMode).mockReturnValue(true); const availability = mockConfig.getModelAvailabilityService();
vi.mocked(availability.snapshot).mockReturnValue({
available: false,
reason: 'quota',
});
mockBaseLlmClient.generateJson = vi.fn().mockResolvedValueOnce({ mockBaseLlmClient.generateJson = vi.fn().mockResolvedValueOnce({
unproductive_state_confidence: 0.9, unproductive_state_confidence: 0.9,
@@ -472,7 +472,9 @@ export class LoopDetectionService {
return false; return false;
} }
if (this.config.isInFallbackMode()) { const availability = this.config.getModelAvailabilityService();
if (!availability.snapshot(doubleCheckModelName).available) {
const flashModelName = this.config.modelConfigService.getResolvedConfig({ const flashModelName = this.config.modelConfigService.getResolvedConfig({
model: 'loop-detection', model: 'loop-detection',
}).model; }).model;
-21
View File
@@ -34,16 +34,6 @@ export interface UserFeedbackPayload {
error?: unknown; error?: unknown;
} }
/**
* Payload for the 'fallback-mode-changed' event.
*/
export interface FallbackModeChangedPayload {
/**
* Whether fallback mode is now active.
*/
isInFallbackMode: boolean;
}
/** /**
* Payload for the 'model-changed' event. * Payload for the 'model-changed' event.
*/ */
@@ -78,7 +68,6 @@ export type MemoryChangedPayload = LoadServerHierarchicalMemoryResponse;
export enum CoreEvent { export enum CoreEvent {
UserFeedback = 'user-feedback', UserFeedback = 'user-feedback',
FallbackModeChanged = 'fallback-mode-changed',
ModelChanged = 'model-changed', ModelChanged = 'model-changed',
ConsoleLog = 'console-log', ConsoleLog = 'console-log',
Output = 'output', Output = 'output',
@@ -88,7 +77,6 @@ export enum CoreEvent {
export interface CoreEvents { export interface CoreEvents {
[CoreEvent.UserFeedback]: [UserFeedbackPayload]; [CoreEvent.UserFeedback]: [UserFeedbackPayload];
[CoreEvent.FallbackModeChanged]: [FallbackModeChangedPayload];
[CoreEvent.ModelChanged]: [ModelChangedPayload]; [CoreEvent.ModelChanged]: [ModelChangedPayload];
[CoreEvent.ConsoleLog]: [ConsoleLogPayload]; [CoreEvent.ConsoleLog]: [ConsoleLogPayload];
[CoreEvent.Output]: [OutputPayload]; [CoreEvent.Output]: [OutputPayload];
@@ -166,15 +154,6 @@ export class CoreEventEmitter extends EventEmitter<CoreEvents> {
this._emitOrQueue(CoreEvent.Output, payload); this._emitOrQueue(CoreEvent.Output, payload);
} }
/**
* Notifies subscribers that fallback mode has changed.
* This is synchronous and doesn't use backlog (UI should already be initialized).
*/
emitFallbackModeChanged(isInFallbackMode: boolean): void {
const payload: FallbackModeChangedPayload = { isInFallbackMode };
this.emit(CoreEvent.FallbackModeChanged, payload);
}
/** /**
* Notifies subscribers that the model has changed. * Notifies subscribers that the model has changed.
*/ */