mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-18 18:11:02 -07:00
324 lines
10 KiB
TypeScript
324 lines
10 KiB
TypeScript
/**
|
|
* @license
|
|
* Copyright 2025 Google LLC
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*/
|
|
|
|
import type {
|
|
AgentCard,
|
|
Message,
|
|
MessageSendParams,
|
|
Task,
|
|
TaskStatusUpdateEvent,
|
|
TaskArtifactUpdateEvent,
|
|
} from '@a2a-js/sdk';
|
|
import {
|
|
ClientFactory,
|
|
ClientFactoryOptions,
|
|
DefaultAgentCardResolver,
|
|
RestTransportFactory,
|
|
JsonRpcTransportFactory,
|
|
type AuthenticationHandler,
|
|
createAuthenticatingFetchWithRetry,
|
|
type Client,
|
|
} from '@a2a-js/sdk/client';
|
|
import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc';
|
|
import { v4 as uuidv4 } from 'uuid';
|
|
import { Agent as UndiciAgent } from 'undici';
|
|
import { getGrpcCredentials, normalizeAgentCard } from './a2aUtils.js';
|
|
import { isPrivateIp } from '../utils/fetch.js';
|
|
import { debugLogger } from '../utils/debugLogger.js';
|
|
|
|
// Remote agents can take 10+ minutes (e.g. Deep Research).
|
|
// Use a dedicated dispatcher so the global 5-min timeout isn't affected.
|
|
const A2A_TIMEOUT = 1800000; // 30 minutes
|
|
const a2aDispatcher = new UndiciAgent({
|
|
headersTimeout: A2A_TIMEOUT,
|
|
bodyTimeout: A2A_TIMEOUT,
|
|
});
|
|
const a2aFetch: typeof fetch = (input, init) =>
|
|
// @ts-expect-error The `dispatcher` property is a Node.js extension to fetch not present in standard types.
|
|
fetch(input, { ...init, dispatcher: a2aDispatcher });
|
|
|
|
export type SendMessageResult =
|
|
| Message
|
|
| Task
|
|
| TaskStatusUpdateEvent
|
|
| TaskArtifactUpdateEvent;
|
|
|
|
/**
|
|
* Orchestrates communication with A2A agents.
|
|
*
|
|
* This manager handles agent discovery, card caching, and client lifecycle.
|
|
* It provides a unified messaging interface using the standard A2A SDK.
|
|
*/
|
|
export class A2AClientManager {
|
|
private static instance: A2AClientManager;
|
|
|
|
// Each agent should manage their own context/taskIds/card/etc
|
|
private clients = new Map<string, Client>();
|
|
private agentCards = new Map<string, AgentCard>();
|
|
|
|
private constructor() {}
|
|
|
|
/**
|
|
* Gets the singleton instance of the A2AClientManager.
|
|
*/
|
|
static getInstance(): A2AClientManager {
|
|
if (!A2AClientManager.instance) {
|
|
A2AClientManager.instance = new A2AClientManager();
|
|
}
|
|
return A2AClientManager.instance;
|
|
}
|
|
|
|
/**
|
|
* Resets the singleton instance. Only for testing purposes.
|
|
* @internal
|
|
*/
|
|
static resetInstanceForTesting() {
|
|
// @ts-expect-error - Resetting singleton for testing
|
|
A2AClientManager.instance = undefined;
|
|
}
|
|
|
|
/**
|
|
* Loads an agent by fetching its AgentCard and caches the client.
|
|
* @param name The name to assign to the agent.
|
|
* @param agentCardUrl The full URL to the agent's card.
|
|
* @param authHandler Optional authentication handler to use for this agent.
|
|
* @returns The loaded AgentCard.
|
|
*/
|
|
async loadAgent(
|
|
name: string,
|
|
agentCardUrl: string,
|
|
authHandler?: AuthenticationHandler,
|
|
): Promise<AgentCard> {
|
|
if (this.clients.has(name) && this.agentCards.has(name)) {
|
|
throw new Error(`Agent with name '${name}' is already loaded.`);
|
|
}
|
|
|
|
const fetchImpl = this.getFetchImpl(authHandler);
|
|
const resolver = new DefaultAgentCardResolver({ fetchImpl });
|
|
const agentCard = await this.resolveAgentCard(name, agentCardUrl, resolver);
|
|
|
|
// Configure standard SDK client for tool registration and discovery
|
|
const clientOptions = ClientFactoryOptions.createFrom(
|
|
ClientFactoryOptions.default,
|
|
{
|
|
transports: [
|
|
new RestTransportFactory({ fetchImpl }),
|
|
new JsonRpcTransportFactory({ fetchImpl }),
|
|
new GrpcTransportFactory({
|
|
grpcChannelCredentials: getGrpcCredentials(agentCard.url),
|
|
}),
|
|
],
|
|
cardResolver: resolver,
|
|
},
|
|
);
|
|
const factory = new ClientFactory(clientOptions);
|
|
const client = await factory.createFromAgentCard(agentCard);
|
|
|
|
this.clients.set(name, client);
|
|
this.agentCards.set(name, agentCard);
|
|
|
|
debugLogger.debug(
|
|
`[A2AClientManager] Loaded agent '${name}' from ${agentCardUrl}`,
|
|
);
|
|
|
|
return agentCard;
|
|
}
|
|
|
|
/**
|
|
* Invalidates all cached clients and agent cards.
|
|
*/
|
|
clearCache(): void {
|
|
this.clients.clear();
|
|
this.agentCards.clear();
|
|
debugLogger.debug('[A2AClientManager] Cache cleared.');
|
|
}
|
|
|
|
/**
|
|
* Sends a message to a loaded agent and returns a stream of responses.
|
|
* @param agentName The name of the agent to send the message to.
|
|
* @param message The message content.
|
|
* @param options Optional context and task IDs to maintain conversation state.
|
|
* @returns An async iterable of responses from the agent (Message or Task).
|
|
* @throws Error if the agent returns an error response.
|
|
*/
|
|
async *sendMessageStream(
|
|
agentName: string,
|
|
message: string,
|
|
options?: { contextId?: string; taskId?: string; signal?: AbortSignal },
|
|
): AsyncIterable<SendMessageResult> {
|
|
const client = this.clients.get(agentName);
|
|
if (!client) throw new Error(`Agent '${agentName}' not found.`);
|
|
|
|
const messageParams: MessageSendParams = {
|
|
message: {
|
|
kind: 'message',
|
|
role: 'user',
|
|
messageId: uuidv4(),
|
|
parts: [{ kind: 'text', text: message }],
|
|
contextId: options?.contextId,
|
|
taskId: options?.taskId,
|
|
},
|
|
};
|
|
|
|
try {
|
|
yield* client.sendMessageStream(messageParams, {
|
|
signal: options?.signal,
|
|
});
|
|
} catch (error: unknown) {
|
|
const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`;
|
|
if (error instanceof Error) {
|
|
throw new Error(`${prefix}: ${error.message}`, { cause: error });
|
|
}
|
|
throw new Error(
|
|
`${prefix}: Unexpected error during sendMessageStream: ${String(error)}`,
|
|
);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Retrieves a loaded agent card.
|
|
* @param name The name of the agent.
|
|
* @returns The agent card, or undefined if not found.
|
|
*/
|
|
getAgentCard(name: string): AgentCard | undefined {
|
|
return this.agentCards.get(name);
|
|
}
|
|
|
|
/**
|
|
* Retrieves a loaded client.
|
|
* @param name The name of the agent.
|
|
* @returns The client, or undefined if not found.
|
|
*/
|
|
getClient(name: string): Client | undefined {
|
|
return this.clients.get(name);
|
|
}
|
|
|
|
/**
|
|
* Retrieves a task from an agent.
|
|
* @param agentName The name of the agent.
|
|
* @param taskId The ID of the task to retrieve.
|
|
* @returns The task details.
|
|
*/
|
|
async getTask(agentName: string, taskId: string): Promise<Task> {
|
|
const client = this.clients.get(agentName);
|
|
if (!client) throw new Error(`Agent '${agentName}' not found.`);
|
|
try {
|
|
return await client.getTask({ id: taskId });
|
|
} catch (error: unknown) {
|
|
const prefix = `A2AClient getTask Error [${agentName}]`;
|
|
if (error instanceof Error) {
|
|
throw new Error(`${prefix}: ${error.message}`, { cause: error });
|
|
}
|
|
throw new Error(`${prefix}: Unexpected error: ${String(error)}`);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Cancels a task on an agent.
|
|
* @param agentName The name of the agent.
|
|
* @param taskId The ID of the task to cancel.
|
|
* @returns The cancellation response.
|
|
*/
|
|
async cancelTask(agentName: string, taskId: string): Promise<Task> {
|
|
const client = this.clients.get(agentName);
|
|
if (!client) throw new Error(`Agent '${agentName}' not found.`);
|
|
try {
|
|
return await client.cancelTask({ id: taskId });
|
|
} catch (error: unknown) {
|
|
const prefix = `A2AClient cancelTask Error [${agentName}]`;
|
|
if (error instanceof Error) {
|
|
throw new Error(`${prefix}: ${error.message}`, { cause: error });
|
|
}
|
|
throw new Error(`${prefix}: Unexpected error: ${String(error)}`);
|
|
}
|
|
}
|
|
|
|
/**
|
|
* Resolves the appropriate fetch implementation for an agent.
|
|
*/
|
|
private getFetchImpl(authHandler?: AuthenticationHandler): typeof fetch {
|
|
return authHandler
|
|
? createAuthenticatingFetchWithRetry(a2aFetch, authHandler)
|
|
: a2aFetch;
|
|
}
|
|
|
|
/**
|
|
* Resolves and normalizes an agent card from a given URL.
|
|
* Handles splitting the URL if it already contains the standard .well-known path.
|
|
* Also performs basic SSRF validation to prevent internal IP access.
|
|
*/
|
|
private async resolveAgentCard(
|
|
agentName: string,
|
|
url: string,
|
|
resolver: DefaultAgentCardResolver,
|
|
): Promise<AgentCard> {
|
|
const standardPath = '.well-known/agent-card.json';
|
|
let baseUrl = url;
|
|
let path: string | undefined;
|
|
|
|
// Validate URL to prevent SSRF
|
|
if (isPrivateIp(url)) {
|
|
// Local/private IPs are allowed ONLY for localhost for testing.
|
|
const parsed = new URL(url);
|
|
if (parsed.hostname !== 'localhost' && parsed.hostname !== '127.0.0.1') {
|
|
throw new Error(
|
|
`Refusing to load agent '${agentName}' from private IP range: ${url}. Remote agents must use public URLs.`,
|
|
);
|
|
}
|
|
}
|
|
|
|
try {
|
|
const parsedUrl = new URL(url);
|
|
if (parsedUrl.pathname.endsWith(standardPath)) {
|
|
// Correctly split the URL into baseUrl and standard path
|
|
path = standardPath;
|
|
baseUrl = url.substring(0, url.lastIndexOf(standardPath));
|
|
}
|
|
} catch (e) {
|
|
throw new Error(`Invalid agent card URL: ${url}`, { cause: e });
|
|
}
|
|
|
|
const rawCard = await resolver.resolve(baseUrl, path);
|
|
const agentCard = normalizeAgentCard(rawCard);
|
|
|
|
// Deep validation of all transport URLs within the card to prevent SSRF
|
|
this.validateAgentCardUrls(agentName, agentCard);
|
|
|
|
return agentCard;
|
|
}
|
|
|
|
/**
|
|
* Validates all URLs (top-level and interfaces) within an AgentCard for SSRF.
|
|
*/
|
|
private validateAgentCardUrls(agentName: string, card: AgentCard): void {
|
|
const urlsToValidate = [card.url];
|
|
if (card.additionalInterfaces) {
|
|
for (const intf of card.additionalInterfaces) {
|
|
if (intf.url) urlsToValidate.push(intf.url);
|
|
}
|
|
}
|
|
|
|
for (const url of urlsToValidate) {
|
|
if (!url) continue;
|
|
|
|
// Ensure URL has a scheme for the parser (gRPC often provides raw IP:port)
|
|
const validationUrl = url.includes('://') ? url : `http://${url}`;
|
|
|
|
if (isPrivateIp(validationUrl)) {
|
|
const parsed = new URL(validationUrl);
|
|
if (
|
|
parsed.hostname !== 'localhost' &&
|
|
parsed.hostname !== '127.0.0.1'
|
|
) {
|
|
throw new Error(
|
|
`Refusing to load agent '${agentName}': contains transport URL pointing to private IP range: ${url}.`,
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|