feat(a2a): Introduce restore command for a2a server (#13015)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Shreya Keshive <shreyakeshive@google.com>
This commit is contained in:
Coco Sheng
2025-12-09 10:08:23 -05:00
committed by GitHub
parent afd4829f10
commit 1f813f6a06
23 changed files with 1173 additions and 148 deletions
@@ -5,6 +5,7 @@
*/
import { ExtensionsCommand } from './extensions.js';
import { RestoreCommand } from './restore.js';
import type { Command } from './types.js';
class CommandRegistry {
@@ -12,6 +13,7 @@ class CommandRegistry {
constructor() {
this.register(new ExtensionsCommand());
this.register(new RestoreCommand());
}
register(command: Command) {
@@ -6,7 +6,7 @@
import { describe, it, expect, vi } from 'vitest';
import { ExtensionsCommand, ListExtensionsCommand } from './extensions.js';
import type { Config } from '@google/gemini-cli-core';
import type { CommandContext } from './types.js';
const mockListExtensions = vi.hoisted(() => vi.fn());
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
@@ -42,14 +42,14 @@ describe('ExtensionsCommand', () => {
it('should default to listing extensions', async () => {
const command = new ExtensionsCommand();
const mockConfig = {} as Config;
const mockConfig = { config: {} } as CommandContext;
const mockExtensions = [{ name: 'ext1' }];
mockListExtensions.mockReturnValue(mockExtensions);
const result = await command.execute(mockConfig, []);
expect(result).toEqual({ name: 'extensions list', data: mockExtensions });
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig);
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig.config);
});
});
@@ -61,19 +61,19 @@ describe('ListExtensionsCommand', () => {
it('should call listExtensions with the provided config', async () => {
const command = new ListExtensionsCommand();
const mockConfig = {} as Config;
const mockConfig = { config: {} } as CommandContext;
const mockExtensions = [{ name: 'ext1' }];
mockListExtensions.mockReturnValue(mockExtensions);
const result = await command.execute(mockConfig, []);
expect(result).toEqual({ name: 'extensions list', data: mockExtensions });
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig);
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig.config);
});
it('should return a message when no extensions are installed', async () => {
const command = new ListExtensionsCommand();
const mockConfig = {} as Config;
const mockConfig = { config: {} } as CommandContext;
mockListExtensions.mockReturnValue([]);
const result = await command.execute(mockConfig, []);
@@ -82,6 +82,6 @@ describe('ListExtensionsCommand', () => {
name: 'extensions list',
data: 'No extensions installed.',
});
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig);
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig.config);
});
});
+10 -6
View File
@@ -4,8 +4,12 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { listExtensions, type Config } from '@google/gemini-cli-core';
import type { Command, CommandExecutionResponse } from './types.js';
import { listExtensions } from '@google/gemini-cli-core';
import type {
Command,
CommandContext,
CommandExecutionResponse,
} from './types.js';
export class ExtensionsCommand implements Command {
readonly name = 'extensions';
@@ -14,10 +18,10 @@ export class ExtensionsCommand implements Command {
readonly topLevel = true;
async execute(
config: Config,
context: CommandContext,
_: string[],
): Promise<CommandExecutionResponse> {
return new ListExtensionsCommand().execute(config, _);
return new ListExtensionsCommand().execute(context, _);
}
}
@@ -26,10 +30,10 @@ export class ListExtensionsCommand implements Command {
readonly description = 'Lists all installed extensions.';
async execute(
config: Config,
context: CommandContext,
_: string[],
): Promise<CommandExecutionResponse> {
const extensions = listExtensions(config);
const extensions = listExtensions(context.config);
const data = extensions.length ? extensions : 'No extensions installed.';
return { name: this.name, data };
@@ -0,0 +1,137 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { RestoreCommand, ListCheckpointsCommand } from './restore.js';
import type { CommandContext } from './types.js';
import type { Config } from '@google/gemini-cli-core';
import { createMockConfig } from '../utils/testing_utils.js';
beforeEach(() => {
vi.clearAllMocks();
});
const mockPerformRestore = vi.hoisted(() => vi.fn());
const mockLoggerInfo = vi.hoisted(() => vi.fn());
const mockGetCheckpointInfoList = vi.hoisted(() => vi.fn());
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const original =
await importOriginal<typeof import('@google/gemini-cli-core')>();
return {
...original,
performRestore: mockPerformRestore,
getCheckpointInfoList: mockGetCheckpointInfoList,
};
});
const mockFs = vi.hoisted(() => ({
readFile: vi.fn(),
readdir: vi.fn(),
mkdir: vi.fn(),
}));
vi.mock('node:fs/promises', () => mockFs);
vi.mock('../utils/logger.js', () => ({
logger: {
info: mockLoggerInfo,
},
}));
describe('RestoreCommand', () => {
const mockConfig = {
config: createMockConfig() as Config,
git: {},
} as CommandContext;
it('should return error if no checkpoint name is provided', async () => {
const command = new RestoreCommand();
const result = await command.execute(mockConfig, []);
expect(result.data).toEqual({
type: 'message',
messageType: 'error',
content: 'Please provide a checkpoint name to restore.',
});
});
it('should restore a checkpoint when a valid file is provided', async () => {
const command = new RestoreCommand();
const toolCallData = {
toolCall: {
name: 'test-tool',
args: {},
},
history: [],
clientHistory: [],
commitHash: '123',
};
mockFs.readFile.mockResolvedValue(JSON.stringify(toolCallData));
const restoreContent = {
type: 'message',
messageType: 'info',
content: 'Restored',
};
mockPerformRestore.mockReturnValue(
(async function* () {
yield restoreContent;
})(),
);
const result = await command.execute(mockConfig, ['checkpoint1.json']);
expect(result.data).toEqual([restoreContent]);
});
it('should show "file not found" error for a non-existent checkpoint', async () => {
const command = new RestoreCommand();
const error = new Error('File not found');
(error as NodeJS.ErrnoException).code = 'ENOENT';
mockFs.readFile.mockRejectedValue(error);
const result = await command.execute(mockConfig, ['checkpoint2.json']);
expect(result.data).toEqual({
type: 'message',
messageType: 'error',
content: 'File not found: checkpoint2.json',
});
});
it('should handle invalid JSON in checkpoint file', async () => {
const command = new RestoreCommand();
mockFs.readFile.mockResolvedValue('invalid json');
const result = await command.execute(mockConfig, ['checkpoint1.json']);
expect((result.data as { content: string }).content).toContain(
'An unexpected error occurred during restore.',
);
});
});
describe('ListCheckpointsCommand', () => {
const mockConfig = {
config: createMockConfig() as Config,
} as CommandContext;
it('should list all available checkpoints', async () => {
const command = new ListCheckpointsCommand();
const checkpointInfo = [{ file: 'checkpoint1.json', description: 'Test' }];
mockFs.readdir.mockResolvedValue(['checkpoint1.json']);
mockFs.readFile.mockResolvedValue(
JSON.stringify({ toolCall: { name: 'Test', args: {} } }),
);
mockGetCheckpointInfoList.mockReturnValue(checkpointInfo);
const result = await command.execute(mockConfig);
expect((result.data as { content: string }).content).toEqual(
JSON.stringify(checkpointInfo),
);
});
it('should handle errors when listing checkpoints', async () => {
const command = new ListCheckpointsCommand();
mockFs.readdir.mockRejectedValue(new Error('Read error'));
const result = await command.execute(mockConfig);
expect((result.data as { content: string }).content).toContain(
'An unexpected error occurred while listing checkpoints.',
);
});
});
+155
View File
@@ -0,0 +1,155 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
getCheckpointInfoList,
getToolCallDataSchema,
isNodeError,
performRestore,
} from '@google/gemini-cli-core';
import * as fs from 'node:fs/promises';
import * as path from 'node:path';
import type {
Command,
CommandContext,
CommandExecutionResponse,
} from './types.js';
export class RestoreCommand implements Command {
readonly name = 'restore';
readonly description =
'Restore to a previous checkpoint, or list available checkpoints to restore. This will reset the conversation and file history to the state it was in when the checkpoint was created';
readonly topLevel = true;
readonly requiresWorkspace = true;
readonly subCommands = [new ListCheckpointsCommand()];
async execute(
context: CommandContext,
args: string[],
): Promise<CommandExecutionResponse> {
const { config, git: gitService } = context;
const argsStr = args.join(' ');
try {
if (!argsStr) {
return {
name: this.name,
data: {
type: 'message',
messageType: 'error',
content: 'Please provide a checkpoint name to restore.',
},
};
}
const selectedFile = argsStr.endsWith('.json')
? argsStr
: `${argsStr}.json`;
const checkpointDir = config.storage.getProjectTempCheckpointsDir();
const filePath = path.join(checkpointDir, selectedFile);
let data: string;
try {
data = await fs.readFile(filePath, 'utf-8');
} catch (error) {
if (isNodeError(error) && error.code === 'ENOENT') {
return {
name: this.name,
data: {
type: 'message',
messageType: 'error',
content: `File not found: ${selectedFile}`,
},
};
}
throw error;
}
const toolCallData = JSON.parse(data);
const ToolCallDataSchema = getToolCallDataSchema();
const parseResult = ToolCallDataSchema.safeParse(toolCallData);
if (!parseResult.success) {
return {
name: this.name,
data: {
type: 'message',
messageType: 'error',
content: 'Checkpoint file is invalid or corrupted.',
},
};
}
const restoreResultGenerator = performRestore(
parseResult.data,
gitService,
);
const restoreResult = [];
for await (const result of restoreResultGenerator) {
restoreResult.push(result);
}
return {
name: this.name,
data: restoreResult,
};
} catch (_error) {
return {
name: this.name,
data: {
type: 'message',
messageType: 'error',
content: 'An unexpected error occurred during restore.',
},
};
}
}
}
export class ListCheckpointsCommand implements Command {
readonly name = 'restore list';
readonly description = 'Lists all available checkpoints.';
readonly topLevel = false;
async execute(context: CommandContext): Promise<CommandExecutionResponse> {
const { config } = context;
try {
const checkpointDir = config.storage.getProjectTempCheckpointsDir();
await fs.mkdir(checkpointDir, { recursive: true });
const files = await fs.readdir(checkpointDir);
const jsonFiles = files.filter((file) => file.endsWith('.json'));
const checkpointFiles = new Map<string, string>();
for (const file of jsonFiles) {
const filePath = path.join(checkpointDir, file);
const data = await fs.readFile(filePath, 'utf-8');
checkpointFiles.set(file, data);
}
const checkpointInfoList = getCheckpointInfoList(checkpointFiles);
return {
name: this.name,
data: {
type: 'message',
messageType: 'info',
content: JSON.stringify(checkpointInfoList),
},
};
} catch (_error) {
return {
name: this.name,
data: {
type: 'message',
messageType: 'error',
content: 'An unexpected error occurred while listing checkpoints.',
},
};
}
}
}
+11 -2
View File
@@ -4,7 +4,12 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { Config } from '@google/gemini-cli-core';
import type { Config, GitService } from '@google/gemini-cli-core';
export interface CommandContext {
config: Config;
git?: GitService;
}
export interface CommandArgument {
readonly name: string;
@@ -18,8 +23,12 @@ export interface Command {
readonly arguments?: CommandArgument[];
readonly subCommands?: Command[];
readonly topLevel?: boolean;
readonly requiresWorkspace?: boolean;
execute(config: Config, args: string[]): Promise<CommandExecutionResponse>;
execute(
config: CommandContext,
args: string[],
): Promise<CommandExecutionResponse>;
}
export interface CommandExecutionResponse {