diff --git a/packages/cli/src/ui/AppContainer.test.tsx b/packages/cli/src/ui/AppContainer.test.tsx index acbf175d35..c8814969f7 100644 --- a/packages/cli/src/ui/AppContainer.test.tsx +++ b/packages/cli/src/ui/AppContainer.test.tsx @@ -15,6 +15,7 @@ import { type MockedObject, } from 'vitest'; import { render } from '../test-utils/render.js'; +import { waitFor } from '../test-utils/async.js'; import { cleanup } from 'ink-testing-library'; import { act, useContext } from 'react'; import { AppContainer } from './AppContainer.js'; @@ -36,12 +37,18 @@ const mockCoreEvents = vi.hoisted(() => ({ emit: vi.fn(), })); +// Mock IdeClient +const mockIdeClient = vi.hoisted(() => ({ + getInstance: vi.fn().mockReturnValue(new Promise(() => {})), +})); + vi.mock('@google/gemini-cli-core', async (importOriginal) => { const actual = await importOriginal(); return { ...actual, coreEvents: mockCoreEvents, + IdeClient: mockIdeClient, }; }); import type { LoadedSettings } from '../config/settings.js'; @@ -312,7 +319,7 @@ describe('AppContainer State Management', () => { // Add other properties if AppContainer uses them }); mockedUseLogger.mockReturnValue({ - getPreviousUserMessages: vi.fn().mockResolvedValue([]), + getPreviousUserMessages: vi.fn().mockReturnValue(new Promise(() => {})), }); mockedUseLoadingIndicator.mockReturnValue({ elapsedTime: '0.0s', @@ -367,9 +374,7 @@ describe('AppContainer State Management', () => { describe('Basic Rendering', () => { it('renders without crashing with minimal props', async () => { const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); unmount(); }); @@ -377,9 +382,7 @@ describe('AppContainer State Management', () => { const startupWarnings = ['Warning 1', 'Warning 2']; const { unmount } = renderAppContainer({ startupWarnings }); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); unmount(); }); }); @@ -394,9 +397,7 @@ describe('AppContainer State Management', () => { const { unmount } = renderAppContainer({ initResult: initResultWithError, }); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); unmount(); }); @@ -413,9 +414,7 @@ describe('AppContainer State Management', () => { describe('Context Providers', () => { it('provides AppContext with correct values', async () => { const { unmount } = renderAppContainer({ version: '2.0.0' }); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); // Should render and unmount cleanly expect(() => unmount()).not.toThrow(); @@ -423,25 +422,19 @@ describe('AppContainer State Management', () => { it('provides UIStateContext with state management', async () => { const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); unmount(); }); it('provides UIActionsContext with action handlers', async () => { const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); unmount(); }); it('provides ConfigContext with config object', async () => { const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); unmount(); }); }); @@ -458,9 +451,7 @@ describe('AppContainer State Management', () => { } as unknown as LoadedSettings; const { unmount } = renderAppContainer({ settings: settingsAllHidden }); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); unmount(); }); @@ -475,9 +466,7 @@ describe('AppContainer State Management', () => { } as unknown as LoadedSettings; const { unmount } = renderAppContainer({ settings: settingsWithMemory }); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); unmount(); }); }); @@ -487,9 +476,7 @@ describe('AppContainer State Management', () => { 'handles version format: %s', async (version) => { const { unmount } = renderAppContainer({ version }); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); unmount(); }, ); @@ -504,9 +491,6 @@ describe('AppContainer State Management', () => { // Should still render without crashing - errors should be handled internally const { unmount } = renderAppContainer({ config: errorConfig }); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); unmount(); }); @@ -516,9 +500,7 @@ describe('AppContainer State Management', () => { } as LoadedSettings; const { unmount } = renderAppContainer({ settings: undefinedSettings }); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); unmount(); }); }); @@ -849,12 +831,10 @@ describe('AppContainer State Management', () => { it('passes a null proQuotaRequest to UIStateContext by default', async () => { // The default mock from beforeEach already sets proQuotaRequest to null const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); + await waitFor(() => { + // Assert that the context value is as expected + expect(capturedUIState.proQuotaRequest).toBeNull(); }); - - // Assert that the context value is as expected - expect(capturedUIState.proQuotaRequest).toBeNull(); unmount(); }); @@ -872,12 +852,10 @@ describe('AppContainer State Management', () => { // Act: Render the container const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); + await waitFor(() => { + // Assert: The mock request is correctly passed through the context + expect(capturedUIState.proQuotaRequest).toEqual(mockRequest); }); - - // Assert: The mock request is correctly passed through the context - expect(capturedUIState.proQuotaRequest).toEqual(mockRequest); unmount(); }); @@ -891,13 +869,11 @@ describe('AppContainer State Management', () => { // Act: Render the container const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); + await waitFor(() => { + // Assert: The action in the context is the mock handler we provided + expect(capturedUIActions.handleProQuotaChoice).toBe(mockHandler); }); - // Assert: The action in the context is the mock handler we provided - expect(capturedUIActions.handleProQuotaChoice).toBe(mockHandler); - // You can even verify that the plumbed function is callable act(() => { capturedUIActions.handleProQuotaChoice('retry_later'); @@ -1308,13 +1284,7 @@ describe('AppContainer State Management', () => { }); const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); - - // Assert: The shell should be resized to a minimum height of 1, not a negative number. - // The old code would have tried to set a negative height. - expect(resizePtySpy).toHaveBeenCalled(); + await waitFor(() => expect(resizePtySpy).toHaveBeenCalled()); const lastCall = resizePtySpy.mock.calls[resizePtySpy.mock.calls.length - 1]; // Check the height argument specifically @@ -1649,9 +1619,7 @@ describe('AppContainer State Management', () => { }); const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); expect(capturedUIState.isModelDialogOpen).toBe(true); unmount(); @@ -1667,9 +1635,7 @@ describe('AppContainer State Management', () => { }); const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); // Verify that the actions are correctly passed through context act(() => { @@ -1683,9 +1649,7 @@ describe('AppContainer State Management', () => { describe('CoreEvents Integration', () => { it('subscribes to UserFeedback and drains backlog on mount', async () => { const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); expect(mockCoreEvents.on).toHaveBeenCalledWith( CoreEvent.UserFeedback, @@ -1697,9 +1661,7 @@ describe('AppContainer State Management', () => { it('unsubscribes from UserFeedback on unmount', async () => { const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); unmount(); @@ -1711,9 +1673,7 @@ describe('AppContainer State Management', () => { it('adds history item when UserFeedback event is received', async () => { const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); + await waitFor(() => expect(capturedUIState).toBeTruthy()); // Get the registered handler const handler = mockCoreEvents.on.mock.calls.find( @@ -1745,10 +1705,8 @@ describe('AppContainer State Management', () => { vi.spyOn(mockConfig, 'getModel').mockReturnValue('initial-model'); const { unmount } = renderAppContainer(); - await act(async () => { - await vi.waitFor(() => { - expect(capturedUIState?.currentModel).toBe('initial-model'); - }); + await waitFor(() => { + expect(capturedUIState?.currentModel).toBe('initial-model'); }); // Get the registered handler for ModelChanged @@ -1789,11 +1747,7 @@ describe('AppContainer State Management', () => { // The main assertion is that the render does not throw. const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); - - expect(resizePtySpy).toHaveBeenCalled(); + await waitFor(() => expect(resizePtySpy).toHaveBeenCalled()); unmount(); }); }); @@ -1805,13 +1759,70 @@ describe('AppContainer State Management', () => { apiKey: 'fake-key', }); const { unmount } = renderAppContainer(); - await act(async () => { - await new Promise((resolve) => setTimeout(resolve, 0)); - }); - await vi.waitFor(() => { + await waitFor(() => { expect(capturedUIState.bannerData.defaultText).toBeDefined(); unmount(); }); }); }); + + describe('onCancelSubmit Behavior', () => { + let mockSetText: Mock; + + // Helper to extract arguments from the useGeminiStream hook call + // This isolates the positional argument dependency to a single location + const extractUseGeminiStreamArgs = (args: unknown[]) => ({ + onCancelSubmit: args[14] as (shouldRestorePrompt?: boolean) => void, + }); + + beforeEach(() => { + mockSetText = vi.fn(); + mockedUseTextBuffer.mockReturnValue({ + text: '', + setText: mockSetText, + }); + }); + + it('clears the prompt when onCancelSubmit is called with shouldRestorePrompt=false', async () => { + const { unmount } = renderAppContainer(); + await waitFor(() => expect(capturedUIState).toBeTruthy()); + + const { onCancelSubmit } = extractUseGeminiStreamArgs( + mockedUseGeminiStream.mock.lastCall!, + ); + + act(() => { + onCancelSubmit(false); + }); + + expect(mockSetText).toHaveBeenCalledWith(''); + + unmount(); + }); + + it('restores the prompt when onCancelSubmit is called with shouldRestorePrompt=true (or undefined)', async () => { + mockedUseLogger.mockReturnValue({ + getPreviousUserMessages: vi + .fn() + .mockResolvedValue(['previous message']), + }); + + const { unmount } = renderAppContainer(); + await waitFor(() => + expect(capturedUIState.userMessages).toContain('previous message'), + ); + + const { onCancelSubmit } = extractUseGeminiStreamArgs( + mockedUseGeminiStream.mock.lastCall!, + ); + + await act(async () => { + onCancelSubmit(true); + }); + + expect(mockSetText).toHaveBeenCalledWith('previous message'); + + unmount(); + }); + }); }); diff --git a/packages/cli/src/ui/AppContainer.tsx b/packages/cli/src/ui/AppContainer.tsx index 22e4537cec..a827483cd9 100644 --- a/packages/cli/src/ui/AppContainer.tsx +++ b/packages/cli/src/ui/AppContainer.tsx @@ -648,15 +648,17 @@ Logging in with Google... Please restart Gemini CLI to continue. } }, [config, historyManager]); - const cancelHandlerRef = useRef<() => void>(() => {}); + const cancelHandlerRef = useRef<(shouldRestorePrompt?: boolean) => void>( + () => {}, + ); const getPreferredEditor = useCallback( () => settings.merged.general?.preferredEditor as EditorType, [settings.merged.general?.preferredEditor], ); - const onCancelSubmit = useCallback(() => { - cancelHandlerRef.current(); + const onCancelSubmit = useCallback((shouldRestorePrompt?: boolean) => { + cancelHandlerRef.current(shouldRestorePrompt); }, []); const { @@ -710,36 +712,39 @@ Logging in with Google... Please restart Gemini CLI to continue. submitQuery, }); - cancelHandlerRef.current = useCallback(() => { - const pendingHistoryItems = [ - ...pendingSlashCommandHistoryItems, - ...pendingGeminiHistoryItems, - ]; - if (isToolExecuting(pendingHistoryItems)) { - buffer.setText(''); // Just clear the prompt - return; - } + cancelHandlerRef.current = useCallback( + (shouldRestorePrompt: boolean = true) => { + const pendingHistoryItems = [ + ...pendingSlashCommandHistoryItems, + ...pendingGeminiHistoryItems, + ]; + if (isToolExecuting(pendingHistoryItems)) { + buffer.setText(''); // Just clear the prompt + return; + } - const lastUserMessage = userMessages.at(-1); - let textToSet = lastUserMessage || ''; + const lastUserMessage = userMessages.at(-1); + let textToSet = shouldRestorePrompt ? lastUserMessage || '' : ''; - const queuedText = getQueuedMessagesText(); - if (queuedText) { - textToSet = textToSet ? `${textToSet}\n\n${queuedText}` : queuedText; - clearQueue(); - } + const queuedText = getQueuedMessagesText(); + if (queuedText) { + textToSet = textToSet ? `${textToSet}\n\n${queuedText}` : queuedText; + clearQueue(); + } - if (textToSet) { - buffer.setText(textToSet); - } - }, [ - buffer, - userMessages, - getQueuedMessagesText, - clearQueue, - pendingSlashCommandHistoryItems, - pendingGeminiHistoryItems, - ]); + if (textToSet || !shouldRestorePrompt) { + buffer.setText(textToSet); + } + }, + [ + buffer, + userMessages, + getQueuedMessagesText, + clearQueue, + pendingSlashCommandHistoryItems, + pendingGeminiHistoryItems, + ], + ); const handleFinalSubmit = useCallback( (submittedValue: string) => { diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index 5019d51d9c..4ebab04c54 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -1050,7 +1050,7 @@ describe('useGeminiStream', () => { simulateEscapeKeyPress(); - expect(cancelSubmitSpy).toHaveBeenCalled(); + expect(cancelSubmitSpy).toHaveBeenCalledWith(false); }); it('should call setShellInputFocused(false) when escape is pressed', async () => { @@ -1968,7 +1968,7 @@ describe('useGeminiStream', () => { // Check that onCancelSubmit was called await waitFor(() => { - expect(onCancelSubmitSpy).toHaveBeenCalled(); + expect(onCancelSubmitSpy).toHaveBeenCalledWith(true); }); }); diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index d4fd21942d..f23f0f83d0 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -105,7 +105,7 @@ export const useGeminiStream = ( modelSwitchedFromQuotaError: boolean, setModelSwitchedFromQuotaError: React.Dispatch>, onEditorClose: () => void, - onCancelSubmit: () => void, + onCancelSubmit: (shouldRestorePrompt?: boolean) => void, setShellInputFocused: (value: boolean) => void, terminalWidth: number, terminalHeight: number, @@ -324,7 +324,7 @@ export const useGeminiStream = ( setIsResponding(false); } - onCancelSubmit(); + onCancelSubmit(false); setShellInputFocused(false); }, [ streamingState, @@ -690,7 +690,7 @@ export const useGeminiStream = ( const handleContextWindowWillOverflowEvent = useCallback( (estimatedRequestTokenCount: number, remainingTokenCount: number) => { - onCancelSubmit(); + onCancelSubmit(true); const limit = tokenLimit(config.getModel());