mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-27 13:34:15 -07:00
feat(a2a): Introduce restore command for a2a server (#13015)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Shreya Keshive <shreyakeshive@google.com>
This commit is contained in:
@@ -18,6 +18,7 @@ import {
|
||||
GeminiEventType,
|
||||
type Config,
|
||||
type ToolCallRequestInfo,
|
||||
type GitService,
|
||||
type CompletedToolCall,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { createMockConfig } from '../utils/testing_utils.js';
|
||||
@@ -25,6 +26,17 @@ import type { ExecutionEventBus, RequestContext } from '@a2a-js/sdk/server';
|
||||
import { CoderAgentEvent } from '../types.js';
|
||||
import type { ToolCall } from '@google/gemini-cli-core';
|
||||
|
||||
const mockProcessRestorableToolCalls = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
|
||||
const original =
|
||||
await importOriginal<typeof import('@google/gemini-cli-core')>();
|
||||
return {
|
||||
...original,
|
||||
processRestorableToolCalls: mockProcessRestorableToolCalls,
|
||||
};
|
||||
});
|
||||
|
||||
describe('Task', () => {
|
||||
it('scheduleToolCalls should not modify the input requests array', async () => {
|
||||
const mockConfig = createMockConfig();
|
||||
@@ -72,6 +84,141 @@ describe('Task', () => {
|
||||
expect(requests).toEqual(originalRequests);
|
||||
});
|
||||
|
||||
describe('scheduleToolCalls', () => {
|
||||
const mockConfig = createMockConfig();
|
||||
const mockEventBus: ExecutionEventBus = {
|
||||
publish: vi.fn(),
|
||||
on: vi.fn(),
|
||||
off: vi.fn(),
|
||||
once: vi.fn(),
|
||||
removeAllListeners: vi.fn(),
|
||||
finished: vi.fn(),
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should not create a checkpoint if no restorable tools are called', async () => {
|
||||
// @ts-expect-error - Calling private constructor for test purposes.
|
||||
const task = new Task(
|
||||
'task-id',
|
||||
'context-id',
|
||||
mockConfig as Config,
|
||||
mockEventBus,
|
||||
);
|
||||
const requests: ToolCallRequestInfo[] = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'run_shell_command',
|
||||
args: { command: 'ls' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
];
|
||||
const abortController = new AbortController();
|
||||
await task.scheduleToolCalls(requests, abortController.signal);
|
||||
expect(mockProcessRestorableToolCalls).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should create a checkpoint if a restorable tool is called', async () => {
|
||||
const mockConfig = createMockConfig({
|
||||
getCheckpointingEnabled: () => true,
|
||||
getGitService: () => Promise.resolve({} as GitService),
|
||||
});
|
||||
mockProcessRestorableToolCalls.mockResolvedValue({
|
||||
checkpointsToWrite: new Map([['test.json', 'test content']]),
|
||||
toolCallToCheckpointMap: new Map(),
|
||||
errors: [],
|
||||
});
|
||||
// @ts-expect-error - Calling private constructor for test purposes.
|
||||
const task = new Task(
|
||||
'task-id',
|
||||
'context-id',
|
||||
mockConfig as Config,
|
||||
mockEventBus,
|
||||
);
|
||||
const requests: ToolCallRequestInfo[] = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'replace',
|
||||
args: {
|
||||
file_path: 'test.txt',
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
];
|
||||
const abortController = new AbortController();
|
||||
await task.scheduleToolCalls(requests, abortController.signal);
|
||||
expect(mockProcessRestorableToolCalls).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should process all restorable tools for checkpointing in a single batch', async () => {
|
||||
const mockConfig = createMockConfig({
|
||||
getCheckpointingEnabled: () => true,
|
||||
getGitService: () => Promise.resolve({} as GitService),
|
||||
});
|
||||
mockProcessRestorableToolCalls.mockResolvedValue({
|
||||
checkpointsToWrite: new Map([
|
||||
['test1.json', 'test content 1'],
|
||||
['test2.json', 'test content 2'],
|
||||
]),
|
||||
toolCallToCheckpointMap: new Map([
|
||||
['1', 'test1'],
|
||||
['2', 'test2'],
|
||||
]),
|
||||
errors: [],
|
||||
});
|
||||
// @ts-expect-error - Calling private constructor for test purposes.
|
||||
const task = new Task(
|
||||
'task-id',
|
||||
'context-id',
|
||||
mockConfig as Config,
|
||||
mockEventBus,
|
||||
);
|
||||
const requests: ToolCallRequestInfo[] = [
|
||||
{
|
||||
callId: '1',
|
||||
name: 'replace',
|
||||
args: {
|
||||
file_path: 'test.txt',
|
||||
old_string: 'old',
|
||||
new_string: 'new',
|
||||
},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-1',
|
||||
},
|
||||
{
|
||||
callId: '2',
|
||||
name: 'write_file',
|
||||
args: { file_path: 'test2.txt', content: 'new content' },
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-2',
|
||||
},
|
||||
{
|
||||
callId: '3',
|
||||
name: 'not_restorable',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
prompt_id: 'prompt-id-3',
|
||||
},
|
||||
];
|
||||
const abortController = new AbortController();
|
||||
await task.scheduleToolCalls(requests, abortController.signal);
|
||||
expect(mockProcessRestorableToolCalls).toHaveBeenCalledExactlyOnceWith(
|
||||
[
|
||||
expect.objectContaining({ callId: '1' }),
|
||||
expect.objectContaining({ callId: '2' }),
|
||||
],
|
||||
expect.anything(),
|
||||
expect.anything(),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('acceptAgentMessage', () => {
|
||||
it('should set currentTraceId when event has traceId', async () => {
|
||||
const mockConfig = createMockConfig();
|
||||
|
||||
@@ -27,6 +27,8 @@ import {
|
||||
type Config,
|
||||
type UserTierId,
|
||||
type AnsiOutput,
|
||||
EDIT_TOOL_NAMES,
|
||||
processRestorableToolCalls,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { RequestContext } from '@a2a-js/sdk/server';
|
||||
import { type ExecutionEventBus } from '@a2a-js/sdk/server';
|
||||
@@ -40,7 +42,8 @@ import type {
|
||||
} from '@a2a-js/sdk';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { logger } from '../utils/logger.js';
|
||||
import * as fs from 'node:fs';
|
||||
import * as fs from 'node:fs/promises';
|
||||
import * as path from 'node:path';
|
||||
import { CoderAgentEvent } from '../types.js';
|
||||
import type {
|
||||
CoderAgentMessage,
|
||||
@@ -511,7 +514,7 @@ export class Task {
|
||||
new_string: string,
|
||||
): Promise<string> {
|
||||
try {
|
||||
const currentContent = fs.readFileSync(file_path, 'utf8');
|
||||
const currentContent = await fs.readFile(file_path, 'utf8');
|
||||
return this._applyReplacement(
|
||||
currentContent,
|
||||
old_string,
|
||||
@@ -554,6 +557,44 @@ export class Task {
|
||||
return;
|
||||
}
|
||||
|
||||
// Set checkpoint file before any file modification tool executes
|
||||
const restorableToolCalls = requests.filter((request) =>
|
||||
EDIT_TOOL_NAMES.has(request.name),
|
||||
);
|
||||
|
||||
if (restorableToolCalls.length > 0) {
|
||||
const gitService = await this.config.getGitService();
|
||||
if (gitService) {
|
||||
const { checkpointsToWrite, toolCallToCheckpointMap, errors } =
|
||||
await processRestorableToolCalls(
|
||||
restorableToolCalls,
|
||||
gitService,
|
||||
this.geminiClient,
|
||||
);
|
||||
|
||||
if (errors.length > 0) {
|
||||
errors.forEach((error) => logger.error(error));
|
||||
}
|
||||
|
||||
if (checkpointsToWrite.size > 0) {
|
||||
const checkpointDir =
|
||||
this.config.storage.getProjectTempCheckpointsDir();
|
||||
await fs.mkdir(checkpointDir, { recursive: true });
|
||||
for (const [fileName, content] of checkpointsToWrite) {
|
||||
const filePath = path.join(checkpointDir, fileName);
|
||||
await fs.writeFile(filePath, content);
|
||||
}
|
||||
}
|
||||
|
||||
for (const request of requests) {
|
||||
const checkpoint = toolCallToCheckpointMap.get(request.callId);
|
||||
if (checkpoint) {
|
||||
request.checkpoint = checkpoint;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const updatedRequests = await Promise.all(
|
||||
requests.map(async (request) => {
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user