diff --git a/packages/a2a-server/src/agent/task.test.ts b/packages/a2a-server/src/agent/task.test.ts index 47448b8c49..b4a342707f 100644 --- a/packages/a2a-server/src/agent/task.test.ts +++ b/packages/a2a-server/src/agent/task.test.ts @@ -18,6 +18,7 @@ import { GeminiEventType, type Config, type ToolCallRequestInfo, + type GitService, type CompletedToolCall, } from '@google/gemini-cli-core'; import { createMockConfig } from '../utils/testing_utils.js'; @@ -25,6 +26,17 @@ import type { ExecutionEventBus, RequestContext } from '@a2a-js/sdk/server'; import { CoderAgentEvent } from '../types.js'; import type { ToolCall } from '@google/gemini-cli-core'; +const mockProcessRestorableToolCalls = vi.hoisted(() => vi.fn()); + +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const original = + await importOriginal(); + return { + ...original, + processRestorableToolCalls: mockProcessRestorableToolCalls, + }; +}); + describe('Task', () => { it('scheduleToolCalls should not modify the input requests array', async () => { const mockConfig = createMockConfig(); @@ -72,6 +84,141 @@ describe('Task', () => { expect(requests).toEqual(originalRequests); }); + describe('scheduleToolCalls', () => { + const mockConfig = createMockConfig(); + const mockEventBus: ExecutionEventBus = { + publish: vi.fn(), + on: vi.fn(), + off: vi.fn(), + once: vi.fn(), + removeAllListeners: vi.fn(), + finished: vi.fn(), + }; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('should not create a checkpoint if no restorable tools are called', async () => { + // @ts-expect-error - Calling private constructor for test purposes. + const task = new Task( + 'task-id', + 'context-id', + mockConfig as Config, + mockEventBus, + ); + const requests: ToolCallRequestInfo[] = [ + { + callId: '1', + name: 'run_shell_command', + args: { command: 'ls' }, + isClientInitiated: false, + prompt_id: 'prompt-id-1', + }, + ]; + const abortController = new AbortController(); + await task.scheduleToolCalls(requests, abortController.signal); + expect(mockProcessRestorableToolCalls).not.toHaveBeenCalled(); + }); + + it('should create a checkpoint if a restorable tool is called', async () => { + const mockConfig = createMockConfig({ + getCheckpointingEnabled: () => true, + getGitService: () => Promise.resolve({} as GitService), + }); + mockProcessRestorableToolCalls.mockResolvedValue({ + checkpointsToWrite: new Map([['test.json', 'test content']]), + toolCallToCheckpointMap: new Map(), + errors: [], + }); + // @ts-expect-error - Calling private constructor for test purposes. + const task = new Task( + 'task-id', + 'context-id', + mockConfig as Config, + mockEventBus, + ); + const requests: ToolCallRequestInfo[] = [ + { + callId: '1', + name: 'replace', + args: { + file_path: 'test.txt', + old_string: 'old', + new_string: 'new', + }, + isClientInitiated: false, + prompt_id: 'prompt-id-1', + }, + ]; + const abortController = new AbortController(); + await task.scheduleToolCalls(requests, abortController.signal); + expect(mockProcessRestorableToolCalls).toHaveBeenCalledOnce(); + }); + + it('should process all restorable tools for checkpointing in a single batch', async () => { + const mockConfig = createMockConfig({ + getCheckpointingEnabled: () => true, + getGitService: () => Promise.resolve({} as GitService), + }); + mockProcessRestorableToolCalls.mockResolvedValue({ + checkpointsToWrite: new Map([ + ['test1.json', 'test content 1'], + ['test2.json', 'test content 2'], + ]), + toolCallToCheckpointMap: new Map([ + ['1', 'test1'], + ['2', 'test2'], + ]), + errors: [], + }); + // @ts-expect-error - Calling private constructor for test purposes. + const task = new Task( + 'task-id', + 'context-id', + mockConfig as Config, + mockEventBus, + ); + const requests: ToolCallRequestInfo[] = [ + { + callId: '1', + name: 'replace', + args: { + file_path: 'test.txt', + old_string: 'old', + new_string: 'new', + }, + isClientInitiated: false, + prompt_id: 'prompt-id-1', + }, + { + callId: '2', + name: 'write_file', + args: { file_path: 'test2.txt', content: 'new content' }, + isClientInitiated: false, + prompt_id: 'prompt-id-2', + }, + { + callId: '3', + name: 'not_restorable', + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-3', + }, + ]; + const abortController = new AbortController(); + await task.scheduleToolCalls(requests, abortController.signal); + expect(mockProcessRestorableToolCalls).toHaveBeenCalledExactlyOnceWith( + [ + expect.objectContaining({ callId: '1' }), + expect.objectContaining({ callId: '2' }), + ], + expect.anything(), + expect.anything(), + ); + }); + }); + describe('acceptAgentMessage', () => { it('should set currentTraceId when event has traceId', async () => { const mockConfig = createMockConfig(); diff --git a/packages/a2a-server/src/agent/task.ts b/packages/a2a-server/src/agent/task.ts index cca753b242..e37965af6a 100644 --- a/packages/a2a-server/src/agent/task.ts +++ b/packages/a2a-server/src/agent/task.ts @@ -27,6 +27,8 @@ import { type Config, type UserTierId, type AnsiOutput, + EDIT_TOOL_NAMES, + processRestorableToolCalls, } from '@google/gemini-cli-core'; import type { RequestContext } from '@a2a-js/sdk/server'; import { type ExecutionEventBus } from '@a2a-js/sdk/server'; @@ -40,7 +42,8 @@ import type { } from '@a2a-js/sdk'; import { v4 as uuidv4 } from 'uuid'; import { logger } from '../utils/logger.js'; -import * as fs from 'node:fs'; +import * as fs from 'node:fs/promises'; +import * as path from 'node:path'; import { CoderAgentEvent } from '../types.js'; import type { CoderAgentMessage, @@ -511,7 +514,7 @@ export class Task { new_string: string, ): Promise { try { - const currentContent = fs.readFileSync(file_path, 'utf8'); + const currentContent = await fs.readFile(file_path, 'utf8'); return this._applyReplacement( currentContent, old_string, @@ -554,6 +557,44 @@ export class Task { return; } + // Set checkpoint file before any file modification tool executes + const restorableToolCalls = requests.filter((request) => + EDIT_TOOL_NAMES.has(request.name), + ); + + if (restorableToolCalls.length > 0) { + const gitService = await this.config.getGitService(); + if (gitService) { + const { checkpointsToWrite, toolCallToCheckpointMap, errors } = + await processRestorableToolCalls( + restorableToolCalls, + gitService, + this.geminiClient, + ); + + if (errors.length > 0) { + errors.forEach((error) => logger.error(error)); + } + + if (checkpointsToWrite.size > 0) { + const checkpointDir = + this.config.storage.getProjectTempCheckpointsDir(); + await fs.mkdir(checkpointDir, { recursive: true }); + for (const [fileName, content] of checkpointsToWrite) { + const filePath = path.join(checkpointDir, fileName); + await fs.writeFile(filePath, content); + } + } + + for (const request of requests) { + const checkpoint = toolCallToCheckpointMap.get(request.callId); + if (checkpoint) { + request.checkpoint = checkpoint; + } + } + } + } + const updatedRequests = await Promise.all( requests.map(async (request) => { if ( diff --git a/packages/a2a-server/src/commands/command-registry.ts b/packages/a2a-server/src/commands/command-registry.ts index 658193a603..964a2f5f7b 100644 --- a/packages/a2a-server/src/commands/command-registry.ts +++ b/packages/a2a-server/src/commands/command-registry.ts @@ -5,6 +5,7 @@ */ import { ExtensionsCommand } from './extensions.js'; +import { RestoreCommand } from './restore.js'; import type { Command } from './types.js'; class CommandRegistry { @@ -12,6 +13,7 @@ class CommandRegistry { constructor() { this.register(new ExtensionsCommand()); + this.register(new RestoreCommand()); } register(command: Command) { diff --git a/packages/a2a-server/src/commands/extensions.test.ts b/packages/a2a-server/src/commands/extensions.test.ts index ffccb017c9..611c51b161 100644 --- a/packages/a2a-server/src/commands/extensions.test.ts +++ b/packages/a2a-server/src/commands/extensions.test.ts @@ -6,7 +6,7 @@ import { describe, it, expect, vi } from 'vitest'; import { ExtensionsCommand, ListExtensionsCommand } from './extensions.js'; -import type { Config } from '@google/gemini-cli-core'; +import type { CommandContext } from './types.js'; const mockListExtensions = vi.hoisted(() => vi.fn()); vi.mock('@google/gemini-cli-core', async (importOriginal) => { @@ -42,14 +42,14 @@ describe('ExtensionsCommand', () => { it('should default to listing extensions', async () => { const command = new ExtensionsCommand(); - const mockConfig = {} as Config; + const mockConfig = { config: {} } as CommandContext; const mockExtensions = [{ name: 'ext1' }]; mockListExtensions.mockReturnValue(mockExtensions); const result = await command.execute(mockConfig, []); expect(result).toEqual({ name: 'extensions list', data: mockExtensions }); - expect(mockListExtensions).toHaveBeenCalledWith(mockConfig); + expect(mockListExtensions).toHaveBeenCalledWith(mockConfig.config); }); }); @@ -61,19 +61,19 @@ describe('ListExtensionsCommand', () => { it('should call listExtensions with the provided config', async () => { const command = new ListExtensionsCommand(); - const mockConfig = {} as Config; + const mockConfig = { config: {} } as CommandContext; const mockExtensions = [{ name: 'ext1' }]; mockListExtensions.mockReturnValue(mockExtensions); const result = await command.execute(mockConfig, []); expect(result).toEqual({ name: 'extensions list', data: mockExtensions }); - expect(mockListExtensions).toHaveBeenCalledWith(mockConfig); + expect(mockListExtensions).toHaveBeenCalledWith(mockConfig.config); }); it('should return a message when no extensions are installed', async () => { const command = new ListExtensionsCommand(); - const mockConfig = {} as Config; + const mockConfig = { config: {} } as CommandContext; mockListExtensions.mockReturnValue([]); const result = await command.execute(mockConfig, []); @@ -82,6 +82,6 @@ describe('ListExtensionsCommand', () => { name: 'extensions list', data: 'No extensions installed.', }); - expect(mockListExtensions).toHaveBeenCalledWith(mockConfig); + expect(mockListExtensions).toHaveBeenCalledWith(mockConfig.config); }); }); diff --git a/packages/a2a-server/src/commands/extensions.ts b/packages/a2a-server/src/commands/extensions.ts index 91893cd55b..27149eb3c7 100644 --- a/packages/a2a-server/src/commands/extensions.ts +++ b/packages/a2a-server/src/commands/extensions.ts @@ -4,8 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { listExtensions, type Config } from '@google/gemini-cli-core'; -import type { Command, CommandExecutionResponse } from './types.js'; +import { listExtensions } from '@google/gemini-cli-core'; +import type { + Command, + CommandContext, + CommandExecutionResponse, +} from './types.js'; export class ExtensionsCommand implements Command { readonly name = 'extensions'; @@ -14,10 +18,10 @@ export class ExtensionsCommand implements Command { readonly topLevel = true; async execute( - config: Config, + context: CommandContext, _: string[], ): Promise { - return new ListExtensionsCommand().execute(config, _); + return new ListExtensionsCommand().execute(context, _); } } @@ -26,10 +30,10 @@ export class ListExtensionsCommand implements Command { readonly description = 'Lists all installed extensions.'; async execute( - config: Config, + context: CommandContext, _: string[], ): Promise { - const extensions = listExtensions(config); + const extensions = listExtensions(context.config); const data = extensions.length ? extensions : 'No extensions installed.'; return { name: this.name, data }; diff --git a/packages/a2a-server/src/commands/restore.test.ts b/packages/a2a-server/src/commands/restore.test.ts new file mode 100644 index 0000000000..a655f36a74 --- /dev/null +++ b/packages/a2a-server/src/commands/restore.test.ts @@ -0,0 +1,137 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { RestoreCommand, ListCheckpointsCommand } from './restore.js'; +import type { CommandContext } from './types.js'; +import type { Config } from '@google/gemini-cli-core'; +import { createMockConfig } from '../utils/testing_utils.js'; + +beforeEach(() => { + vi.clearAllMocks(); +}); + +const mockPerformRestore = vi.hoisted(() => vi.fn()); +const mockLoggerInfo = vi.hoisted(() => vi.fn()); +const mockGetCheckpointInfoList = vi.hoisted(() => vi.fn()); + +vi.mock('@google/gemini-cli-core', async (importOriginal) => { + const original = + await importOriginal(); + return { + ...original, + performRestore: mockPerformRestore, + getCheckpointInfoList: mockGetCheckpointInfoList, + }; +}); + +const mockFs = vi.hoisted(() => ({ + readFile: vi.fn(), + readdir: vi.fn(), + mkdir: vi.fn(), +})); + +vi.mock('node:fs/promises', () => mockFs); + +vi.mock('../utils/logger.js', () => ({ + logger: { + info: mockLoggerInfo, + }, +})); + +describe('RestoreCommand', () => { + const mockConfig = { + config: createMockConfig() as Config, + git: {}, + } as CommandContext; + + it('should return error if no checkpoint name is provided', async () => { + const command = new RestoreCommand(); + const result = await command.execute(mockConfig, []); + expect(result.data).toEqual({ + type: 'message', + messageType: 'error', + content: 'Please provide a checkpoint name to restore.', + }); + }); + + it('should restore a checkpoint when a valid file is provided', async () => { + const command = new RestoreCommand(); + const toolCallData = { + toolCall: { + name: 'test-tool', + args: {}, + }, + history: [], + clientHistory: [], + commitHash: '123', + }; + mockFs.readFile.mockResolvedValue(JSON.stringify(toolCallData)); + const restoreContent = { + type: 'message', + messageType: 'info', + content: 'Restored', + }; + mockPerformRestore.mockReturnValue( + (async function* () { + yield restoreContent; + })(), + ); + const result = await command.execute(mockConfig, ['checkpoint1.json']); + expect(result.data).toEqual([restoreContent]); + }); + + it('should show "file not found" error for a non-existent checkpoint', async () => { + const command = new RestoreCommand(); + const error = new Error('File not found'); + (error as NodeJS.ErrnoException).code = 'ENOENT'; + mockFs.readFile.mockRejectedValue(error); + const result = await command.execute(mockConfig, ['checkpoint2.json']); + expect(result.data).toEqual({ + type: 'message', + messageType: 'error', + content: 'File not found: checkpoint2.json', + }); + }); + + it('should handle invalid JSON in checkpoint file', async () => { + const command = new RestoreCommand(); + mockFs.readFile.mockResolvedValue('invalid json'); + const result = await command.execute(mockConfig, ['checkpoint1.json']); + expect((result.data as { content: string }).content).toContain( + 'An unexpected error occurred during restore.', + ); + }); +}); + +describe('ListCheckpointsCommand', () => { + const mockConfig = { + config: createMockConfig() as Config, + } as CommandContext; + + it('should list all available checkpoints', async () => { + const command = new ListCheckpointsCommand(); + const checkpointInfo = [{ file: 'checkpoint1.json', description: 'Test' }]; + mockFs.readdir.mockResolvedValue(['checkpoint1.json']); + mockFs.readFile.mockResolvedValue( + JSON.stringify({ toolCall: { name: 'Test', args: {} } }), + ); + mockGetCheckpointInfoList.mockReturnValue(checkpointInfo); + const result = await command.execute(mockConfig); + expect((result.data as { content: string }).content).toEqual( + JSON.stringify(checkpointInfo), + ); + }); + + it('should handle errors when listing checkpoints', async () => { + const command = new ListCheckpointsCommand(); + mockFs.readdir.mockRejectedValue(new Error('Read error')); + const result = await command.execute(mockConfig); + expect((result.data as { content: string }).content).toContain( + 'An unexpected error occurred while listing checkpoints.', + ); + }); +}); diff --git a/packages/a2a-server/src/commands/restore.ts b/packages/a2a-server/src/commands/restore.ts new file mode 100644 index 0000000000..5b4839a2e4 --- /dev/null +++ b/packages/a2a-server/src/commands/restore.ts @@ -0,0 +1,155 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + getCheckpointInfoList, + getToolCallDataSchema, + isNodeError, + performRestore, +} from '@google/gemini-cli-core'; +import * as fs from 'node:fs/promises'; +import * as path from 'node:path'; +import type { + Command, + CommandContext, + CommandExecutionResponse, +} from './types.js'; + +export class RestoreCommand implements Command { + readonly name = 'restore'; + readonly description = + 'Restore to a previous checkpoint, or list available checkpoints to restore. This will reset the conversation and file history to the state it was in when the checkpoint was created'; + readonly topLevel = true; + readonly requiresWorkspace = true; + readonly subCommands = [new ListCheckpointsCommand()]; + + async execute( + context: CommandContext, + args: string[], + ): Promise { + const { config, git: gitService } = context; + const argsStr = args.join(' '); + + try { + if (!argsStr) { + return { + name: this.name, + data: { + type: 'message', + messageType: 'error', + content: 'Please provide a checkpoint name to restore.', + }, + }; + } + + const selectedFile = argsStr.endsWith('.json') + ? argsStr + : `${argsStr}.json`; + + const checkpointDir = config.storage.getProjectTempCheckpointsDir(); + const filePath = path.join(checkpointDir, selectedFile); + + let data: string; + try { + data = await fs.readFile(filePath, 'utf-8'); + } catch (error) { + if (isNodeError(error) && error.code === 'ENOENT') { + return { + name: this.name, + data: { + type: 'message', + messageType: 'error', + content: `File not found: ${selectedFile}`, + }, + }; + } + throw error; + } + + const toolCallData = JSON.parse(data); + const ToolCallDataSchema = getToolCallDataSchema(); + const parseResult = ToolCallDataSchema.safeParse(toolCallData); + + if (!parseResult.success) { + return { + name: this.name, + data: { + type: 'message', + messageType: 'error', + content: 'Checkpoint file is invalid or corrupted.', + }, + }; + } + + const restoreResultGenerator = performRestore( + parseResult.data, + gitService, + ); + const restoreResult = []; + for await (const result of restoreResultGenerator) { + restoreResult.push(result); + } + + return { + name: this.name, + data: restoreResult, + }; + } catch (_error) { + return { + name: this.name, + data: { + type: 'message', + messageType: 'error', + content: 'An unexpected error occurred during restore.', + }, + }; + } + } +} + +export class ListCheckpointsCommand implements Command { + readonly name = 'restore list'; + readonly description = 'Lists all available checkpoints.'; + readonly topLevel = false; + + async execute(context: CommandContext): Promise { + const { config } = context; + + try { + const checkpointDir = config.storage.getProjectTempCheckpointsDir(); + await fs.mkdir(checkpointDir, { recursive: true }); + const files = await fs.readdir(checkpointDir); + const jsonFiles = files.filter((file) => file.endsWith('.json')); + + const checkpointFiles = new Map(); + for (const file of jsonFiles) { + const filePath = path.join(checkpointDir, file); + const data = await fs.readFile(filePath, 'utf-8'); + checkpointFiles.set(file, data); + } + + const checkpointInfoList = getCheckpointInfoList(checkpointFiles); + + return { + name: this.name, + data: { + type: 'message', + messageType: 'info', + content: JSON.stringify(checkpointInfoList), + }, + }; + } catch (_error) { + return { + name: this.name, + data: { + type: 'message', + messageType: 'error', + content: 'An unexpected error occurred while listing checkpoints.', + }, + }; + } + } +} diff --git a/packages/a2a-server/src/commands/types.ts b/packages/a2a-server/src/commands/types.ts index ef6a876c5c..aca5693e13 100644 --- a/packages/a2a-server/src/commands/types.ts +++ b/packages/a2a-server/src/commands/types.ts @@ -4,7 +4,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { Config } from '@google/gemini-cli-core'; +import type { Config, GitService } from '@google/gemini-cli-core'; + +export interface CommandContext { + config: Config; + git?: GitService; +} export interface CommandArgument { readonly name: string; @@ -18,8 +23,12 @@ export interface Command { readonly arguments?: CommandArgument[]; readonly subCommands?: Command[]; readonly topLevel?: boolean; + readonly requiresWorkspace?: boolean; - execute(config: Config, args: string[]): Promise; + execute( + config: CommandContext, + args: string[], + ): Promise; } export interface CommandExecutionResponse { diff --git a/packages/a2a-server/src/config/config.ts b/packages/a2a-server/src/config/config.ts index 999f6b95b2..e7a3609ca5 100644 --- a/packages/a2a-server/src/config/config.ts +++ b/packages/a2a-server/src/config/config.ts @@ -71,6 +71,9 @@ export async function loadConfig( ideMode: false, folderTrust: settings.folderTrust === true, extensionLoader, + checkpointing: process.env['CHECKPOINTING'] + ? process.env['CHECKPOINTING'] === 'true' + : settings.checkpointing?.enabled, previewFeatures: settings.general?.previewFeatures, }; diff --git a/packages/a2a-server/src/http/app.test.ts b/packages/a2a-server/src/http/app.test.ts index 57269feeba..641b3749e6 100644 --- a/packages/a2a-server/src/http/app.test.ts +++ b/packages/a2a-server/src/http/app.test.ts @@ -36,7 +36,7 @@ import { createMockConfig, } from '../utils/testing_utils.js'; import { MockTool } from '@google/gemini-cli-core'; -import type { Command } from '../commands/types.js'; +import type { Command, CommandContext } from '../commands/types.js'; const mockToolConfirmationFn = async () => ({}) as unknown as ToolCallConfirmationDetails; @@ -97,6 +97,7 @@ vi.mock('@google/gemini-cli-core', async () => { getUserTier: vi.fn().mockReturnValue('free'), initialize: vi.fn(), })), + performRestore: vi.fn(), }; }); @@ -939,6 +940,17 @@ describe('E2E Tests', () => { }); it('should return extensions for valid command', async () => { + const mockExtensionsCommand = { + name: 'extensions list', + description: 'a mock command', + execute: vi.fn(async (context: CommandContext) => { + // Simulate the actual command's behavior + const extensions = context.config.getExtensions(); + return { name: 'extensions list', data: extensions }; + }), + }; + vi.spyOn(commandRegistry, 'get').mockReturnValue(mockExtensionsCommand); + const agent = request.agent(app); const res = await agent .post('/executeCommand') @@ -954,6 +966,8 @@ describe('E2E Tests', () => { }); it('should return 404 for invalid command', async () => { + vi.spyOn(commandRegistry, 'get').mockReturnValue(undefined); + const agent = request.agent(app); const res = await agent .post('/executeCommand') @@ -986,5 +1000,66 @@ describe('E2E Tests', () => { expect(res.body.error).toBe('"args" field must be an array.'); expect(getExtensionsSpy).not.toHaveBeenCalled(); }); + + it('should execute a command that does not require a workspace when CODER_AGENT_WORKSPACE_PATH is not set', async () => { + const mockCommand = { + name: 'test-command', + description: 'a mock command', + execute: vi + .fn() + .mockResolvedValue({ name: 'test-command', data: 'success' }), + }; + vi.spyOn(commandRegistry, 'get').mockReturnValue(mockCommand); + + delete process.env['CODER_AGENT_WORKSPACE_PATH']; + const response = await request(app) + .post('/executeCommand') + .send({ command: 'test-command', args: [] }); + + expect(response.status).toBe(200); + expect(response.body.data).toBe('success'); + }); + + it('should return 400 for a command that requires a workspace when CODER_AGENT_WORKSPACE_PATH is not set', async () => { + const mockWorkspaceCommand = { + name: 'workspace-command', + description: 'A command that requires a workspace', + requiresWorkspace: true, + execute: vi + .fn() + .mockResolvedValue({ name: 'workspace-command', data: 'success' }), + }; + vi.spyOn(commandRegistry, 'get').mockReturnValue(mockWorkspaceCommand); + + delete process.env['CODER_AGENT_WORKSPACE_PATH']; + const response = await request(app) + .post('/executeCommand') + .send({ command: 'workspace-command', args: [] }); + + expect(response.status).toBe(400); + expect(response.body.error).toBe( + 'Command "workspace-command" requires a workspace, but CODER_AGENT_WORKSPACE_PATH is not set.', + ); + }); + + it('should execute a command that requires a workspace when CODER_AGENT_WORKSPACE_PATH is set', async () => { + const mockWorkspaceCommand = { + name: 'workspace-command', + description: 'A command that requires a workspace', + requiresWorkspace: true, + execute: vi + .fn() + .mockResolvedValue({ name: 'workspace-command', data: 'success' }), + }; + vi.spyOn(commandRegistry, 'get').mockReturnValue(mockWorkspaceCommand); + + process.env['CODER_AGENT_WORKSPACE_PATH'] = '/tmp/test-workspace'; + const response = await request(app) + .post('/executeCommand') + .send({ command: 'workspace-command', args: [] }); + + expect(response.status).toBe(200); + expect(response.body.data).toBe('success'); + }); }); }); diff --git a/packages/a2a-server/src/http/app.ts b/packages/a2a-server/src/http/app.ts index 81d57e4abd..91f1e70dd4 100644 --- a/packages/a2a-server/src/http/app.ts +++ b/packages/a2a-server/src/http/app.ts @@ -22,6 +22,7 @@ import { loadExtensions } from '../config/extension.js'; import { commandRegistry } from '../commands/command-registry.js'; import { SimpleExtensionLoader } from '@google/gemini-cli-core'; import type { Command, CommandArgument } from '../commands/types.js'; +import { GitService } from '@google/gemini-cli-core'; type CommandResponse = { name: string; @@ -85,6 +86,14 @@ export async function createApp() { 'a2a-server', ); + let git: GitService | undefined; + if (config.getCheckpointingEnabled()) { + git = new GitService(config.getTargetDir(), config.storage); + await git.initialize(); + } + + const context = { config, git }; + // loadEnvironment() is called within getConfig now const bucketName = process.env['GCS_BUCKET_NAME']; let taskStoreForExecutor: TaskStore; @@ -144,6 +153,7 @@ export async function createApp() { }); expressApp.post('/executeCommand', async (req, res) => { + logger.info('[CoreAgent] Received /executeCommand request: ', req.body); try { const { command, args } = req.body; @@ -159,13 +169,22 @@ export async function createApp() { const commandToExecute = commandRegistry.get(command); + if (commandToExecute?.requiresWorkspace) { + if (!process.env['CODER_AGENT_WORKSPACE_PATH']) { + return res.status(400).json({ + error: `Command "${command}" requires a workspace, but CODER_AGENT_WORKSPACE_PATH is not set.`, + }); + } + } + if (!commandToExecute) { return res .status(404) .json({ error: `Command not found: ${command}` }); } - const result = await commandToExecute.execute(config, args ?? []); + const result = await commandToExecute.execute(context, args ?? []); + logger.info('[CoreAgent] Sending /executeCommand response: ', result); return res.status(200).json(result); } catch (e) { logger.error('Error executing /executeCommand:', e); diff --git a/packages/a2a-server/src/utils/testing_utils.ts b/packages/a2a-server/src/utils/testing_utils.ts index fcd184924c..0abd9f3b31 100644 --- a/packages/a2a-server/src/utils/testing_utils.ts +++ b/packages/a2a-server/src/utils/testing_utils.ts @@ -37,8 +37,10 @@ export function createMockConfig( isPathWithinWorkspace: () => true, }), getTargetDir: () => '/test', + getCheckpointingEnabled: vi.fn().mockReturnValue(false), storage: { getProjectTempDir: () => '/tmp', + getProjectTempCheckpointsDir: () => '/tmp/checkpoints', } as Storage, getTruncateToolOutputThreshold: () => DEFAULT_TRUNCATE_TOOL_OUTPUT_THRESHOLD, @@ -62,6 +64,7 @@ export function createMockConfig( getMcpClientManager: vi.fn().mockReturnValue({ getMcpServers: vi.fn().mockReturnValue({}), }), + getGitService: vi.fn(), ...overrides, } as unknown as Config; mockConfig.getMessageBus = vi.fn().mockReturnValue(createMockMessageBus()); diff --git a/packages/cli/src/ui/commands/restoreCommand.ts b/packages/cli/src/ui/commands/restoreCommand.ts index ee219637b0..5825572642 100644 --- a/packages/cli/src/ui/commands/restoreCommand.ts +++ b/packages/cli/src/ui/commands/restoreCommand.ts @@ -9,6 +9,9 @@ import path from 'node:path'; import { z } from 'zod'; import { type Config, + formatCheckpointDisplayList, + getToolCallDataSchema, + getTruncatedCheckpointNames, performRestore, type ToolCallData, } from '@google/gemini-cli-core'; @@ -28,23 +31,7 @@ const HistoryItemSchema = z }) .passthrough(); -const ContentSchema = z - .object({ - role: z.string().optional(), - parts: z.array(z.record(z.unknown())), - }) - .passthrough(); - -const ToolCallDataSchema = z.object({ - history: z.array(HistoryItemSchema).optional(), - clientHistory: z.array(ContentSchema).optional(), - commitHash: z.string().optional(), - toolCall: z.object({ - name: z.string(), - args: z.record(z.unknown()), - }), - messageId: z.string().optional(), -}); +const ToolCallDataSchema = getToolCallDataSchema(HistoryItemSchema); async function restoreAction( context: CommandContext, @@ -78,15 +65,7 @@ async function restoreAction( content: 'No restorable tool calls found.', }; } - const truncatedFiles = jsonFiles.map((file) => { - const components = file.split('.'); - if (components.length <= 1) { - return file; - } - components.pop(); - return components.join('.'); - }); - const fileList = truncatedFiles.join('\n'); + const fileList = formatCheckpointDisplayList(jsonFiles); return { type: 'message', messageType: 'info', @@ -171,9 +150,8 @@ async function completion( } try { const files = await fs.readdir(checkpointDir); - return files - .filter((file) => file.endsWith('.json')) - .map((file) => file.replace('.json', '')); + const jsonFiles = files.filter((file) => file.endsWith('.json')); + return getTruncatedCheckpointNames(jsonFiles); } catch (_err) { return []; } diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index a361fdb966..47b97936f6 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -16,7 +16,6 @@ import type { ThoughtSummary, ToolCallRequestInfo, GeminiErrorEventValue, - ToolCallData, } from '@google/gemini-cli-core'; import { GeminiEventType as ServerGeminiEventType, @@ -34,10 +33,11 @@ import { parseAndFormatApiError, ToolConfirmationOutcome, promptIdContext, - WRITE_FILE_TOOL_NAME, tokenLimit, debugLogger, runInDevTraceSpan, + EDIT_TOOL_NAMES, + processRestorableToolCalls, } from '@google/gemini-cli-core'; import { type Part, type PartListUnion, FinishReason } from '@google/genai'; import type { @@ -76,8 +76,6 @@ enum StreamProcessingStatus { Error, } -const EDIT_TOOL_NAMES = new Set(['replace', WRITE_FILE_TOOL_NAME]); - function showCitations(settings: LoadedSettings): boolean { const enabled = settings?.merged?.ui?.showCitations; if (enabled !== undefined) { @@ -1248,98 +1246,37 @@ export const useGeminiStream = ( ); if (restorableToolCalls.length > 0) { - const checkpointDir = storage.getProjectTempCheckpointsDir(); - - if (!checkpointDir) { + if (!gitService) { + onDebugMessage( + 'Checkpointing is enabled but Git service is not available. Failed to create snapshot. Ensure Git is installed and working properly.', + ); return; } - try { - await fs.mkdir(checkpointDir, { recursive: true }); - } catch (error) { - if (!isNodeError(error) || error.code !== 'EEXIST') { - onDebugMessage( - `Failed to create checkpoint directory: ${getErrorMessage(error)}`, - ); - return; - } + const { checkpointsToWrite, errors } = await processRestorableToolCalls< + HistoryItem[] + >( + restorableToolCalls.map((call) => call.request), + gitService, + geminiClient, + history, + ); + + if (errors.length > 0) { + errors.forEach(onDebugMessage); } - for (const toolCall of restorableToolCalls) { - const filePath = toolCall.request.args['file_path'] as string; - if (!filePath) { - onDebugMessage( - `Skipping restorable tool call due to missing file_path: ${toolCall.request.name}`, - ); - continue; - } - + if (checkpointsToWrite.size > 0) { + const checkpointDir = storage.getProjectTempCheckpointsDir(); try { - if (!gitService) { - onDebugMessage( - `Checkpointing is enabled but Git service is not available. Failed to create snapshot for ${filePath}. Ensure Git is installed and working properly.`, - ); - continue; + await fs.mkdir(checkpointDir, { recursive: true }); + for (const [fileName, content] of checkpointsToWrite) { + const filePath = path.join(checkpointDir, fileName); + await fs.writeFile(filePath, content); } - - let commitHash: string | undefined; - try { - commitHash = await gitService.createFileSnapshot( - `Snapshot for ${toolCall.request.name}`, - ); - } catch (error) { - onDebugMessage( - `Failed to create new snapshot: ${getErrorMessage(error)}. Attempting to use current commit.`, - ); - } - - if (!commitHash) { - commitHash = await gitService.getCurrentCommitHash(); - } - - if (!commitHash) { - onDebugMessage( - `Failed to create snapshot for ${filePath}. Checkpointing may not be working properly. Ensure Git is installed and the project directory is accessible.`, - ); - continue; - } - - const timestamp = new Date() - .toISOString() - .replace(/:/g, '-') - .replace(/\./g, '_'); - const toolName = toolCall.request.name; - const fileName = path.basename(filePath); - const toolCallWithSnapshotFileName = `${timestamp}-${fileName}-${toolName}.json`; - const clientHistory = await geminiClient?.getHistory(); - const toolCallWithSnapshotFilePath = path.join( - checkpointDir, - toolCallWithSnapshotFileName, - ); - - const checkpointData: ToolCallData< - HistoryItem[], - Record - > & { filePath: string } = { - history, - clientHistory, - toolCall: { - name: toolCall.request.name, - args: toolCall.request.args, - }, - commitHash, - filePath, - }; - - await fs.writeFile( - toolCallWithSnapshotFilePath, - JSON.stringify(checkpointData, null, 2), - ); } catch (error) { onDebugMessage( - `Failed to create checkpoint for ${filePath}: ${getErrorMessage( - error, - )}. This may indicate a problem with Git or file system permissions.`, + `Failed to write checkpoint file: ${getErrorMessage(error)}`, ); } } diff --git a/packages/core/src/commands/restore.test.ts b/packages/core/src/commands/restore.test.ts index 634f4989d4..4dcba5dd87 100644 --- a/packages/core/src/commands/restore.test.ts +++ b/packages/core/src/commands/restore.test.ts @@ -5,7 +5,8 @@ */ import { describe, it, expect, vi, beforeEach } from 'vitest'; -import { performRestore, type ToolCallData } from './restore.js'; +import { performRestore } from './restore.js'; +import { type ToolCallData } from '../utils/checkpointUtils.js'; import type { GitService } from '../services/gitService.js'; describe('performRestore', () => { diff --git a/packages/core/src/commands/restore.ts b/packages/core/src/commands/restore.ts index 778836f96a..06c2013845 100644 --- a/packages/core/src/commands/restore.ts +++ b/packages/core/src/commands/restore.ts @@ -4,20 +4,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { Content } from '@google/genai'; import type { GitService } from '../services/gitService.js'; import type { CommandActionReturn } from './types.js'; - -export interface ToolCallData { - history?: HistoryType; - clientHistory?: Content[]; - commitHash?: string; - toolCall: { - name: string; - args: ArgsType; - }; - messageId?: string; -} +import { type ToolCallData } from '../utils/checkpointUtils.js'; export async function* performRestore< HistoryType = unknown, diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 2f05b5bcba..e29eeef893 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -107,6 +107,7 @@ export interface ToolCallRequestInfo { args: Record; isClientInitiated: boolean; prompt_id: string; + checkpoint?: string; } export interface ToolCallResponseInfo { diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index e667185ca5..bb9521860c 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -78,6 +78,7 @@ export * from './utils/debugLogger.js'; export * from './utils/events.js'; export * from './utils/extensionLoader.js'; export * from './utils/package.js'; +export * from './utils/checkpointUtils.js'; // Export services export * from './services/fileDiscoveryService.js'; diff --git a/packages/core/src/services/gitService.test.ts b/packages/core/src/services/gitService.test.ts index 3d24f7e580..7923b7e647 100644 --- a/packages/core/src/services/gitService.test.ts +++ b/packages/core/src/services/gitService.test.ts @@ -32,6 +32,7 @@ const hoistedMockInit = vi.hoisted(() => vi.fn()); const hoistedMockRaw = vi.hoisted(() => vi.fn()); const hoistedMockAdd = vi.hoisted(() => vi.fn()); const hoistedMockCommit = vi.hoisted(() => vi.fn()); +const hoistedMockStatus = vi.hoisted(() => vi.fn()); vi.mock('simple-git', () => ({ simpleGit: hoistedMockSimpleGit.mockImplementation(() => ({ checkIsRepo: hoistedMockCheckIsRepo, @@ -39,6 +40,7 @@ vi.mock('simple-git', () => ({ raw: hoistedMockRaw, add: hoistedMockAdd, commit: hoistedMockCommit, + status: hoistedMockStatus, env: hoistedMockEnv, })), CheckRepoActions: { IS_REPO_ROOT: 'is-repo-root' }, @@ -89,6 +91,7 @@ describe('GitService', () => { raw: hoistedMockRaw, add: hoistedMockAdd, commit: hoistedMockCommit, + status: hoistedMockStatus, })); hoistedMockSimpleGit.mockImplementation(() => ({ checkIsRepo: hoistedMockCheckIsRepo, @@ -96,6 +99,7 @@ describe('GitService', () => { raw: hoistedMockRaw, add: hoistedMockAdd, commit: hoistedMockCommit, + status: hoistedMockStatus, env: hoistedMockEnv, })); hoistedMockCheckIsRepo.mockResolvedValue(false); @@ -248,6 +252,7 @@ describe('GitService', () => { describe('createFileSnapshot', () => { it('should commit with --no-verify flag', async () => { + hoistedMockStatus.mockResolvedValue({ isClean: () => false }); const service = new GitService(projectRoot, storage); await service.initialize(); await service.createFileSnapshot('test commit'); @@ -255,5 +260,30 @@ describe('GitService', () => { '--no-verify': null, }); }); + + it('should create a new commit if there are staged changes', async () => { + hoistedMockStatus.mockResolvedValue({ isClean: () => false }); + hoistedMockCommit.mockResolvedValue({ commit: 'new-commit-hash' }); + const service = new GitService(projectRoot, storage); + const commitHash = await service.createFileSnapshot('test message'); + expect(hoistedMockAdd).toHaveBeenCalledWith('.'); + expect(hoistedMockStatus).toHaveBeenCalled(); + expect(hoistedMockCommit).toHaveBeenCalledWith('test message', { + '--no-verify': null, + }); + expect(commitHash).toBe('new-commit-hash'); + }); + + it('should return the current HEAD commit hash if there are no staged changes', async () => { + hoistedMockStatus.mockResolvedValue({ isClean: () => true }); + hoistedMockRaw.mockResolvedValue('current-head-hash'); + const service = new GitService(projectRoot, storage); + const commitHash = await service.createFileSnapshot('test message'); + expect(hoistedMockAdd).toHaveBeenCalledWith('.'); + expect(hoistedMockStatus).toHaveBeenCalled(); + expect(hoistedMockCommit).not.toHaveBeenCalled(); + expect(hoistedMockRaw).toHaveBeenCalledWith('rev-parse', 'HEAD'); + expect(commitHash).toBe('current-head-hash'); + }); }); }); diff --git a/packages/core/src/services/gitService.ts b/packages/core/src/services/gitService.ts index b38f3e289f..cbda1ae1c3 100644 --- a/packages/core/src/services/gitService.ts +++ b/packages/core/src/services/gitService.ts @@ -112,6 +112,11 @@ export class GitService { try { const repo = this.shadowGitRepository; await repo.add('.'); + const status = await repo.status(); + if (status.isClean()) { + // If no changes are staged, return the current HEAD commit hash + return await this.getCurrentCommitHash(); + } const commitResult = await repo.commit(message, { '--no-verify': null, }); diff --git a/packages/core/src/tools/tool-names.ts b/packages/core/src/tools/tool-names.ts index 77ce28fe8f..ec8ebb6d80 100644 --- a/packages/core/src/tools/tool-names.ts +++ b/packages/core/src/tools/tool-names.ts @@ -20,3 +20,4 @@ export const READ_MANY_FILES_TOOL_NAME = 'read_many_files'; export const READ_FILE_TOOL_NAME = 'read_file'; export const LS_TOOL_NAME = 'list_directory'; export const MEMORY_TOOL_NAME = 'save_memory'; +export const EDIT_TOOL_NAMES = new Set([EDIT_TOOL_NAME, WRITE_FILE_TOOL_NAME]); diff --git a/packages/core/src/utils/checkpointUtils.test.ts b/packages/core/src/utils/checkpointUtils.test.ts new file mode 100644 index 0000000000..2a0d198f31 --- /dev/null +++ b/packages/core/src/utils/checkpointUtils.test.ts @@ -0,0 +1,305 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, type Mock } from 'vitest'; +import { z } from 'zod'; +import { + getToolCallDataSchema, + generateCheckpointFileName, + formatCheckpointDisplayList, + getTruncatedCheckpointNames, + processRestorableToolCalls, + getCheckpointInfoList, +} from './checkpointUtils.js'; +import type { GitService } from '../services/gitService.js'; +import type { GeminiClient } from '../core/client.js'; +import type { ToolCallRequestInfo } from '../core/turn.js'; + +describe('checkpoint utils', () => { + describe('getToolCallDataSchema', () => { + it('should return a schema that validates a basic tool call data object', () => { + const schema = getToolCallDataSchema(); + const validData = { + toolCall: { name: 'test-tool', args: { foo: 'bar' } }, + }; + const result = schema.safeParse(validData); + expect(result.success).toBe(true); + }); + + it('should validate with an optional history schema', () => { + const historyItemSchema = z.object({ id: z.number(), event: z.string() }); + const schema = getToolCallDataSchema(historyItemSchema); + const validData = { + history: [{ id: 1, event: 'start' }], + toolCall: { name: 'test-tool', args: {} }, + }; + const result = schema.safeParse(validData); + expect(result.success).toBe(true); + }); + + it('should fail validation if history items do not match the schema', () => { + const historyItemSchema = z.object({ id: z.number(), event: z.string() }); + const schema = getToolCallDataSchema(historyItemSchema); + const invalidData = { + history: [{ id: '1', event: 'start' }], // id should be a number + toolCall: { name: 'test-tool', args: {} }, + }; + const result = schema.safeParse(invalidData); + expect(result.success).toBe(false); + }); + + it('should validate clientHistory with the correct schema', () => { + const schema = getToolCallDataSchema(); + const validData = { + clientHistory: [{ role: 'user', parts: [{ text: 'Hello' }] }], + toolCall: { name: 'test-tool', args: {} }, + }; + const result = schema.safeParse(validData); + expect(result.success).toBe(true); + }); + }); + + describe('generateCheckpointFileName', () => { + it('should generate a filename with timestamp, basename, and tool name', () => { + vi.useFakeTimers(); + vi.setSystemTime(new Date('2025-01-01T12:00:00.000Z')); + const toolCall = { + callId: '1', + name: 'replace', + args: { file_path: '/path/to/my-file.txt' }, + isClientInitiated: false, + prompt_id: 'p1', + } as ToolCallRequestInfo; + + const expected = '2025-01-01T12-00-00_000Z-my-file.txt-replace'; + const actual = generateCheckpointFileName(toolCall); + expect(actual).toBe(expected); + + vi.useRealTimers(); + }); + + it('should return null if file_path is not in the tool arguments', () => { + const toolCall = { + callId: '1', + name: 'replace', + args: { some_other_arg: 'value' }, + isClientInitiated: false, + prompt_id: 'p1', + } as ToolCallRequestInfo; + + const actual = generateCheckpointFileName(toolCall); + expect(actual).toBeNull(); + }); + }); + + describe('formatCheckpointDisplayList and getTruncatedCheckpointNames', () => { + const filenames = [ + '2025-01-01T12-00-00_000Z-my-file.txt-replace.json', + '2025-01-01T13-00-00_000Z-another.js-write_file.json', + 'no-extension-file', + ]; + + it('getTruncatedCheckpointNames should remove the .json extension', () => { + const expected = [ + '2025-01-01T12-00-00_000Z-my-file.txt-replace', + '2025-01-01T13-00-00_000Z-another.js-write_file', + 'no-extension-file', + ]; + const actual = getTruncatedCheckpointNames(filenames); + expect(actual).toEqual(expected); + }); + + it('formatCheckpointDisplayList should return a newline-separated string of truncated names', () => { + const expected = [ + '2025-01-01T12-00-00_000Z-my-file.txt-replace', + '2025-01-01T13-00-00_000Z-another.js-write_file', + 'no-extension-file', + ].join('\n'); + const actual = formatCheckpointDisplayList(filenames); + expect(actual).toEqual(expected); + }); + }); + + describe('processRestorableToolCalls', () => { + const mockGitService = { + createFileSnapshot: vi.fn(), + getCurrentCommitHash: vi.fn(), + } as unknown as GitService; + + const mockGeminiClient = { + getHistory: vi.fn(), + } as unknown as GeminiClient; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('should create checkpoints for restorable tool calls', async () => { + const toolCalls = [ + { + callId: '1', + name: 'replace', + args: { file_path: 'a.txt' }, + prompt_id: 'p1', + isClientInitiated: false, + }, + ] as ToolCallRequestInfo[]; + + (mockGitService.createFileSnapshot as Mock).mockResolvedValue('hash123'); + (mockGeminiClient.getHistory as Mock).mockResolvedValue([ + { role: 'user', parts: [] }, + ]); + + const { checkpointsToWrite, toolCallToCheckpointMap, errors } = + await processRestorableToolCalls( + toolCalls, + mockGitService, + mockGeminiClient, + 'history-data', + ); + + expect(errors).toHaveLength(0); + expect(checkpointsToWrite.size).toBe(1); + expect(toolCallToCheckpointMap.get('1')).toBeDefined(); + + const fileName = checkpointsToWrite.values().next().value; + expect(fileName).toBeDefined(); + const fileContent = JSON.parse(fileName!); + + expect(fileContent.commitHash).toBe('hash123'); + expect(fileContent.history).toBe('history-data'); + expect(fileContent.clientHistory).toEqual([{ role: 'user', parts: [] }]); + expect(fileContent.toolCall.name).toBe('replace'); + expect(fileContent.messageId).toBe('p1'); + }); + + it('should handle git snapshot failure by using current commit hash', async () => { + const toolCalls = [ + { + callId: '1', + name: 'replace', + args: { file_path: 'a.txt' }, + prompt_id: 'p1', + isClientInitiated: false, + }, + ] as ToolCallRequestInfo[]; + + (mockGitService.createFileSnapshot as Mock).mockRejectedValue( + new Error('Snapshot failed'), + ); + (mockGitService.getCurrentCommitHash as Mock).mockResolvedValue( + 'fallback-hash', + ); + + const { checkpointsToWrite, errors } = await processRestorableToolCalls( + toolCalls, + mockGitService, + mockGeminiClient, + ); + + expect(errors).toHaveLength(1); + expect(errors[0]).toContain('Failed to create new snapshot'); + expect(checkpointsToWrite.size).toBe(1); + const value = checkpointsToWrite.values().next().value; + expect(value).toBeDefined(); + const fileContent = JSON.parse(value!); + expect(fileContent.commitHash).toBe('fallback-hash'); + }); + + it('should skip tool calls with no file_path', async () => { + const toolCalls = [ + { + callId: '1', + name: 'replace', + args: { not_a_path: 'a.txt' }, + prompt_id: 'p1', + isClientInitiated: false, + }, + ] as ToolCallRequestInfo[]; + (mockGitService.createFileSnapshot as Mock).mockResolvedValue('hash123'); + + const { checkpointsToWrite, errors } = await processRestorableToolCalls( + toolCalls, + mockGitService, + mockGeminiClient, + ); + + expect(errors).toHaveLength(1); + expect(errors[0]).toContain( + 'Skipping restorable tool call due to missing file_path', + ); + expect(checkpointsToWrite.size).toBe(0); + }); + + it('should log an error if git snapshot fails and then skip the tool call', async () => { + const toolCalls = [ + { + callId: '1', + name: 'replace', + args: { file_path: 'a.txt' }, + prompt_id: 'p1', + isClientInitiated: false, + }, + ] as ToolCallRequestInfo[]; + (mockGitService.createFileSnapshot as Mock).mockRejectedValue( + new Error('Snapshot failed'), + ); + (mockGitService.getCurrentCommitHash as Mock).mockResolvedValue( + undefined, + ); + + const { checkpointsToWrite, errors } = await processRestorableToolCalls( + toolCalls, + mockGitService, + mockGeminiClient, + ); + + expect(errors).toHaveLength(2); + expect(errors[0]).toContain('Failed to create new snapshot'); + expect(errors[1]).toContain('Failed to create snapshot for replace'); + expect(checkpointsToWrite.size).toBe(0); + }); + }); + + describe('getCheckpointInfoList', () => { + it('should parse valid checkpoint files and return a list of info', () => { + const checkpointFiles = new Map([ + ['checkpoint1.json', JSON.stringify({ messageId: 'msg1' })], + ['checkpoint2.json', JSON.stringify({ messageId: 'msg2' })], + ]); + + const expected = [ + { messageId: 'msg1', checkpoint: 'checkpoint1' }, + { messageId: 'msg2', checkpoint: 'checkpoint2' }, + ]; + + const actual = getCheckpointInfoList(checkpointFiles); + expect(actual).toEqual(expected); + }); + + it('should ignore files with invalid JSON', () => { + const checkpointFiles = new Map([ + ['checkpoint1.json', JSON.stringify({ messageId: 'msg1' })], + ['invalid.json', 'not-json'], + ]); + + const expected = [{ messageId: 'msg1', checkpoint: 'checkpoint1' }]; + const actual = getCheckpointInfoList(checkpointFiles); + expect(actual).toEqual(expected); + }); + + it('should ignore files that are missing a messageId', () => { + const checkpointFiles = new Map([ + ['checkpoint1.json', JSON.stringify({ messageId: 'msg1' })], + ['no-msg-id.json', JSON.stringify({ other_prop: 'value' })], + ]); + + const expected = [{ messageId: 'msg1', checkpoint: 'checkpoint1' }]; + const actual = getCheckpointInfoList(checkpointFiles); + expect(actual).toEqual(expected); + }); + }); +}); diff --git a/packages/core/src/utils/checkpointUtils.ts b/packages/core/src/utils/checkpointUtils.ts new file mode 100644 index 0000000000..5508531185 --- /dev/null +++ b/packages/core/src/utils/checkpointUtils.ts @@ -0,0 +1,182 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import * as path from 'node:path'; +import type { GitService } from '../services/gitService.js'; +import type { GeminiClient } from '../core/client.js'; +import { getErrorMessage } from './errors.js'; +import { z } from 'zod'; +import type { Content } from '@google/genai'; +import type { ToolCallRequestInfo } from '../core/turn.js'; + +export interface ToolCallData { + history?: HistoryType; + clientHistory?: Content[]; + commitHash?: string; + toolCall: { + name: string; + args: ArgsType; + }; + messageId?: string; +} + +const ContentSchema = z + .object({ + role: z.string().optional(), + parts: z.array(z.record(z.unknown())), + }) + .passthrough(); + +export function getToolCallDataSchema(historyItemSchema?: z.ZodTypeAny) { + const schema = historyItemSchema ?? z.any(); + + return z.object({ + history: z.array(schema).optional(), + clientHistory: z.array(ContentSchema).optional(), + commitHash: z.string().optional(), + toolCall: z.object({ + name: z.string(), + args: z.record(z.unknown()), + }), + messageId: z.string().optional(), + }); +} + +export function generateCheckpointFileName( + toolCall: ToolCallRequestInfo, +): string | null { + const toolArgs = toolCall.args as Record; + const toolFilePath = toolArgs['file_path'] as string; + + if (!toolFilePath) { + return null; + } + + const timestamp = new Date() + .toISOString() + .replace(/:/g, '-') + .replace(/\./g, '_'); + const toolName = toolCall.name; + const fileName = path.basename(toolFilePath); + + return `${timestamp}-${fileName}-${toolName}`; +} + +export function formatCheckpointDisplayList(filenames: string[]): string { + return getTruncatedCheckpointNames(filenames).join('\n'); +} + +export function getTruncatedCheckpointNames(filenames: string[]): string[] { + return filenames.map((file) => { + const components = file.split('.'); + if (components.length <= 1) { + return file; + } + components.pop(); + return components.join('.'); + }); +} + +export async function processRestorableToolCalls( + toolCalls: ToolCallRequestInfo[], + gitService: GitService, + geminiClient: GeminiClient, + history?: HistoryType, +): Promise<{ + checkpointsToWrite: Map; + toolCallToCheckpointMap: Map; + errors: string[]; +}> { + const checkpointsToWrite = new Map(); + const toolCallToCheckpointMap = new Map(); + const errors: string[] = []; + + for (const toolCall of toolCalls) { + try { + let commitHash: string | undefined; + try { + commitHash = await gitService.createFileSnapshot( + `Snapshot for ${toolCall.name}`, + ); + } catch (error) { + errors.push( + `Failed to create new snapshot for ${ + toolCall.name + }: ${getErrorMessage(error)}. Attempting to use current commit.`, + ); + commitHash = await gitService.getCurrentCommitHash(); + } + + if (!commitHash) { + errors.push( + `Failed to create snapshot for ${toolCall.name}. Checkpointing may not be working properly. Ensure Git is installed and the project directory is accessible.`, + ); + continue; + } + + const checkpointFileName = generateCheckpointFileName(toolCall); + if (!checkpointFileName) { + errors.push( + `Skipping restorable tool call due to missing file_path: ${toolCall.name}`, + ); + continue; + } + + const clientHistory = await geminiClient.getHistory(); + const checkpointData: ToolCallData = { + history, + clientHistory, + toolCall: { + name: toolCall.name, + args: toolCall.args, + }, + commitHash, + messageId: toolCall.prompt_id, + }; + + const fileName = `${checkpointFileName}.json`; + checkpointsToWrite.set(fileName, JSON.stringify(checkpointData, null, 2)); + toolCallToCheckpointMap.set( + toolCall.callId, + fileName.replace('.json', ''), + ); + } catch (error) { + errors.push( + `Failed to create checkpoint for ${toolCall.name}: ${getErrorMessage( + error, + )}`, + ); + } + } + + return { checkpointsToWrite, toolCallToCheckpointMap, errors }; +} + +export interface CheckpointInfo { + messageId: string; + checkpoint: string; +} + +export function getCheckpointInfoList( + checkpointFiles: Map, +): CheckpointInfo[] { + const checkpointInfoList: CheckpointInfo[] = []; + + for (const [file, content] of checkpointFiles) { + try { + const toolCallData = JSON.parse(content) as ToolCallData; + if (toolCallData.messageId) { + checkpointInfoList.push({ + messageId: toolCallData.messageId, + checkpoint: file.replace('.json', ''), + }); + } + } catch (_e) { + // Ignore invalid JSON files + } + } + return checkpointInfoList; +}