Files
gemini-cli/packages/a2a-server/src/agent/executor.ts

617 lines
18 KiB
TypeScript

/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import type { Message, Task as SDKTask } from '@a2a-js/sdk';
import type {
TaskStore,
AgentExecutor,
AgentExecutionEvent,
RequestContext,
ExecutionEventBus,
} from '@a2a-js/sdk/server';
import {
GeminiEventType,
SimpleExtensionLoader,
type ToolCallRequestInfo,
type Config,
} from '@google/gemini-cli-core';
import { v4 as uuidv4 } from 'uuid';
import { logger } from '../utils/logger.js';
import {
CoderAgentEvent,
getPersistedState,
setPersistedState,
type StateChange,
type AgentSettings,
type PersistedStateMetadata,
getContextIdFromMetadata,
getAgentSettingsFromMetadata,
} from '../types.js';
import { loadConfig, loadEnvironment, setTargetDir } from '../config/config.js';
import { loadSettings } from '../config/settings.js';
import { loadExtensions } from '../config/extension.js';
import { Task } from './task.js';
import { requestStorage } from '../http/requestStorage.js';
import { pushTaskStateFailed } from '../utils/executor_utils.js';
/**
 * Couples a runtime {@link Task} with the {@link AgentSettings} it was created
 * with, and converts that pair into the SDK's {@link SDKTask} representation.
 * CoderAgentExecutor stores these wrappers in place of bare Task objects.
 */
class TaskWrapper {
  constructor(
    public task: Task,
    public agentSettings: AgentSettings,
  ) {}

  /** Identifier of the wrapped runtime task. */
  get id() {
    return this.task.id;
  }

  /**
   * Produces an SDK task snapshot of the wrapped runtime task.
   *
   * The snapshot carries:
   * - the current task state and a freshly generated ISO timestamp;
   * - metadata containing the persisted agent settings and task state
   *   (via setPersistedState), plus the context id under `_contextId`;
   * - empty `history` and `artifacts` arrays.
   */
  toSDKTask(): SDKTask {
    const { id, contextId, taskState } = this.task;
    const persistedState: PersistedStateMetadata = {
      _agentSettings: this.agentSettings,
      _taskState: taskState,
    };
    return {
      id,
      contextId,
      kind: 'task',
      status: {
        state: taskState,
        timestamp: new Date().toISOString(),
      },
      metadata: {
        ...setPersistedState({}, persistedState),
        _contextId: contextId,
      },
      history: [],
      artifacts: [],
    };
  }
}
/**
 * CoderAgentExecutor implements the agent's core logic for code generation.
 *
 * Runtime tasks are cached in memory (and rehydrated from the optional
 * TaskStore when a request references a task that is not cached). Each
 * incoming user message drives the LLM / tool-call loop until the agent needs
 * more input. Task state is saved to the TaskStore when one is provided.
 */
export class CoderAgentExecutor implements AgentExecutor {
  /** Task states after which no further execution or cancellation is done. */
  private static readonly terminalStates: ReadonlySet<string> = new Set([
    'canceled',
    'failed',
    'completed',
  ]);

  private tasks: Map<string, TaskWrapper> = new Map();
  // Track tasks with an active execution loop.
  private executingTasks = new Set<string>();

  constructor(private taskStore?: TaskStore) {}

  /**
   * Builds the runtime Config for a task: resolves the workspace directory
   * from the agent settings, loads workspace environment variables, settings
   * and extensions for that workspace, then delegates to loadConfig.
   */
  private async getConfig(
    agentSettings: AgentSettings,
    taskId: string,
  ): Promise<Config> {
    const workspaceRoot = setTargetDir(agentSettings);
    loadEnvironment(); // Will override any global env with workspace envs
    const settings = loadSettings(workspaceRoot);
    const extensions = loadExtensions(workspaceRoot);
    return loadConfig(settings, new SimpleExtensionLoader(extensions), taskId);
  }

