fix(ui): Clear input prompt on Escape key press (#13335)

This commit is contained in:
Sandy Tao
2025-11-19 11:11:36 +08:00
committed by GitHub
parent c5498bbb07
commit b644f037a3
4 changed files with 138 additions and 122 deletions
+98 -87
View File
@@ -15,6 +15,7 @@ import {
type MockedObject, type MockedObject,
} from 'vitest'; } from 'vitest';
import { render } from '../test-utils/render.js'; import { render } from '../test-utils/render.js';
import { waitFor } from '../test-utils/async.js';
import { cleanup } from 'ink-testing-library'; import { cleanup } from 'ink-testing-library';
import { act, useContext } from 'react'; import { act, useContext } from 'react';
import { AppContainer } from './AppContainer.js'; import { AppContainer } from './AppContainer.js';
@@ -36,12 +37,18 @@ const mockCoreEvents = vi.hoisted(() => ({
emit: vi.fn(), emit: vi.fn(),
})); }));
// Mock IdeClient
const mockIdeClient = vi.hoisted(() => ({
getInstance: vi.fn().mockReturnValue(new Promise(() => {})),
}));
vi.mock('@google/gemini-cli-core', async (importOriginal) => { vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const actual = const actual =
await importOriginal<typeof import('@google/gemini-cli-core')>(); await importOriginal<typeof import('@google/gemini-cli-core')>();
return { return {
...actual, ...actual,
coreEvents: mockCoreEvents, coreEvents: mockCoreEvents,
IdeClient: mockIdeClient,
}; };
}); });
import type { LoadedSettings } from '../config/settings.js'; import type { LoadedSettings } from '../config/settings.js';
@@ -312,7 +319,7 @@ describe('AppContainer State Management', () => {
// Add other properties if AppContainer uses them // Add other properties if AppContainer uses them
}); });
mockedUseLogger.mockReturnValue({ mockedUseLogger.mockReturnValue({
getPreviousUserMessages: vi.fn().mockResolvedValue([]), getPreviousUserMessages: vi.fn().mockReturnValue(new Promise(() => {})),
}); });
mockedUseLoadingIndicator.mockReturnValue({ mockedUseLoadingIndicator.mockReturnValue({
elapsedTime: '0.0s', elapsedTime: '0.0s',
@@ -367,9 +374,7 @@ describe('AppContainer State Management', () => {
describe('Basic Rendering', () => { describe('Basic Rendering', () => {
it('renders without crashing with minimal props', async () => { it('renders without crashing with minimal props', async () => {
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
unmount(); unmount();
}); });
@@ -377,9 +382,7 @@ describe('AppContainer State Management', () => {
const startupWarnings = ['Warning 1', 'Warning 2']; const startupWarnings = ['Warning 1', 'Warning 2'];
const { unmount } = renderAppContainer({ startupWarnings }); const { unmount } = renderAppContainer({ startupWarnings });
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
unmount(); unmount();
}); });
}); });
@@ -394,9 +397,7 @@ describe('AppContainer State Management', () => {
const { unmount } = renderAppContainer({ const { unmount } = renderAppContainer({
initResult: initResultWithError, initResult: initResultWithError,
}); });
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
unmount(); unmount();
}); });
@@ -413,9 +414,7 @@ describe('AppContainer State Management', () => {
describe('Context Providers', () => { describe('Context Providers', () => {
it('provides AppContext with correct values', async () => { it('provides AppContext with correct values', async () => {
const { unmount } = renderAppContainer({ version: '2.0.0' }); const { unmount } = renderAppContainer({ version: '2.0.0' });
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
// Should render and unmount cleanly // Should render and unmount cleanly
expect(() => unmount()).not.toThrow(); expect(() => unmount()).not.toThrow();
@@ -423,25 +422,19 @@ describe('AppContainer State Management', () => {
it('provides UIStateContext with state management', async () => { it('provides UIStateContext with state management', async () => {
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
unmount(); unmount();
}); });
it('provides UIActionsContext with action handlers', async () => { it('provides UIActionsContext with action handlers', async () => {
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
unmount(); unmount();
}); });
it('provides ConfigContext with config object', async () => { it('provides ConfigContext with config object', async () => {
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
unmount(); unmount();
}); });
}); });
@@ -458,9 +451,7 @@ describe('AppContainer State Management', () => {
} as unknown as LoadedSettings; } as unknown as LoadedSettings;
const { unmount } = renderAppContainer({ settings: settingsAllHidden }); const { unmount } = renderAppContainer({ settings: settingsAllHidden });
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
unmount(); unmount();
}); });
@@ -475,9 +466,7 @@ describe('AppContainer State Management', () => {
} as unknown as LoadedSettings; } as unknown as LoadedSettings;
const { unmount } = renderAppContainer({ settings: settingsWithMemory }); const { unmount } = renderAppContainer({ settings: settingsWithMemory });
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
unmount(); unmount();
}); });
}); });
@@ -487,9 +476,7 @@ describe('AppContainer State Management', () => {
'handles version format: %s', 'handles version format: %s',
async (version) => { async (version) => {
const { unmount } = renderAppContainer({ version }); const { unmount } = renderAppContainer({ version });
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
unmount(); unmount();
}, },
); );
@@ -504,9 +491,6 @@ describe('AppContainer State Management', () => {
// Should still render without crashing - errors should be handled internally // Should still render without crashing - errors should be handled internally
const { unmount } = renderAppContainer({ config: errorConfig }); const { unmount } = renderAppContainer({ config: errorConfig });
await act(async () => {
await new Promise((resolve) => setTimeout(resolve, 0));
});
unmount(); unmount();
}); });
@@ -516,9 +500,7 @@ describe('AppContainer State Management', () => {
} as LoadedSettings; } as LoadedSettings;
const { unmount } = renderAppContainer({ settings: undefinedSettings }); const { unmount } = renderAppContainer({ settings: undefinedSettings });
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
unmount(); unmount();
}); });
}); });
@@ -849,12 +831,10 @@ describe('AppContainer State Management', () => {
it('passes a null proQuotaRequest to UIStateContext by default', async () => { it('passes a null proQuotaRequest to UIStateContext by default', async () => {
// The default mock from beforeEach already sets proQuotaRequest to null // The default mock from beforeEach already sets proQuotaRequest to null
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => {
await new Promise((resolve) => setTimeout(resolve, 0)); // Assert that the context value is as expected
expect(capturedUIState.proQuotaRequest).toBeNull();
}); });
// Assert that the context value is as expected
expect(capturedUIState.proQuotaRequest).toBeNull();
unmount(); unmount();
}); });
@@ -872,12 +852,10 @@ describe('AppContainer State Management', () => {
// Act: Render the container // Act: Render the container
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => {
await new Promise((resolve) => setTimeout(resolve, 0)); // Assert: The mock request is correctly passed through the context
expect(capturedUIState.proQuotaRequest).toEqual(mockRequest);
}); });
// Assert: The mock request is correctly passed through the context
expect(capturedUIState.proQuotaRequest).toEqual(mockRequest);
unmount(); unmount();
}); });
@@ -891,13 +869,11 @@ describe('AppContainer State Management', () => {
// Act: Render the container // Act: Render the container
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => {
await new Promise((resolve) => setTimeout(resolve, 0)); // Assert: The action in the context is the mock handler we provided
expect(capturedUIActions.handleProQuotaChoice).toBe(mockHandler);
}); });
// Assert: The action in the context is the mock handler we provided
expect(capturedUIActions.handleProQuotaChoice).toBe(mockHandler);
// You can even verify that the plumbed function is callable // You can even verify that the plumbed function is callable
act(() => { act(() => {
capturedUIActions.handleProQuotaChoice('retry_later'); capturedUIActions.handleProQuotaChoice('retry_later');
@@ -1308,13 +1284,7 @@ describe('AppContainer State Management', () => {
}); });
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => expect(resizePtySpy).toHaveBeenCalled());
await new Promise((resolve) => setTimeout(resolve, 0));
});
// Assert: The shell should be resized to a minimum height of 1, not a negative number.
// The old code would have tried to set a negative height.
expect(resizePtySpy).toHaveBeenCalled();
const lastCall = const lastCall =
resizePtySpy.mock.calls[resizePtySpy.mock.calls.length - 1]; resizePtySpy.mock.calls[resizePtySpy.mock.calls.length - 1];
// Check the height argument specifically // Check the height argument specifically
@@ -1649,9 +1619,7 @@ describe('AppContainer State Management', () => {
}); });
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
expect(capturedUIState.isModelDialogOpen).toBe(true); expect(capturedUIState.isModelDialogOpen).toBe(true);
unmount(); unmount();
@@ -1667,9 +1635,7 @@ describe('AppContainer State Management', () => {
}); });
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
// Verify that the actions are correctly passed through context // Verify that the actions are correctly passed through context
act(() => { act(() => {
@@ -1683,9 +1649,7 @@ describe('AppContainer State Management', () => {
describe('CoreEvents Integration', () => { describe('CoreEvents Integration', () => {
it('subscribes to UserFeedback and drains backlog on mount', async () => { it('subscribes to UserFeedback and drains backlog on mount', async () => {
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
expect(mockCoreEvents.on).toHaveBeenCalledWith( expect(mockCoreEvents.on).toHaveBeenCalledWith(
CoreEvent.UserFeedback, CoreEvent.UserFeedback,
@@ -1697,9 +1661,7 @@ describe('AppContainer State Management', () => {
it('unsubscribes from UserFeedback on unmount', async () => { it('unsubscribes from UserFeedback on unmount', async () => {
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
unmount(); unmount();
@@ -1711,9 +1673,7 @@ describe('AppContainer State Management', () => {
it('adds history item when UserFeedback event is received', async () => { it('adds history item when UserFeedback event is received', async () => {
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => expect(capturedUIState).toBeTruthy());
await new Promise((resolve) => setTimeout(resolve, 0));
});
// Get the registered handler // Get the registered handler
const handler = mockCoreEvents.on.mock.calls.find( const handler = mockCoreEvents.on.mock.calls.find(
@@ -1745,10 +1705,8 @@ describe('AppContainer State Management', () => {
vi.spyOn(mockConfig, 'getModel').mockReturnValue('initial-model'); vi.spyOn(mockConfig, 'getModel').mockReturnValue('initial-model');
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => {
await vi.waitFor(() => { expect(capturedUIState?.currentModel).toBe('initial-model');
expect(capturedUIState?.currentModel).toBe('initial-model');
});
}); });
// Get the registered handler for ModelChanged // Get the registered handler for ModelChanged
@@ -1789,11 +1747,7 @@ describe('AppContainer State Management', () => {
// The main assertion is that the render does not throw. // The main assertion is that the render does not throw.
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => expect(resizePtySpy).toHaveBeenCalled());
await new Promise((resolve) => setTimeout(resolve, 0));
});
expect(resizePtySpy).toHaveBeenCalled();
unmount(); unmount();
}); });
}); });
@@ -1805,13 +1759,70 @@ describe('AppContainer State Management', () => {
apiKey: 'fake-key', apiKey: 'fake-key',
}); });
const { unmount } = renderAppContainer(); const { unmount } = renderAppContainer();
await act(async () => { await waitFor(() => {
await new Promise((resolve) => setTimeout(resolve, 0));
});
await vi.waitFor(() => {
expect(capturedUIState.bannerData.defaultText).toBeDefined(); expect(capturedUIState.bannerData.defaultText).toBeDefined();
unmount(); unmount();
}); });
}); });
}); });
describe('onCancelSubmit Behavior', () => {
let mockSetText: Mock;
// Helper to extract arguments from the useGeminiStream hook call
// This isolates the positional argument dependency to a single location
const extractUseGeminiStreamArgs = (args: unknown[]) => ({
onCancelSubmit: args[14] as (shouldRestorePrompt?: boolean) => void,
});
beforeEach(() => {
mockSetText = vi.fn();
mockedUseTextBuffer.mockReturnValue({
text: '',
setText: mockSetText,
});
});
it('clears the prompt when onCancelSubmit is called with shouldRestorePrompt=false', async () => {
const { unmount } = renderAppContainer();
await waitFor(() => expect(capturedUIState).toBeTruthy());
const { onCancelSubmit } = extractUseGeminiStreamArgs(
mockedUseGeminiStream.mock.lastCall!,
);
act(() => {
onCancelSubmit(false);
});
expect(mockSetText).toHaveBeenCalledWith('');
unmount();
});
it('restores the prompt when onCancelSubmit is called with shouldRestorePrompt=true (or undefined)', async () => {
mockedUseLogger.mockReturnValue({
getPreviousUserMessages: vi
.fn()
.mockResolvedValue(['previous message']),
});
const { unmount } = renderAppContainer();
await waitFor(() =>
expect(capturedUIState.userMessages).toContain('previous message'),
);
const { onCancelSubmit } = extractUseGeminiStreamArgs(
mockedUseGeminiStream.mock.lastCall!,
);
await act(async () => {
onCancelSubmit(true);
});
expect(mockSetText).toHaveBeenCalledWith('previous message');
unmount();
});
});
}); });
+35 -30
View File
@@ -648,15 +648,17 @@ Logging in with Google... Please restart Gemini CLI to continue.
} }
}, [config, historyManager]); }, [config, historyManager]);
const cancelHandlerRef = useRef<() => void>(() => {}); const cancelHandlerRef = useRef<(shouldRestorePrompt?: boolean) => void>(
() => {},
);
const getPreferredEditor = useCallback( const getPreferredEditor = useCallback(
() => settings.merged.general?.preferredEditor as EditorType, () => settings.merged.general?.preferredEditor as EditorType,
[settings.merged.general?.preferredEditor], [settings.merged.general?.preferredEditor],
); );
const onCancelSubmit = useCallback(() => { const onCancelSubmit = useCallback((shouldRestorePrompt?: boolean) => {
cancelHandlerRef.current(); cancelHandlerRef.current(shouldRestorePrompt);
}, []); }, []);
const { const {
@@ -710,36 +712,39 @@ Logging in with Google... Please restart Gemini CLI to continue.
submitQuery, submitQuery,
}); });
cancelHandlerRef.current = useCallback(() => { cancelHandlerRef.current = useCallback(
const pendingHistoryItems = [ (shouldRestorePrompt: boolean = true) => {
...pendingSlashCommandHistoryItems, const pendingHistoryItems = [
...pendingGeminiHistoryItems, ...pendingSlashCommandHistoryItems,
]; ...pendingGeminiHistoryItems,
if (isToolExecuting(pendingHistoryItems)) { ];
buffer.setText(''); // Just clear the prompt if (isToolExecuting(pendingHistoryItems)) {
return; buffer.setText(''); // Just clear the prompt
} return;
}
const lastUserMessage = userMessages.at(-1); const lastUserMessage = userMessages.at(-1);
let textToSet = lastUserMessage || ''; let textToSet = shouldRestorePrompt ? lastUserMessage || '' : '';
const queuedText = getQueuedMessagesText(); const queuedText = getQueuedMessagesText();
if (queuedText) { if (queuedText) {
textToSet = textToSet ? `${textToSet}\n\n${queuedText}` : queuedText; textToSet = textToSet ? `${textToSet}\n\n${queuedText}` : queuedText;
clearQueue(); clearQueue();
} }
if (textToSet) { if (textToSet || !shouldRestorePrompt) {
buffer.setText(textToSet); buffer.setText(textToSet);
} }
}, [ },
buffer, [
userMessages, buffer,
getQueuedMessagesText, userMessages,
clearQueue, getQueuedMessagesText,
pendingSlashCommandHistoryItems, clearQueue,
pendingGeminiHistoryItems, pendingSlashCommandHistoryItems,
]); pendingGeminiHistoryItems,
],
);
const handleFinalSubmit = useCallback( const handleFinalSubmit = useCallback(
(submittedValue: string) => { (submittedValue: string) => {
@@ -1050,7 +1050,7 @@ describe('useGeminiStream', () => {
simulateEscapeKeyPress(); simulateEscapeKeyPress();
expect(cancelSubmitSpy).toHaveBeenCalled(); expect(cancelSubmitSpy).toHaveBeenCalledWith(false);
}); });
it('should call setShellInputFocused(false) when escape is pressed', async () => { it('should call setShellInputFocused(false) when escape is pressed', async () => {
@@ -1968,7 +1968,7 @@ describe('useGeminiStream', () => {
// Check that onCancelSubmit was called // Check that onCancelSubmit was called
await waitFor(() => { await waitFor(() => {
expect(onCancelSubmitSpy).toHaveBeenCalled(); expect(onCancelSubmitSpy).toHaveBeenCalledWith(true);
}); });
}); });
+3 -3
View File
@@ -105,7 +105,7 @@ export const useGeminiStream = (
modelSwitchedFromQuotaError: boolean, modelSwitchedFromQuotaError: boolean,
setModelSwitchedFromQuotaError: React.Dispatch<React.SetStateAction<boolean>>, setModelSwitchedFromQuotaError: React.Dispatch<React.SetStateAction<boolean>>,
onEditorClose: () => void, onEditorClose: () => void,
onCancelSubmit: () => void, onCancelSubmit: (shouldRestorePrompt?: boolean) => void,
setShellInputFocused: (value: boolean) => void, setShellInputFocused: (value: boolean) => void,
terminalWidth: number, terminalWidth: number,
terminalHeight: number, terminalHeight: number,
@@ -324,7 +324,7 @@ export const useGeminiStream = (
setIsResponding(false); setIsResponding(false);
} }
onCancelSubmit(); onCancelSubmit(false);
setShellInputFocused(false); setShellInputFocused(false);
}, [ }, [
streamingState, streamingState,
@@ -690,7 +690,7 @@ export const useGeminiStream = (
const handleContextWindowWillOverflowEvent = useCallback( const handleContextWindowWillOverflowEvent = useCallback(
(estimatedRequestTokenCount: number, remainingTokenCount: number) => { (estimatedRequestTokenCount: number, remainingTokenCount: number) => {
onCancelSubmit(); onCancelSubmit(true);
const limit = tokenLimit(config.getModel()); const limit = tokenLimit(config.getModel());