mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-15 16:41:11 -07:00
feat(a2a): Introduce restore command for a2a server (#13015)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Shreya Keshive <shreyakeshive@google.com>
This commit is contained in:
@@ -18,6 +18,7 @@ import {
|
||||
GeminiEventType,
|
||||
type Config,
|
||||
type ToolCallRequestInfo,
|
||||
type GitService,
|
||||
type CompletedToolCall,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { createMockConfig } from '../utils/testing_utils.js';
|
||||
@@ -25,6 +26,17 @@ import type { ExecutionEventBus, RequestContext } from '@a2a-js/sdk/server';
|
||||
import { CoderAgentEvent } from '../types.js';
|
||||
import type { ToolCall } from '@google/gemini-cli-core';
|
||||
|
||||
const mockProcessRestorableToolCalls = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const original =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
return {
|
||||
...original,
|
||||
processRestorableToolCalls: mockProcessRestorableToolCalls,
|
||||
};
|
||||
});
|
||||
|
||||
describe('Task', () => {
|
||||
it('scheduleToolCalls should not modify the input requests array', async () => {
|
||||
const mockConfig = createMockConfig();
|
||||
@@ -72,6 +84,141 @@ describe('Task', () => {
|
||||
expect(requests).toEqual(originalRequests);
|
||||
});
|
||||
|
||||
describe('scheduleToolCalls', () => {
|
||||
const mockConfig = createMockConfig();
|
||||
const mockEventBus: ExecutionEventBus = {
|
||||
publish: vi.fn(),
|
||||
on: vi.fn(),
|
||||
off: vi.fn(),
|
||||
once: vi.fn(),
|
||||
removeAllListeners: vi.fn(),
|
||||
finished: vi.fn(),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should not create a checkpoint if no restorable tools are called', async () => {
|
||||
// @ts-expect-error - Calling private constructor for test purposes.
|
||||
const task = new Task(
|
||||
'task-id',
|
||||
'context-id',
|
||||
mockConfig as Config,
|
||||
mockEventBus,
|
||||
);
|
||||
const requests: ToolCallRequestInfo[] = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'run_shell_command',
|
||||
args: { command: 'ls' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
];
|
||||
const abortController = new AbortController();
|
||||
await task.scheduleToolCalls(requests, abortController.signal);
|
||||
expect(mockProcessRestorableToolCalls).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should create a checkpoint if a restorable tool is called', async () => {
|
||||
const mockConfig = createMockConfig({
|
||||
getCheckpointingEnabled: () => true,
|
||||
getGitService: () => Promise.resolve({} as GitService),
|
||||
});
|
||||
mockProcessRestorableToolCalls.mockResolvedValue({
|
||||
checkpointsToWrite: new Map([['test.json', 'test content']]),
|
||||
toolCallToCheckpointMap: new Map(),
|
||||
errors: [],
|
||||
});
|
||||
// @ts-expect-error - Calling private constructor for test purposes.
|
||||
const task = new Task(
|
||||
'task-id',
|
||||
'context-id',
|
||||
mockConfig as Config,
|
||||
mockEventBus,
|
||||
);
|
||||
const requests: ToolCallRequestInfo[] = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'replace',
|
||||
args: {
|
||||
file_path: 'test.txt',
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
];
|
||||
const abortController = new AbortController();
|
||||
await task.scheduleToolCalls(requests, abortController.signal);
|
||||
expect(mockProcessRestorableToolCalls).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should process all restorable tools for checkpointing in a single batch', async () => {
|
||||
const mockConfig = createMockConfig({
|
||||
getCheckpointingEnabled: () => true,
|
||||
getGitService: () => Promise.resolve({} as GitService),
|
||||
});
|
||||
mockProcessRestorableToolCalls.mockResolvedValue({
|
||||
checkpointsToWrite: new Map([
|
||||
['test1.json', 'test content 1'],
|
||||
['test2.json', 'test content 2'],
|
||||
]),
|
||||
toolCallToCheckpointMap: new Map([
|
||||
['1', 'test1'],
|
||||
['2', 'test2'],
|
||||
]),
|
||||
errors: [],
|
||||
});
|
||||
// @ts-expect-error - Calling private constructor for test purposes.
|
||||
const task = new Task(
|
||||
'task-id',
|
||||
'context-id',
|
||||
mockConfig as Config,
|
||||
mockEventBus,
|
||||
);
|
||||
const requests: ToolCallRequestInfo[] = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'replace',
|
||||
args: {
|
||||
file_path: 'test.txt',
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
{
|
||||
callId: '2',
|
||||
name: 'write_file',
|
||||
args: { file_path: 'test2.txt', content: 'new content' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-2',
|
||||
},
|
||||
{
|
||||
callId: '3',
|
||||
name: 'not_restorable',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-3',
|
||||
},
|
||||
];
|
||||
const abortController = new AbortController();
|
||||
await task.scheduleToolCalls(requests, abortController.signal);
|
||||
expect(mockProcessRestorableToolCalls).toHaveBeenCalledExactlyOnceWith(
|
||||
[
|
||||
expect.objectContaining({ callId: '1' }),
|
||||
expect.objectContaining({ callId: '2' }),
|
||||
],
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('acceptAgentMessage', () => {
|
||||
it('should set currentTraceId when event has traceId', async () => {
|
||||
const mockConfig = createMockConfig();
|
||||
|
||||
@@ -27,6 +27,8 @@ import {
|
||||
type Config,
|
||||
type UserTierId,
|
||||
type AnsiOutput,
|
||||
EDIT_TOOL_NAMES,
|
||||
processRestorableToolCalls,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { RequestContext } from '@a2a-js/sdk/server';
|
||||
import { type ExecutionEventBus } from '@a2a-js/sdk/server';
|
||||
@@ -40,7 +42,8 @@ import type {
|
||||
} from '@a2a-js/sdk';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { logger } from '../utils/logger.js';
|
||||
import * as fs from 'node:fs';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
import { CoderAgentEvent } from '../types.js';
|
||||
import type {
|
||||
CoderAgentMessage,
|
||||
@@ -511,7 +514,7 @@ export class Task {
|
||||
new_string: string,
|
||||
): Promise<string> {
|
||||
try {
|
||||
const currentContent = fs.readFileSync(file_path, 'utf8');
|
||||
const currentContent = await fs.readFile(file_path, 'utf8');
|
||||
return this._applyReplacement(
|
||||
currentContent,
|
||||
old_string,
|
||||
@@ -554,6 +557,44 @@ export class Task {
|
||||
return;
|
||||
}
|
||||
|
||||
// Set checkpoint file before any file modification tool executes
|
||||
const restorableToolCalls = requests.filter((request) =>
|
||||
EDIT_TOOL_NAMES.has(request.name),
|
||||
);
|
||||
|
||||
if (restorableToolCalls.length > 0) {
|
||||
const gitService = await this.config.getGitService();
|
||||
if (gitService) {
|
||||
const { checkpointsToWrite, toolCallToCheckpointMap, errors } =
|
||||
await processRestorableToolCalls(
|
||||
restorableToolCalls,
|
||||
gitService,
|
||||
this.geminiClient,
|
||||
);
|
||||
|
||||
if (errors.length > 0) {
|
||||
errors.forEach((error) => logger.error(error));
|
||||
}
|
||||
|
||||
if (checkpointsToWrite.size > 0) {
|
||||
const checkpointDir =
|
||||
this.config.storage.getProjectTempCheckpointsDir();
|
||||
await fs.mkdir(checkpointDir, { recursive: true });
|
||||
for (const [fileName, content] of checkpointsToWrite) {
|
||||
const filePath = path.join(checkpointDir, fileName);
|
||||
await fs.writeFile(filePath, content);
|
||||
}
|
||||
}
|
||||
|
||||
for (const request of requests) {
|
||||
const checkpoint = toolCallToCheckpointMap.get(request.callId);
|
||||
if (checkpoint) {
|
||||
request.checkpoint = checkpoint;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const updatedRequests = await Promise.all(
|
||||
requests.map(async (request) => {
|
||||
if (
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import { ExtensionsCommand } from './extensions.js';
|
||||
import { RestoreCommand } from './restore.js';
|
||||
import type { Command } from './types.js';
|
||||
|
||||
class CommandRegistry {
|
||||
@@ -12,6 +13,7 @@ class CommandRegistry {
|
||||
|
||||
constructor() {
|
||||
this.register(new ExtensionsCommand());
|
||||
this.register(new RestoreCommand());
|
||||
}
|
||||
|
||||
register(command: Command) {
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { ExtensionsCommand, ListExtensionsCommand } from './extensions.js';
|
||||
import type { Config } from '@google/gemini-cli-core';
|
||||
import type { CommandContext } from './types.js';
|
||||
|
||||
const mockListExtensions = vi.hoisted(() => vi.fn());
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
@@ -42,14 +42,14 @@ describe('ExtensionsCommand', () => {
|
||||
|
||||
it('should default to listing extensions', async () => {
|
||||
const command = new ExtensionsCommand();
|
||||
const mockConfig = {} as Config;
|
||||
const mockConfig = { config: {} } as CommandContext;
|
||||
const mockExtensions = [{ name: 'ext1' }];
|
||||
mockListExtensions.mockReturnValue(mockExtensions);
|
||||
|
||||
const result = await command.execute(mockConfig, []);
|
||||
|
||||
expect(result).toEqual({ name: 'extensions list', data: mockExtensions });
|
||||
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig);
|
||||
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig.config);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -61,19 +61,19 @@ describe('ListExtensionsCommand', () => {
|
||||
|
||||
it('should call listExtensions with the provided config', async () => {
|
||||
const command = new ListExtensionsCommand();
|
||||
const mockConfig = {} as Config;
|
||||
const mockConfig = { config: {} } as CommandContext;
|
||||
const mockExtensions = [{ name: 'ext1' }];
|
||||
mockListExtensions.mockReturnValue(mockExtensions);
|
||||
|
||||
const result = await command.execute(mockConfig, []);
|
||||
|
||||
expect(result).toEqual({ name: 'extensions list', data: mockExtensions });
|
||||
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig);
|
||||
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig.config);
|
||||
});
|
||||
|
||||
it('should return a message when no extensions are installed', async () => {
|
||||
const command = new ListExtensionsCommand();
|
||||
const mockConfig = {} as Config;
|
||||
const mockConfig = { config: {} } as CommandContext;
|
||||
mockListExtensions.mockReturnValue([]);
|
||||
|
||||
const result = await command.execute(mockConfig, []);
|
||||
@@ -82,6 +82,6 @@ describe('ListExtensionsCommand', () => {
|
||||
name: 'extensions list',
|
||||
data: 'No extensions installed.',
|
||||
});
|
||||
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig);
|
||||
expect(mockListExtensions).toHaveBeenCalledWith(mockConfig.config);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,8 +4,12 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { listExtensions, type Config } from '@google/gemini-cli-core';
|
||||
import type { Command, CommandExecutionResponse } from './types.js';
|
||||
import { listExtensions } from '@google/gemini-cli-core';
|
||||
import type {
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandExecutionResponse,
|
||||
} from './types.js';
|
||||
|
||||
export class ExtensionsCommand implements Command {
|
||||
readonly name = 'extensions';
|
||||
@@ -14,10 +18,10 @@ export class ExtensionsCommand implements Command {
|
||||
readonly topLevel = true;
|
||||
|
||||
async execute(
|
||||
config: Config,
|
||||
context: CommandContext,
|
||||
_: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
return new ListExtensionsCommand().execute(config, _);
|
||||
return new ListExtensionsCommand().execute(context, _);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,10 +30,10 @@ export class ListExtensionsCommand implements Command {
|
||||
readonly description = 'Lists all installed extensions.';
|
||||
|
||||
async execute(
|
||||
config: Config,
|
||||
context: CommandContext,
|
||||
_: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const extensions = listExtensions(config);
|
||||
const extensions = listExtensions(context.config);
|
||||
const data = extensions.length ? extensions : 'No extensions installed.';
|
||||
|
||||
return { name: this.name, data };
|
||||
|
||||
137
packages/a2a-server/src/commands/restore.test.ts
Normal file
137
packages/a2a-server/src/commands/restore.test.ts
Normal file
@@ -0,0 +1,137 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { RestoreCommand, ListCheckpointsCommand } from './restore.js';
|
||||
import type { CommandContext } from './types.js';
|
||||
import type { Config } from '@google/gemini-cli-core';
|
||||
import { createMockConfig } from '../utils/testing_utils.js';
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
const mockPerformRestore = vi.hoisted(() => vi.fn());
|
||||
const mockLoggerInfo = vi.hoisted(() => vi.fn());
|
||||
const mockGetCheckpointInfoList = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const original =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
return {
|
||||
...original,
|
||||
performRestore: mockPerformRestore,
|
||||
getCheckpointInfoList: mockGetCheckpointInfoList,
|
||||
};
|
||||
});
|
||||
|
||||
const mockFs = vi.hoisted(() => ({
|
||||
readFile: vi.fn(),
|
||||
readdir: vi.fn(),
|
||||
mkdir: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('node:fs/promises', () => mockFs);
|
||||
|
||||
vi.mock('../utils/logger.js', () => ({
|
||||
logger: {
|
||||
info: mockLoggerInfo,
|
||||
},
|
||||
}));
|
||||
|
||||
describe('RestoreCommand', () => {
|
||||
const mockConfig = {
|
||||
config: createMockConfig() as Config,
|
||||
git: {},
|
||||
} as CommandContext;
|
||||
|
||||
it('should return error if no checkpoint name is provided', async () => {
|
||||
const command = new RestoreCommand();
|
||||
const result = await command.execute(mockConfig, []);
|
||||
expect(result.data).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Please provide a checkpoint name to restore.',
|
||||
});
|
||||
});
|
||||
|
||||
it('should restore a checkpoint when a valid file is provided', async () => {
|
||||
const command = new RestoreCommand();
|
||||
const toolCallData = {
|
||||
toolCall: {
|
||||
name: 'test-tool',
|
||||
args: {},
|
||||
},
|
||||
history: [],
|
||||
clientHistory: [],
|
||||
commitHash: '123',
|
||||
};
|
||||
mockFs.readFile.mockResolvedValue(JSON.stringify(toolCallData));
|
||||
const restoreContent = {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: 'Restored',
|
||||
};
|
||||
mockPerformRestore.mockReturnValue(
|
||||
(async function* () {
|
||||
yield restoreContent;
|
||||
})(),
|
||||
);
|
||||
const result = await command.execute(mockConfig, ['checkpoint1.json']);
|
||||
expect(result.data).toEqual([restoreContent]);
|
||||
});
|
||||
|
||||
it('should show "file not found" error for a non-existent checkpoint', async () => {
|
||||
const command = new RestoreCommand();
|
||||
const error = new Error('File not found');
|
||||
(error as NodeJS.ErrnoException).code = 'ENOENT';
|
||||
mockFs.readFile.mockRejectedValue(error);
|
||||
const result = await command.execute(mockConfig, ['checkpoint2.json']);
|
||||
expect(result.data).toEqual({
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'File not found: checkpoint2.json',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle invalid JSON in checkpoint file', async () => {
|
||||
const command = new RestoreCommand();
|
||||
mockFs.readFile.mockResolvedValue('invalid json');
|
||||
const result = await command.execute(mockConfig, ['checkpoint1.json']);
|
||||
expect((result.data as { content: string }).content).toContain(
|
||||
'An unexpected error occurred during restore.',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('ListCheckpointsCommand', () => {
|
||||
const mockConfig = {
|
||||
config: createMockConfig() as Config,
|
||||
} as CommandContext;
|
||||
|
||||
it('should list all available checkpoints', async () => {
|
||||
const command = new ListCheckpointsCommand();
|
||||
const checkpointInfo = [{ file: 'checkpoint1.json', description: 'Test' }];
|
||||
mockFs.readdir.mockResolvedValue(['checkpoint1.json']);
|
||||
mockFs.readFile.mockResolvedValue(
|
||||
JSON.stringify({ toolCall: { name: 'Test', args: {} } }),
|
||||
);
|
||||
mockGetCheckpointInfoList.mockReturnValue(checkpointInfo);
|
||||
const result = await command.execute(mockConfig);
|
||||
expect((result.data as { content: string }).content).toEqual(
|
||||
JSON.stringify(checkpointInfo),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle errors when listing checkpoints', async () => {
|
||||
const command = new ListCheckpointsCommand();
|
||||
mockFs.readdir.mockRejectedValue(new Error('Read error'));
|
||||
const result = await command.execute(mockConfig);
|
||||
expect((result.data as { content: string }).content).toContain(
|
||||
'An unexpected error occurred while listing checkpoints.',
|
||||
);
|
||||
});
|
||||
});
|
||||
155
packages/a2a-server/src/commands/restore.ts
Normal file
155
packages/a2a-server/src/commands/restore.ts
Normal file
@@ -0,0 +1,155 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
getCheckpointInfoList,
|
||||
getToolCallDataSchema,
|
||||
isNodeError,
|
||||
performRestore,
|
||||
} from '@google/gemini-cli-core';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
import type {
|
||||
Command,
|
||||
CommandContext,
|
||||
CommandExecutionResponse,
|
||||
} from './types.js';
|
||||
|
||||
export class RestoreCommand implements Command {
|
||||
readonly name = 'restore';
|
||||
readonly description =
|
||||
'Restore to a previous checkpoint, or list available checkpoints to restore. This will reset the conversation and file history to the state it was in when the checkpoint was created';
|
||||
readonly topLevel = true;
|
||||
readonly requiresWorkspace = true;
|
||||
readonly subCommands = [new ListCheckpointsCommand()];
|
||||
|
||||
async execute(
|
||||
context: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse> {
|
||||
const { config, git: gitService } = context;
|
||||
const argsStr = args.join(' ');
|
||||
|
||||
try {
|
||||
if (!argsStr) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Please provide a checkpoint name to restore.',
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const selectedFile = argsStr.endsWith('.json')
|
||||
? argsStr
|
||||
: `${argsStr}.json`;
|
||||
|
||||
const checkpointDir = config.storage.getProjectTempCheckpointsDir();
|
||||
const filePath = path.join(checkpointDir, selectedFile);
|
||||
|
||||
let data: string;
|
||||
try {
|
||||
data = await fs.readFile(filePath, 'utf-8');
|
||||
} catch (error) {
|
||||
if (isNodeError(error) && error.code === 'ENOENT') {
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: `File not found: ${selectedFile}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
|
||||
const toolCallData = JSON.parse(data);
|
||||
const ToolCallDataSchema = getToolCallDataSchema();
|
||||
const parseResult = ToolCallDataSchema.safeParse(toolCallData);
|
||||
|
||||
if (!parseResult.success) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'Checkpoint file is invalid or corrupted.',
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const restoreResultGenerator = performRestore(
|
||||
parseResult.data,
|
||||
gitService,
|
||||
);
|
||||
const restoreResult = [];
|
||||
for await (const result of restoreResultGenerator) {
|
||||
restoreResult.push(result);
|
||||
}
|
||||
|
||||
return {
|
||||
name: this.name,
|
||||
data: restoreResult,
|
||||
};
|
||||
} catch (_error) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'An unexpected error occurred during restore.',
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class ListCheckpointsCommand implements Command {
|
||||
readonly name = 'restore list';
|
||||
readonly description = 'Lists all available checkpoints.';
|
||||
readonly topLevel = false;
|
||||
|
||||
async execute(context: CommandContext): Promise<CommandExecutionResponse> {
|
||||
const { config } = context;
|
||||
|
||||
try {
|
||||
const checkpointDir = config.storage.getProjectTempCheckpointsDir();
|
||||
await fs.mkdir(checkpointDir, { recursive: true });
|
||||
const files = await fs.readdir(checkpointDir);
|
||||
const jsonFiles = files.filter((file) => file.endsWith('.json'));
|
||||
|
||||
const checkpointFiles = new Map<string, string>();
|
||||
for (const file of jsonFiles) {
|
||||
const filePath = path.join(checkpointDir, file);
|
||||
const data = await fs.readFile(filePath, 'utf-8');
|
||||
checkpointFiles.set(file, data);
|
||||
}
|
||||
|
||||
const checkpointInfoList = getCheckpointInfoList(checkpointFiles);
|
||||
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
content: JSON.stringify(checkpointInfoList),
|
||||
},
|
||||
};
|
||||
} catch (_error) {
|
||||
return {
|
||||
name: this.name,
|
||||
data: {
|
||||
type: 'message',
|
||||
messageType: 'error',
|
||||
content: 'An unexpected error occurred while listing checkpoints.',
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,12 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '@google/gemini-cli-core';
|
||||
import type { Config, GitService } from '@google/gemini-cli-core';
|
||||
|
||||
export interface CommandContext {
|
||||
config: Config;
|
||||
git?: GitService;
|
||||
}
|
||||
|
||||
export interface CommandArgument {
|
||||
readonly name: string;
|
||||
@@ -18,8 +23,12 @@ export interface Command {
|
||||
readonly arguments?: CommandArgument[];
|
||||
readonly subCommands?: Command[];
|
||||
readonly topLevel?: boolean;
|
||||
readonly requiresWorkspace?: boolean;
|
||||
|
||||
execute(config: Config, args: string[]): Promise<CommandExecutionResponse>;
|
||||
execute(
|
||||
config: CommandContext,
|
||||
args: string[],
|
||||
): Promise<CommandExecutionResponse>;
|
||||
}
|
||||
|
||||
export interface CommandExecutionResponse {
|
||||
|
||||
@@ -71,6 +71,9 @@ export async function loadConfig(
|
||||
ideMode: false,
|
||||
folderTrust: settings.folderTrust === true,
|
||||
extensionLoader,
|
||||
checkpointing: process.env['CHECKPOINTING']
|
||||
? process.env['CHECKPOINTING'] === 'true'
|
||||
: settings.checkpointing?.enabled,
|
||||
previewFeatures: settings.general?.previewFeatures,
|
||||
};
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ import {
|
||||
createMockConfig,
|
||||
} from '../utils/testing_utils.js';
|
||||
import { MockTool } from '@google/gemini-cli-core';
|
||||
import type { Command } from '../commands/types.js';
|
||||
import type { Command, CommandContext } from '../commands/types.js';
|
||||
|
||||
const mockToolConfirmationFn = async () =>
|
||||
({}) as unknown as ToolCallConfirmationDetails;
|
||||
@@ -97,6 +97,7 @@ vi.mock('@google/gemini-cli-core', async () => {
|
||||
getUserTier: vi.fn().mockReturnValue('free'),
|
||||
initialize: vi.fn(),
|
||||
})),
|
||||
performRestore: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
@@ -939,6 +940,17 @@ describe('E2E Tests', () => {
|
||||
});
|
||||
|
||||
it('should return extensions for valid command', async () => {
|
||||
const mockExtensionsCommand = {
|
||||
name: 'extensions list',
|
||||
description: 'a mock command',
|
||||
execute: vi.fn(async (context: CommandContext) => {
|
||||
// Simulate the actual command's behavior
|
||||
const extensions = context.config.getExtensions();
|
||||
return { name: 'extensions list', data: extensions };
|
||||
}),
|
||||
};
|
||||
vi.spyOn(commandRegistry, 'get').mockReturnValue(mockExtensionsCommand);
|
||||
|
||||
const agent = request.agent(app);
|
||||
const res = await agent
|
||||
.post('/executeCommand')
|
||||
@@ -954,6 +966,8 @@ describe('E2E Tests', () => {
|
||||
});
|
||||
|
||||
it('should return 404 for invalid command', async () => {
|
||||
vi.spyOn(commandRegistry, 'get').mockReturnValue(undefined);
|
||||
|
||||
const agent = request.agent(app);
|
||||
const res = await agent
|
||||
.post('/executeCommand')
|
||||
@@ -986,5 +1000,66 @@ describe('E2E Tests', () => {
|
||||
expect(res.body.error).toBe('"args" field must be an array.');
|
||||
expect(getExtensionsSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should execute a command that does not require a workspace when CODER_AGENT_WORKSPACE_PATH is not set', async () => {
|
||||
const mockCommand = {
|
||||
name: 'test-command',
|
||||
description: 'a mock command',
|
||||
execute: vi
|
||||
.fn()
|
||||
.mockResolvedValue({ name: 'test-command', data: 'success' }),
|
||||
};
|
||||
vi.spyOn(commandRegistry, 'get').mockReturnValue(mockCommand);
|
||||
|
||||
delete process.env['CODER_AGENT_WORKSPACE_PATH'];
|
||||
const response = await request(app)
|
||||
.post('/executeCommand')
|
||||
.send({ command: 'test-command', args: [] });
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.data).toBe('success');
|
||||
});
|
||||
|
||||
it('should return 400 for a command that requires a workspace when CODER_AGENT_WORKSPACE_PATH is not set', async () => {
|
||||
const mockWorkspaceCommand = {
|
||||
name: 'workspace-command',
|
||||
description: 'A command that requires a workspace',
|
||||
requiresWorkspace: true,
|
||||
execute: vi
|
||||
.fn()
|
||||
.mockResolvedValue({ name: 'workspace-command', data: 'success' }),
|
||||
};
|
||||
vi.spyOn(commandRegistry, 'get').mockReturnValue(mockWorkspaceCommand);
|
||||
|
||||
delete process.env['CODER_AGENT_WORKSPACE_PATH'];
|
||||
const response = await request(app)
|
||||
.post('/executeCommand')
|
||||
.send({ command: 'workspace-command', args: [] });
|
||||
|
||||
expect(response.status).toBe(400);
|
||||
expect(response.body.error).toBe(
|
||||
'Command "workspace-command" requires a workspace, but CODER_AGENT_WORKSPACE_PATH is not set.',
|
||||
);
|
||||
});
|
||||
|
||||
it('should execute a command that requires a workspace when CODER_AGENT_WORKSPACE_PATH is set', async () => {
|
||||
const mockWorkspaceCommand = {
|
||||
name: 'workspace-command',
|
||||
description: 'A command that requires a workspace',
|
||||
requiresWorkspace: true,
|
||||
execute: vi
|
||||
.fn()
|
||||
.mockResolvedValue({ name: 'workspace-command', data: 'success' }),
|
||||
};
|
||||
vi.spyOn(commandRegistry, 'get').mockReturnValue(mockWorkspaceCommand);
|
||||
|
||||
process.env['CODER_AGENT_WORKSPACE_PATH'] = '/tmp/test-workspace';
|
||||
const response = await request(app)
|
||||
.post('/executeCommand')
|
||||
.send({ command: 'workspace-command', args: [] });
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.body.data).toBe('success');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -22,6 +22,7 @@ import { loadExtensions } from '../config/extension.js';
|
||||
import { commandRegistry } from '../commands/command-registry.js';
|
||||
import { SimpleExtensionLoader } from '@google/gemini-cli-core';
|
||||
import type { Command, CommandArgument } from '../commands/types.js';
|
||||
import { GitService } from '@google/gemini-cli-core';
|
||||
|
||||
type CommandResponse = {
|
||||
name: string;
|
||||
@@ -85,6 +86,14 @@ export async function createApp() {
|
||||
'a2a-server',
|
||||
);
|
||||
|
||||
let git: GitService | undefined;
|
||||
if (config.getCheckpointingEnabled()) {
|
||||
git = new GitService(config.getTargetDir(), config.storage);
|
||||
await git.initialize();
|
||||
}
|
||||
|
||||
const context = { config, git };
|
||||
|
||||
// loadEnvironment() is called within getConfig now
|
||||
const bucketName = process.env['GCS_BUCKET_NAME'];
|
||||
let taskStoreForExecutor: TaskStore;
|
||||
@@ -144,6 +153,7 @@ export async function createApp() {
|
||||
});
|
||||
|
||||
expressApp.post('/executeCommand', async (req, res) => {
|
||||
logger.info('[CoreAgent] Received /executeCommand request: ', req.body);
|
||||
try {
|
||||
const { command, args } = req.body;
|
||||
|
||||
@@ -159,13 +169,22 @@ export async function createApp() {
|
||||
|
||||
const commandToExecute = commandRegistry.get(command);
|
||||
|
||||
if (commandToExecute?.requiresWorkspace) {
|
||||
if (!process.env['CODER_AGENT_WORKSPACE_PATH']) {
|
||||
return res.status(400).json({
|
||||
error: `Command "${command}" requires a workspace, but CODER_AGENT_WORKSPACE_PATH is not set.`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (!commandToExecute) {
|
||||
return res
|
||||
.status(404)
|
||||
.json({ error: `Command not found: ${command}` });
|
||||
}
|
||||
|
||||
const result = await commandToExecute.execute(config, args ?? []);
|
||||
const result = await commandToExecute.execute(context, args ?? []);
|
||||
logger.info('[CoreAgent] Sending /executeCommand response: ', result);
|
||||
return res.status(200).json(result);
|
||||
} catch (e) {
|
||||
logger.error('Error executing /executeCommand:', e);
|
||||
|
||||
@@ -37,8 +37,10 @@ export function createMockConfig(
|
||||
isPathWithinWorkspace: () => true,
|
||||
}),
|
||||
getTargetDir: () => '/test',
|
||||
getCheckpointingEnabled: vi.fn().mockReturnValue(false),
|
||||
storage: {
|
||||
getProjectTempDir: () => '/tmp',
|
||||
getProjectTempCheckpointsDir: () => '/tmp/checkpoints',
|
||||
} as Storage,
|
||||
getTruncateToolOutputThreshold: () =>
|
||||
DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD,
|
||||
@@ -62,6 +64,7 @@ export function createMockConfig(
|
||||
getMcpClientManager: vi.fn().mockReturnValue({
|
||||
getMcpServers: vi.fn().mockReturnValue({}),
|
||||
}),
|
||||
getGitService: vi.fn(),
|
||||
...overrides,
|
||||
} as unknown as Config;
|
||||
mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus());
|
||||
|
||||
@@ -9,6 +9,9 @@ import path from 'node:path';
|
||||
import { z } from 'zod';
|
||||
import {
|
||||
type Config,
|
||||
formatCheckpointDisplayList,
|
||||
getToolCallDataSchema,
|
||||
getTruncatedCheckpointNames,
|
||||
performRestore,
|
||||
type ToolCallData,
|
||||
} from '@google/gemini-cli-core';
|
||||
@@ -28,23 +31,7 @@ const HistoryItemSchema = z
|
||||
})
|
||||
.passthrough();
|
||||
|
||||
const ContentSchema = z
|
||||
.object({
|
||||
role: z.string().optional(),
|
||||
parts: z.array(z.record(z.unknown())),
|
||||
})
|
||||
.passthrough();
|
||||
|
||||
const ToolCallDataSchema = z.object({
|
||||
history: z.array(HistoryItemSchema).optional(),
|
||||
clientHistory: z.array(ContentSchema).optional(),
|
||||
commitHash: z.string().optional(),
|
||||
toolCall: z.object({
|
||||
name: z.string(),
|
||||
args: z.record(z.unknown()),
|
||||
}),
|
||||
messageId: z.string().optional(),
|
||||
});
|
||||
const ToolCallDataSchema = getToolCallDataSchema(HistoryItemSchema);
|
||||
|
||||
async function restoreAction(
|
||||
context: CommandContext,
|
||||
@@ -78,15 +65,7 @@ async function restoreAction(
|
||||
content: 'No restorable tool calls found.',
|
||||
};
|
||||
}
|
||||
const truncatedFiles = jsonFiles.map((file) => {
|
||||
const components = file.split('.');
|
||||
if (components.length <= 1) {
|
||||
return file;
|
||||
}
|
||||
components.pop();
|
||||
return components.join('.');
|
||||
});
|
||||
const fileList = truncatedFiles.join('\n');
|
||||
const fileList = formatCheckpointDisplayList(jsonFiles);
|
||||
return {
|
||||
type: 'message',
|
||||
messageType: 'info',
|
||||
@@ -171,9 +150,8 @@ async function completion(
|
||||
}
|
||||
try {
|
||||
const files = await fs.readdir(checkpointDir);
|
||||
return files
|
||||
.filter((file) => file.endsWith('.json'))
|
||||
.map((file) => file.replace('.json', ''));
|
||||
const jsonFiles = files.filter((file) => file.endsWith('.json'));
|
||||
return getTruncatedCheckpointNames(jsonFiles);
|
||||
} catch (_err) {
|
||||
return [];
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import type {
|
||||
ThoughtSummary,
|
||||
ToolCallRequestInfo,
|
||||
GeminiErrorEventValue,
|
||||
ToolCallData,
|
||||
} from '@google/gemini-cli-core';
|
||||
import {
|
||||
GeminiEventType as ServerGeminiEventType,
|
||||
@@ -34,10 +33,11 @@ import {
|
||||
parseAndFormatApiError,
|
||||
ToolConfirmationOutcome,
|
||||
promptIdContext,
|
||||
WRITE_FILE_TOOL_NAME,
|
||||
tokenLimit,
|
||||
debugLogger,
|
||||
runInDevTraceSpan,
|
||||
EDIT_TOOL_NAMES,
|
||||
processRestorableToolCalls,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { type Part, type PartListUnion, FinishReason } from '@google/genai';
|
||||
import type {
|
||||
@@ -76,8 +76,6 @@ enum StreamProcessingStatus {
|
||||
Error,
|
||||
}
|
||||
|
||||
const EDIT_TOOL_NAMES = new Set(['replace', WRITE_FILE_TOOL_NAME]);
|
||||
|
||||
function showCitations(settings: LoadedSettings): boolean {
|
||||
const enabled = settings?.merged?.ui?.showCitations;
|
||||
if (enabled !== undefined) {
|
||||
@@ -1248,98 +1246,37 @@ export const useGeminiStream = (
|
||||
);
|
||||
|
||||
if (restorableToolCalls.length > 0) {
|
||||
const checkpointDir = storage.getProjectTempCheckpointsDir();
|
||||
|
||||
if (!checkpointDir) {
|
||||
if (!gitService) {
|
||||
onDebugMessage(
|
||||
'Checkpointing is enabled but Git service is not available. Failed to create snapshot. Ensure Git is installed and working properly.',
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await fs.mkdir(checkpointDir, { recursive: true });
|
||||
} catch (error) {
|
||||
if (!isNodeError(error) || error.code !== 'EEXIST') {
|
||||
onDebugMessage(
|
||||
`Failed to create checkpoint directory: ${getErrorMessage(error)}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
const { checkpointsToWrite, errors } = await processRestorableToolCalls<
|
||||
HistoryItem[]
|
||||
>(
|
||||
restorableToolCalls.map((call) => call.request),
|
||||
gitService,
|
||||
geminiClient,
|
||||
history,
|
||||
);
|
||||
|
||||
if (errors.length > 0) {
|
||||
errors.forEach(onDebugMessage);
|
||||
}
|
||||
|
||||
for (const toolCall of restorableToolCalls) {
|
||||
const filePath = toolCall.request.args['file_path'] as string;
|
||||
if (!filePath) {
|
||||
onDebugMessage(
|
||||
`Skipping restorable tool call due to missing file_path: ${toolCall.request.name}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (checkpointsToWrite.size > 0) {
|
||||
const checkpointDir = storage.getProjectTempCheckpointsDir();
|
||||
try {
|
||||
if (!gitService) {
|
||||
onDebugMessage(
|
||||
`Checkpointing is enabled but Git service is not available. Failed to create snapshot for ${filePath}. Ensure Git is installed and working properly.`,
|
||||
);
|
||||
continue;
|
||||
await fs.mkdir(checkpointDir, { recursive: true });
|
||||
for (const [fileName, content] of checkpointsToWrite) {
|
||||
const filePath = path.join(checkpointDir, fileName);
|
||||
await fs.writeFile(filePath, content);
|
||||
}
|
||||
|
||||
let commitHash: string | undefined;
|
||||
try {
|
||||
commitHash = await gitService.createFileSnapshot(
|
||||
`Snapshot for ${toolCall.request.name}`,
|
||||
);
|
||||
} catch (error) {
|
||||
onDebugMessage(
|
||||
`Failed to create new snapshot: ${getErrorMessage(error)}. Attempting to use current commit.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!commitHash) {
|
||||
commitHash = await gitService.getCurrentCommitHash();
|
||||
}
|
||||
|
||||
if (!commitHash) {
|
||||
onDebugMessage(
|
||||
`Failed to create snapshot for ${filePath}. Checkpointing may not be working properly. Ensure Git is installed and the project directory is accessible.`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const timestamp = new Date()
|
||||
.toISOString()
|
||||
.replace(/:/g, '-')
|
||||
.replace(/\./g, '_');
|
||||
const toolName = toolCall.request.name;
|
||||
const fileName = path.basename(filePath);
|
||||
const toolCallWithSnapshotFileName = `${timestamp}-${fileName}-${toolName}.json`;
|
||||
const clientHistory = await geminiClient?.getHistory();
|
||||
const toolCallWithSnapshotFilePath = path.join(
|
||||
checkpointDir,
|
||||
toolCallWithSnapshotFileName,
|
||||
);
|
||||
|
||||
const checkpointData: ToolCallData<
|
||||
HistoryItem[],
|
||||
Record<string, unknown>
|
||||
> & { filePath: string } = {
|
||||
history,
|
||||
clientHistory,
|
||||
toolCall: {
|
||||
name: toolCall.request.name,
|
||||
args: toolCall.request.args,
|
||||
},
|
||||
commitHash,
|
||||
filePath,
|
||||
};
|
||||
|
||||
await fs.writeFile(
|
||||
toolCallWithSnapshotFilePath,
|
||||
JSON.stringify(checkpointData, null, 2),
|
||||
);
|
||||
} catch (error) {
|
||||
onDebugMessage(
|
||||
`Failed to create checkpoint for ${filePath}: ${getErrorMessage(
|
||||
error,
|
||||
)}. This may indicate a problem with Git or file system permissions.`,
|
||||
`Failed to write checkpoint file: ${getErrorMessage(error)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import { performRestore, type ToolCallData } from './restore.js';
|
||||
import { performRestore } from './restore.js';
|
||||
import { type ToolCallData } from '../utils/checkpointUtils.js';
|
||||
import type { GitService } from '../services/gitService.js';
|
||||
|
||||
describe('performRestore', () => {
|
||||
|
||||
@@ -4,20 +4,9 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Content } from '@google/genai';
|
||||
import type { GitService } from '../services/gitService.js';
|
||||
import type { CommandActionReturn } from './types.js';
|
||||
|
||||
export interface ToolCallData<HistoryType = unknown, ArgsType = unknown> {
|
||||
history?: HistoryType;
|
||||
clientHistory?: Content[];
|
||||
commitHash?: string;
|
||||
toolCall: {
|
||||
name: string;
|
||||
args: ArgsType;
|
||||
};
|
||||
messageId?: string;
|
||||
}
|
||||
import { type ToolCallData } from '../utils/checkpointUtils.js';
|
||||
|
||||
export async function* performRestore<
|
||||
HistoryType = unknown,
|
||||
|
||||
@@ -107,6 +107,7 @@ export interface ToolCallRequestInfo {
|
||||
args: Record<string, unknown>;
|
||||
isClientInitiated: boolean;
|
||||
prompt_id: string;
|
||||
checkpoint?: string;
|
||||
}
|
||||
|
||||
export interface ToolCallResponseInfo {
|
||||
|
||||
@@ -78,6 +78,7 @@ export * from './utils/debugLogger.js';
|
||||
export * from './utils/events.js';
|
||||
export * from './utils/extensionLoader.js';
|
||||
export * from './utils/package.js';
|
||||
export * from './utils/checkpointUtils.js';
|
||||
|
||||
// Export services
|
||||
export * from './services/fileDiscoveryService.js';
|
||||
|
||||
@@ -32,6 +32,7 @@ const hoistedMockInit = vi.hoisted(() => vi.fn());
|
||||
const hoistedMockRaw = vi.hoisted(() => vi.fn());
|
||||
const hoistedMockAdd = vi.hoisted(() => vi.fn());
|
||||
const hoistedMockCommit = vi.hoisted(() => vi.fn());
|
||||
const hoistedMockStatus = vi.hoisted(() => vi.fn());
|
||||
vi.mock('simple-git', () => ({
|
||||
simpleGit: hoistedMockSimpleGit.mockImplementation(() => ({
|
||||
checkIsRepo: hoistedMockCheckIsRepo,
|
||||
@@ -39,6 +40,7 @@ vi.mock('simple-git', () => ({
|
||||
raw: hoistedMockRaw,
|
||||
add: hoistedMockAdd,
|
||||
commit: hoistedMockCommit,
|
||||
status: hoistedMockStatus,
|
||||
env: hoistedMockEnv,
|
||||
})),
|
||||
CheckRepoActions: { IS_REPO_ROOT: 'is-repo-root' },
|
||||
@@ -89,6 +91,7 @@ describe('GitService', () => {
|
||||
raw: hoistedMockRaw,
|
||||
add: hoistedMockAdd,
|
||||
commit: hoistedMockCommit,
|
||||
status: hoistedMockStatus,
|
||||
}));
|
||||
hoistedMockSimpleGit.mockImplementation(() => ({
|
||||
checkIsRepo: hoistedMockCheckIsRepo,
|
||||
@@ -96,6 +99,7 @@ describe('GitService', () => {
|
||||
raw: hoistedMockRaw,
|
||||
add: hoistedMockAdd,
|
||||
commit: hoistedMockCommit,
|
||||
status: hoistedMockStatus,
|
||||
env: hoistedMockEnv,
|
||||
}));
|
||||
hoistedMockCheckIsRepo.mockResolvedValue(false);
|
||||
@@ -248,6 +252,7 @@ describe('GitService', () => {
|
||||
|
||||
describe('createFileSnapshot', () => {
|
||||
it('should commit with --no-verify flag', async () => {
|
||||
hoistedMockStatus.mockResolvedValue({ isClean: () => false });
|
||||
const service = new GitService(projectRoot, storage);
|
||||
await service.initialize();
|
||||
await service.createFileSnapshot('test commit');
|
||||
@@ -255,5 +260,30 @@ describe('GitService', () => {
|
||||
'--no-verify': null,
|
||||
});
|
||||
});
|
||||
|
||||
it('should create a new commit if there are staged changes', async () => {
|
||||
hoistedMockStatus.mockResolvedValue({ isClean: () => false });
|
||||
hoistedMockCommit.mockResolvedValue({ commit: 'new-commit-hash' });
|
||||
const service = new GitService(projectRoot, storage);
|
||||
const commitHash = await service.createFileSnapshot('test message');
|
||||
expect(hoistedMockAdd).toHaveBeenCalledWith('.');
|
||||
expect(hoistedMockStatus).toHaveBeenCalled();
|
||||
expect(hoistedMockCommit).toHaveBeenCalledWith('test message', {
|
||||
'--no-verify': null,
|
||||
});
|
||||
expect(commitHash).toBe('new-commit-hash');
|
||||
});
|
||||
|
||||
it('should return the current HEAD commit hash if there are no staged changes', async () => {
|
||||
hoistedMockStatus.mockResolvedValue({ isClean: () => true });
|
||||
hoistedMockRaw.mockResolvedValue('current-head-hash');
|
||||
const service = new GitService(projectRoot, storage);
|
||||
const commitHash = await service.createFileSnapshot('test message');
|
||||
expect(hoistedMockAdd).toHaveBeenCalledWith('.');
|
||||
expect(hoistedMockStatus).toHaveBeenCalled();
|
||||
expect(hoistedMockCommit).not.toHaveBeenCalled();
|
||||
expect(hoistedMockRaw).toHaveBeenCalledWith('rev-parse', 'HEAD');
|
||||
expect(commitHash).toBe('current-head-hash');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -112,6 +112,11 @@ export class GitService {
|
||||
try {
|
||||
const repo = this.shadowGitRepository;
|
||||
await repo.add('.');
|
||||
const status = await repo.status();
|
||||
if (status.isClean()) {
|
||||
// If no changes are staged, return the current HEAD commit hash
|
||||
return await this.getCurrentCommitHash();
|
||||
}
|
||||
const commitResult = await repo.commit(message, {
|
||||
'--no-verify': null,
|
||||
});
|
||||
|
||||
@@ -20,3 +20,4 @@ export const READ_MANY_FILES_TOOL_NAME = 'read_many_files';
|
||||
export const READ_FILE_TOOL_NAME = 'read_file';
|
||||
export const LS_TOOL_NAME = 'list_directory';
|
||||
export const MEMORY_TOOL_NAME = 'save_memory';
|
||||
export const EDIT_TOOL_NAMES = new Set([EDIT_TOOL_NAME, WRITE_FILE_TOOL_NAME]);
|
||||
|
||||
305
packages/core/src/utils/checkpointUtils.test.ts
Normal file
305
packages/core/src/utils/checkpointUtils.test.ts
Normal file
@@ -0,0 +1,305 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
|
||||
import { z } from 'zod';
|
||||
import {
|
||||
getToolCallDataSchema,
|
||||
generateCheckpointFileName,
|
||||
formatCheckpointDisplayList,
|
||||
getTruncatedCheckpointNames,
|
||||
processRestorableToolCalls,
|
||||
getCheckpointInfoList,
|
||||
} from './checkpointUtils.js';
|
||||
import type { GitService } from '../services/gitService.js';
|
||||
import type { GeminiClient } from '../core/client.js';
|
||||
import type { ToolCallRequestInfo } from '../core/turn.js';
|
||||
|
||||
describe('checkpoint utils', () => {
|
||||
describe('getToolCallDataSchema', () => {
|
||||
it('should return a schema that validates a basic tool call data object', () => {
|
||||
const schema = getToolCallDataSchema();
|
||||
const validData = {
|
||||
toolCall: { name: 'test-tool', args: { foo: 'bar' } },
|
||||
};
|
||||
const result = schema.safeParse(validData);
|
||||
expect(result.success).toBe(true);
|
||||
});
|
||||
|
||||
it('should validate with an optional history schema', () => {
|
||||
const historyItemSchema = z.object({ id: z.number(), event: z.string() });
|
||||
const schema = getToolCallDataSchema(historyItemSchema);
|
||||
const validData = {
|
||||
history: [{ id: 1, event: 'start' }],
|
||||
toolCall: { name: 'test-tool', args: {} },
|
||||
};
|
||||
const result = schema.safeParse(validData);
|
||||
expect(result.success).toBe(true);
|
||||
});
|
||||
|
||||
it('should fail validation if history items do not match the schema', () => {
|
||||
const historyItemSchema = z.object({ id: z.number(), event: z.string() });
|
||||
const schema = getToolCallDataSchema(historyItemSchema);
|
||||
const invalidData = {
|
||||
history: [{ id: '1', event: 'start' }], // id should be a number
|
||||
toolCall: { name: 'test-tool', args: {} },
|
||||
};
|
||||
const result = schema.safeParse(invalidData);
|
||||
expect(result.success).toBe(false);
|
||||
});
|
||||
|
||||
it('should validate clientHistory with the correct schema', () => {
|
||||
const schema = getToolCallDataSchema();
|
||||
const validData = {
|
||||
clientHistory: [{ role: 'user', parts: [{ text: 'Hello' }] }],
|
||||
toolCall: { name: 'test-tool', args: {} },
|
||||
};
|
||||
const result = schema.safeParse(validData);
|
||||
expect(result.success).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('generateCheckpointFileName', () => {
|
||||
it('should generate a filename with timestamp, basename, and tool name', () => {
|
||||
vi.useFakeTimers();
|
||||
vi.setSystemTime(new Date('2025-01-01T12:00:00.000Z'));
|
||||
const toolCall = {
|
||||
callId: '1',
|
||||
name: 'replace',
|
||||
args: { file_path: '/path/to/my-file.txt' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'p1',
|
||||
} as ToolCallRequestInfo;
|
||||
|
||||
const expected = '2025-01-01T12-00-00_000Z-my-file.txt-replace';
|
||||
const actual = generateCheckpointFileName(toolCall);
|
||||
expect(actual).toBe(expected);
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
it('should return null if file_path is not in the tool arguments', () => {
|
||||
const toolCall = {
|
||||
callId: '1',
|
||||
name: 'replace',
|
||||
args: { some_other_arg: 'value' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'p1',
|
||||
} as ToolCallRequestInfo;
|
||||
|
||||
const actual = generateCheckpointFileName(toolCall);
|
||||
expect(actual).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe('formatCheckpointDisplayList and getTruncatedCheckpointNames', () => {
|
||||
const filenames = [
|
||||
'2025-01-01T12-00-00_000Z-my-file.txt-replace.json',
|
||||
'2025-01-01T13-00-00_000Z-another.js-write_file.json',
|
||||
'no-extension-file',
|
||||
];
|
||||
|
||||
it('getTruncatedCheckpointNames should remove the .json extension', () => {
|
||||
const expected = [
|
||||
'2025-01-01T12-00-00_000Z-my-file.txt-replace',
|
||||
'2025-01-01T13-00-00_000Z-another.js-write_file',
|
||||
'no-extension-file',
|
||||
];
|
||||
const actual = getTruncatedCheckpointNames(filenames);
|
||||
expect(actual).toEqual(expected);
|
||||
});
|
||||
|
||||
it('formatCheckpointDisplayList should return a newline-separated string of truncated names', () => {
|
||||
const expected = [
|
||||
'2025-01-01T12-00-00_000Z-my-file.txt-replace',
|
||||
'2025-01-01T13-00-00_000Z-another.js-write_file',
|
||||
'no-extension-file',
|
||||
].join('\n');
|
||||
const actual = formatCheckpointDisplayList(filenames);
|
||||
expect(actual).toEqual(expected);
|
||||
});
|
||||
});
|
||||
|
||||
describe('processRestorableToolCalls', () => {
|
||||
const mockGitService = {
|
||||
createFileSnapshot: vi.fn(),
|
||||
getCurrentCommitHash: vi.fn(),
|
||||
} as unknown as GitService;
|
||||
|
||||
const mockGeminiClient = {
|
||||
getHistory: vi.fn(),
|
||||
} as unknown as GeminiClient;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should create checkpoints for restorable tool calls', async () => {
|
||||
const toolCalls = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'replace',
|
||||
args: { file_path: 'a.txt' },
|
||||
prompt_id: 'p1',
|
||||
isClientInitiated: false,
|
||||
},
|
||||
] as ToolCallRequestInfo[];
|
||||
|
||||
(mockGitService.createFileSnapshot as Mock).mockResolvedValue('hash123');
|
||||
(mockGeminiClient.getHistory as Mock).mockResolvedValue([
|
||||
{ role: 'user', parts: [] },
|
||||
]);
|
||||
|
||||
const { checkpointsToWrite, toolCallToCheckpointMap, errors } =
|
||||
await processRestorableToolCalls(
|
||||
toolCalls,
|
||||
mockGitService,
|
||||
mockGeminiClient,
|
||||
'history-data',
|
||||
);
|
||||
|
||||
expect(errors).toHaveLength(0);
|
||||
expect(checkpointsToWrite.size).toBe(1);
|
||||
expect(toolCallToCheckpointMap.get('1')).toBeDefined();
|
||||
|
||||
const fileName = checkpointsToWrite.values().next().value;
|
||||
expect(fileName).toBeDefined();
|
||||
const fileContent = JSON.parse(fileName!);
|
||||
|
||||
expect(fileContent.commitHash).toBe('hash123');
|
||||
expect(fileContent.history).toBe('history-data');
|
||||
expect(fileContent.clientHistory).toEqual([{ role: 'user', parts: [] }]);
|
||||
expect(fileContent.toolCall.name).toBe('replace');
|
||||
expect(fileContent.messageId).toBe('p1');
|
||||
});
|
||||
|
||||
it('should handle git snapshot failure by using current commit hash', async () => {
|
||||
const toolCalls = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'replace',
|
||||
args: { file_path: 'a.txt' },
|
||||
prompt_id: 'p1',
|
||||
isClientInitiated: false,
|
||||
},
|
||||
] as ToolCallRequestInfo[];
|
||||
|
||||
(mockGitService.createFileSnapshot as Mock).mockRejectedValue(
|
||||
new Error('Snapshot failed'),
|
||||
);
|
||||
(mockGitService.getCurrentCommitHash as Mock).mockResolvedValue(
|
||||
'fallback-hash',
|
||||
);
|
||||
|
||||
const { checkpointsToWrite, errors } = await processRestorableToolCalls(
|
||||
toolCalls,
|
||||
mockGitService,
|
||||
mockGeminiClient,
|
||||
);
|
||||
|
||||
expect(errors).toHaveLength(1);
|
||||
expect(errors[0]).toContain('Failed to create new snapshot');
|
||||
expect(checkpointsToWrite.size).toBe(1);
|
||||
const value = checkpointsToWrite.values().next().value;
|
||||
expect(value).toBeDefined();
|
||||
const fileContent = JSON.parse(value!);
|
||||
expect(fileContent.commitHash).toBe('fallback-hash');
|
||||
});
|
||||
|
||||
it('should skip tool calls with no file_path', async () => {
|
||||
const toolCalls = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'replace',
|
||||
args: { not_a_path: 'a.txt' },
|
||||
prompt_id: 'p1',
|
||||
isClientInitiated: false,
|
||||
},
|
||||
] as ToolCallRequestInfo[];
|
||||
(mockGitService.createFileSnapshot as Mock).mockResolvedValue('hash123');
|
||||
|
||||
const { checkpointsToWrite, errors } = await processRestorableToolCalls(
|
||||
toolCalls,
|
||||
mockGitService,
|
||||
mockGeminiClient,
|
||||
);
|
||||
|
||||
expect(errors).toHaveLength(1);
|
||||
expect(errors[0]).toContain(
|
||||
'Skipping restorable tool call due to missing file_path',
|
||||
);
|
||||
expect(checkpointsToWrite.size).toBe(0);
|
||||
});
|
||||
|
||||
it('should log an error if git snapshot fails and then skip the tool call', async () => {
|
||||
const toolCalls = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'replace',
|
||||
args: { file_path: 'a.txt' },
|
||||
prompt_id: 'p1',
|
||||
isClientInitiated: false,
|
||||
},
|
||||
] as ToolCallRequestInfo[];
|
||||
(mockGitService.createFileSnapshot as Mock).mockRejectedValue(
|
||||
new Error('Snapshot failed'),
|
||||
);
|
||||
(mockGitService.getCurrentCommitHash as Mock).mockResolvedValue(
|
||||
undefined,
|
||||
);
|
||||
|
||||
const { checkpointsToWrite, errors } = await processRestorableToolCalls(
|
||||
toolCalls,
|
||||
mockGitService,
|
||||
mockGeminiClient,
|
||||
);
|
||||
|
||||
expect(errors).toHaveLength(2);
|
||||
expect(errors[0]).toContain('Failed to create new snapshot');
|
||||
expect(errors[1]).toContain('Failed to create snapshot for replace');
|
||||
expect(checkpointsToWrite.size).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getCheckpointInfoList', () => {
|
||||
it('should parse valid checkpoint files and return a list of info', () => {
|
||||
const checkpointFiles = new Map([
|
||||
['checkpoint1.json', JSON.stringify({ messageId: 'msg1' })],
|
||||
['checkpoint2.json', JSON.stringify({ messageId: 'msg2' })],
|
||||
]);
|
||||
|
||||
const expected = [
|
||||
{ messageId: 'msg1', checkpoint: 'checkpoint1' },
|
||||
{ messageId: 'msg2', checkpoint: 'checkpoint2' },
|
||||
];
|
||||
|
||||
const actual = getCheckpointInfoList(checkpointFiles);
|
||||
expect(actual).toEqual(expected);
|
||||
});
|
||||
|
||||
it('should ignore files with invalid JSON', () => {
|
||||
const checkpointFiles = new Map([
|
||||
['checkpoint1.json', JSON.stringify({ messageId: 'msg1' })],
|
||||
['invalid.json', 'not-json'],
|
||||
]);
|
||||
|
||||
const expected = [{ messageId: 'msg1', checkpoint: 'checkpoint1' }];
|
||||
const actual = getCheckpointInfoList(checkpointFiles);
|
||||
expect(actual).toEqual(expected);
|
||||
});
|
||||
|
||||
it('should ignore files that are missing a messageId', () => {
|
||||
const checkpointFiles = new Map([
|
||||
['checkpoint1.json', JSON.stringify({ messageId: 'msg1' })],
|
||||
['no-msg-id.json', JSON.stringify({ other_prop: 'value' })],
|
||||
]);
|
||||
|
||||
const expected = [{ messageId: 'msg1', checkpoint: 'checkpoint1' }];
|
||||
const actual = getCheckpointInfoList(checkpointFiles);
|
||||
expect(actual).toEqual(expected);
|
||||
});
|
||||
});
|
||||
});
|
||||
182
packages/core/src/utils/checkpointUtils.ts
Normal file
182
packages/core/src/utils/checkpointUtils.ts
Normal file
@@ -0,0 +1,182 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as path from 'node:path';
|
||||
import type { GitService } from '../services/gitService.js';
|
||||
import type { GeminiClient } from '../core/client.js';
|
||||
import { getErrorMessage } from './errors.js';
|
||||
import { z } from 'zod';
|
||||
import type { Content } from '@google/genai';
|
||||
import type { ToolCallRequestInfo } from '../core/turn.js';
|
||||
|
||||
export interface ToolCallData<HistoryType = unknown, ArgsType = unknown> {
|
||||
history?: HistoryType;
|
||||
clientHistory?: Content[];
|
||||
commitHash?: string;
|
||||
toolCall: {
|
||||
name: string;
|
||||
args: ArgsType;
|
||||
};
|
||||
messageId?: string;
|
||||
}
|
||||
|
||||
const ContentSchema = z
|
||||
.object({
|
||||
role: z.string().optional(),
|
||||
parts: z.array(z.record(z.unknown())),
|
||||
})
|
||||
.passthrough();
|
||||
|
||||
export function getToolCallDataSchema(historyItemSchema?: z.ZodTypeAny) {
|
||||
const schema = historyItemSchema ?? z.any();
|
||||
|
||||
return z.object({
|
||||
history: z.array(schema).optional(),
|
||||
clientHistory: z.array(ContentSchema).optional(),
|
||||
commitHash: z.string().optional(),
|
||||
toolCall: z.object({
|
||||
name: z.string(),
|
||||
args: z.record(z.unknown()),
|
||||
}),
|
||||
messageId: z.string().optional(),
|
||||
});
|
||||
}
|
||||
|
||||
export function generateCheckpointFileName(
|
||||
toolCall: ToolCallRequestInfo,
|
||||
): string | null {
|
||||
const toolArgs = toolCall.args as Record<string, unknown>;
|
||||
const toolFilePath = toolArgs['file_path'] as string;
|
||||
|
||||
if (!toolFilePath) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const timestamp = new Date()
|
||||
.toISOString()
|
||||
.replace(/:/g, '-')
|
||||
.replace(/\./g, '_');
|
||||
const toolName = toolCall.name;
|
||||
const fileName = path.basename(toolFilePath);
|
||||
|
||||
return `${timestamp}-${fileName}-${toolName}`;
|
||||
}
|
||||
|
||||
export function formatCheckpointDisplayList(filenames: string[]): string {
|
||||
return getTruncatedCheckpointNames(filenames).join('\n');
|
||||
}
|
||||
|
||||
export function getTruncatedCheckpointNames(filenames: string[]): string[] {
|
||||
return filenames.map((file) => {
|
||||
const components = file.split('.');
|
||||
if (components.length <= 1) {
|
||||
return file;
|
||||
}
|
||||
components.pop();
|
||||
return components.join('.');
|
||||
});
|
||||
}
|
||||
|
||||
export async function processRestorableToolCalls<HistoryType>(
|
||||
toolCalls: ToolCallRequestInfo[],
|
||||
gitService: GitService,
|
||||
geminiClient: GeminiClient,
|
||||
history?: HistoryType,
|
||||
): Promise<{
|
||||
checkpointsToWrite: Map<string, string>;
|
||||
toolCallToCheckpointMap: Map<string, string>;
|
||||
errors: string[];
|
||||
}> {
|
||||
const checkpointsToWrite = new Map<string, string>();
|
||||
const toolCallToCheckpointMap = new Map<string, string>();
|
||||
const errors: string[] = [];
|
||||
|
||||
for (const toolCall of toolCalls) {
|
||||
try {
|
||||
let commitHash: string | undefined;
|
||||
try {
|
||||
commitHash = await gitService.createFileSnapshot(
|
||||
`Snapshot for ${toolCall.name}`,
|
||||
);
|
||||
} catch (error) {
|
||||
errors.push(
|
||||
`Failed to create new snapshot for ${
|
||||
toolCall.name
|
||||
}: ${getErrorMessage(error)}. Attempting to use current commit.`,
|
||||
);
|
||||
commitHash = await gitService.getCurrentCommitHash();
|
||||
}
|
||||
|
||||
if (!commitHash) {
|
||||
errors.push(
|
||||
`Failed to create snapshot for ${toolCall.name}. Checkpointing may not be working properly. Ensure Git is installed and the project directory is accessible.`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const checkpointFileName = generateCheckpointFileName(toolCall);
|
||||
if (!checkpointFileName) {
|
||||
errors.push(
|
||||
`Skipping restorable tool call due to missing file_path: ${toolCall.name}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const clientHistory = await geminiClient.getHistory();
|
||||
const checkpointData: ToolCallData<HistoryType> = {
|
||||
history,
|
||||
clientHistory,
|
||||
toolCall: {
|
||||
name: toolCall.name,
|
||||
args: toolCall.args,
|
||||
},
|
||||
commitHash,
|
||||
messageId: toolCall.prompt_id,
|
||||
};
|
||||
|
||||
const fileName = `${checkpointFileName}.json`;
|
||||
checkpointsToWrite.set(fileName, JSON.stringify(checkpointData, null, 2));
|
||||
toolCallToCheckpointMap.set(
|
||||
toolCall.callId,
|
||||
fileName.replace('.json', ''),
|
||||
);
|
||||
} catch (error) {
|
||||
errors.push(
|
||||
`Failed to create checkpoint for ${toolCall.name}: ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return { checkpointsToWrite, toolCallToCheckpointMap, errors };
|
||||
}
|
||||
|
||||
export interface CheckpointInfo {
|
||||
messageId: string;
|
||||
checkpoint: string;
|
||||
}
|
||||
|
||||
export function getCheckpointInfoList(
|
||||
checkpointFiles: Map<string, string>,
|
||||
): CheckpointInfo[] {
|
||||
const checkpointInfoList: CheckpointInfo[] = [];
|
||||
|
||||
for (const [file, content] of checkpointFiles) {
|
||||
try {
|
||||
const toolCallData = JSON.parse(content) as ToolCallData;
|
||||
if (toolCallData.messageId) {
|
||||
checkpointInfoList.push({
|
||||
messageId: toolCallData.messageId,
|
||||
checkpoint: file.replace('.json', ''),
|
||||
});
|
||||
}
|
||||
} catch (_e) {
|
||||
// Ignore invalid JSON files
|
||||
}
|
||||
}
|
||||
return checkpointInfoList;
|
||||
}
|
||||
Reference in New Issue
Block a user