mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-30 15:04:16 -07:00
feat(a2a): Introduce restore command for a2a server (#13015)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Shreya Keshive <shreyakeshive@google.com>
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import { ExtensionsCommand } from './extensions.js';
|
||||
import { RestoreCommand } from './restore.js';
|
||||
import type { Command } from './types.js';
|
||||
|
||||
class CommandRegistry {
|
||||
@@ -12,6 +13,7 @@ class CommandRegistry {
|
||||
|
||||
constructor() {
|
||||
this.register(new ExtensionsCommand());
|
||||
this.register(new RestoreCommand());
|
||||
}
|
||||
|
||||
register(command: Command) {
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { ExtensionsCommand, ListExtensionsCommand } from './extensions.js';
|
||||
import type { Config } from '@google/gemini-cli-core';
|
||||
import type { CommandContext } from './types.js';
|
||||
|
||||
const mockListExtensions = vi.hoisted(() => vi.fn());
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
@@ -42,14 +42,14 @@ describe('ExtensionsCommand', () => {
|
||||
|
||||
it('should default to listing extensions', async () => {
|
||||
const command = new ExtensionsCommand();
|
||||
const mockConfig = {} as Config;
|
||||
const mockConfig = { config: {} } as CommandContext;
|
||||
const mockExtensions = [{ name: 'ext1' }];
|
||||
mockListExtensions.mockReturnValue(mockExtensions);
|
||||
|
||||
const result = await command.execute(mockConfig, []);
|
||||
|
||||
expect(result).toEqual({ name: 'extensions list', data: mockExtensions });
|
||||
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig);
|
||||
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig.config);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -61,19 +61,19 @@ describe('ListExtensionsCommand', () => {
|
||||
|
||||
it('should call listExtensions with the provided config', async () => {
|
||||
const command = new ListExtensionsCommand();
|
||||
const mockConfig = {} as Config;
|
||||
const mockConfig = { config: {} } as CommandContext;
|
||||
const mockExtensions = [{ name: 'ext1' }];
|
||||
mockListExtensions.mockReturnValue(mockExtensions);
|
||||
|
||||
const result = await command.execute(mockConfig, []);
|
||||
|
||||
expect(result).toEqual({ name: 'extensions list', data: mockExtensions });
|
||||
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig);
|
||||
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig.config);
|
||||
});
|
||||
|
||||
it('should return a message when no extensions are installed', async () => {
|
||||
const command = new ListExtensionsCommand();
|
||||
const mockConfig = {} as Config;
|
||||
const mockConfig = { config: {} } as CommandContext;
|
||||
mockListExtensions.mockReturnValue([]);
|
||||
|
||||
const result = await command.execute(mockConfig, []);
|
||||
@@ -82,6 +82,6 @@ describe('ListExtensionsCommand', () => {
|
||||
name: 'extensions list',
|
||||
data: 'No extensions installed.',
|
||||
});
|
||||
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig);
|
||||
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig.config);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,8 +4,12 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { listExtensions, type Config } from '@google/gemini-cli-core';
|
||||
import type { Command, CommandExecutionResponse } from './types.js';
|
||||
import { listExtensions } from '@google/gemini-cli-core';
|
||||
import type {
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandExecutionResponse,
|
||||
} from './types.js';
|
||||
|
||||
export class ExtensionsCommand implements Command {
|
||||
readonly name = 'extensions';
|
||||
@@ -14,10 +18,10 @@ export class ExtensionsCommand implements Command {
|
||||
readonly topLevel = true;
|
||||
|
||||
async execute(
|
||||
config: Config,
|
||||
context: CommandContext,
|
||||
_: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
return new ListExtensionsCommand().execute(config, _);
|
||||
return new ListExtensionsCommand().execute(context, _);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,10 +30,10 @@ export class ListExtensionsCommand implements Command {
|
||||
readonly description = 'Lists all installed extensions.';
|
||||
|
||||
async execute(
|
||||
config: Config,
|
||||
context: CommandContext,
|
||||
_: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const extensions = listExtensions(config);
|
||||
const extensions = listExtensions(context.config);
|
||||
const data = extensions.length ? extensions : 'No extensions installed.';
|
||||
|
||||
return { name: this.name, data };
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { RestoreCommand, ListCheckpointsCommand } from './restore.js';
|
||||
import type { CommandContext } from './types.js';
|
||||
import type { Config } from '@google/gemini-cli-core';
|
||||
import { createMockConfig } from '../utils/testing_utils.js';
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
const mockPerformRestore = vi.hoisted(() => vi.fn());
|
||||
const mockLoggerInfo = vi.hoisted(() => vi.fn());
|
||||
const mockGetCheckpointInfoList = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const original =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
return {
|
||||
...original,
|
||||
performRestore: mockPerformRestore,
|
||||
getCheckpointInfoList: mockGetCheckpointInfoList,
|
||||
};
|
||||
});
|
||||
|
||||
const mockFs = vi.hoisted(() => ({
|
||||
readFile: vi.fn(),
|
||||
readdir: vi.fn(),
|
||||
mkdir: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('node:fs/promises', () => mockFs);
|
||||
|
||||
vi.mock('../utils/logger.js', () => ({
|
||||
logger: {
|
||||
info: mockLoggerInfo,
|
||||
},
|
||||
}));
|
||||
|
||||
describe('RestoreCommand', () => {
|
||||
const mockConfig = {
|
||||
config: createMockConfig() as Config,
|
||||
git: {},
|
||||
} as CommandContext;
|
||||
|
||||
it('should return error if no checkpoint name is provided', async () => {
|
||||
const command = new RestoreCommand();
|
||||
const result = await command.execute(mockConfig, []);
|
||||
expect(result.data).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Please provide a checkpoint name to restore.',
|
||||
});
|
||||
});
|
||||
|
||||
it('should restore a checkpoint when a valid file is provided', async () => {
|
||||
const command = new RestoreCommand();
|
||||
const toolCallData = {
|
||||
toolCall: {
|
||||
name: 'test-tool',
|
||||
args: {},
|
||||
},
|
||||
history: [],
|
||||
clientHistory: [],
|
||||
commitHash: '123',
|
||||
};
|
||||
mockFs.readFile.mockResolvedValue(JSON.stringify(toolCallData));
|
||||
const restoreContent = {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: 'Restored',
|
||||
};
|
||||
mockPerformRestore.mockReturnValue(
|
||||
(async function* () {
|
||||
yield restoreContent;
|
||||
})(),
|
||||
);
|
||||
const result = await command.execute(mockConfig, ['checkpoint1.json']);
|
||||
expect(result.data).toEqual([restoreContent]);
|
||||
});
|
||||
|
||||
it('should show "file not found" error for a non-existent checkpoint', async () => {
|
||||
const command = new RestoreCommand();
|
||||
const error = new Error('File not found');
|
||||
(error as NodeJS.ErrnoException).code = 'ENOENT';
|
||||
mockFs.readFile.mockRejectedValue(error);
|
||||
const result = await command.execute(mockConfig, ['checkpoint2.json']);
|
||||
expect(result.data).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'File not found: checkpoint2.json',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle invalid JSON in checkpoint file', async () => {
|
||||
const command = new RestoreCommand();
|
||||
mockFs.readFile.mockResolvedValue('invalid json');
|
||||
const result = await command.execute(mockConfig, ['checkpoint1.json']);
|
||||
expect((result.data as { content: string }).content).toContain(
|
||||
'An unexpected error occurred during restore.',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('ListCheckpointsCommand', () => {
|
||||
const mockConfig = {
|
||||
config: createMockConfig() as Config,
|
||||
} as CommandContext;
|
||||
|
||||
it('should list all available checkpoints', async () => {
|
||||
const command = new ListCheckpointsCommand();
|
||||
const checkpointInfo = [{ file: 'checkpoint1.json', description: 'Test' }];
|
||||
mockFs.readdir.mockResolvedValue(['checkpoint1.json']);
|
||||
mockFs.readFile.mockResolvedValue(
|
||||
JSON.stringify({ toolCall: { name: 'Test', args: {} } }),
|
||||
);
|
||||
mockGetCheckpointInfoList.mockReturnValue(checkpointInfo);
|
||||
const result = await command.execute(mockConfig);
|
||||
expect((result.data as { content: string }).content).toEqual(
|
||||
JSON.stringify(checkpointInfo),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle errors when listing checkpoints', async () => {
|
||||
const command = new ListCheckpointsCommand();
|
||||
mockFs.readdir.mockRejectedValue(new Error('Read error'));
|
||||
const result = await command.execute(mockConfig);
|
||||
expect((result.data as { content: string }).content).toContain(
|
||||
'An unexpected error occurred while listing checkpoints.',
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,155 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
getCheckpointInfoList,
|
||||
getToolCallDataSchema,
|
||||
isNodeError,
|
||||
performRestore,
|
||||
} from '@google/gemini-cli-core';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
import type {
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandExecutionResponse,
|
||||
} from './types.js';
|
||||
|
||||
export class RestoreCommand implements Command {
|
||||
readonly name = 'restore';
|
||||
readonly description =
|
||||
'Restore to a previous checkpoint, or list available checkpoints to restore. This will reset the conversation and file history to the state it was in when the checkpoint was created';
|
||||
readonly topLevel = true;
|
||||
readonly requiresWorkspace = true;
|
||||
readonly subCommands = [new ListCheckpointsCommand()];
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const { config, git: gitService } = context;
|
||||
const argsStr = args.join(' ');
|
||||
|
||||
try {
|
||||
if (!argsStr) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Please provide a checkpoint name to restore.',
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const selectedFile = argsStr.endsWith('.json')
|
||||
? argsStr
|
||||
: `${argsStr}.json`;
|
||||
|
||||
const checkpointDir = config.storage.getProjectTempCheckpointsDir();
|
||||
const filePath = path.join(checkpointDir, selectedFile);
|
||||
|
||||
let data: string;
|
||||
try {
|
||||
data = await fs.readFile(filePath, 'utf-8');
|
||||
} catch (error) {
|
||||
if (isNodeError(error) && error.code === 'ENOENT') {
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: `File not found: ${selectedFile}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
|
||||
const toolCallData = JSON.parse(data);
|
||||
const ToolCallDataSchema = getToolCallDataSchema();
|
||||
const parseResult = ToolCallDataSchema.safeParse(toolCallData);
|
||||
|
||||
if (!parseResult.success) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Checkpoint file is invalid or corrupted.',
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const restoreResultGenerator = performRestore(
|
||||
parseResult.data,
|
||||
gitService,
|
||||
);
|
||||
const restoreResult = [];
|
||||
for await (const result of restoreResultGenerator) {
|
||||
restoreResult.push(result);
|
||||
}
|
||||
|
||||
return {
|
||||
name: this.name,
|
||||
data: restoreResult,
|
||||
};
|
||||
} catch (_error) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'An unexpected error occurred during restore.',
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class ListCheckpointsCommand implements Command {
|
||||
readonly name = 'restore list';
|
||||
readonly description = 'Lists all available checkpoints.';
|
||||
readonly topLevel = false;
|
||||
|
||||
async execute(context: CommandContext): Promise<CommandExecutionResponse> {
|
||||
const { config } = context;
|
||||
|
||||
try {
|
||||
const checkpointDir = config.storage.getProjectTempCheckpointsDir();
|
||||
await fs.mkdir(checkpointDir, { recursive: true });
|
||||
const files = await fs.readdir(checkpointDir);
|
||||
const jsonFiles = files.filter((file) => file.endsWith('.json'));
|
||||
|
||||
const checkpointFiles = new Map<string, string>();
|
||||
for (const file of jsonFiles) {
|
||||
const filePath = path.join(checkpointDir, file);
|
||||
const data = await fs.readFile(filePath, 'utf-8');
|
||||
checkpointFiles.set(file, data);
|
||||
}
|
||||
|
||||
const checkpointInfoList = getCheckpointInfoList(checkpointFiles);
|
||||
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: JSON.stringify(checkpointInfoList),
|
||||
},
|
||||
};
|
||||
} catch (_error) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'An unexpected error occurred while listing checkpoints.',
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,12 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '@google/gemini-cli-core';
|
||||
import type { Config, GitService } from '@google/gemini-cli-core';
|
||||
|
||||
export interface CommandContext {
|
||||
config: Config;
|
||||
git?: GitService;
|
||||
}
|
||||
|
||||
export interface CommandArgument {
|
||||
readonly name: string;
|
||||
@@ -18,8 +23,12 @@ export interface Command {
|
||||
readonly arguments?: CommandArgument[];
|
||||
readonly subCommands?: Command[];
|
||||
readonly topLevel?: boolean;
|
||||
readonly requiresWorkspace?: boolean;
|
||||
|
||||
execute(config: Config, args: string[]): Promise<CommandExecutionResponse>;
|
||||
execute(
|
||||
config: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse>;
|
||||
}
|
||||
|
||||
export interface CommandExecutionResponse {
|
||||
|
||||
Reference in New Issue
Block a user