feat(a2a): enable native gRPC support and protocol routing

This commit is contained in:
Alisa Novikova
2026-03-06 06:22:12 -08:00
parent 25aa217306
commit 51aa3cf9db
2 changed files with 458 additions and 197 deletions

View File

@@ -5,96 +5,119 @@
*/
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import {
A2AClientManager,
type SendMessageResult,
} from './a2a-client-manager.js';
import type { AgentCard, Task } from '@a2a-js/sdk';
import {
ClientFactory,
DefaultAgentCardResolver,
createAuthenticatingFetchWithRetry,
ClientFactoryOptions,
type AuthenticationHandler,
type Client,
} from '@a2a-js/sdk/client';
import { A2AClientManager } from './a2a-client-manager.js';
import type { AgentCard } from '@a2a-js/sdk';
import * as sdkClient from '@a2a-js/sdk/client';
import * as dnsPromises from 'node:dns/promises';
import type { LookupOptions } from 'node:dns';
import { debugLogger } from '../utils/debugLogger.js';
interface MockClient {
sendMessageStream: ReturnType<typeof vi.fn>;
getTask: ReturnType<typeof vi.fn>;
cancelTask: ReturnType<typeof vi.fn>;
}
vi.mock('@a2a-js/sdk/client', async (importOriginal) => {
const actual = await importOriginal();
return {
...(actual as Record<string, unknown>),
createAuthenticatingFetchWithRetry: vi.fn(),
ClientFactory: vi.fn(),
DefaultAgentCardResolver: vi.fn(),
ClientFactoryOptions: {
createFrom: vi.fn(),
default: {},
},
};
});
vi.mock('../utils/debugLogger.js', () => ({
debugLogger: {
debug: vi.fn(),
},
}));
vi.mock('@a2a-js/sdk/client', () => {
const ClientFactory = vi.fn();
const DefaultAgentCardResolver = vi.fn();
const RestTransportFactory = vi.fn();
const JsonRpcTransportFactory = vi.fn();
const ClientFactoryOptions = {
default: {},
createFrom: vi.fn(),
};
const createAuthenticatingFetchWithRetry = vi.fn();
DefaultAgentCardResolver.prototype.resolve = vi.fn();
ClientFactory.prototype.createFromUrl = vi.fn();
return {
ClientFactory,
ClientFactoryOptions,
DefaultAgentCardResolver,
RestTransportFactory,
JsonRpcTransportFactory,
createAuthenticatingFetchWithRetry,
};
});
vi.mock('node:dns/promises', () => ({
lookup: vi.fn(),
}));
describe('A2AClientManager', () => {
let manager: A2AClientManager;
const mockAgentCard: AgentCard = {
name: 'test-agent',
description: 'A test agent',
url: 'http://test.agent',
version: '1.0.0',
protocolVersion: '0.1.0',
capabilities: {},
skills: [],
defaultInputModes: [],
defaultOutputModes: [],
};
const mockClient: MockClient = {
sendMessageStream: vi.fn(),
getTask: vi.fn(),
cancelTask: vi.fn(),
};
// Stable mocks initialized once
const sendMessageStreamMock = vi.fn();
const getTaskMock = vi.fn();
const cancelTaskMock = vi.fn();
const getAgentCardMock = vi.fn();
const authFetchMock = vi.fn();
const mockClient = {
sendMessageStream: sendMessageStreamMock,
getTask: getTaskMock,
cancelTask: cancelTaskMock,
getAgentCard: getAgentCardMock,
} as unknown as Client;
const mockAgentCard: Partial<AgentCard> = { name: 'TestAgent' };
beforeEach(() => {
vi.clearAllMocks();
A2AClientManager.resetInstanceForTesting();
manager = A2AClientManager.getInstance();
manager.clearCache();
// Default mock implementations
getAgentCardMock.mockResolvedValue({
// Default DNS mock: resolve to public IP.
// Use any cast only for return value due to complex multi-signature overloads.
vi.mocked(dnsPromises.lookup).mockImplementation(
async (_h: string, options?: LookupOptions | number) => {
const addr = { address: '93.184.216.34', family: 4 };
const isAll = typeof options === 'object' && options?.all;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (isAll ? [addr] : addr) as any;
},
);
// Re-create the instances as plain objects that can be spied on
const factoryInstance = {
createFromUrl: vi.fn(),
createFromAgentCard: vi.fn(),
};
const resolverInstance = {
resolve: vi.fn(),
};
vi.mocked(sdkClient.ClientFactory).mockReturnValue(
factoryInstance as unknown as sdkClient.ClientFactory,
);
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
);
vi.spyOn(factoryInstance, 'createFromUrl').mockResolvedValue(
mockClient as unknown as sdkClient.Client,
);
vi.spyOn(factoryInstance, 'createFromAgentCard').mockResolvedValue(
mockClient as unknown as sdkClient.Client,
);
vi.spyOn(resolverInstance, 'resolve').mockResolvedValue({
...mockAgentCard,
url: 'http://test.agent/real/endpoint',
} as AgentCard);
vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue(
mockClient,
vi.spyOn(sdkClient.ClientFactoryOptions, 'createFrom').mockImplementation(
(_defaults, overrides) =>
overrides as unknown as sdkClient.ClientFactoryOptions,
);
vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({
...mockAgentCard,
url: 'http://test.agent/real/endpoint',
} as AgentCard);
vi.mocked(ClientFactoryOptions.createFrom).mockImplementation(
(_defaults, overrides) => overrides as ClientFactoryOptions,
);
vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue(
authFetchMock,
vi.mocked(sdkClient.createAuthenticatingFetchWithRetry).mockImplementation(
() =>
authFetchMock.mockResolvedValue({
ok: true,
json: async () => ({}),
} as Response),
);
vi.stubGlobal(
@@ -123,137 +146,194 @@ describe('A2AClientManager', () => {
'TestAgent',
'http://test.agent/card',
);
expect(agentCard).toMatchObject(mockAgentCard);
expect(manager.getAgentCard('TestAgent')).toBe(agentCard);
expect(manager.getClient('TestAgent')).toBeDefined();
});
it('should configure ClientFactory with REST, JSON-RPC, and gRPC transports', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(sdkClient.ClientFactoryOptions.createFrom).toHaveBeenCalled();
});
it('should throw an error if an agent with the same name is already loaded', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
await expect(
manager.loadAgent('TestAgent', 'http://another.agent/card'),
manager.loadAgent('TestAgent', 'http://test.agent/card'),
).rejects.toThrow("Agent with name 'TestAgent' is already loaded.");
});
it('should use native fetch by default', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(createAuthenticatingFetchWithRetry).not.toHaveBeenCalled();
expect(
sdkClient.createAuthenticatingFetchWithRetry,
).not.toHaveBeenCalled();
});
it('should use provided custom authentication handler', async () => {
const customAuthHandler = {
headers: vi.fn(),
shouldRetryWithHeaders: vi.fn(),
const authHandler: sdkClient.AuthenticationHandler = {
headers: async () => ({}),
shouldRetryWithHeaders: async () => undefined,
};
await manager.loadAgent(
'CustomAuthAgent',
'http://custom.agent/card',
customAuthHandler as unknown as AuthenticationHandler,
);
expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith(
expect.anything(),
customAuthHandler,
'TestAgent',
'http://test.agent/card',
authHandler,
);
expect(sdkClient.createAuthenticatingFetchWithRetry).toHaveBeenCalled();
});
it('should log a debug message upon loading an agent', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(debugLogger.debug).toHaveBeenCalledWith(
"[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card",
expect.stringContaining("Loaded agent 'TestAgent'"),
);
});
it('should clear the cache', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(manager.getAgentCard('TestAgent')).toBeDefined();
expect(manager.getClient('TestAgent')).toBeDefined();
manager.clearCache();
expect(manager.getAgentCard('TestAgent')).toBeUndefined();
expect(manager.getClient('TestAgent')).toBeUndefined();
expect(debugLogger.debug).toHaveBeenCalledWith(
'[A2AClientManager] Cache cleared.',
});
it('should throw if resolveAgentCard fails', async () => {
const resolverInstance = {
resolve: vi.fn().mockRejectedValue(new Error('Resolution failed')),
};
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
);
await expect(
manager.loadAgent('FailAgent', 'http://fail.agent'),
).rejects.toThrow('Resolution failed');
});
it('should throw if factory.createFromAgentCard fails', async () => {
const factoryInstance = {
createFromAgentCard: vi
.fn()
.mockRejectedValue(new Error('Factory failed')),
};
vi.mocked(sdkClient.ClientFactory).mockReturnValue(
factoryInstance as unknown as sdkClient.ClientFactory,
);
await expect(
manager.loadAgent('FailAgent', 'http://fail.agent'),
).rejects.toThrow('Factory failed');
});
});
describe('getAgentCard and getClient', () => {
it('should return undefined if agent is not found', () => {
expect(manager.getAgentCard('Unknown')).toBeUndefined();
expect(manager.getClient('Unknown')).toBeUndefined();
});
});
describe('sendMessageStream', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
await manager.loadAgent('TestAgent', 'http://test.agent/card');
});
it('should send a message and return a stream', async () => {
const mockResult = {
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult;
sendMessageStreamMock.mockReturnValue(
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield mockResult;
yield { kind: 'message' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello');
const results = [];
for await (const res of stream) {
results.push(res);
for await (const result of stream) {
results.push(result);
}
expect(results).toEqual([mockResult]);
expect(sendMessageStreamMock).toHaveBeenCalledWith(
expect(results).toHaveLength(1);
expect(mockClient.sendMessageStream).toHaveBeenCalled();
});
it('should use contextId and taskId when provided', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield { kind: 'message' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
contextId: 'ctx123',
taskId: 'task456',
});
// trigger execution
for await (const _ of stream) {
break;
}
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.anything(),
message: expect.objectContaining({
contextId: 'ctx123',
taskId: 'task456',
}),
}),
expect.any(Object),
);
});
it('should use contextId and taskId when provided', async () => {
sendMessageStreamMock.mockReturnValue(
it('should correctly propagate AbortSignal to the stream', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield {
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult;
yield { kind: 'message' };
})(),
);
const expectedContextId = 'user-context-id';
const expectedTaskId = 'user-task-id';
const controller = new AbortController();
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
contextId: expectedContextId,
taskId: expectedTaskId,
signal: controller.signal,
});
// trigger execution
for await (const _ of stream) {
// consume stream
break;
}
const call = sendMessageStreamMock.mock.calls[0][0];
expect(call.message.contextId).toBe(expectedContextId);
expect(call.message.taskId).toBe(expectedTaskId);
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
expect.any(Object),
expect.objectContaining({ signal: controller.signal }),
);
});
it('should handle a multi-chunk stream with different event types', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield { kind: 'message', messageId: 'm1' };
yield { kind: 'status-update', taskId: 't1' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello');
const results = [];
for await (const result of stream) {
results.push(result);
}
expect(results).toHaveLength(2);
expect(results[0].kind).toBe('message');
expect(results[1].kind).toBe('status-update');
});
it('should throw prefixed error on failure', async () => {
sendMessageStreamMock.mockImplementationOnce(() => {
throw new Error('Network error');
mockClient.sendMessageStream.mockImplementation(() => {
throw new Error('Network failure');
});
const stream = manager.sendMessageStream('TestAgent', 'Hello');
await expect(async () => {
for await (const _ of stream) {
// consume
// empty
}
}).rejects.toThrow(
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network error',
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network failure',
);
});
@@ -261,7 +341,7 @@ describe('A2AClientManager', () => {
const stream = manager.sendMessageStream('NonExistentAgent', 'Hello');
await expect(async () => {
for await (const _ of stream) {
// consume
// empty
}
}).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
@@ -269,28 +349,23 @@ describe('A2AClientManager', () => {
describe('getTask', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
await manager.loadAgent('TestAgent', 'http://test.agent/card');
});
it('should get a task from the correct agent', async () => {
getTaskMock.mockResolvedValue({
id: 'task123',
contextId: 'a',
kind: 'task',
status: { state: 'completed' },
} as Task);
const mockTask = { id: 'task123', kind: 'task' };
mockClient.getTask.mockResolvedValue(mockTask);
await manager.getTask('TestAgent', 'task123');
expect(getTaskMock).toHaveBeenCalledWith({
id: 'task123',
});
const result = await manager.getTask('TestAgent', 'task123');
expect(result).toBe(mockTask);
expect(mockClient.getTask).toHaveBeenCalledWith({ id: 'task123' });
});
it('should throw prefixed error on failure', async () => {
getTaskMock.mockRejectedValueOnce(new Error('Network error'));
mockClient.getTask.mockRejectedValue(new Error('Not found'));
await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient getTask Error [TestAgent]: Network error',
'A2AClient getTask Error [TestAgent]: Not found',
);
});
@@ -303,28 +378,23 @@ describe('A2AClientManager', () => {
describe('cancelTask', () => {
beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent');
await manager.loadAgent('TestAgent', 'http://test.agent/card');
});
it('should cancel a task on the correct agent', async () => {
cancelTaskMock.mockResolvedValue({
id: 'task123',
contextId: 'a',
kind: 'task',
status: { state: 'canceled' },
} as Task);
const mockTask = { id: 'task123', kind: 'task' };
mockClient.cancelTask.mockResolvedValue(mockTask);
await manager.cancelTask('TestAgent', 'task123');
expect(cancelTaskMock).toHaveBeenCalledWith({
id: 'task123',
});
const result = await manager.cancelTask('TestAgent', 'task123');
expect(result).toBe(mockTask);
expect(mockClient.cancelTask).toHaveBeenCalledWith({ id: 'task123' });
});
it('should throw prefixed error on failure', async () => {
cancelTaskMock.mockRejectedValueOnce(new Error('Network error'));
mockClient.cancelTask.mockRejectedValue(new Error('Cannot cancel'));
await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient cancelTask Error [TestAgent]: Network error',
'A2AClient cancelTask Error [TestAgent]: Cannot cancel',
);
});
@@ -334,4 +404,82 @@ describe('A2AClientManager', () => {
).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
});
describe('Protocol Routing & URL Logic', () => {
it('should correctly split URLs to prevent .well-known doubling', async () => {
const fullUrl = 'http://localhost:9001/.well-known/agent-card.json';
const resolverInstance = {
resolve: vi.fn().mockResolvedValue({ name: 'test' } as AgentCard),
};
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
);
await manager.loadAgent('test-doubling', fullUrl);
expect(resolverInstance.resolve).toHaveBeenCalledWith(
'http://localhost:9001/',
undefined,
);
});
it('should throw if a remote agent uses a private IP (SSRF protection)', async () => {
const privateUrl = 'http://169.254.169.254/.well-known/agent-card.json';
await expect(manager.loadAgent('ssrf-agent', privateUrl)).rejects.toThrow(
/Refusing to load agent 'ssrf-agent' from private IP range/,
);
});
it('should throw if a domain resolves to a private IP (DNS SSRF protection)', async () => {
const maliciousDomainUrl =
'http://malicious.com/.well-known/agent-card.json';
vi.mocked(dnsPromises.lookup).mockImplementationOnce(
async (_h: string, options?: LookupOptions | number) => {
const addr = { address: '10.0.0.1', family: 4 };
const isAll = typeof options === 'object' && options?.all;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (isAll ? [addr] : addr) as any;
},
);
await expect(
manager.loadAgent('dns-ssrf-agent', maliciousDomainUrl),
).rejects.toThrow(/private IP range/);
});
it('should throw if a public agent card contains a private transport URL (Deep SSRF protection)', async () => {
const publicUrl = 'https://public.agent.com/card.json';
const resolverInstance = {
resolve: vi.fn().mockResolvedValue({
...mockAgentCard,
url: 'http://192.168.1.1/api', // Malicious private transport in public card
} as AgentCard),
};
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
);
// DNS for public.agent.com is public
vi.mocked(dnsPromises.lookup).mockImplementation(
async (hostname: string, options?: LookupOptions | number) => {
const isAll = typeof options === 'object' && options?.all;
if (hostname === 'public.agent.com') {
const addr = { address: '1.1.1.1', family: 4 };
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (isAll ? [addr] : addr) as any;
}
const addr = { address: '192.168.1.1', family: 4 };
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (isAll ? [addr] : addr) as any;
},
);
await expect(
manager.loadAgent('malicious-agent', publicUrl),
).rejects.toThrow(
/contains transport URL pointing to private IP range: http:\/\/192.168.1.1\/api/,
);
});
});
});

