fix: update currentSequenceModel when modelChanged (#17051)

This commit is contained in:
Adam Weidman
2026-01-20 01:25:15 -05:00
committed by Tommaso Sciortino
parent 02e68e4554
commit 217f277580
7 changed files with 77 additions and 36 deletions

View File

@@ -166,11 +166,6 @@ describe('useQuotaAndFallback', () => {
const intent = await promise!;
expect(intent).toBe('retry_always');
// Verify activateFallbackMode was called
expect(mockConfig.activateFallbackMode).toHaveBeenCalledWith(
'gemini-flash',
);
// The pending request should be cleared from the state
expect(result.current.proQuotaRequest).toBeNull();
expect(mockHistoryManager.addItem).toHaveBeenCalledTimes(1);
@@ -282,11 +277,6 @@ describe('useQuotaAndFallback', () => {
const intent = await promise!;
expect(intent).toBe('retry_always');
// Verify activateFallbackMode was called
expect(mockConfig.activateFallbackMode).toHaveBeenCalledWith(
'model-B',
);
// The pending request should be cleared from the state
expect(result.current.proQuotaRequest).toBeNull();
expect(mockConfig.setQuotaErrorOccurred).toHaveBeenCalledWith(true);
@@ -342,11 +332,6 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
const intent = await promise!;
expect(intent).toBe('retry_always');
// Verify activateFallbackMode was called
expect(mockConfig.activateFallbackMode).toHaveBeenCalledWith(
'gemini-2.5-pro',
);
expect(result.current.proQuotaRequest).toBeNull();
});
});
@@ -430,11 +415,6 @@ To disable gemini-3-pro-preview, disable "Preview features" in /settings.`,
expect(intent).toBe('retry_always');
expect(result.current.proQuotaRequest).toBeNull();
// Verify activateFallbackMode was called
expect(mockConfig.activateFallbackMode).toHaveBeenCalledWith(
'gemini-flash',
);
// Verify quota error flags are reset
expect(mockSetModelSwitchedFromQuotaError).toHaveBeenCalledWith(false);
expect(mockConfig.setQuotaErrorOccurred).toHaveBeenCalledWith(false);

View File

@@ -135,10 +135,6 @@ export function useQuotaAndFallback({
config.setQuotaErrorOccurred(false);
if (choice === 'retry_always') {
// Set the model to the fallback model for the current session.
// This ensures the Footer updates and future turns use this model.
// The change is not persisted, so the original model is restored on restart.
config.activateFallbackMode(proQuotaRequest.fallbackModel);
historyManager.addItem(
{
type: MessageType.INFO,

View File

@@ -1955,9 +1955,8 @@ export class Config {
*/
async dispose(): Promise<void> {
coreEvents.off(CoreEvent.AgentsRefreshed, this.onAgentsRefreshed);
if (this.agentRegistry) {
this.agentRegistry.dispose();
}
this.agentRegistry?.dispose();
this.geminiClient?.dispose();
if (this.mcpClientManager) {
await this.mcpClientManager.stop();
}

View File

@@ -48,6 +48,7 @@ import type {
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
import * as policyCatalog from '../availability/policyCatalog.js';
import { partToString } from '../utils/partUtils.js';
import { coreEvents } from '../utils/events.js';
vi.mock('../services/chatCompressionService.js');
@@ -290,6 +291,7 @@ describe('Gemini Client (client.ts)', () => {
});
afterEach(() => {
client.dispose();
vi.restoreAllMocks();
});
@@ -1579,6 +1581,55 @@ ${JSON.stringify(
expect.any(AbortSignal),
);
});
it('should re-route within the same prompt when the configured model changes', async () => {
mockTurnRunFn.mockClear();
mockTurnRunFn.mockImplementation(async function* () {
yield { type: 'content', value: 'Hello' };
});
mockRouterService.route.mockResolvedValueOnce({
model: 'original-model',
reason: 'test',
});
let stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-1',
);
await fromAsync(stream);
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
1,
{ model: 'original-model' },
[{ text: 'Hi' }],
expect.any(AbortSignal),
);
mockRouterService.route.mockResolvedValue({
model: 'fallback-model',
reason: 'test',
});
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-flash');
coreEvents.emitModelChanged('gemini-2.5-flash');
stream = client.sendMessageStream(
[{ text: 'Continue' }],
new AbortController().signal,
'prompt-1',
);
await fromAsync(stream);
expect(mockRouterService.route).toHaveBeenCalledTimes(2);
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
2,
{ model: 'fallback-model' },
[{ text: 'Continue' }],
expect.any(AbortSignal),
);
});
});
it('should use getGlobalMemory for system instruction when JIT is enabled', async () => {

View File

@@ -58,6 +58,7 @@ import {
import { resolveModel } from '../config/models.js';
import type { RetryAvailabilityContext } from '../utils/retry.js';
import { partToString } from '../utils/partUtils.js';
import { coreEvents, CoreEvent } from '../utils/events.js';
const MAX_TURNS = 100;
@@ -94,8 +95,14 @@ export class GeminiClient {
this.loopDetector = new LoopDetectionService(config);
this.compressionService = new ChatCompressionService();
this.lastPromptId = this.config.getSessionId();
coreEvents.on(CoreEvent.ModelChanged, this.handleModelChanged);
}
private handleModelChanged = () => {
this.currentSequenceModel = null;
};
// Hook state to deduplicate BeforeAgent calls and track response for
// AfterAgent
private hookStateMap = new Map<
@@ -253,6 +260,10 @@ export class GeminiClient {
this.updateTelemetryTokenCount();
}
dispose() {
coreEvents.off(CoreEvent.ModelChanged, this.handleModelChanged);
}
async resumeChat(
history: Content[],
resumedSessionData?: ResumedSessionData,

View File

@@ -65,6 +65,8 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>
fallbackHandler: undefined,
getFallbackModelHandler: vi.fn(),
setActiveModel: vi.fn(),
setModel: vi.fn(),
activateFallbackMode: vi.fn(),
getModelAvailabilityService: vi.fn(() =>
createAvailabilityServiceMock({
selectedModel: FALLBACK_MODEL,
@@ -198,7 +200,7 @@ describe('handleFallback', () => {
expect(result).toBe(true);
expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
expect(policyConfig.setActiveModel).toHaveBeenCalledWith(
expect(policyConfig.activateFallbackMode).toHaveBeenCalledWith(
DEFAULT_GEMINI_FLASH_MODEL,
);
} finally {
@@ -273,7 +275,7 @@ describe('handleFallback', () => {
expect(openBrowserSecurely).toHaveBeenCalledWith(
'https://goo.gle/set-up-gemini-code-assist',
);
expect(policyConfig.setActiveModel).not.toHaveBeenCalled();
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
});
it('should catch errors from the handler, log an error, and return null', async () => {
@@ -378,7 +380,7 @@ describe('handleFallback', () => {
);
});
it('calls setActiveModel and logs telemetry when handler returns "retry_always"', async () => {
it('calls activateFallbackMode when handler returns "retry_always"', async () => {
policyHandler.mockResolvedValue('retry_always');
vi.mocked(policyConfig.getModel).mockReturnValue(
DEFAULT_GEMINI_MODEL_AUTO,
@@ -391,11 +393,13 @@ describe('handleFallback', () => {
);
expect(result).toBe(true);
expect(policyConfig.setActiveModel).toHaveBeenCalledWith(FALLBACK_MODEL);
expect(policyConfig.activateFallbackMode).toHaveBeenCalledWith(
FALLBACK_MODEL,
);
// TODO: add logging expect statement
});
it('does NOT call setActiveModel when handler returns "stop"', async () => {
it('does NOT call activateFallbackMode when handler returns "stop"', async () => {
policyHandler.mockResolvedValue('stop');
const result = await handleFallback(
@@ -405,11 +409,11 @@ describe('handleFallback', () => {
);
expect(result).toBe(false);
expect(policyConfig.setActiveModel).not.toHaveBeenCalled();
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
// TODO: add logging expect statement
});
it('does NOT call setActiveModel when handler returns "retry_once"', async () => {
it('does NOT call activateFallbackMode when handler returns "retry_once"', async () => {
policyHandler.mockResolvedValue('retry_once');
const result = await handleFallback(
@@ -419,7 +423,7 @@ describe('handleFallback', () => {
);
expect(result).toBe(true);
expect(policyConfig.setActiveModel).not.toHaveBeenCalled();
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
});
});
});

View File

@@ -131,7 +131,7 @@ async function processIntent(
case 'retry_always':
// TODO(telemetry): Implement generic fallback event logging. Existing
// logFlashFallback is specific to a single Model.
config.setActiveModel(fallbackModel);
config.activateFallbackMode(fallbackModel);
return true;
case 'retry_once':