mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-19 10:31:16 -07:00
feat: implement /rewind command (#15720)
This commit is contained in:
351
packages/cli/src/ui/commands/rewindCommand.test.tsx
Normal file
351
packages/cli/src/ui/commands/rewindCommand.test.tsx
Normal file
@@ -0,0 +1,351 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
||||
import { rewindCommand } from './rewindCommand.js';
|
||||
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
|
||||
import { waitFor } from '../../test-utils/async.js';
|
||||
import { RewindOutcome } from '../components/RewindConfirmation.js';
|
||||
import {
|
||||
type OpenCustomDialogActionReturn,
|
||||
type CommandContext,
|
||||
} from './types.js';
|
||||
import type { ReactElement } from 'react';
|
||||
import { coreEvents } from '@google/gemini-cli-core';
|
||||
|
||||
// Mock dependencies
|
||||
const mockRewindTo = vi.fn();
|
||||
const mockRecordMessage = vi.fn();
|
||||
const mockSetHistory = vi.fn();
|
||||
const mockSendMessageStream = vi.fn();
|
||||
const mockGetChatRecordingService = vi.fn();
|
||||
const mockGetConversation = vi.fn();
|
||||
const mockRemoveComponent = vi.fn();
|
||||
const mockLoadHistory = vi.fn();
|
||||
const mockAddItem = vi.fn();
|
||||
const mockSetPendingItem = vi.fn();
|
||||
const mockResetContext = vi.fn();
|
||||
const mockSetInput = vi.fn();
|
||||
const mockRevertFileChanges = vi.fn();
|
||||
const mockGetProjectRoot = vi.fn().mockReturnValue('/mock/root');
|
||||
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const actual =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
return {
|
||||
...actual,
|
||||
coreEvents: {
|
||||
...actual.coreEvents,
|
||||
emitFeedback: vi.fn(),
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../components/RewindViewer.js', () => ({
|
||||
RewindViewer: () => null,
|
||||
}));
|
||||
|
||||
vi.mock('../hooks/useSessionBrowser.js', () => ({
|
||||
convertSessionToHistoryFormats: vi.fn().mockReturnValue({
|
||||
uiHistory: [
|
||||
{ type: 'user', text: 'old user' },
|
||||
{ type: 'gemini', text: 'old gemini' },
|
||||
],
|
||||
clientHistory: [{ role: 'user', parts: [{ text: 'old user' }] }],
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/rewindFileOps.js', () => ({
|
||||
revertFileChanges: (...args: unknown[]) => mockRevertFileChanges(...args),
|
||||
}));
|
||||
|
||||
interface RewindViewerProps {
|
||||
onRewind: (
|
||||
messageId: string,
|
||||
newText: string,
|
||||
outcome: RewindOutcome,
|
||||
) => Promise<void>;
|
||||
conversation: unknown;
|
||||
onExit: () => void;
|
||||
}
|
||||
|
||||
describe('rewindCommand', () => {
|
||||
let mockContext: CommandContext;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
mockGetConversation.mockReturnValue({
|
||||
messages: [{ id: 'msg-1', type: 'user', content: 'hello' }],
|
||||
sessionId: 'test-session',
|
||||
});
|
||||
|
||||
mockRewindTo.mockReturnValue({
|
||||
messages: [], // Mocked rewound messages
|
||||
});
|
||||
|
||||
mockGetChatRecordingService.mockReturnValue({
|
||||
getConversation: mockGetConversation,
|
||||
rewindTo: mockRewindTo,
|
||||
recordMessage: mockRecordMessage,
|
||||
});
|
||||
|
||||
mockContext = createMockCommandContext({
|
||||
services: {
|
||||
config: {
|
||||
getGeminiClient: () => ({
|
||||
getChatRecordingService: mockGetChatRecordingService,
|
||||
setHistory: mockSetHistory,
|
||||
sendMessageStream: mockSendMessageStream,
|
||||
}),
|
||||
getSessionId: () => 'test-session-id',
|
||||
getContextManager: () => ({ refresh: mockResetContext }),
|
||||
getProjectRoot: mockGetProjectRoot,
|
||||
},
|
||||
},
|
||||
ui: {
|
||||
removeComponent: mockRemoveComponent,
|
||||
loadHistory: mockLoadHistory,
|
||||
addItem: mockAddItem,
|
||||
setPendingItem: mockSetPendingItem,
|
||||
},
|
||||
}) as unknown as CommandContext;
|
||||
});
|
||||
|
||||
it('should initialize successfully', async () => {
|
||||
const result = await rewindCommand.action!(mockContext, '');
|
||||
expect(result).toHaveProperty('type', 'custom_dialog');
|
||||
});
|
||||
|
||||
it('should handle RewindOnly correctly', async () => {
|
||||
// 1. Run the command to get the component
|
||||
const result = (await rewindCommand.action!(
|
||||
mockContext,
|
||||
'',
|
||||
)) as OpenCustomDialogActionReturn;
|
||||
const component = result.component as ReactElement<RewindViewerProps>;
|
||||
|
||||
// Access onRewind from props
|
||||
const onRewind = component.props.onRewind;
|
||||
expect(onRewind).toBeDefined();
|
||||
|
||||
await onRewind('msg-id-123', 'New Prompt', RewindOutcome.RewindOnly);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRevertFileChanges).not.toHaveBeenCalled();
|
||||
expect(mockRewindTo).toHaveBeenCalledWith('msg-id-123');
|
||||
expect(mockSetHistory).toHaveBeenCalled();
|
||||
expect(mockResetContext).toHaveBeenCalled();
|
||||
expect(mockLoadHistory).toHaveBeenCalledWith(
|
||||
[
|
||||
expect.objectContaining({ text: 'old user', id: 1 }),
|
||||
expect.objectContaining({ text: 'old gemini', id: 2 }),
|
||||
],
|
||||
'New Prompt',
|
||||
);
|
||||
expect(mockRemoveComponent).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// Verify setInput was NOT called directly (it's handled via loadHistory now)
|
||||
expect(mockSetInput).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle RewindAndRevert correctly', async () => {
|
||||
const result = (await rewindCommand.action!(
|
||||
mockContext,
|
||||
'',
|
||||
)) as OpenCustomDialogActionReturn;
|
||||
const component = result.component as ReactElement<RewindViewerProps>;
|
||||
const onRewind = component.props.onRewind;
|
||||
|
||||
await onRewind('msg-id-123', 'New Prompt', RewindOutcome.RewindAndRevert);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRevertFileChanges).toHaveBeenCalledWith(
|
||||
mockGetConversation(),
|
||||
'msg-id-123',
|
||||
);
|
||||
expect(mockRewindTo).toHaveBeenCalledWith('msg-id-123');
|
||||
expect(mockLoadHistory).toHaveBeenCalledWith(
|
||||
expect.any(Array),
|
||||
'New Prompt',
|
||||
);
|
||||
});
|
||||
expect(mockSetInput).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle RevertOnly correctly', async () => {
|
||||
const result = (await rewindCommand.action!(
|
||||
mockContext,
|
||||
'',
|
||||
)) as OpenCustomDialogActionReturn;
|
||||
const component = result.component as ReactElement<RewindViewerProps>;
|
||||
const onRewind = component.props.onRewind;
|
||||
|
||||
await onRewind('msg-id-123', 'New Prompt', RewindOutcome.RevertOnly);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRevertFileChanges).toHaveBeenCalledWith(
|
||||
mockGetConversation(),
|
||||
'msg-id-123',
|
||||
);
|
||||
expect(mockRewindTo).not.toHaveBeenCalled();
|
||||
expect(mockRemoveComponent).toHaveBeenCalled();
|
||||
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
|
||||
'info',
|
||||
'File changes reverted.',
|
||||
);
|
||||
});
|
||||
expect(mockSetInput).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle Cancel correctly', async () => {
|
||||
const result = (await rewindCommand.action!(
|
||||
mockContext,
|
||||
'',
|
||||
)) as OpenCustomDialogActionReturn;
|
||||
const component = result.component as ReactElement<RewindViewerProps>;
|
||||
const onRewind = component.props.onRewind;
|
||||
|
||||
await onRewind('msg-id-123', 'New Prompt', RewindOutcome.Cancel);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRevertFileChanges).not.toHaveBeenCalled();
|
||||
expect(mockRewindTo).not.toHaveBeenCalled();
|
||||
expect(mockRemoveComponent).toHaveBeenCalled();
|
||||
});
|
||||
expect(mockSetInput).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle onExit correctly', async () => {
|
||||
const result = (await rewindCommand.action!(
|
||||
mockContext,
|
||||
'',
|
||||
)) as OpenCustomDialogActionReturn;
|
||||
const component = result.component as ReactElement<RewindViewerProps>;
|
||||
const onExit = component.props.onExit;
|
||||
|
||||
onExit();
|
||||
|
||||
expect(mockRemoveComponent).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle rewind error correctly', async () => {
|
||||
const result = (await rewindCommand.action!(
|
||||
mockContext,
|
||||
'',
|
||||
)) as OpenCustomDialogActionReturn;
|
||||
const component = result.component as ReactElement<RewindViewerProps>;
|
||||
const onRewind = component.props.onRewind;
|
||||
|
||||
mockRewindTo.mockImplementation(() => {
|
||||
throw new Error('Rewind Failed');
|
||||
});
|
||||
|
||||
await onRewind('msg-1', 'Prompt', RewindOutcome.RewindOnly);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
|
||||
'error',
|
||||
'Rewind Failed',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle null conversation from rewindTo', async () => {
|
||||
const result = (await rewindCommand.action!(
|
||||
mockContext,
|
||||
'',
|
||||
)) as OpenCustomDialogActionReturn;
|
||||
const component = result.component as ReactElement<RewindViewerProps>;
|
||||
const onRewind = component.props.onRewind;
|
||||
|
||||
mockRewindTo.mockReturnValue(null);
|
||||
|
||||
await onRewind('msg-1', 'Prompt', RewindOutcome.RewindOnly);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(coreEvents.emitFeedback).toHaveBeenCalledWith(
|
||||
'error',
|
||||
'Could not fetch conversation file',
|
||||
);
|
||||
expect(mockRemoveComponent).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
it('should fail if config is missing', () => {
|
||||
const context = { services: {} } as CommandContext;
|
||||
|
||||
const result = rewindCommand.action!(context, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Config not found',
|
||||
});
|
||||
});
|
||||
|
||||
it('should fail if client is not initialized', () => {
|
||||
const context = createMockCommandContext({
|
||||
services: {
|
||||
config: { getGeminiClient: () => undefined },
|
||||
},
|
||||
}) as unknown as CommandContext;
|
||||
|
||||
const result = rewindCommand.action!(context, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Client not initialized',
|
||||
});
|
||||
});
|
||||
|
||||
it('should fail if recording service is unavailable', () => {
|
||||
const context = createMockCommandContext({
|
||||
services: {
|
||||
config: {
|
||||
getGeminiClient: () => ({ getChatRecordingService: () => undefined }),
|
||||
},
|
||||
},
|
||||
}) as unknown as CommandContext;
|
||||
|
||||
const result = rewindCommand.action!(context, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Recording service unavailable',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return info if no conversation found', () => {
|
||||
mockGetConversation.mockReturnValue(null);
|
||||
|
||||
const result = rewindCommand.action!(mockContext, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: 'No conversation found.',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return info if no user interactions found', () => {
|
||||
mockGetConversation.mockReturnValue({
|
||||
messages: [{ id: 'msg-1', type: 'gemini', content: 'hello' }],
|
||||
sessionId: 'test-session',
|
||||
});
|
||||
|
||||
const result = rewindCommand.action!(mockContext, '');
|
||||
|
||||
expect(result).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: 'Nothing to rewind to.',
|
||||
});
|
||||
});
|
||||
});
|
||||
191
packages/cli/src/ui/commands/rewindCommand.tsx
Normal file
191
packages/cli/src/ui/commands/rewindCommand.tsx
Normal file
@@ -0,0 +1,191 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
CommandKind,
|
||||
type CommandContext,
|
||||
type SlashCommand,
|
||||
} from './types.js';
|
||||
import { RewindViewer } from '../components/RewindViewer.js';
|
||||
import { type HistoryItem } from '../types.js';
|
||||
import { convertSessionToHistoryFormats } from '../hooks/useSessionBrowser.js';
|
||||
import { revertFileChanges } from '../utils/rewindFileOps.js';
|
||||
import { RewindOutcome } from '../components/RewindConfirmation.js';
|
||||
import { checkExhaustive } from '../../utils/checks.js';
|
||||
|
||||
import type { Content } from '@google/genai';
|
||||
import type {
|
||||
ChatRecordingService,
|
||||
GeminiClient,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { coreEvents, debugLogger } from '@google/gemini-cli-core';
|
||||
|
||||
/**
|
||||
* Helper function to handle the core logic of rewinding a conversation.
|
||||
* This function encapsulates the steps needed to rewind the conversation,
|
||||
* update the client and UI history, and clear the component.
|
||||
*
|
||||
* @param context The command context.
|
||||
* @param client Gemini client
|
||||
* @param recordingService The chat recording service.
|
||||
* @param messageId The ID of the message to rewind to.
|
||||
* @param newText The new text for the input field after rewinding.
|
||||
*/
|
||||
async function rewindConversation(
|
||||
context: CommandContext,
|
||||
client: GeminiClient,
|
||||
recordingService: ChatRecordingService,
|
||||
messageId: string,
|
||||
newText: string,
|
||||
) {
|
||||
try {
|
||||
const conversation = recordingService.rewindTo(messageId);
|
||||
if (!conversation) {
|
||||
const errorMsg = 'Could not fetch conversation file';
|
||||
debugLogger.error(errorMsg);
|
||||
context.ui.removeComponent();
|
||||
coreEvents.emitFeedback('error', errorMsg);
|
||||
return;
|
||||
}
|
||||
|
||||
// Convert to UI and Client formats
|
||||
const { uiHistory, clientHistory } = convertSessionToHistoryFormats(
|
||||
conversation.messages,
|
||||
);
|
||||
|
||||
client.setHistory(clientHistory as Content[]);
|
||||
|
||||
// Reset context manager as we are rewinding history
|
||||
await context.services.config?.getContextManager()?.refresh();
|
||||
|
||||
// Update UI History
|
||||
// We generate IDs based on index for the rewind history
|
||||
const startId = 1;
|
||||
const historyWithIds = uiHistory.map(
|
||||
(item, idx) =>
|
||||
({
|
||||
...item,
|
||||
id: startId + idx,
|
||||
}) as HistoryItem,
|
||||
);
|
||||
|
||||
// 1. Remove component FIRST to avoid flicker and clear the stage
|
||||
context.ui.removeComponent();
|
||||
|
||||
// 2. Load the rewound history and set the input
|
||||
context.ui.loadHistory(historyWithIds, newText);
|
||||
} catch (error) {
|
||||
// If an error occurs, we still want to remove the component if possible
|
||||
context.ui.removeComponent();
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
error instanceof Error ? error.message : 'Unknown error during rewind',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export const rewindCommand: SlashCommand = {
|
||||
name: 'rewind',
|
||||
description: 'Jump back to a specific message and restart the conversation',
|
||||
kind: CommandKind.BUILT_IN,
|
||||
action: (context) => {
|
||||
const config = context.services.config;
|
||||
if (!config)
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Config not found',
|
||||
};
|
||||
|
||||
const client = config.getGeminiClient();
|
||||
if (!client)
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Client not initialized',
|
||||
};
|
||||
|
||||
const recordingService = client.getChatRecordingService();
|
||||
if (!recordingService)
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Recording service unavailable',
|
||||
};
|
||||
|
||||
const conversation = recordingService.getConversation();
|
||||
if (!conversation)
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: 'No conversation found.',
|
||||
};
|
||||
|
||||
const hasUserInteractions = conversation.messages.some(
|
||||
(msg) => msg.type === 'user',
|
||||
);
|
||||
if (!hasUserInteractions) {
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: 'Nothing to rewind to.',
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'custom_dialog',
|
||||
component: (
|
||||
<RewindViewer
|
||||
conversation={conversation}
|
||||
onExit={() => {
|
||||
context.ui.removeComponent();
|
||||
}}
|
||||
onRewind={async (messageId, newText, outcome) => {
|
||||
switch (outcome) {
|
||||
case RewindOutcome.Cancel:
|
||||
context.ui.removeComponent();
|
||||
return;
|
||||
|
||||
case RewindOutcome.RevertOnly:
|
||||
if (conversation) {
|
||||
await revertFileChanges(conversation, messageId);
|
||||
}
|
||||
context.ui.removeComponent();
|
||||
coreEvents.emitFeedback('info', 'File changes reverted.');
|
||||
return;
|
||||
|
||||
case RewindOutcome.RewindAndRevert:
|
||||
if (conversation) {
|
||||
await revertFileChanges(conversation, messageId);
|
||||
}
|
||||
await rewindConversation(
|
||||
context,
|
||||
client,
|
||||
recordingService,
|
||||
messageId,
|
||||
newText,
|
||||
);
|
||||
return;
|
||||
|
||||
case RewindOutcome.RewindOnly:
|
||||
await rewindConversation(
|
||||
context,
|
||||
client,
|
||||
recordingService,
|
||||
messageId,
|
||||
newText,
|
||||
);
|
||||
return;
|
||||
|
||||
default:
|
||||
checkExhaustive(outcome);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
),
|
||||
};
|
||||
},
|
||||
};
|
||||
@@ -66,8 +66,9 @@ export interface CommandContext {
|
||||
* Loads a new set of history items, replacing the current history.
|
||||
*
|
||||
* @param history The array of history items to load.
|
||||
* @param postLoadInput Optional text to set in the input buffer after loading history.
|
||||
*/
|
||||
loadHistory: UseHistoryManagerReturn['loadHistory'];
|
||||
loadHistory: (history: HistoryItem[], postLoadInput?: string) => void;
|
||||
/** Toggles a special display mode. */
|
||||
toggleCorgiMode: () => void;
|
||||
toggleDebugProfiler: () => void;
|
||||
|
||||
Reference in New Issue
Block a user