  /**
   * Reconstructs TaskWrapper from SDKTask and caches it in memory.
   *
   * @param sdkTask Task snapshot previously produced by TaskWrapper.toSDKTask.
   * @param eventBus Bus the rebuilt runtime task publishes to.
   * @throws Error when the snapshot's metadata has no persisted state.
   */
  async reconstruct(
    sdkTask: SDKTask,
    eventBus?: ExecutionEventBus,
  ): Promise<TaskWrapper> {
    const metadata = sdkTask.metadata ?? {};
    const persistedState = getPersistedState(metadata);
    if (!persistedState) {
      throw new Error(
        `Cannot reconstruct task ${sdkTask.id}: missing persisted state in metadata.`,
      );
    }
    const agentSettings = persistedState._agentSettings;
    const config = await this.getConfig(agentSettings, sdkTask.id);
    const contextId: string =
      getContextIdFromMetadata(metadata) || sdkTask.contextId;
    const runtimeTask = await Task.create(
      sdkTask.id,
      contextId,
      config,
      eventBus,
      agentSettings.autoExecute,
    );
    runtimeTask.taskState = persistedState._taskState;
    await runtimeTask.geminiClient.initialize();
    const wrapper = new TaskWrapper(runtimeTask, agentSettings);
    this.tasks.set(sdkTask.id, wrapper);
    logger.info(`Task ${sdkTask.id} reconstructed from store.`);
    return wrapper;
  }

  /**
   * Creates, initializes and caches a brand-new runtime task.
   *
   * @param agentSettingsInput Settings for the task; when omitted the task
   *   uses the process's current working directory as its workspace.
   */
  async createTask(
    taskId: string,
    contextId: string,
    agentSettingsInput?: AgentSettings,
    eventBus?: ExecutionEventBus,
  ): Promise<TaskWrapper> {
    const agentSettings: AgentSettings = agentSettingsInput ?? {
      kind: CoderAgentEvent.StateAgentSettingsEvent,
      workspacePath: process.cwd(),
    };
    const config = await this.getConfig(agentSettings, taskId);
    const runtimeTask = await Task.create(
      taskId,
      contextId,
      config,
      eventBus,
      agentSettings.autoExecute,
    );
    await runtimeTask.geminiClient.initialize();
    const wrapper = new TaskWrapper(runtimeTask, agentSettings);
    this.tasks.set(taskId, wrapper);
    logger.info(`New task ${taskId} created.`);
    return wrapper;
  }

  /** Returns the cached wrapper for `taskId`, if any. */
  getTask(taskId: string): TaskWrapper | undefined {
    return this.tasks.get(taskId);
  }

  /** Returns every cached task wrapper. */
  getAllTasks(): TaskWrapper[] {
    return Array.from(this.tasks.values());
  }

  /**
   * Cancels a cached task: cancels its pending tools, publishes a final
   * `canceled` status and saves the task.
   *
   * A task that is not cached gets a final `failed` status. A task already in
   * a terminal state (canceled, failed or completed) is left unchanged, and
   * its current state is re-published.
   */
  cancelTask = async (
    taskId: string,
    eventBus: ExecutionEventBus,
  ): Promise<void> => {
    logger.info(
      `[CoderAgentExecutor] Received cancel request for task ${taskId}`,
    );
    const wrapper = this.tasks.get(taskId);
    if (!wrapper) {
      logger.warn(
        `[CoderAgentExecutor] Task ${taskId} not found for cancellation.`,
      );
      eventBus.publish({
        kind: 'status-update',
        taskId,
        contextId: uuidv4(),
        status: {
          state: 'failed',
          message: {
            kind: 'message',
            role: 'agent',
            parts: [{ kind: 'text', text: `Task ${taskId} not found.` }],
            messageId: uuidv4(),
            taskId,
          },
        },
        final: true,
      });
      return;
    }
    const { task } = wrapper;
    // Terminal tasks (including completed ones) must not be overwritten with
    // 'canceled'; execute() uses the same set of terminal states.
    if (CoderAgentExecutor.terminalStates.has(task.taskState)) {
      logger.info(
        `[CoderAgentExecutor] Task ${taskId} is already in a final state: ${task.taskState}. No action needed for cancellation.`,
      );
      eventBus.publish({
        kind: 'status-update',
        taskId,
        contextId: task.contextId,
        status: {
          state: task.taskState,
          message: {
            kind: 'message',
            role: 'agent',
            parts: [
              {
                kind: 'text',
                text: `Task ${taskId} is already ${task.taskState}.`,
              },
            ],
            messageId: uuidv4(),
            taskId,
          },
        },
        final: true,
      });
      return;
    }
    try {
      logger.info(
        `[CoderAgentExecutor] Initiating cancellation for task ${taskId}.`,
      );
      task.cancelPendingTools('Task canceled by user request.');
      const stateChange: StateChange = {
        kind: CoderAgentEvent.StateChangeEvent,
      };
      task.setTaskStateAndPublishUpdate(
        'canceled',
        stateChange,
        'Task canceled by user request.',
        undefined,
        true,
      );
      logger.info(
        `[CoderAgentExecutor] Task ${taskId} cancellation processed. Saving state.`,
      );
      await this.taskStore?.save(wrapper.toSDKTask());
      logger.info(`[CoderAgentExecutor] Task ${taskId} state CANCELED saved.`);
    } catch (error) {
      const errorMessage =
        error instanceof Error ? error.message : 'Unknown error';
      logger.error(
        `[CoderAgentExecutor] Error during task cancellation for ${taskId}: ${errorMessage}`,
        error,
      );
      eventBus.publish({
        kind: 'status-update',
        taskId,
        contextId: task.contextId,
        status: {
          state: 'failed',
          message: {
            kind: 'message',
            role: 'agent',
            parts: [
              {
                kind: 'text',
                text: `Failed to process cancellation for task ${taskId}: ${errorMessage}`,
              },
            ],
            messageId: uuidv4(),
            taskId,
          },
        },
        final: true,
      });
    }
  };

