diff --git a/packages/cli/src/ui/commands/chatCommand.ts b/packages/cli/src/ui/commands/chatCommand.ts index d346d3653c..b41d403443 100644 --- a/packages/cli/src/ui/commands/chatCommand.ts +++ b/packages/cli/src/ui/commands/chatCommand.ts @@ -11,11 +11,13 @@ import { theme } from '../semantic-colors.js'; import type { CommandContext, SlashCommand, - MessageActionReturn, SlashCommandActionReturn, } from './types.js'; import { CommandKind } from './types.js'; -import { decodeTagName } from '@google/gemini-cli-core'; +import { + decodeTagName, + type MessageActionReturn, +} from '@google/gemini-cli-core'; import path from 'node:path'; import type { HistoryItemWithoutId, diff --git a/packages/cli/src/ui/commands/hooksCommand.ts b/packages/cli/src/ui/commands/hooksCommand.ts index bdf226e85b..03312c361c 100644 --- a/packages/cli/src/ui/commands/hooksCommand.ts +++ b/packages/cli/src/ui/commands/hooksCommand.ts @@ -4,14 +4,13 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { - SlashCommand, - CommandContext, - MessageActionReturn, -} from './types.js'; +import type { SlashCommand, CommandContext } from './types.js'; import { CommandKind } from './types.js'; import { MessageType, type HistoryItemHooksList } from '../types.js'; -import type { HookRegistryEntry } from '@google/gemini-cli-core'; +import type { + HookRegistryEntry, + MessageActionReturn, +} from '@google/gemini-cli-core'; import { getErrorMessage } from '@google/gemini-cli-core'; import { SettingScope } from '../../config/settings.js'; diff --git a/packages/cli/src/ui/commands/initCommand.test.ts b/packages/cli/src/ui/commands/initCommand.test.ts index c38fd4196f..54bb4d164e 100644 --- a/packages/cli/src/ui/commands/initCommand.test.ts +++ b/packages/cli/src/ui/commands/initCommand.test.ts @@ -9,7 +9,8 @@ import * as fs from 'node:fs'; import * as path from 'node:path'; import { initCommand } from './initCommand.js'; import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; -import type { SubmitPromptActionReturn, CommandContext } from './types.js'; +import type { CommandContext } from './types.js'; +import type { SubmitPromptActionReturn } from '@google/gemini-cli-core'; // Mock the 'fs' module vi.mock('fs', () => ({ diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index f3fdd06b8e..56d37b5e14 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -8,10 +8,12 @@ import type { SlashCommand, SlashCommandActionReturn, CommandContext, - MessageActionReturn, } from './types.js'; import { CommandKind } from './types.js'; -import type { DiscoveredMCPPrompt } from '@google/gemini-cli-core'; +import type { + DiscoveredMCPPrompt, + MessageActionReturn, +} from '@google/gemini-cli-core'; import { DiscoveredMCPTool, getMCPDiscoveryState, diff --git a/packages/cli/src/ui/commands/restoreCommand.test.ts b/packages/cli/src/ui/commands/restoreCommand.test.ts index f1c73e3c96..2a5def5c42 100644 --- a/packages/cli/src/ui/commands/restoreCommand.test.ts +++ b/packages/cli/src/ui/commands/restoreCommand.test.ts @@ -155,10 +155,10 @@ describe('restoreCommand', () => { it('should restore a tool call and project state', async () => { const toolCallData = { - history: [{ type: 'user', text: 'do a thing' }], + history: [{ type: 'user', text: 'do a thing', id: 123 }], clientHistory: [{ role: 'user', parts: [{ text: 'do a thing' }] }], commitHash: 'abcdef123', - toolCall: { name: 'run_shell_command', args: 'ls' }, + toolCall: { name: 'run_shell_command', args: { command: 'ls' } }, }; await fs.writeFile( path.join(checkpointsDir, 'my-checkpoint.json'), @@ -169,7 +169,7 @@ describe('restoreCommand', () => { expect(await command?.action?.(mockContext, 'my-checkpoint')).toEqual({ type: 'tool', toolName: 'run_shell_command', - toolArgs: 'ls', + toolArgs: { command: 'ls' }, }); expect(mockContext.ui.loadHistory).toHaveBeenCalledWith( toolCallData.history, @@ -189,7 +189,7 @@ describe('restoreCommand', () => { it('should restore even if only toolCall is present', async () => { const toolCallData = { - toolCall: { name: 'run_shell_command', args: 'ls' }, + toolCall: { name: 'run_shell_command', args: { command: 'ls' } }, }; await fs.writeFile( path.join(checkpointsDir, 'my-checkpoint.json'), @@ -201,7 +201,7 @@ describe('restoreCommand', () => { expect(await command?.action?.(mockContext, 'my-checkpoint')).toEqual({ type: 'tool', toolName: 'run_shell_command', - toolArgs: 'ls', + toolArgs: { command: 'ls' }, }); expect(mockContext.ui.loadHistory).not.toHaveBeenCalled(); @@ -222,7 +222,7 @@ describe('restoreCommand', () => { type: 'message', messageType: 'error', // A more specific error message would be ideal, but for now, we can assert the current behavior. - content: expect.stringContaining('Could not read restorable tool calls.'), + content: expect.stringContaining('Checkpoint file is invalid'), }); }); diff --git a/packages/cli/src/ui/commands/restoreCommand.ts b/packages/cli/src/ui/commands/restoreCommand.ts index cbee8ec0ea..cf77b10456 100644 --- a/packages/cli/src/ui/commands/restoreCommand.ts +++ b/packages/cli/src/ui/commands/restoreCommand.ts @@ -6,13 +6,45 @@ import * as fs from 'node:fs/promises'; import path from 'node:path'; +import { z } from 'zod'; +import { + type Config, + performRestore, + type ToolCallData, +} from '@google/gemini-cli-core'; import { type CommandContext, type SlashCommand, type SlashCommandActionReturn, CommandKind, } from './types.js'; -import type { Config } from '@google/gemini-cli-core'; +import type { HistoryItem } from '../types.js'; +import type { Content } from '@google/genai'; + +const HistoryItemSchema = z + .object({ + type: z.string(), + id: z.number(), + }) + .passthrough(); + +const ContentSchema = z + .object({ + role: z.string().optional(), + parts: z.array(z.record(z.unknown())), + }) + .passthrough(); + +const ToolCallDataSchema = z.object({ + history: z.array(HistoryItemSchema).optional(), + clientHistory: z.array(ContentSchema).optional(), + commitHash: z.string().optional(), + toolCall: z.object({ + name: z.string(), + args: z.record(z.unknown()), + }), + messageId: z.string().optional(), +}); async function restoreAction( context: CommandContext, @@ -74,33 +106,43 @@ async function restoreAction( const filePath = path.join(checkpointDir, selectedFile); const data = await fs.readFile(filePath, 'utf-8'); - const toolCallData = JSON.parse(data); + const parseResult = ToolCallDataSchema.safeParse(JSON.parse(data)); - if (toolCallData.history) { - if (!loadHistory) { - // This should not happen - return { - type: 'message', - messageType: 'error', - content: 'loadHistory function is not available.', - }; + if (!parseResult.success) { + return { + type: 'message', + messageType: 'error', + content: `Checkpoint file is invalid: ${parseResult.error.message}`, + }; + } + + // We safely cast here because: + // 1. ToolCallDataSchema strictly validates the existence of 'history' as an array and 'id'/'type' on each item. + // 2. We trust that files valid according to this schema (written by useGeminiStream) contain the full HistoryItem structure. + const toolCallData = parseResult.data as ToolCallData< + HistoryItem[], + Record + >; + + const actionStream = performRestore(toolCallData, gitService); + + for await (const action of actionStream) { + if (action.type === 'message') { + addItem( + { + type: action.messageType, + text: action.content, + }, + Date.now(), + ); + } else if (action.type === 'load_history' && loadHistory) { + loadHistory(action.history); + if (action.clientHistory) { + await config + ?.getGeminiClient() + ?.setHistory(action.clientHistory as Content[]); + } } - loadHistory(toolCallData.history); - } - - if (toolCallData.clientHistory) { - await config?.getGeminiClient()?.setHistory(toolCallData.clientHistory); - } - - if (toolCallData.commitHash) { - await gitService?.restoreProjectFromSnapshot(toolCallData.commitHash); - addItem( - { - type: 'info', - text: 'Restored project to the state before the tool call.', - }, - Date.now(), - ); } return { diff --git a/packages/cli/src/ui/commands/setupGithubCommand.test.ts b/packages/cli/src/ui/commands/setupGithubCommand.test.ts index 2eb6da7abc..0125ae70bd 100644 --- a/packages/cli/src/ui/commands/setupGithubCommand.test.ts +++ b/packages/cli/src/ui/commands/setupGithubCommand.test.ts @@ -15,8 +15,9 @@ import { updateGitignore, GITHUB_WORKFLOW_PATHS, } from './setupGithubCommand.js'; -import type { CommandContext, ToolActionReturn } from './types.js'; +import type { CommandContext } from './types.js'; import * as commandUtils from '../utils/commandUtils.js'; +import type { ToolActionReturn } from '@google/gemini-cli-core'; import { debugLogger } from '@google/gemini-cli-core'; vi.mock('child_process'); diff --git a/packages/cli/src/ui/commands/terminalSetupCommand.ts b/packages/cli/src/ui/commands/terminalSetupCommand.ts index c5772ae5a7..780513ab6c 100644 --- a/packages/cli/src/ui/commands/terminalSetupCommand.ts +++ b/packages/cli/src/ui/commands/terminalSetupCommand.ts @@ -4,9 +4,10 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { MessageActionReturn, SlashCommand } from './types.js'; +import type { SlashCommand } from './types.js'; import { CommandKind } from './types.js'; import { terminalSetup } from '../utils/terminalSetup.js'; +import { type MessageActionReturn } from '@google/gemini-cli-core'; /** * Command to configure terminal keybindings for multiline input support. diff --git a/packages/cli/src/ui/commands/types.ts b/packages/cli/src/ui/commands/types.ts index a0f102c3c3..99d98f4009 100644 --- a/packages/cli/src/ui/commands/types.ts +++ b/packages/cli/src/ui/commands/types.ts @@ -5,13 +5,17 @@ */ import type { ReactNode } from 'react'; -import type { Content, PartListUnion } from '@google/genai'; import type { HistoryItemWithoutId, HistoryItem, ConfirmationRequest, } from '../types.js'; -import type { Config, GitService, Logger } from '@google/gemini-cli-core'; +import type { + Config, + GitService, + Logger, + CommandActionReturn, +} from '@google/gemini-cli-core'; import type { LoadedSettings } from '../../config/settings.js'; import type { UseHistoryManagerReturn } from '../hooks/useHistoryManager.js'; import type { SessionStatsState } from '../contexts/SessionContext.js'; @@ -84,31 +88,12 @@ export interface CommandContext { overwriteConfirmed?: boolean; } -/** - * The return type for a command action that results in scheduling a tool call. - */ -export interface ToolActionReturn { - type: 'tool'; - toolName: string; - toolArgs: Record; -} - /** The return type for a command action that results in the app quitting. */ export interface QuitActionReturn { type: 'quit'; messages: HistoryItem[]; } -/** - * The return type for a command action that results in a simple message - * being displayed to the user. - */ -export interface MessageActionReturn { - type: 'message'; - messageType: 'info' | 'error'; - content: string; -} - /** * The return type for a command action that needs to open a dialog. */ @@ -128,25 +113,6 @@ export interface OpenDialogActionReturn { | 'permissions'; } -/** - * The return type for a command action that results in replacing - * the entire conversation history. - */ -export interface LoadHistoryActionReturn { - type: 'load_history'; - history: HistoryItemWithoutId[]; - clientHistory: Content[]; // The history for the generative client -} - -/** - * The return type for a command action that should immediately submit - * content as a prompt to the Gemini model. - */ -export interface SubmitPromptActionReturn { - type: 'submit_prompt'; - content: PartListUnion; -} - /** * The return type for a command action that needs to pause and request * confirmation for a set of shell commands before proceeding. @@ -177,12 +143,9 @@ export interface OpenCustomDialogActionReturn { } export type SlashCommandActionReturn = - | ToolActionReturn - | MessageActionReturn + | CommandActionReturn | QuitActionReturn | OpenDialogActionReturn - | LoadHistoryActionReturn - | SubmitPromptActionReturn | ConfirmShellCommandsActionReturn | ConfirmActionReturn | OpenCustomDialogActionReturn; diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 979f520dcd..78e955dd9a 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -16,6 +16,7 @@ import type { ThoughtSummary, ToolCallRequestInfo, GeminiErrorEventValue, + ToolCallData, } from '@google/gemini-cli-core'; import { GeminiEventType as ServerGeminiEventType, @@ -1313,22 +1314,23 @@ export const useGeminiStream = ( toolCallWithSnapshotFileName, ); + const checkpointData: ToolCallData< + HistoryItem[], + Record + > & { filePath: string } = { + history, + clientHistory, + toolCall: { + name: toolCall.request.name, + args: toolCall.request.args, + }, + commitHash, + filePath, + }; + await fs.writeFile( toolCallWithSnapshotFilePath, - JSON.stringify( - { - history, - clientHistory, - toolCall: { - name: toolCall.request.name, - args: toolCall.request.args, - }, - commitHash, - filePath, - }, - null, - 2, - ), + JSON.stringify(checkpointData, null, 2), ); } catch (error) { onDebugMessage( diff --git a/packages/core/src/commands/restore.test.ts b/packages/core/src/commands/restore.test.ts new file mode 100644 index 0000000000..634f4989d4 --- /dev/null +++ b/packages/core/src/commands/restore.test.ts @@ -0,0 +1,168 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { performRestore, type ToolCallData } from './restore.js'; +import type { GitService } from '../services/gitService.js'; + +describe('performRestore', () => { + let mockGitService: GitService; + + beforeEach(() => { + mockGitService = { + initialize: vi.fn(), + verifyGitAvailability: vi.fn(), + setupShadowGitRepository: vi.fn(), + getCurrentCommitHash: vi.fn(), + createFileSnapshot: vi.fn(), + restoreProjectFromSnapshot: vi.fn(), + storage: {}, + getHistoryDir: vi.fn().mockReturnValue('mock-history-dir'), + shadowGitRepository: {}, + } as unknown as GitService; + }); + + it('should yield load_history if history and clientHistory are present', async () => { + const toolCallData: ToolCallData = { + toolCall: { name: 'test', args: {} }, + history: [{ some: 'history' }], + clientHistory: [{ role: 'user', parts: [{ text: 'hello' }] }], + }; + + const generator = performRestore(toolCallData, undefined); + const result = await generator.next(); + + expect(result.value).toEqual({ + type: 'load_history', + history: toolCallData.history, + clientHistory: toolCallData.clientHistory, + }); + expect(result.done).toBe(false); + + const nextResult = await generator.next(); + expect(nextResult.done).toBe(true); + }); + + it('should call restoreProjectFromSnapshot and yield a message if commitHash and gitService are present', async () => { + const toolCallData: ToolCallData = { + toolCall: { name: 'test', args: {} }, + commitHash: 'test-commit-hash', + }; + const spy = vi + .spyOn(mockGitService, 'restoreProjectFromSnapshot') + .mockResolvedValue(undefined); + + const generator = performRestore(toolCallData, mockGitService); + const result = await generator.next(); + + expect(spy).toHaveBeenCalledWith('test-commit-hash'); + expect(result.value).toEqual({ + type: 'message', + messageType: 'info', + content: 'Restored project to the state before the tool call.', + }); + expect(result.done).toBe(false); + + const nextResult = await generator.next(); + expect(nextResult.done).toBe(true); + }); + + it('should yield an error message if restoreProjectFromSnapshot throws "unable to read tree" error', async () => { + const toolCallData: ToolCallData = { + toolCall: { name: 'test', args: {} }, + commitHash: 'invalid-commit-hash', + }; + const spy = vi + .spyOn(mockGitService, 'restoreProjectFromSnapshot') + .mockRejectedValue( + new Error('fatal: unable to read tree invalid-commit-hash'), + ); + + const generator = performRestore(toolCallData, mockGitService); + const result = await generator.next(); + + expect(spy).toHaveBeenCalledWith('invalid-commit-hash'); + expect(result.value).toEqual({ + type: 'message', + messageType: 'error', + content: + "The commit hash 'invalid-commit-hash' associated with this checkpoint could not be found in your Git repository. This can happen if the repository has been re-cloned, reset, or if old commits have been garbage collected. This checkpoint cannot be restored.", + }); + expect(result.done).toBe(false); + + const nextResult = await generator.next(); + expect(nextResult.done).toBe(true); + }); + + it('should re-throw other errors from restoreProjectFromSnapshot', async () => { + const toolCallData: ToolCallData = { + toolCall: { name: 'test', args: {} }, + commitHash: 'some-commit-hash', + }; + const testError = new Error('something went wrong'); + vi.spyOn(mockGitService, 'restoreProjectFromSnapshot').mockRejectedValue( + testError, + ); + + const generator = performRestore(toolCallData, mockGitService); + await expect(generator.next()).rejects.toThrow(testError); + }); + + it('should yield load_history then a message if both are present', async () => { + const toolCallData: ToolCallData = { + toolCall: { name: 'test', args: {} }, + history: [{ some: 'history' }], + clientHistory: [{ role: 'user', parts: [{ text: 'hello' }] }], + commitHash: 'test-commit-hash', + }; + const spy = vi + .spyOn(mockGitService, 'restoreProjectFromSnapshot') + .mockResolvedValue(undefined); + + const generator = performRestore(toolCallData, mockGitService); + + const historyResult = await generator.next(); + expect(historyResult.value).toEqual({ + type: 'load_history', + history: toolCallData.history, + clientHistory: toolCallData.clientHistory, + }); + expect(historyResult.done).toBe(false); + + const messageResult = await generator.next(); + expect(spy).toHaveBeenCalledWith('test-commit-hash'); + expect(messageResult.value).toEqual({ + type: 'message', + messageType: 'info', + content: 'Restored project to the state before the tool call.', + }); + expect(messageResult.done).toBe(false); + + const nextResult = await generator.next(); + expect(nextResult.done).toBe(true); + }); + + it('should yield error message if commitHash is present but gitService is undefined', async () => { + const toolCallData: ToolCallData = { + toolCall: { name: 'test', args: {} }, + commitHash: 'test-commit-hash', + }; + + const generator = performRestore(toolCallData, undefined); + const result = await generator.next(); + + expect(result.value).toEqual({ + type: 'message', + messageType: 'error', + content: + 'Git service is not available, cannot restore checkpoint. Please ensure you are in a git repository.', + }); + expect(result.done).toBe(false); + + const nextResult = await generator.next(); + expect(nextResult.done).toBe(true); + }); +}); diff --git a/packages/core/src/commands/restore.ts b/packages/core/src/commands/restore.ts new file mode 100644 index 0000000000..778836f96a --- /dev/null +++ b/packages/core/src/commands/restore.ts @@ -0,0 +1,68 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Content } from '@google/genai'; +import type { GitService } from '../services/gitService.js'; +import type { CommandActionReturn } from './types.js'; + +export interface ToolCallData { + history?: HistoryType; + clientHistory?: Content[]; + commitHash?: string; + toolCall: { + name: string; + args: ArgsType; + }; + messageId?: string; +} + +export async function* performRestore< + HistoryType = unknown, + ArgsType = unknown, +>( + toolCallData: ToolCallData, + gitService: GitService | undefined, +): AsyncGenerator> { + if (toolCallData.history && toolCallData.clientHistory) { + yield { + type: 'load_history', + history: toolCallData.history, + clientHistory: toolCallData.clientHistory, + }; + } + + if (toolCallData.commitHash) { + if (!gitService) { + yield { + type: 'message', + messageType: 'error', + content: + 'Git service is not available, cannot restore checkpoint. Please ensure you are in a git repository.', + }; + return; + } + + try { + await gitService.restoreProjectFromSnapshot(toolCallData.commitHash); + yield { + type: 'message', + messageType: 'info', + content: 'Restored project to the state before the tool call.', + }; + } catch (e) { + const error = e as Error; + if (error.message.includes('unable to read tree')) { + yield { + type: 'message', + messageType: 'error', + content: `The commit hash '${toolCallData.commitHash}' associated with this checkpoint could not be found in your Git repository. This can happen if the repository has been re-cloned, reset, or if old commits have been garbage collected. This checkpoint cannot be restored.`, + }; + return; + } + throw e; + } + } +} diff --git a/packages/core/src/commands/types.ts b/packages/core/src/commands/types.ts new file mode 100644 index 0000000000..31491a27be --- /dev/null +++ b/packages/core/src/commands/types.ts @@ -0,0 +1,50 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { Content, PartListUnion } from '@google/genai'; +/** + * The return type for a command action that results in scheduling a tool call. + */ +export interface ToolActionReturn { + type: 'tool'; + toolName: string; + toolArgs: Record; +} + +/** + * The return type for a command action that results in a simple message + * being displayed to the user. + */ +export interface MessageActionReturn { + type: 'message'; + messageType: 'info' | 'error'; + content: string; +} + +/** + * The return type for a command action that results in replacing + * the entire conversation history. + */ +export interface LoadHistoryActionReturn { + type: 'load_history'; + history: HistoryType; + clientHistory: Content[]; // The history for the generative client +} + +/** + * The return type for a command action that should immediately submit + * content as a prompt to the Gemini model. + */ +export interface SubmitPromptActionReturn { + type: 'submit_prompt'; + content: PartListUnion; +} + +export type CommandActionReturn = + | ToolActionReturn + | MessageActionReturn + | LoadHistoryActionReturn + | SubmitPromptActionReturn; diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index d69f4c1bd7..235c459b2c 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -20,6 +20,8 @@ export * from './confirmation-bus/message-bus.js'; // Export Commands logic export * from './commands/extensions.js'; +export * from './commands/restore.js'; +export * from './commands/types.js'; // Export Core Logic export * from './core/client.js';