View File

@@ -12,20 +12,56 @@ import type {
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
} from '@a2a-js/sdk';
import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client';
import {
type Client,
ClientFactory,
ClientFactoryOptions,
DefaultAgentCardResolver,
RestTransportFactory,
JsonRpcTransportFactory,
type AuthenticationHandler,
RestTransportFactory,
createAuthenticatingFetchWithRetry,
} from '@a2a-js/sdk/client';
import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc';
import { v4 as uuidv4 } from 'uuid';
import { Agent as UndiciAgent } from 'undici';
import {
getGrpcChannelOptions,
getGrpcCredentials,
normalizeAgentCard,
pinUrlToIp,
splitAgentCardUrl,
} from './a2aUtils.js';
import {
isPrivateIpAsync,
safeLookup,
isLoopbackHost,
} from '../utils/fetch.js';
import { debugLogger } from '../utils/debugLogger.js';
import { safeLookup } from '../utils/fetch.js';
/**
* Result of sending a message, which can be a full message, a task,
* or an incremental status/artifact update.
*/
export type SendMessageResult =
| Message
| Task
| TaskStatusUpdateEvent
| TaskArtifactUpdateEvent;
/**
* Internal interface representing properties we inject into the SDK
* to enable DNS rebinding protection for gRPC connections.
* TODO: Replace with official SDK pinning API once available.
*/
interface InternalGrpcExtensions {
target: string;
grpcChannelOptions: Record<string, unknown>;
}
// Local extension of RequestInit to support Node.js/undici dispatcher
interface NodeFetchInit extends RequestInit {
dispatcher?: UndiciAgent;
}
// Remote agents can take 10+ minutes (e.g. Deep Research).
// Use a dedicated dispatcher so the global 5-min timeout isn't affected.
@@ -34,22 +70,18 @@ const a2aDispatcher = new UndiciAgent({
headersTimeout: A2A_TIMEOUT,
bodyTimeout: A2A_TIMEOUT,
connect: {
lookup: safeLookup, // SSRF protection at connection level
// SSRF protection at the connection level (mitigates DNS rebinding)
lookup: safeLookup,
},
});
const a2aFetch: typeof fetch = (input, init) =>
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
fetch(input, { ...init, dispatcher: a2aDispatcher } as RequestInit);
export type SendMessageResult =
| Message
| Task
| TaskStatusUpdateEvent
| TaskArtifactUpdateEvent;
const a2aFetch: typeof fetch = (input, init) => {
const nodeInit: NodeFetchInit = { ...init, dispatcher: a2aDispatcher };
return fetch(input, nodeInit as RequestInit);
};
/**
* Manages A2A clients and caches loaded agent information.
* Follows a singleton pattern to ensure a single client instance.
* Orchestrates communication with remote A2A agents.
* Manages protocol negotiation, authentication, and transport selection.
*/
export class A2AClientManager {
private static instance: A2AClientManager;
@@ -70,19 +102,10 @@ export class A2AClientManager {
return A2AClientManager.instance;
}
/**
* Resets the singleton instance. Only for testing purposes.
* @internal
*/
static resetInstanceForTesting() {
// @ts-expect-error - Resetting singleton for testing
A2AClientManager.instance = undefined;
}
/**
* Loads an agent by fetching its AgentCard and caches the client.
* @param name The name to assign to the agent.
* @param agentCardUrl The full URL to the agent's card.
* @param agentCardUrl {string} The full URL to the agent's card.
* @param authHandler Optional authentication handler to use for this agent.
* @returns The loaded AgentCard.
*/
@@ -95,27 +118,52 @@ export class A2AClientManager {
throw new Error(`Agent with name '${name}' is already loaded.`);
}
let fetchImpl: typeof fetch = a2aFetch;
if (authHandler) {
fetchImpl = createAuthenticatingFetchWithRetry(a2aFetch, authHandler);
}
const fetchImpl = this.getFetchImpl(authHandler);
const resolver = new DefaultAgentCardResolver({ fetchImpl });
const agentCard = await this.resolveAgentCard(name, agentCardUrl, resolver);
const options = ClientFactoryOptions.createFrom(
// Pin URL to IP to prevent DNS rebinding for gRPC (connection-level SSRF protection)
const grpcInterface = agentCard.additionalInterfaces?.find(
(i) => i.transport === 'GRPC',
);
const urlToPin = grpcInterface?.url ?? agentCard.url;
const { pinnedUrl, hostname } = await pinUrlToIp(urlToPin, name);
// Prepare base gRPC options
const baseGrpcOptions: ConstructorParameters<
typeof GrpcTransportFactory
>[0] = {
grpcChannelCredentials: getGrpcCredentials(urlToPin),
};
// We inject additional properties into the transport options to force
// the use of a pinned IP address and matching SSL authority. This is
// required for robust DNS Rebinding protection.
const transportOptions = {
...baseGrpcOptions,
target: pinnedUrl,
grpcChannelOptions: getGrpcChannelOptions(hostname),
} as ConstructorParameters<typeof GrpcTransportFactory>[0] &
InternalGrpcExtensions;
// Configure standard SDK client for tool registration and discovery
const clientOptions = ClientFactoryOptions.createFrom(
ClientFactoryOptions.default,
{
transports: [
new RestTransportFactory({ fetchImpl }),
new JsonRpcTransportFactory({ fetchImpl }),
new GrpcTransportFactory(
transportOptions as ConstructorParameters<
typeof GrpcTransportFactory
>[0],
),
],
cardResolver: resolver,
},
);
const factory = new ClientFactory(options);
const client = await factory.createFromUrl(agentCardUrl, '');
const agentCard = await client.getAgentCard();
const factory = new ClientFactory(clientOptions);
const client = await factory.createFromAgentCard(agentCard);
this.clients.set(name, client);
this.agentCards.set(name, agentCard);
@@ -150,9 +198,7 @@ export class A2AClientManager {
options?: { contextId?: string; taskId?: string; signal?: AbortSignal },
): AsyncIterable<SendMessageResult> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
if (!client) throw new Error(`Agent '${agentName}' not found.`);
const messageParams: MessageSendParams = {
message: {
@@ -168,7 +214,7 @@ export class A2AClientManager {
try {
yield* client.sendMessageStream(messageParams, {
signal: options?.signal,
});
}) as AsyncIterable<SendMessageResult>;
} catch (error: unknown) {
const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`;
if (error instanceof Error) {
@@ -206,9 +252,7 @@ export class A2AClientManager {
*/
async getTask(agentName: string, taskId: string): Promise<Task> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
if (!client) throw new Error(`Agent '${agentName}' not found.`);
try {
return await client.getTask({ id: taskId });
} catch (error: unknown) {
@@ -228,9 +272,7 @@ export class A2AClientManager {
*/
async cancelTask(agentName: string, taskId: string): Promise<Task> {
const client = this.clients.get(agentName);
if (!client) {
throw new Error(`Agent '${agentName}' not found.`);
}
if (!client) throw new Error(`Agent '${agentName}' not found.`);
try {
return await client.cancelTask({ id: taskId });
} catch (error: unknown) {
@@ -241,4 +283,75 @@ export class A2AClientManager {
throw new Error(`${prefix}: Unexpected error: ${String(error)}`);
}
}
/**
* Resolves the appropriate fetch implementation for an agent.
*/
private getFetchImpl(authHandler?: AuthenticationHandler): typeof fetch {
return authHandler
? createAuthenticatingFetchWithRetry(a2aFetch, authHandler)
: a2aFetch;
}
/**
* Resolves and normalizes an agent card from a given URL.
* Handles splitting the URL if it already contains the standard .well-known path.
* Also performs basic SSRF validation to prevent internal IP access.
*/
private async resolveAgentCard(
agentName: string,
url: string,
resolver: DefaultAgentCardResolver,
): Promise<AgentCard> {
// Validate URL to prevent SSRF (with DNS resolution)
if (await isPrivateIpAsync(url)) {
// Local/private IPs are allowed ONLY for localhost for testing.
const parsed = new URL(url);
if (!isLoopbackHost(parsed.hostname)) {
throw new Error(
`Refusing to load agent '${agentName}' from private IP range: ${url}. Remote agents must use public URLs.`,
);
}
}
const { baseUrl, path } = splitAgentCardUrl(url);
const rawCard = await resolver.resolve(baseUrl, path);
const agentCard = normalizeAgentCard(rawCard);
// Deep validation of all transport URLs within the card to prevent SSRF
await this.validateAgentCardUrls(agentName, agentCard);
return agentCard;
}
/**
* Validates all URLs (top-level and interfaces) within an AgentCard for SSRF.
*/
private async validateAgentCardUrls(
agentName: string,
card: AgentCard,
): Promise<void> {
const urlsToValidate = [card.url];
if (card.additionalInterfaces) {
for (const intf of card.additionalInterfaces) {
if (intf.url) urlsToValidate.push(intf.url);
}
}
for (const url of urlsToValidate) {
if (!url) continue;
// Ensure URL has a scheme for the parser (gRPC often provides raw IP:port)
const validationUrl = url.includes('://') ? url : `http://${url}`;
if (await isPrivateIpAsync(validationUrl)) {
const parsed = new URL(validationUrl);
if (!isLoopbackHost(parsed.hostname)) {
throw new Error(
`Refusing to load agent '${agentName}': contains transport URL pointing to private IP range: ${url}.`,
);
}
}
}
}
}