  /**
   * Handles one incoming user message: resolves (or creates) the task, then
   * runs the agent loop for it. Execution is aborted when the client's socket
   * ends. The socket listener is always detached when this method returns, so
   * listeners do not accumulate on reused keep-alive connections.
   */
  async execute(
    requestContext: RequestContext,
    eventBus: ExecutionEventBus,
  ): Promise<void> {
    const userMessage = requestContext.userMessage;
    const sdkTask = requestContext.task;
    const taskId = sdkTask?.id || userMessage.taskId || uuidv4();
    const contextId: string =
      userMessage.contextId ||
      sdkTask?.contextId ||
      getContextIdFromMetadata(sdkTask?.metadata) ||
      uuidv4();
    logger.info(
      `[CoderAgentExecutor] Executing for taskId: ${taskId}, contextId: ${contextId}`,
    );
    logger.info(
      `[CoderAgentExecutor] userMessage: ${JSON.stringify(userMessage)}`,
    );
    eventBus.on('event', (event: AgentExecutionEvent) =>
      logger.info('[EventBus event]: ', event),
    );

    const abortController = new AbortController();
    const detachClientListener = this.watchClientDisconnect(
      taskId,
      abortController,
    );
    try {
      const wrapper = await this.resolveTaskWrapper(
        requestContext,
        eventBus,
        taskId,
        contextId,
      );
      if (!wrapper) {
        // Loading/creating the task failed; a failure status was published.
        return;
      }
      await this.driveTask(
        wrapper,
        requestContext,
        eventBus,
        taskId,
        abortController.signal,
      );
    } finally {
      detachClientListener();
    }
  }

  /**
   * Aborts `abortController` when the current request's socket emits 'end'
   * (the remote closed the connection).
   *
   * @returns A function that detaches the socket listener.
   */
  private watchClientDisconnect(
    taskId: string,
    abortController: AbortController,
  ): () => void {
    const store = requestStorage.getStore();
    if (!store) {
      logger.error(
        '[CoderAgentExecutor] Could not get request from async local storage. Cancellation on socket close will not be handled for this request.',
      );
      return () => {
        // No socket listener was attached, so there is nothing to detach.
      };
    }
    // Grab the raw socket from the request object
    const socket = store.req.socket;
    const onClientEnd = () => {
      logger.info(
        `[CoderAgentExecutor] Client socket closed for task ${taskId}. Cancelling execution.`,
      );
      if (!abortController.signal.aborted) {
        abortController.abort();
      }
    };
    // `once` removes the handler after it fires. The returned detach function
    // removes it when execution finishes first.
    socket.once('end', onClientEnd);
    logger.info(
      `[CoderAgentExecutor] Socket close handler set up for task ${taskId}.`,
    );
    return () => {
      socket.removeListener('end', onClientEnd);
    };
  }

