From 265f24e5d7d382c8b44674849935626343c9e267 Mon Sep 17 00:00:00 2001 From: Abhi <43648792+abhipatel12@users.noreply.github.com> Date: Mon, 3 Nov 2025 14:59:51 -0500 Subject: [PATCH] fix(ui): ensure model changes update the UI immediately (#12412) --- packages/cli/src/ui/AppContainer.test.tsx | 36 +++++++++++++++++++++++ packages/cli/src/ui/AppContainer.tsx | 9 +++++- packages/core/src/config/config.test.ts | 1 + packages/core/src/config/config.ts | 7 +++-- packages/core/src/utils/events.test.ts | 13 ++++++++ packages/core/src/utils/events.ts | 31 +++++++++++++++++++ 6 files changed, 94 insertions(+), 3 deletions(-) diff --git a/packages/cli/src/ui/AppContainer.test.tsx b/packages/cli/src/ui/AppContainer.test.tsx index 5e9ce80623..25e53d2657 100644 --- a/packages/cli/src/ui/AppContainer.test.tsx +++ b/packages/cli/src/ui/AppContainer.test.tsx @@ -1501,5 +1501,41 @@ describe('AppContainer State Management', () => { ); unmount(); }); + + it('updates currentModel when ModelChanged event is received', async () => { + // Arrange: Mock initial model + vi.spyOn(mockConfig, 'getModel').mockReturnValue('initial-model'); + + const { unmount } = render( + , + ); + + // Verify initial model + await act(async () => { + await vi.waitFor(() => { + expect(capturedUIState?.currentModel).toBe('initial-model'); + }); + }); + + // Get the registered handler for ModelChanged + const handler = mockCoreEvents.on.mock.calls.find( + (call: unknown[]) => call[0] === CoreEvent.ModelChanged, + )?.[1]; + expect(handler).toBeDefined(); + + // Act: Simulate ModelChanged event + act(() => { + handler({ model: 'new-model' }); + }); + + // Assert: Verify model is updated + expect(capturedUIState.currentModel).toBe('new-model'); + unmount(); + }); }); }); diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index 5983d2a09d..da4394877c 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -48,6 +48,7 @@ import { debugLogger, coreEvents, CoreEvent, + type ModelChangedPayload, } from '@google/gemini-cli-core'; import { validateAuthMethod } from '../config/auth.js'; import { loadHierarchicalGeminiMemory } from '../config/config.js'; @@ -258,16 +259,22 @@ export const AppContainer = (props: AppContainerProps) => { [historyManager.addItem], ); - // Subscribe to fallback mode changes from core + // Subscribe to fallback mode and model changes from core useEffect(() => { const handleFallbackModeChanged = () => { const effectiveModel = getEffectiveModel(); setCurrentModel(effectiveModel); }; + const handleModelChanged = (payload: ModelChangedPayload) => { + setCurrentModel(payload.model); + }; + coreEvents.on(CoreEvent.FallbackModeChanged, handleFallbackModeChanged); + coreEvents.on(CoreEvent.ModelChanged, handleModelChanged); return () => { coreEvents.off(CoreEvent.FallbackModeChanged, handleFallbackModeChanged); + coreEvents.off(CoreEvent.ModelChanged, handleModelChanged); }; }, [getEffectiveModel]); diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index e50b31829d..620b9f9b55 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -147,6 +147,7 @@ vi.mock('../agents/subagent-tool-wrapper.js', () => ({ const mockCoreEvents = vi.hoisted(() => ({ emitFeedback: vi.fn(), + emitModelChanged: vi.fn(), })); const mockSetGlobalProxy = vi.hoisted(() => vi.fn()); diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 8dab6cbaa4..3dffa488e4 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -41,6 +41,7 @@ import { DEFAULT_OTLP_ENDPOINT, uiTelemetryService, } from '../telemetry/index.js'; +import { coreEvents } from '../utils/events.js'; import { tokenLimit } from '../core/tokenLimits.js'; import { DEFAULT_GEMINI_EMBEDDING_MODEL, @@ -76,7 +77,6 @@ import type { UserTierId } from '../code_assist/types.js'; import { AgentRegistry } from '../agents/registry.js'; import { setGlobalProxy } from '../utils/fetch.js'; import { SubagentToolWrapper } from '../agents/subagent-tool-wrapper.js'; -import { coreEvents } from '../utils/events.js'; export enum ApprovalMode { DEFAULT = 'default', @@ -711,7 +711,10 @@ export class Config { return; } - this.model = newModel; + if (this.model !== newModel) { + this.model = newModel; + coreEvents.emitModelChanged(newModel); + } } isInFallbackMode(): boolean { diff --git a/packages/core/src/utils/events.test.ts b/packages/core/src/utils/events.test.ts index 4a11263014..9ba660bf26 100644 --- a/packages/core/src/utils/events.test.ts +++ b/packages/core/src/utils/events.test.ts @@ -156,4 +156,17 @@ describe('CoreEventEmitter', () => { }); expect(listener.mock.calls[2][0]).toMatchObject({ message: 'Buffered 2' }); }); + + describe('ModelChanged Event', () => { + it('should emit ModelChanged event with correct payload', () => { + const listener = vi.fn(); + events.on(CoreEvent.ModelChanged, listener); + + const newModel = 'gemini-2.5-pro'; + events.emitModelChanged(newModel); + + expect(listener).toHaveBeenCalledTimes(1); + expect(listener).toHaveBeenCalledWith({ model: newModel }); + }); + }); }); diff --git a/packages/core/src/utils/events.ts b/packages/core/src/utils/events.ts index 9b34d27883..386200fad7 100644 --- a/packages/core/src/utils/events.ts +++ b/packages/core/src/utils/events.ts @@ -43,9 +43,20 @@ export interface FallbackModeChangedPayload { isInFallbackMode: boolean; } +/** + * Payload for the 'model-changed' event. + */ +export interface ModelChangedPayload { + /** + * The new model that was set. + */ + model: string; +} + export enum CoreEvent { UserFeedback = 'user-feedback', FallbackModeChanged = 'fallback-mode-changed', + ModelChanged = 'model-changed', } export class CoreEventEmitter extends EventEmitter { @@ -86,6 +97,14 @@ export class CoreEventEmitter extends EventEmitter { this.emit(CoreEvent.FallbackModeChanged, payload); } + /** + * Notifies subscribers that the model has changed. + */ + emitModelChanged(model: string): void { + const payload: ModelChangedPayload = { model }; + this.emit(CoreEvent.ModelChanged, payload); + } + /** * Flushes buffered messages. Call this immediately after primary UI listener * subscribes. @@ -106,6 +125,10 @@ export class CoreEventEmitter extends EventEmitter { event: CoreEvent.FallbackModeChanged, listener: (payload: FallbackModeChangedPayload) => void, ): this; + override on( + event: CoreEvent.ModelChanged, + listener: (payload: ModelChangedPayload) => void, + ): this; override on( event: string | symbol, // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -122,6 +145,10 @@ export class CoreEventEmitter extends EventEmitter { event: CoreEvent.FallbackModeChanged, listener: (payload: FallbackModeChangedPayload) => void, ): this; + override off( + event: CoreEvent.ModelChanged, + listener: (payload: ModelChangedPayload) => void, + ): this; override off( event: string | symbol, // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -138,6 +165,10 @@ export class CoreEventEmitter extends EventEmitter { event: CoreEvent.FallbackModeChanged, payload: FallbackModeChangedPayload, ): boolean; + override emit( + event: CoreEvent.ModelChanged, + payload: ModelChangedPayload, + ): boolean; // eslint-disable-next-line @typescript-eslint/no-explicit-any override emit(event: string | symbol, ...args: any[]): boolean { return super.emit(event, ...args);