feat(a2a): Introduce /init command for a2a server (#13419)

This commit is contained in:
Coco Sheng
2025-12-12 12:09:04 -05:00
committed by GitHub
parent a02abcf578
commit 299cc9bebf
14 changed files with 742 additions and 106 deletions
+8 -1
View File
@@ -127,6 +127,7 @@ export class CoderAgentExecutor implements AgentExecutor {
contextId,
config,
eventBus,
agentSettings.autoExecute,
);
runtimeTask.taskState = persistedState._taskState;
await runtimeTask.geminiClient.initialize();
@@ -145,7 +146,13 @@ export class CoderAgentExecutor implements AgentExecutor {
): Promise<TaskWrapper> {
const agentSettings = agentSettingsInput || ({} as AgentSettings);
const config = await this.getConfig(agentSettings, taskId);
const runtimeTask = await Task.create(taskId, contextId, config, eventBus);
const runtimeTask = await Task.create(
taskId,
contextId,
config,
eventBus,
agentSettings.autoExecute,
);
await runtimeTask.geminiClient.initialize();
const wrapper = new TaskWrapper(runtimeTask, agentSettings);
+67 -2
View File
@@ -20,6 +20,8 @@ import {
type ToolCallRequestInfo,
type GitService,
type CompletedToolCall,
ApprovalMode,
ToolConfirmationOutcome,
} from '@google/gemini-cli-core';
import { createMockConfig } from '../utils/testing_utils.js';
import type { ExecutionEventBus, RequestContext } from '@a2a-js/sdk/server';
@@ -353,10 +355,12 @@ describe('Task', () => {
let task: Task;
type SpyInstance = ReturnType<typeof vi.spyOn>;
let setTaskStateAndPublishUpdateSpy: SpyInstance;
let mockConfig: Config;
let mockEventBus: ExecutionEventBus;
beforeEach(() => {
const mockConfig = createMockConfig();
const mockEventBus: ExecutionEventBus = {
mockConfig = createMockConfig() as Config;
mockEventBus = {
publish: vi.fn(),
on: vi.fn(),
off: vi.fn(),
@@ -465,6 +469,67 @@ describe('Task', () => {
);
expect(finalCall).toBeUndefined();
});
describe('auto-approval', () => {
it('should auto-approve tool calls when autoExecute is true', () => {
task.autoExecute = true;
const onConfirmSpy = vi.fn();
const toolCalls = [
{
request: { callId: '1' },
status: 'awaiting_approval',
confirmationDetails: { onConfirm: onConfirmSpy },
},
] as unknown as ToolCall[];
// @ts-expect-error - Calling private method
task._schedulerToolCallsUpdate(toolCalls);
expect(onConfirmSpy).toHaveBeenCalledWith(
ToolConfirmationOutcome.ProceedOnce,
);
});
it('should auto-approve tool calls when approval mode is YOLO', () => {
(mockConfig.getApprovalMode as Mock).mockReturnValue(ApprovalMode.YOLO);
task.autoExecute = false;
const onConfirmSpy = vi.fn();
const toolCalls = [
{
request: { callId: '1' },
status: 'awaiting_approval',
confirmationDetails: { onConfirm: onConfirmSpy },
},
] as unknown as ToolCall[];
// @ts-expect-error - Calling private method
task._schedulerToolCallsUpdate(toolCalls);
expect(onConfirmSpy).toHaveBeenCalledWith(
ToolConfirmationOutcome.ProceedOnce,
);
});
it('should NOT auto-approve when autoExecute is false and mode is not YOLO', () => {
task.autoExecute = false;
(mockConfig.getApprovalMode as Mock).mockReturnValue(
ApprovalMode.DEFAULT,
);
const onConfirmSpy = vi.fn();
const toolCalls = [
{
request: { callId: '1' },
status: 'awaiting_approval',
confirmationDetails: { onConfirm: onConfirmSpy },
},
] as unknown as ToolCall[];
// @ts-expect-error - Calling private method
task._schedulerToolCallsUpdate(toolCalls);
expect(onConfirmSpy).not.toHaveBeenCalled();
});
});
});
describe('currentPromptId and promptCount', () => {
+14 -3
View File
@@ -73,6 +73,7 @@ export class Task {
modelInfo?: string;
currentPromptId: string | undefined;
promptCount = 0;
autoExecute: boolean;
// For tool waiting logic
private pendingToolCalls: Map<string, string> = new Map(); //toolCallId --> status
@@ -87,6 +88,7 @@ export class Task {
contextId: string,
config: Config,
eventBus?: ExecutionEventBus,
autoExecute = false,
) {
this.id = id;
this.contextId = contextId;
@@ -98,6 +100,7 @@ export class Task {
this.eventBus = eventBus;
this.completedToolCalls = [];
this._resetToolCompletionPromise();
this.autoExecute = autoExecute;
this.config.setFallbackModelHandler(
// For a2a-server, we want to automatically switch to the fallback model
// for future requests without retrying the current one. The 'stop'
@@ -111,8 +114,9 @@ export class Task {
contextId: string,
config: Config,
eventBus?: ExecutionEventBus,
autoExecute?: boolean,
): Promise<Task> {
return new Task(id, contextId, config, eventBus);
return new Task(id, contextId, config, eventBus, autoExecute);
}
// Note: `getAllMCPServerStatuses` retrieves the status of all MCP servers for the entire
@@ -396,8 +400,15 @@ export class Task {
}
});
if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
logger.info('[Task] YOLO mode enabled. Auto-approving all tool calls.');
if (
this.autoExecute ||
this.config.getApprovalMode() === ApprovalMode.YOLO
) {
logger.info(
'[Task] ' +
(this.autoExecute ? '' : 'YOLO mode enabled. ') +
'Auto-approving all tool calls.',
);
toolCalls.forEach((tc: ToolCall) => {
if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
@@ -5,6 +5,7 @@
*/
import { ExtensionsCommand } from './extensions.js';
import { InitCommand } from './init.js';
import { RestoreCommand } from './restore.js';
import type { Command } from './types.js';
@@ -14,6 +15,7 @@ class CommandRegistry {
constructor() {
this.register(new ExtensionsCommand());
this.register(new RestoreCommand());
this.register(new InitCommand());
}
register(command: Command) {
@@ -0,0 +1,182 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { InitCommand } from './init.js';
import { performInit } from '@google/gemini-cli-core';
import * as fs from 'node:fs';
import * as path from 'node:path';
import { CoderAgentExecutor } from '../agent/executor.js';
import { CoderAgentEvent } from '../types.js';
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
import { createMockConfig } from '../utils/testing_utils.js';
import type { CommandContext } from './types.js';
import type { CommandActionReturn, Config } from '@google/gemini-cli-core';
import { logger } from '../utils/logger.js';
vi.mock('@google/gemini-cli-core', async (importOriginal) => {
const actual =
await importOriginal<typeof import('@google/gemini-cli-core')>();
return {
...actual,
performInit: vi.fn(),
};
});
vi.mock('node:fs', () => ({
existsSync: vi.fn(),
writeFileSync: vi.fn(),
}));
vi.mock('../agent/executor.js', () => ({
CoderAgentExecutor: vi.fn().mockImplementation(() => ({
execute: vi.fn(),
})),
}));
vi.mock('../utils/logger.js', () => ({
logger: {
info: vi.fn(),
error: vi.fn(),
},
}));
describe('InitCommand', () => {
let eventBus: ExecutionEventBus;
let command: InitCommand;
let context: CommandContext;
let publishSpy: ReturnType<typeof vi.spyOn>;
let mockExecute: ReturnType<typeof vi.fn>;
const mockWorkspacePath = path.resolve('/tmp');
beforeEach(() => {
process.env['CODER_AGENT_WORKSPACE_PATH'] = mockWorkspacePath;
eventBus = {
publish: vi.fn(),
} as unknown as ExecutionEventBus;
command = new InitCommand();
const mockConfig = createMockConfig({
getModel: () => 'gemini-pro',
});
const mockExecutorInstance = new CoderAgentExecutor();
context = {
config: mockConfig as unknown as Config,
agentExecutor: mockExecutorInstance,
eventBus,
} as CommandContext;
publishSpy = vi.spyOn(eventBus, 'publish');
mockExecute = vi.fn();
vi.spyOn(mockExecutorInstance, 'execute').mockImplementation(mockExecute);
vi.clearAllMocks();
});
it('has requiresWorkspace set to true', () => {
expect(command.requiresWorkspace).toBe(true);
});
describe('execute', () => {
it('handles info from performInit', async () => {
vi.mocked(performInit).mockReturnValue({
type: 'message',
messageType: 'info',
content: 'GEMINI.md already exists.',
} as CommandActionReturn);
await command.execute(context, []);
expect(logger.info).toHaveBeenCalledWith(
'[EventBus event]: ',
expect.objectContaining({
kind: 'status-update',
status: expect.objectContaining({
state: 'completed',
message: expect.objectContaining({
parts: [{ kind: 'text', text: 'GEMINI.md already exists.' }],
}),
}),
}),
);
expect(publishSpy).toHaveBeenCalledWith(
expect.objectContaining({
kind: 'status-update',
status: expect.objectContaining({
state: 'completed',
message: expect.objectContaining({
parts: [{ kind: 'text', text: 'GEMINI.md already exists.' }],
}),
}),
}),
);
});
it('handles error from performInit', async () => {
vi.mocked(performInit).mockReturnValue({
type: 'message',
messageType: 'error',
content: 'An error occurred.',
} as CommandActionReturn);
await command.execute(context, []);
expect(publishSpy).toHaveBeenCalledWith(
expect.objectContaining({
kind: 'status-update',
status: expect.objectContaining({
state: 'failed',
message: expect.objectContaining({
parts: [{ kind: 'text', text: 'An error occurred.' }],
}),
}),
}),
);
});
describe('when handling submit_prompt', () => {
beforeEach(() => {
vi.mocked(performInit).mockReturnValue({
type: 'submit_prompt',
content: 'Create a new GEMINI.md file.',
} as CommandActionReturn);
});
it('writes the file and executes the agent', async () => {
await command.execute(context, []);
expect(fs.writeFileSync).toHaveBeenCalledWith(
path.join(mockWorkspacePath, 'GEMINI.md'),
'',
'utf8',
);
expect(mockExecute).toHaveBeenCalled();
});
it('passes autoExecute to the agent executor', async () => {
await command.execute(context, []);
expect(mockExecute).toHaveBeenCalledWith(
expect.objectContaining({
userMessage: expect.objectContaining({
parts: expect.arrayContaining([
expect.objectContaining({
text: 'Create a new GEMINI.md file.',
}),
]),
metadata: {
coderAgent: {
kind: CoderAgentEvent.StateAgentSettingsEvent,
workspacePath: mockWorkspacePath,
autoExecute: true,
},
},
}),
}),
eventBus,
);
});
});
});
});
+168
View File
@@ -0,0 +1,168 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'node:fs';
import * as path from 'node:path';
import { CoderAgentEvent, type AgentSettings } from '../types.js';
import { performInit } from '@google/gemini-cli-core';
import type {
Command,
CommandContext,
CommandExecutionResponse,
} from './types.js';
import type { CoderAgentExecutor } from '../agent/executor.js';
import type {
ExecutionEventBus,
RequestContext,
AgentExecutionEvent,
} from '@a2a-js/sdk/server';
import { v4 as uuidv4 } from 'uuid';
import { logger } from '../utils/logger.js';
export class InitCommand implements Command {
name = 'init';
description = 'Analyzes the project and creates a tailored GEMINI.md file';
requiresWorkspace = true;
streaming = true;
private handleMessageResult(
result: { content: string; messageType: 'info' | 'error' },
context: CommandContext,
eventBus: ExecutionEventBus,
taskId: string,
contextId: string,
): CommandExecutionResponse {
const statusState = result.messageType === 'error' ? 'failed' : 'completed';
const eventType =
result.messageType === 'error'
? CoderAgentEvent.StateChangeEvent
: CoderAgentEvent.TextContentEvent;
const event: AgentExecutionEvent = {
kind: 'status-update',
taskId,
contextId,
status: {
state: statusState,
message: {
kind: 'message',
role: 'agent',
parts: [{ kind: 'text', text: result.content }],
messageId: uuidv4(),
taskId,
contextId,
},
timestamp: new Date().toISOString(),
},
final: true,
metadata: {
coderAgent: { kind: eventType },
model: context.config.getModel(),
},
};
logger.info('[EventBus event]: ', event);
eventBus.publish(event);
return {
name: this.name,
data: result,
};
}
private async handleSubmitPromptResult(
result: { content: unknown },
context: CommandContext,
geminiMdPath: string,
eventBus: ExecutionEventBus,
taskId: string,
contextId: string,
): Promise<CommandExecutionResponse> {
fs.writeFileSync(geminiMdPath, '', 'utf8');
if (!context.agentExecutor) {
throw new Error('Agent executor not found in context.');
}
const agentExecutor = context.agentExecutor as CoderAgentExecutor;
const agentSettings: AgentSettings = {
kind: CoderAgentEvent.StateAgentSettingsEvent,
workspacePath: process.env['CODER_AGENT_WORKSPACE_PATH']!,
autoExecute: true,
};
if (typeof result.content !== 'string') {
throw new Error('Init command content must be a string.');
}
const promptText = result.content;
const requestContext: RequestContext = {
userMessage: {
kind: 'message',
role: 'user',
parts: [{ kind: 'text', text: promptText }],
messageId: uuidv4(),
taskId,
contextId,
metadata: {
coderAgent: agentSettings,
},
},
taskId,
contextId,
};
// The executor will handle the entire agentic loop, including
// creating the task, streaming responses, and handling tools.
await agentExecutor.execute(requestContext, eventBus);
return {
name: this.name,
data: geminiMdPath,
};
}
async execute(
context: CommandContext,
_args: string[] = [],
): Promise<CommandExecutionResponse> {
if (!context.eventBus) {
return {
name: this.name,
data: 'Use executeStream to get streaming results.',
};
}
const geminiMdPath = path.join(
process.env['CODER_AGENT_WORKSPACE_PATH']!,
'GEMINI.md',
);
const result = performInit(fs.existsSync(geminiMdPath));
const taskId = uuidv4();
const contextId = uuidv4();
switch (result.type) {
case 'message':
return this.handleMessageResult(
result,
context,
context.eventBus,
taskId,
contextId,
);
case 'submit_prompt':
return this.handleSubmitPromptResult(
result,
context,
geminiMdPath,
context.eventBus,
taskId,
contextId,
);
default:
throw new Error('Unknown result type from performInit');
}
}
}
@@ -4,11 +4,14 @@
* SPDX-License-Identifier: Apache-2.0
*/
import type { ExecutionEventBus, AgentExecutor } from '@a2a-js/sdk/server';
import type { Config, GitService } from '@google/gemini-cli-core';
export interface CommandContext {
config: Config;
git?: GitService;
agentExecutor?: AgentExecutor;
eventBus?: ExecutionEventBus;
}
export interface CommandArgument {
@@ -24,6 +27,7 @@ export interface Command {
readonly subCommands?: Command[];
readonly topLevel?: boolean;
readonly requiresWorkspace?: boolean;
readonly streaming?: boolean;
execute(
config: CommandContext,
+113
View File
@@ -1061,5 +1061,118 @@ describe('E2E Tests', () => {
expect(response.status).toBe(200);
expect(response.body.data).toBe('success');
});
it('should include agentExecutor in context', async () => {
const mockCommand = {
name: 'context-check-command',
description: 'checks context',
execute: vi.fn(async (context: CommandContext) => {
if (!context.agentExecutor) {
throw new Error('agentExecutor missing');
}
return { name: 'context-check-command', data: 'success' };
}),
};
vi.spyOn(commandRegistry, 'get').mockReturnValue(mockCommand);
const agent = request.agent(app);
const res = await agent
.post('/executeCommand')
.send({ command: 'context-check-command', args: [] })
.set('Content-Type', 'application/json')
.expect(200);
expect(res.body.data).toBe('success');
});
describe('/executeCommand streaming', () => {
it('should execute a streaming command and stream back events', (done: (
err?: unknown,
) => void) => {
const executeSpy = vi.fn(async (context: CommandContext) => {
context.eventBus?.publish({
kind: 'status-update',
status: { state: 'working' },
taskId: 'test-task',
contextId: 'test-context',
final: false,
});
context.eventBus?.publish({
kind: 'status-update',
status: { state: 'completed' },
taskId: 'test-task',
contextId: 'test-context',
final: true,
});
return { name: 'stream-test', data: 'done' };
});
const mockStreamCommand = {
name: 'stream-test',
description: 'A test streaming command',
streaming: true,
execute: executeSpy,
};
vi.spyOn(commandRegistry, 'get').mockReturnValue(mockStreamCommand);
const agent = request.agent(app);
agent
.post('/executeCommand')
.send({ command: 'stream-test', args: [] })
.set('Content-Type', 'application/json')
.set('Accept', 'text/event-stream')
.on('response', (res) => {
let data = '';
res.on('data', (chunk: Buffer) => {
data += chunk.toString();
});
res.on('end', () => {
try {
const events = streamToSSEEvents(data);
expect(events.length).toBe(2);
expect(events[0].result).toEqual({
kind: 'status-update',
status: { state: 'working' },
taskId: 'test-task',
contextId: 'test-context',
final: false,
});
expect(events[1].result).toEqual({
kind: 'status-update',
status: { state: 'completed' },
taskId: 'test-task',
contextId: 'test-context',
final: true,
});
expect(executeSpy).toHaveBeenCalled();
done();
} catch (e) {
done(e);
}
});
})
.end();
});
it('should handle non-streaming commands gracefully', async () => {
const mockNonStreamCommand = {
name: 'non-stream-test',
description: 'A test non-streaming command',
execute: vi
.fn()
.mockResolvedValue({ name: 'non-stream-test', data: 'done' }),
};
vi.spyOn(commandRegistry, 'get').mockReturnValue(mockNonStreamCommand);
const agent = request.agent(app);
const res = await agent
.post('/executeCommand')
.send({ command: 'non-stream-test', args: [] })
.set('Content-Type', 'application/json')
.expect(200);
expect(res.body).toEqual({ name: 'non-stream-test', data: 'done' });
});
});
});
});
+81 -44
View File
@@ -6,9 +6,14 @@
import express from 'express';
import type { AgentCard } from '@a2a-js/sdk';
import type { AgentCard, Message } from '@a2a-js/sdk';
import type { TaskStore } from '@a2a-js/sdk/server';
import { DefaultRequestHandler, InMemoryTaskStore } from '@a2a-js/sdk/server';
import {
DefaultRequestHandler,
InMemoryTaskStore,
DefaultExecutionEventBus,
type AgentExecutionEvent,
} from '@a2a-js/sdk/server';
import { A2AExpressApp } from '@a2a-js/sdk/server/express'; // Import server components
import { v4 as uuidv4 } from 'uuid';
import { logger } from '../utils/logger.js';
@@ -73,6 +78,76 @@ export function updateCoderAgentCardUrl(port: number) {
coderAgentCard.url = `http://localhost:${port}/`;
}
async function handleExecuteCommand(
req: express.Request,
res: express.Response,
context: {
config: Awaited<ReturnType<typeof loadConfig>>;
git: GitService | undefined;
agentExecutor: CoderAgentExecutor;
},
) {
logger.info('[CoreAgent] Received /executeCommand request: ', req.body);
const { command, args } = req.body;
try {
if (typeof command !== 'string') {
return res.status(400).json({ error: 'Invalid "command" field.' });
}
if (args && !Array.isArray(args)) {
return res.status(400).json({ error: '"args" field must be an array.' });
}
const commandToExecute = commandRegistry.get(command);
if (commandToExecute?.requiresWorkspace) {
if (!process.env['CODER_AGENT_WORKSPACE_PATH']) {
return res.status(400).json({
error: `Command "${command}" requires a workspace, but CODER_AGENT_WORKSPACE_PATH is not set.`,
});
}
}
if (!commandToExecute) {
return res.status(404).json({ error: `Command not found: ${command}` });
}
if (commandToExecute.streaming) {
const eventBus = new DefaultExecutionEventBus();
res.setHeader('Content-Type', 'text/event-stream');
const eventHandler = (event: AgentExecutionEvent) => {
const jsonRpcResponse = {
jsonrpc: '2.0',
id: 'taskId' in event ? event.taskId : (event as Message).messageId,
result: event,
};
res.write(`data: ${JSON.stringify(jsonRpcResponse)}\n`);
};
eventBus.on('event', eventHandler);
await commandToExecute.execute({ ...context, eventBus }, args ?? []);
eventBus.off('event', eventHandler);
eventBus.finished();
return res.end(); // Explicit return for streaming path
} else {
const result = await commandToExecute.execute(context, args ?? []);
logger.info('[CoreAgent] Sending /executeCommand response: ', result);
return res.status(200).json(result);
}
} catch (e) {
logger.error(
`Error executing /executeCommand: ${command} with args: ${JSON.stringify(
args,
)}`,
e,
);
const errorMessage =
e instanceof Error ? e.message : 'Unknown error executing command';
return res.status(500).json({ error: errorMessage });
}
}
export async function createApp() {
try {
// Load the server configuration once on startup.
@@ -92,8 +167,6 @@ export async function createApp() {
await git.initialize();
}
const context = { config, git };
// loadEnvironment() is called within getConfig now
const bucketName = process.env['GCS_BUCKET_NAME'];
let taskStoreForExecutor: TaskStore;
@@ -113,6 +186,8 @@ export async function createApp() {
const agentExecutor = new CoderAgentExecutor(taskStoreForExecutor);
const context = { config, git, agentExecutor };
const requestHandler = new DefaultRequestHandler(
coderAgentCard,
taskStoreForHandler,
@@ -152,46 +227,8 @@ export async function createApp() {
}
});
expressApp.post('/executeCommand', async (req, res) => {
logger.info('[CoreAgent] Received /executeCommand request: ', req.body);
try {
const { command, args } = req.body;
if (typeof command !== 'string') {
return res.status(400).json({ error: 'Invalid "command" field.' });
}
if (args && !Array.isArray(args)) {
return res
.status(400)
.json({ error: '"args" field must be an array.' });
}
const commandToExecute = commandRegistry.get(command);
if (commandToExecute?.requiresWorkspace) {
if (!process.env['CODER_AGENT_WORKSPACE_PATH']) {
return res.status(400).json({
error: `Command "${command}" requires a workspace, but CODER_AGENT_WORKSPACE_PATH is not set.`,
});
}
}
if (!commandToExecute) {
return res
.status(404)
.json({ error: `Command not found: ${command}` });
}
const result = await commandToExecute.execute(context, args ?? []);
logger.info('[CoreAgent] Sending /executeCommand response: ', result);
return res.status(200).json(result);
} catch (e) {
logger.error('Error executing /executeCommand:', e);
const errorMessage =
e instanceof Error ? e.message : 'Unknown error executing command';
return res.status(500).json({ error: errorMessage });
}
expressApp.post('/executeCommand', (req, res) => {
void handleExecuteCommand(req, res, context);
});
expressApp.get('/listCommands', (req, res) => {
+1
View File
@@ -46,6 +46,7 @@ export enum CoderAgentEvent {
export interface AgentSettings {
kind: CoderAgentEvent.StateAgentSettingsEvent;
workspacePath: string;
autoExecute?: boolean;
}
export interface ToolCallConfirmation {