  /**
   * Returns the task for this request from the memory cache, by rehydrating
   * the request's stored SDK task, or by creating a new task. New tasks are
   * announced as 'submitted' and saved to the store.
   *
   * @returns The wrapper, or undefined when rehydration or creation failed
   *   (a failure status has already been published in that case).
   */
  private async resolveTaskWrapper(
    requestContext: RequestContext,
    eventBus: ExecutionEventBus,
    taskId: string,
    contextId: string,
  ): Promise<TaskWrapper | undefined> {
    const cached = this.tasks.get(taskId);
    if (cached) {
      cached.task.eventBus = eventBus;
      logger.info(`[CoderAgentExecutor] Task ${taskId} found in memory cache.`);
      return cached;
    }

    const sdkTask = requestContext.task;
    if (sdkTask) {
      logger.info(
        `[CoderAgentExecutor] Task ${taskId} found in TaskStore. Reconstructing...`,
      );
      try {
        return await this.reconstruct(sdkTask, eventBus);
      } catch (e) {
        logger.error(
          `[CoderAgentExecutor] Failed to hydrate task ${taskId}:`,
          e,
        );
        const stateChange: StateChange = {
          kind: CoderAgentEvent.StateChangeEvent,
        };
        eventBus.publish({
          kind: 'status-update',
          taskId,
          contextId: sdkTask.contextId,
          status: {
            state: 'failed',
            message: {
              kind: 'message',
              role: 'agent',
              parts: [
                {
                  kind: 'text',
                  text: 'Internal error: Task state lost or corrupted.',
                },
              ],
              messageId: uuidv4(),
              taskId,
              contextId: sdkTask.contextId,
            } as Message,
          },
          final: true,
          metadata: { coderAgent: stateChange },
        });
        return undefined;
      }
    }

    logger.info(`[CoderAgentExecutor] Creating new task ${taskId}.`);
    const userMessage = requestContext.userMessage;
    const agentSettings = getAgentSettingsFromMetadata(userMessage.metadata);
    let wrapper: TaskWrapper;
    try {
      wrapper = await this.createTask(
        taskId,
        contextId,
        agentSettings,
        eventBus,
      );
    } catch (error) {
      logger.error(
        `[CoderAgentExecutor] Error creating task ${taskId}:`,
        error,
      );
      // eslint-disable-next-line @typescript-eslint/no-floating-promises
      pushTaskStateFailed(error, eventBus, taskId, contextId);
      return undefined;
    }
    const newTaskSDK = wrapper.toSDKTask();
    eventBus.publish({
      ...newTaskSDK,
      kind: 'task',
      status: { state: 'submitted', timestamp: new Date().toISOString() },
      history: [userMessage],
    });
    try {
      await this.taskStore?.save(newTaskSDK);
      logger.info(`[CoderAgentExecutor] New task ${taskId} saved to store.`);
    } catch (saveError) {
      logger.error(
        `[CoderAgentExecutor] Failed to save new task ${taskId} to store:`,
        saveError,
      );
    }
    return wrapper;
  }

