fix: update currentSequenceModel when modelChanged (#17051)

This commit is contained in:
Adam Weidman
2026-01-20 01:25:15 -05:00
committed by GitHub
parent 4920ad2694
commit e34f0b4a98
7 changed files with 77 additions and 36 deletions

View File

@@ -2037,9 +2037,8 @@ export class Config {
*/
async dispose(): Promise<void> {
coreEvents.off(CoreEvent.AgentsRefreshed, this.onAgentsRefreshed);
if (this.agentRegistry) {
this.agentRegistry.dispose();
}
this.agentRegistry?.dispose();
this.geminiClient?.dispose();
if (this.mcpClientManager) {
await this.mcpClientManager.stop();
}

View File

@@ -48,6 +48,7 @@ import type {
import { ClearcutLogger } from '../telemetry/clearcut-logger/clearcut-logger.js';
import * as policyCatalog from '../availability/policyCatalog.js';
import { partToString } from '../utils/partUtils.js';
import { coreEvents } from '../utils/events.js';
// Mock fs module to prevent actual file system operations during tests
const mockFileSystem = new Map<string, string>();
@@ -281,6 +282,7 @@ describe('Gemini Client (client.ts)', () => {
});
afterEach(() => {
client.dispose();
vi.restoreAllMocks();
});
@@ -1757,6 +1759,55 @@ ${JSON.stringify(
expect.any(AbortSignal),
);
});
it('should re-route within the same prompt when the configured model changes', async () => {
mockTurnRunFn.mockClear();
mockTurnRunFn.mockImplementation(async function* () {
yield { type: 'content', value: 'Hello' };
});
mockRouterService.route.mockResolvedValueOnce({
model: 'original-model',
reason: 'test',
});
let stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-1',
);
await fromAsync(stream);
expect(mockRouterService.route).toHaveBeenCalledTimes(1);
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
1,
{ model: 'original-model' },
[{ text: 'Hi' }],
expect.any(AbortSignal),
);
mockRouterService.route.mockResolvedValue({
model: 'fallback-model',
reason: 'test',
});
vi.mocked(mockConfig.getModel).mockReturnValue('gemini-2.5-flash');
coreEvents.emitModelChanged('gemini-2.5-flash');
stream = client.sendMessageStream(
[{ text: 'Continue' }],
new AbortController().signal,
'prompt-1',
);
await fromAsync(stream);
expect(mockRouterService.route).toHaveBeenCalledTimes(2);
expect(mockTurnRunFn).toHaveBeenNthCalledWith(
2,
{ model: 'fallback-model' },
[{ text: 'Continue' }],
expect.any(AbortSignal),
);
});
});
it('should use getGlobalMemory for system instruction when JIT is enabled', async () => {

View File

@@ -58,6 +58,7 @@ import {
import { resolveModel } from '../config/models.js';
import type { RetryAvailabilityContext } from '../utils/retry.js';
import { partToString } from '../utils/partUtils.js';
import { coreEvents, CoreEvent } from '../utils/events.js';
const MAX_TURNS = 100;
@@ -94,8 +95,14 @@ export class GeminiClient {
this.loopDetector = new LoopDetectionService(config);
this.compressionService = new ChatCompressionService();
this.lastPromptId = this.config.getSessionId();
coreEvents.on(CoreEvent.ModelChanged, this.handleModelChanged);
}
private handleModelChanged = () => {
this.currentSequenceModel = null;
};
// Hook state to deduplicate BeforeAgent calls and track response for
// AfterAgent
private hookStateMap = new Map<
@@ -253,6 +260,10 @@ export class GeminiClient {
this.updateTelemetryTokenCount();
}
dispose() {
coreEvents.off(CoreEvent.ModelChanged, this.handleModelChanged);
}
async resumeChat(
history: Content[],
resumedSessionData?: ResumedSessionData,

View File

@@ -65,6 +65,8 @@ const createMockConfig = (overrides: Partial<Config> = {}): Config =>
fallbackHandler: undefined,
getFallbackModelHandler: vi.fn(),
setActiveModel: vi.fn(),
setModel: vi.fn(),
activateFallbackMode: vi.fn(),
getModelAvailabilityService: vi.fn(() =>
createAvailabilityServiceMock({
selectedModel: FALLBACK_MODEL,
@@ -198,7 +200,7 @@ describe('handleFallback', () => {
expect(result).toBe(true);
expect(policyConfig.getFallbackModelHandler).not.toHaveBeenCalled();
expect(policyConfig.setActiveModel).toHaveBeenCalledWith(
expect(policyConfig.activateFallbackMode).toHaveBeenCalledWith(
DEFAULT_GEMINI_FLASH_MODEL,
);
} finally {
@@ -273,7 +275,7 @@ describe('handleFallback', () => {
expect(openBrowserSecurely).toHaveBeenCalledWith(
'https://goo.gle/set-up-gemini-code-assist',
);
expect(policyConfig.setActiveModel).not.toHaveBeenCalled();
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
});
it('should catch errors from the handler, log an error, and return null', async () => {
@@ -378,7 +380,7 @@ describe('handleFallback', () => {
);
});
it('calls setActiveModel and logs telemetry when handler returns "retry_always"', async () => {
it('calls activateFallbackMode when handler returns "retry_always"', async () => {
policyHandler.mockResolvedValue('retry_always');
vi.mocked(policyConfig.getModel).mockReturnValue(
DEFAULT_GEMINI_MODEL_AUTO,
@@ -391,11 +393,13 @@ describe('handleFallback', () => {
);
expect(result).toBe(true);
expect(policyConfig.setActiveModel).toHaveBeenCalledWith(FALLBACK_MODEL);
expect(policyConfig.activateFallbackMode).toHaveBeenCalledWith(
FALLBACK_MODEL,
);
// TODO: add logging expect statement
});
it('does NOT call setActiveModel when handler returns "stop"', async () => {
it('does NOT call activateFallbackMode when handler returns "stop"', async () => {
policyHandler.mockResolvedValue('stop');
const result = await handleFallback(
@@ -405,11 +409,11 @@ describe('handleFallback', () => {
);
expect(result).toBe(false);
expect(policyConfig.setActiveModel).not.toHaveBeenCalled();
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
// TODO: add logging expect statement
});
it('does NOT call setActiveModel when handler returns "retry_once"', async () => {
it('does NOT call activateFallbackMode when handler returns "retry_once"', async () => {
policyHandler.mockResolvedValue('retry_once');
const result = await handleFallback(
@@ -419,7 +423,7 @@ describe('handleFallback', () => {
);
expect(result).toBe(true);
expect(policyConfig.setActiveModel).not.toHaveBeenCalled();
expect(policyConfig.activateFallbackMode).not.toHaveBeenCalled();
});
});
});

View File

@@ -131,7 +131,7 @@ async function processIntent(
case 'retry_always':
// TODO(telemetry): Implement generic fallback event logging. Existing
// logFlashFallback is specific to a single Model.
config.setActiveModel(fallbackModel);
config.activateFallbackMode(fallbackModel);
return true;
case 'retry_once':