feat(a2a): Introduce restore command for a2a server (#13015)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Shreya Keshive <shreyakeshive@google.com>
This commit is contained in:
Coco Sheng
2025-12-09 10:08:23 -05:00
committed by GitHub
parent afd4829f10
commit 1f813f6a06
23 changed files with 1173 additions and 148 deletions
+2 -1
View File
@@ -5,7 +5,8 @@
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { performRestore, type ToolCallData } from './restore.js';
import { performRestore } from './restore.js';
import { type ToolCallData } from '../utils/checkpointUtils.js';
import type { GitService } from '../services/gitService.js';
describe('performRestore', () => {
+1 -12
View File
@@ -4,20 +4,9 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { Content } from '@google/genai';
import type { GitService } from '../services/gitService.js';
import type { CommandActionReturn } from './types.js';
export interface ToolCallData<HistoryType = unknown, ArgsType = unknown> {
history?: HistoryType;
clientHistory?: Content[];
commitHash?: string;
toolCall: {
name: string;
args: ArgsType;
};
messageId?: string;
}
import { type ToolCallData } from '../utils/checkpointUtils.js';
export async function* performRestore<
HistoryType = unknown,
+1
View File
@@ -107,6 +107,7 @@ export interface ToolCallRequestInfo {
args: Record<string, unknown>;
isClientInitiated: boolean;
prompt_id: string;
checkpoint?: string;
}
export interface ToolCallResponseInfo {
+1
View File
@@ -78,6 +78,7 @@ export * from './utils/debugLogger.js';
export * from './utils/events.js';
export * from './utils/extensionLoader.js';
export * from './utils/package.js';
export * from './utils/checkpointUtils.js';
// Export services
export * from './services/fileDiscoveryService.js';
@@ -32,6 +32,7 @@ const hoistedMockInit = vi.hoisted(() => vi.fn());
const hoistedMockRaw = vi.hoisted(() => vi.fn());
const hoistedMockAdd = vi.hoisted(() => vi.fn());
const hoistedMockCommit = vi.hoisted(() => vi.fn());
const hoistedMockStatus = vi.hoisted(() => vi.fn());
vi.mock('simple-git', () => ({
simpleGit: hoistedMockSimpleGit.mockImplementation(() => ({
checkIsRepo: hoistedMockCheckIsRepo,
@@ -39,6 +40,7 @@ vi.mock('simple-git', () => ({
raw: hoistedMockRaw,
add: hoistedMockAdd,
commit: hoistedMockCommit,
status: hoistedMockStatus,
env: hoistedMockEnv,
})),
CheckRepoActions: { IS_REPO_ROOT: 'is-repo-root' },
@@ -89,6 +91,7 @@ describe('GitService', () => {
raw: hoistedMockRaw,
add: hoistedMockAdd,
commit: hoistedMockCommit,
status: hoistedMockStatus,
}));
hoistedMockSimpleGit.mockImplementation(() => ({
checkIsRepo: hoistedMockCheckIsRepo,
@@ -96,6 +99,7 @@ describe('GitService', () => {
raw: hoistedMockRaw,
add: hoistedMockAdd,
commit: hoistedMockCommit,
status: hoistedMockStatus,
env: hoistedMockEnv,
}));
hoistedMockCheckIsRepo.mockResolvedValue(false);
@@ -248,6 +252,7 @@ describe('GitService', () => {
describe('createFileSnapshot', () => {
it('should commit with --no-verify flag', async () => {
hoistedMockStatus.mockResolvedValue({ isClean: () => false });
const service = new GitService(projectRoot, storage);
await service.initialize();
await service.createFileSnapshot('test commit');
@@ -255,5 +260,30 @@ describe('GitService', () => {
'--no-verify': null,
});
});
it('should create a new commit if there are staged changes', async () => {
hoistedMockStatus.mockResolvedValue({ isClean: () => false });
hoistedMockCommit.mockResolvedValue({ commit: 'new-commit-hash' });
const service = new GitService(projectRoot, storage);
const commitHash = await service.createFileSnapshot('test message');
expect(hoistedMockAdd).toHaveBeenCalledWith('.');
expect(hoistedMockStatus).toHaveBeenCalled();
expect(hoistedMockCommit).toHaveBeenCalledWith('test message', {
'--no-verify': null,
});
expect(commitHash).toBe('new-commit-hash');
});
it('should return the current HEAD commit hash if there are no staged changes', async () => {
hoistedMockStatus.mockResolvedValue({ isClean: () => true });
hoistedMockRaw.mockResolvedValue('current-head-hash');
const service = new GitService(projectRoot, storage);
const commitHash = await service.createFileSnapshot('test message');
expect(hoistedMockAdd).toHaveBeenCalledWith('.');
expect(hoistedMockStatus).toHaveBeenCalled();
expect(hoistedMockCommit).not.toHaveBeenCalled();
expect(hoistedMockRaw).toHaveBeenCalledWith('rev-parse', 'HEAD');
expect(commitHash).toBe('current-head-hash');
});
});
});
+5
View File
@@ -112,6 +112,11 @@ export class GitService {
try {
const repo = this.shadowGitRepository;
await repo.add('.');
const status = await repo.status();
if (status.isClean()) {
// If no changes are staged, return the current HEAD commit hash
return await this.getCurrentCommitHash();
}
const commitResult = await repo.commit(message, {
'--no-verify': null,
});
+1
View File
@@ -20,3 +20,4 @@ export const READ_MANY_FILES_TOOL_NAME = 'read_many_files';
export const READ_FILE_TOOL_NAME = 'read_file';
export const LS_TOOL_NAME = 'list_directory';
export const MEMORY_TOOL_NAME = 'save_memory';
export const EDIT_TOOL_NAMES = new Set([EDIT_TOOL_NAME, WRITE_FILE_TOOL_NAME]);
@@ -0,0 +1,305 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest';
import { z } from 'zod';
import {
getToolCallDataSchema,
generateCheckpointFileName,
formatCheckpointDisplayList,
getTruncatedCheckpointNames,
processRestorableToolCalls,
getCheckpointInfoList,
} from './checkpointUtils.js';
import type { GitService } from '../services/gitService.js';
import type { GeminiClient } from '../core/client.js';
import type { ToolCallRequestInfo } from '../core/turn.js';
describe('checkpoint utils', () => {
describe('getToolCallDataSchema', () => {
it('should return a schema that validates a basic tool call data object', () => {
const schema = getToolCallDataSchema();
const validData = {
toolCall: { name: 'test-tool', args: { foo: 'bar' } },
};
const result = schema.safeParse(validData);
expect(result.success).toBe(true);
});
it('should validate with an optional history schema', () => {
const historyItemSchema = z.object({ id: z.number(), event: z.string() });
const schema = getToolCallDataSchema(historyItemSchema);
const validData = {
history: [{ id: 1, event: 'start' }],
toolCall: { name: 'test-tool', args: {} },
};
const result = schema.safeParse(validData);
expect(result.success).toBe(true);
});
it('should fail validation if history items do not match the schema', () => {
const historyItemSchema = z.object({ id: z.number(), event: z.string() });
const schema = getToolCallDataSchema(historyItemSchema);
const invalidData = {
history: [{ id: '1', event: 'start' }], // id should be a number
toolCall: { name: 'test-tool', args: {} },
};
const result = schema.safeParse(invalidData);
expect(result.success).toBe(false);
});
it('should validate clientHistory with the correct schema', () => {
const schema = getToolCallDataSchema();
const validData = {
clientHistory: [{ role: 'user', parts: [{ text: 'Hello' }] }],
toolCall: { name: 'test-tool', args: {} },
};
const result = schema.safeParse(validData);
expect(result.success).toBe(true);
});
});
describe('generateCheckpointFileName', () => {
it('should generate a filename with timestamp, basename, and tool name', () => {
vi.useFakeTimers();
vi.setSystemTime(new Date('2025-01-01T12:00:00.000Z'));
const toolCall = {
callId: '1',
name: 'replace',
args: { file_path: '/path/to/my-file.txt' },
isClientInitiated: false,
prompt_id: 'p1',
} as ToolCallRequestInfo;
const expected = '2025-01-01T12-00-00_000Z-my-file.txt-replace';
const actual = generateCheckpointFileName(toolCall);
expect(actual).toBe(expected);
vi.useRealTimers();
});
it('should return null if file_path is not in the tool arguments', () => {
const toolCall = {
callId: '1',
name: 'replace',
args: { some_other_arg: 'value' },
isClientInitiated: false,
prompt_id: 'p1',
} as ToolCallRequestInfo;
const actual = generateCheckpointFileName(toolCall);
expect(actual).toBeNull();
});
});
describe('formatCheckpointDisplayList and getTruncatedCheckpointNames', () => {
const filenames = [
'2025-01-01T12-00-00_000Z-my-file.txt-replace.json',
'2025-01-01T13-00-00_000Z-another.js-write_file.json',
'no-extension-file',
];
it('getTruncatedCheckpointNames should remove the .json extension', () => {
const expected = [
'2025-01-01T12-00-00_000Z-my-file.txt-replace',
'2025-01-01T13-00-00_000Z-another.js-write_file',
'no-extension-file',
];
const actual = getTruncatedCheckpointNames(filenames);
expect(actual).toEqual(expected);
});
it('formatCheckpointDisplayList should return a newline-separated string of truncated names', () => {
const expected = [
'2025-01-01T12-00-00_000Z-my-file.txt-replace',
'2025-01-01T13-00-00_000Z-another.js-write_file',
'no-extension-file',
].join('\n');
const actual = formatCheckpointDisplayList(filenames);
expect(actual).toEqual(expected);
});
});
describe('processRestorableToolCalls', () => {
const mockGitService = {
createFileSnapshot: vi.fn(),
getCurrentCommitHash: vi.fn(),
} as unknown as GitService;
const mockGeminiClient = {
getHistory: vi.fn(),
} as unknown as GeminiClient;
beforeEach(() => {
vi.clearAllMocks();
});
it('should create checkpoints for restorable tool calls', async () => {
const toolCalls = [
{
callId: '1',
name: 'replace',
args: { file_path: 'a.txt' },
prompt_id: 'p1',
isClientInitiated: false,
},
] as ToolCallRequestInfo[];
(mockGitService.createFileSnapshot as Mock).mockResolvedValue('hash123');
(mockGeminiClient.getHistory as Mock).mockResolvedValue([
{ role: 'user', parts: [] },
]);
const { checkpointsToWrite, toolCallToCheckpointMap, errors } =
await processRestorableToolCalls(
toolCalls,
mockGitService,
mockGeminiClient,
'history-data',
);
expect(errors).toHaveLength(0);
expect(checkpointsToWrite.size).toBe(1);
expect(toolCallToCheckpointMap.get('1')).toBeDefined();
const fileName = checkpointsToWrite.values().next().value;
expect(fileName).toBeDefined();
const fileContent = JSON.parse(fileName!);
expect(fileContent.commitHash).toBe('hash123');
expect(fileContent.history).toBe('history-data');
expect(fileContent.clientHistory).toEqual([{ role: 'user', parts: [] }]);
expect(fileContent.toolCall.name).toBe('replace');
expect(fileContent.messageId).toBe('p1');
});
it('should handle git snapshot failure by using current commit hash', async () => {
const toolCalls = [
{
callId: '1',
name: 'replace',
args: { file_path: 'a.txt' },
prompt_id: 'p1',
isClientInitiated: false,
},
] as ToolCallRequestInfo[];
(mockGitService.createFileSnapshot as Mock).mockRejectedValue(
new Error('Snapshot failed'),
);
(mockGitService.getCurrentCommitHash as Mock).mockResolvedValue(
'fallback-hash',
);
const { checkpointsToWrite, errors } = await processRestorableToolCalls(
toolCalls,
mockGitService,
mockGeminiClient,
);
expect(errors).toHaveLength(1);
expect(errors[0]).toContain('Failed to create new snapshot');
expect(checkpointsToWrite.size).toBe(1);
const value = checkpointsToWrite.values().next().value;
expect(value).toBeDefined();
const fileContent = JSON.parse(value!);
expect(fileContent.commitHash).toBe('fallback-hash');
});
it('should skip tool calls with no file_path', async () => {
const toolCalls = [
{
callId: '1',
name: 'replace',
args: { not_a_path: 'a.txt' },
prompt_id: 'p1',
isClientInitiated: false,
},
] as ToolCallRequestInfo[];
(mockGitService.createFileSnapshot as Mock).mockResolvedValue('hash123');
const { checkpointsToWrite, errors } = await processRestorableToolCalls(
toolCalls,
mockGitService,
mockGeminiClient,
);
expect(errors).toHaveLength(1);
expect(errors[0]).toContain(
'Skipping restorable tool call due to missing file_path',
);
expect(checkpointsToWrite.size).toBe(0);
});
it('should log an error if git snapshot fails and then skip the tool call', async () => {
const toolCalls = [
{
callId: '1',
name: 'replace',
args: { file_path: 'a.txt' },
prompt_id: 'p1',
isClientInitiated: false,
},
] as ToolCallRequestInfo[];
(mockGitService.createFileSnapshot as Mock).mockRejectedValue(
new Error('Snapshot failed'),
);
(mockGitService.getCurrentCommitHash as Mock).mockResolvedValue(
undefined,
);
const { checkpointsToWrite, errors } = await processRestorableToolCalls(
toolCalls,
mockGitService,
mockGeminiClient,
);
expect(errors).toHaveLength(2);
expect(errors[0]).toContain('Failed to create new snapshot');
expect(errors[1]).toContain('Failed to create snapshot for replace');
expect(checkpointsToWrite.size).toBe(0);
});
});
describe('getCheckpointInfoList', () => {
it('should parse valid checkpoint files and return a list of info', () => {
const checkpointFiles = new Map([
['checkpoint1.json', JSON.stringify({ messageId: 'msg1' })],
['checkpoint2.json', JSON.stringify({ messageId: 'msg2' })],
]);
const expected = [
{ messageId: 'msg1', checkpoint: 'checkpoint1' },
{ messageId: 'msg2', checkpoint: 'checkpoint2' },
];
const actual = getCheckpointInfoList(checkpointFiles);
expect(actual).toEqual(expected);
});
it('should ignore files with invalid JSON', () => {
const checkpointFiles = new Map([
['checkpoint1.json', JSON.stringify({ messageId: 'msg1' })],
['invalid.json', 'not-json'],
]);
const expected = [{ messageId: 'msg1', checkpoint: 'checkpoint1' }];
const actual = getCheckpointInfoList(checkpointFiles);
expect(actual).toEqual(expected);
});
it('should ignore files that are missing a messageId', () => {
const checkpointFiles = new Map([
['checkpoint1.json', JSON.stringify({ messageId: 'msg1' })],
['no-msg-id.json', JSON.stringify({ other_prop: 'value' })],
]);
const expected = [{ messageId: 'msg1', checkpoint: 'checkpoint1' }];
const actual = getCheckpointInfoList(checkpointFiles);
expect(actual).toEqual(expected);
});
});
});
+182
View File
@@ -0,0 +1,182 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as path from 'node:path';
import type { GitService } from '../services/gitService.js';
import type { GeminiClient } from '../core/client.js';
import { getErrorMessage } from './errors.js';
import { z } from 'zod';
import type { Content } from '@google/genai';
import type { ToolCallRequestInfo } from '../core/turn.js';
export interface ToolCallData<HistoryType = unknown, ArgsType = unknown> {
history?: HistoryType;
clientHistory?: Content[];
commitHash?: string;
toolCall: {
name: string;
args: ArgsType;
};
messageId?: string;
}
const ContentSchema = z
.object({
role: z.string().optional(),
parts: z.array(z.record(z.unknown())),
})
.passthrough();
export function getToolCallDataSchema(historyItemSchema?: z.ZodTypeAny) {
const schema = historyItemSchema ?? z.any();
return z.object({
history: z.array(schema).optional(),
clientHistory: z.array(ContentSchema).optional(),
commitHash: z.string().optional(),
toolCall: z.object({
name: z.string(),
args: z.record(z.unknown()),
}),
messageId: z.string().optional(),
});
}
export function generateCheckpointFileName(
toolCall: ToolCallRequestInfo,
): string | null {
const toolArgs = toolCall.args as Record<string, unknown>;
const toolFilePath = toolArgs['file_path'] as string;
if (!toolFilePath) {
return null;
}
const timestamp = new Date()
.toISOString()
.replace(/:/g, '-')
.replace(/\./g, '_');
const toolName = toolCall.name;
const fileName = path.basename(toolFilePath);
return `${timestamp}-${fileName}-${toolName}`;
}
export function formatCheckpointDisplayList(filenames: string[]): string {
return getTruncatedCheckpointNames(filenames).join('\n');
}
export function getTruncatedCheckpointNames(filenames: string[]): string[] {
return filenames.map((file) => {
const components = file.split('.');
if (components.length <= 1) {
return file;
}
components.pop();
return components.join('.');
});
}
export async function processRestorableToolCalls<HistoryType>(
toolCalls: ToolCallRequestInfo[],
gitService: GitService,
geminiClient: GeminiClient,
history?: HistoryType,
): Promise<{
checkpointsToWrite: Map<string, string>;
toolCallToCheckpointMap: Map<string, string>;
errors: string[];
}> {
const checkpointsToWrite = new Map<string, string>();
const toolCallToCheckpointMap = new Map<string, string>();
const errors: string[] = [];
for (const toolCall of toolCalls) {
try {
let commitHash: string | undefined;
try {
commitHash = await gitService.createFileSnapshot(
`Snapshot for ${toolCall.name}`,
);
} catch (error) {
errors.push(
`Failed to create new snapshot for ${
toolCall.name
}: ${getErrorMessage(error)}. Attempting to use current commit.`,
);
commitHash = await gitService.getCurrentCommitHash();
}
if (!commitHash) {
errors.push(
`Failed to create snapshot for ${toolCall.name}. Checkpointing may not be working properly. Ensure Git is installed and the project directory is accessible.`,
);
continue;
}
const checkpointFileName = generateCheckpointFileName(toolCall);
if (!checkpointFileName) {
errors.push(
`Skipping restorable tool call due to missing file_path: ${toolCall.name}`,
);
continue;
}
const clientHistory = await geminiClient.getHistory();
const checkpointData: ToolCallData<HistoryType> = {
history,
clientHistory,
toolCall: {
name: toolCall.name,
args: toolCall.args,
},
commitHash,
messageId: toolCall.prompt_id,
};
const fileName = `${checkpointFileName}.json`;
checkpointsToWrite.set(fileName, JSON.stringify(checkpointData, null, 2));
toolCallToCheckpointMap.set(
toolCall.callId,
fileName.replace('.json', ''),
);
} catch (error) {
errors.push(
`Failed to create checkpoint for ${toolCall.name}: ${getErrorMessage(
error,
)}`,
);
}
}
return { checkpointsToWrite, toolCallToCheckpointMap, errors };
}
export interface CheckpointInfo {
messageId: string;
checkpoint: string;
}
export function getCheckpointInfoList(
checkpointFiles: Map<string, string>,
): CheckpointInfo[] {
const checkpointInfoList: CheckpointInfo[] = [];
for (const [file, content] of checkpointFiles) {
try {
const toolCallData = JSON.parse(content) as ToolCallData;
if (toolCallData.messageId) {
checkpointInfoList.push({
messageId: toolCallData.messageId,
checkpoint: file.replace('.json', ''),
});
}
} catch (_e) {
// Ignore invalid JSON files
}
}
return checkpointInfoList;
}