  /**
   * Runs the agent loop for `wrapper`'s task until the agent needs user input.
   *
   * Terminal tasks are ignored. If another execution loop is already running
   * for the task, the message is only handed to the task and this call
   * returns. The task's state is saved to the store when the main loop ends.
   */
  private async driveTask(
    wrapper: TaskWrapper,
    requestContext: RequestContext,
    eventBus: ExecutionEventBus,
    taskId: string,
    abortSignal: AbortSignal,
  ): Promise<void> {
    const userMessage = requestContext.userMessage;
    const currentTask = wrapper.task;
    if (CoderAgentExecutor.terminalStates.has(currentTask.taskState)) {
      logger.warn(
        `[CoderAgentExecutor] Attempted to execute task ${taskId} which is already in state ${currentTask.taskState}. Ignoring.`,
      );
      return;
    }
    if (this.executingTasks.has(taskId)) {
      logger.info(
        `[CoderAgentExecutor] Task ${taskId} has a pending execution. Processing message and yielding.`,
      );
      currentTask.eventBus = eventBus;
      for await (const _ of currentTask.acceptUserMessage(
        requestContext,
        abortSignal,
      )) {
        logger.info(
          `[CoderAgentExecutor] Processing user message ${userMessage.messageId} in secondary execution loop for task ${taskId}.`,
        );
      }
      // End this execution-- the original/source will be resumed.
      return;
    }

    logger.info(
      `[CoderAgentExecutor] Starting main execution for message ${userMessage.messageId} for task ${taskId}.`,
    );
    this.executingTasks.add(taskId);
    try {
      let agentTurnActive = true;
      logger.info(`[CoderAgentExecutor] Task ${taskId}: Processing user turn.`);
      let agentEvents = currentTask.acceptUserMessage(
        requestContext,
        abortSignal,
      );
      while (agentTurnActive) {
        logger.info(
          `[CoderAgentExecutor] Task ${taskId}: Processing agent turn (LLM stream).`,
        );
        // Tool calls are collected and scheduled together once the stream ends.
        const toolCallRequests: ToolCallRequestInfo[] = [];
        for await (const event of agentEvents) {
          if (abortSignal.aborted) {
            logger.warn(
              `[CoderAgentExecutor] Task ${taskId}: Abort signal received during agent event processing.`,
            );
            throw new Error('Execution aborted');
          }
          if (event.type === GeminiEventType.ToolCallRequest) {
            toolCallRequests.push(event.value);
            continue;
          }
          await currentTask.acceptAgentMessage(event);
        }
        if (abortSignal.aborted) throw new Error('Execution aborted');
        if (toolCallRequests.length > 0) {
          logger.info(
            `[CoderAgentExecutor] Task ${taskId}: Found ${toolCallRequests.length} tool call requests. Scheduling as a batch.`,
          );
          await currentTask.scheduleToolCalls(toolCallRequests, abortSignal);
        }
        logger.info(
          `[CoderAgentExecutor] Task ${taskId}: Waiting for pending tools if any.`,
        );
        await currentTask.waitForPendingTools();
        logger.info(
          `[CoderAgentExecutor] Task ${taskId}: All pending tools completed or none were pending.`,
        );
        if (abortSignal.aborted) throw new Error('Execution aborted');
        const completedTools = currentTask.getAndClearCompletedTools();
        if (completedTools.length === 0) {
          logger.info(
            `[CoderAgentExecutor] Task ${taskId}: No more tool calls to process. Ending agent turn.`,
          );
          agentTurnActive = false;
        } else if (
          completedTools.every((tool) => tool.status === 'cancelled')
        ) {
          // All tool calls were cancelled: record them in history and end the
          // turn. The final input-required update is published once, after
          // the loop (previously it was also published here, so it went out twice).
          logger.info(
            `[CoderAgentExecutor] Task ${taskId}: All tool calls were cancelled. Updating history and ending agent turn.`,
          );
          currentTask.addToolResponsesToHistory(completedTools);
          agentTurnActive = false;
        } else {
          logger.info(
            `[CoderAgentExecutor] Task ${taskId}: Found ${completedTools.length} completed tool calls. Sending results back to LLM.`,
          );
          // Continue the loop to process the LLM response to the tool results.
          agentEvents = currentTask.sendCompletedToolsToLlm(
            completedTools,
            abortSignal,
          );
        }
      }
      logger.info(
        `[CoderAgentExecutor] Task ${taskId}: Agent turn finished, setting to input-required.`,
      );
      const stateChange: StateChange = {
        kind: CoderAgentEvent.StateChangeEvent,
      };
      currentTask.setTaskStateAndPublishUpdate(
        'input-required',
        stateChange,
        undefined,
        undefined,
        true,
      );
    } catch (error) {
      if (abortSignal.aborted) {
        logger.warn(`[CoderAgentExecutor] Task ${taskId} execution aborted.`);
        currentTask.cancelPendingTools('Execution aborted');
        if (
          currentTask.taskState !== 'canceled' &&
          currentTask.taskState !== 'failed'
        ) {
          currentTask.setTaskStateAndPublishUpdate(
            'input-required',
            { kind: CoderAgentEvent.StateChangeEvent },
            'Execution aborted by client.',
            undefined,
            true,
          );
        }
      } else {
        const errorMessage =
          error instanceof Error ? error.message : 'Agent execution error';
        logger.error(
          `[CoderAgentExecutor] Error executing agent for task ${taskId}:`,
          error,
        );
        currentTask.cancelPendingTools(errorMessage);
        if (currentTask.taskState !== 'failed') {
          const stateChange: StateChange = {
            kind: CoderAgentEvent.StateChangeEvent,
          };
          currentTask.setTaskStateAndPublishUpdate(
            'failed',
            stateChange,
            errorMessage,
            undefined,
            true,
          );
        }
      }
    } finally {
      this.executingTasks.delete(taskId);
      logger.info(
        `[CoderAgentExecutor] Saving final state for task ${taskId}.`,
      );
      try {
        await this.taskStore?.save(wrapper.toSDKTask());
        logger.info(`[CoderAgentExecutor] Task ${taskId} state saved.`);
      } catch (saveError) {
        logger.error(
          `[CoderAgentExecutor] Failed to save task ${taskId} state in finally block:`,
          saveError,
        );
      }
    }
  }
}