feat(a2a): enable native gRPC support and protocol routing

This commit is contained in:
Alisa Novikova
2026-03-06 06:22:12 -08:00
parent 25aa217306
commit 51aa3cf9db
2 changed files with 458 additions and 197 deletions

View File

@@ -5,96 +5,119 @@
*/ */
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import { import { A2AClientManager } from './a2a-client-manager.js';
A2AClientManager, import type { AgentCard } from '@a2a-js/sdk';
type SendMessageResult, import * as sdkClient from '@a2a-js/sdk/client';
} from './a2a-client-manager.js'; import * as dnsPromises from 'node:dns/promises';
import type { AgentCard, Task } from '@a2a-js/sdk'; import type { LookupOptions } from 'node:dns';
import {
ClientFactory,
DefaultAgentCardResolver,
createAuthenticatingFetchWithRetry,
ClientFactoryOptions,
type AuthenticationHandler,
type Client,
} from '@a2a-js/sdk/client';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
interface MockClient {
sendMessageStream: ReturnType<typeof vi.fn>;
getTask: ReturnType<typeof vi.fn>;
cancelTask: ReturnType<typeof vi.fn>;
}
vi.mock('@a2a-js/sdk/client', async (importOriginal) => {
const actual = await importOriginal();
return {
...(actual as Record<string, unknown>),
createAuthenticatingFetchWithRetry: vi.fn(),
ClientFactory: vi.fn(),
DefaultAgentCardResolver: vi.fn(),
ClientFactoryOptions: {
createFrom: vi.fn(),
default: {},
},
};
});
vi.mock('../utils/debugLogger.js', () => ({ vi.mock('../utils/debugLogger.js', () => ({
debugLogger: { debugLogger: {
debug: vi.fn(), debug: vi.fn(),
}, },
})); }));
vi.mock('@a2a-js/sdk/client', () => { vi.mock('node:dns/promises', () => ({
const ClientFactory = vi.fn(); lookup: vi.fn(),
const DefaultAgentCardResolver = vi.fn(); }));
const RestTransportFactory = vi.fn();
const JsonRpcTransportFactory = vi.fn();
const ClientFactoryOptions = {
default: {},
createFrom: vi.fn(),
};
const createAuthenticatingFetchWithRetry = vi.fn();
DefaultAgentCardResolver.prototype.resolve = vi.fn();
ClientFactory.prototype.createFromUrl = vi.fn();
return {
ClientFactory,
ClientFactoryOptions,
DefaultAgentCardResolver,
RestTransportFactory,
JsonRpcTransportFactory,
createAuthenticatingFetchWithRetry,
};
});
describe('A2AClientManager', () => { describe('A2AClientManager', () => {
let manager: A2AClientManager; let manager: A2AClientManager;
const mockAgentCard: AgentCard = {
name: 'test-agent',
description: 'A test agent',
url: 'http://test.agent',
version: '1.0.0',
protocolVersion: '0.1.0',
capabilities: {},
skills: [],
defaultInputModes: [],
defaultOutputModes: [],
};
const mockClient: MockClient = {
sendMessageStream: vi.fn(),
getTask: vi.fn(),
cancelTask: vi.fn(),
};
// Stable mocks initialized once
const sendMessageStreamMock = vi.fn();
const getTaskMock = vi.fn();
const cancelTaskMock = vi.fn();
const getAgentCardMock = vi.fn();
const authFetchMock = vi.fn(); const authFetchMock = vi.fn();
const mockClient = {
sendMessageStream: sendMessageStreamMock,
getTask: getTaskMock,
cancelTask: cancelTaskMock,
getAgentCard: getAgentCardMock,
} as unknown as Client;
const mockAgentCard: Partial<AgentCard> = { name: 'TestAgent' };
beforeEach(() => { beforeEach(() => {
vi.clearAllMocks(); vi.clearAllMocks();
A2AClientManager.resetInstanceForTesting();
manager = A2AClientManager.getInstance(); manager = A2AClientManager.getInstance();
manager.clearCache();
// Default mock implementations // Default DNS mock: resolve to public IP.
getAgentCardMock.mockResolvedValue({ // Use any cast only for return value due to complex multi-signature overloads.
vi.mocked(dnsPromises.lookup).mockImplementation(
async (_h: string, options?: LookupOptions | number) => {
const addr = { address: '93.184.216.34', family: 4 };
const isAll = typeof options === 'object' && options?.all;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (isAll ? [addr] : addr) as any;
},
);
// Re-create the instances as plain objects that can be spied on
const factoryInstance = {
createFromUrl: vi.fn(),
createFromAgentCard: vi.fn(),
};
const resolverInstance = {
resolve: vi.fn(),
};
vi.mocked(sdkClient.ClientFactory).mockReturnValue(
factoryInstance as unknown as sdkClient.ClientFactory,
);
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
);
vi.spyOn(factoryInstance, 'createFromUrl').mockResolvedValue(
mockClient as unknown as sdkClient.Client,
);
vi.spyOn(factoryInstance, 'createFromAgentCard').mockResolvedValue(
mockClient as unknown as sdkClient.Client,
);
vi.spyOn(resolverInstance, 'resolve').mockResolvedValue({
...mockAgentCard, ...mockAgentCard,
url: 'http://test.agent/real/endpoint', url: 'http://test.agent/real/endpoint',
} as AgentCard); } as AgentCard);
vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue( vi.spyOn(sdkClient.ClientFactoryOptions, 'createFrom').mockImplementation(
mockClient, (_defaults, overrides) =>
overrides as unknown as sdkClient.ClientFactoryOptions,
); );
vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({ vi.mocked(sdkClient.createAuthenticatingFetchWithRetry).mockImplementation(
...mockAgentCard, () =>
url: 'http://test.agent/real/endpoint', authFetchMock.mockResolvedValue({
} as AgentCard); ok: true,
json: async () => ({}),
vi.mocked(ClientFactoryOptions.createFrom).mockImplementation( } as Response),
(_defaults, overrides) => overrides as ClientFactoryOptions,
);
vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue(
authFetchMock,
); );
vi.stubGlobal( vi.stubGlobal(
@@ -123,137 +146,194 @@ describe('A2AClientManager', () => {
'TestAgent', 'TestAgent',
'http://test.agent/card', 'http://test.agent/card',
); );
expect(agentCard).toMatchObject(mockAgentCard);
expect(manager.getAgentCard('TestAgent')).toBe(agentCard); expect(manager.getAgentCard('TestAgent')).toBe(agentCard);
expect(manager.getClient('TestAgent')).toBeDefined(); expect(manager.getClient('TestAgent')).toBeDefined();
}); });
it('should configure ClientFactory with REST, JSON-RPC, and gRPC transports', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(sdkClient.ClientFactoryOptions.createFrom).toHaveBeenCalled();
});
it('should throw an error if an agent with the same name is already loaded', async () => { it('should throw an error if an agent with the same name is already loaded', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card'); await manager.loadAgent('TestAgent', 'http://test.agent/card');
await expect( await expect(
manager.loadAgent('TestAgent', 'http://another.agent/card'), manager.loadAgent('TestAgent', 'http://test.agent/card'),
).rejects.toThrow("Agent with name 'TestAgent' is already loaded."); ).rejects.toThrow("Agent with name 'TestAgent' is already loaded.");
}); });
it('should use native fetch by default', async () => { it('should use native fetch by default', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card'); await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(createAuthenticatingFetchWithRetry).not.toHaveBeenCalled(); expect(
sdkClient.createAuthenticatingFetchWithRetry,
).not.toHaveBeenCalled();
}); });
it('should use provided custom authentication handler', async () => { it('should use provided custom authentication handler', async () => {
const customAuthHandler = { const authHandler: sdkClient.AuthenticationHandler = {
headers: vi.fn(), headers: async () => ({}),
shouldRetryWithHeaders: vi.fn(), shouldRetryWithHeaders: async () => undefined,
}; };
await manager.loadAgent( await manager.loadAgent(
'CustomAuthAgent', 'TestAgent',
'http://custom.agent/card', 'http://test.agent/card',
customAuthHandler as unknown as AuthenticationHandler, authHandler,
);
expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith(
expect.anything(),
customAuthHandler,
); );
expect(sdkClient.createAuthenticatingFetchWithRetry).toHaveBeenCalled();
}); });
it('should log a debug message upon loading an agent', async () => { it('should log a debug message upon loading an agent', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card'); await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(debugLogger.debug).toHaveBeenCalledWith( expect(debugLogger.debug).toHaveBeenCalledWith(
"[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card", expect.stringContaining("Loaded agent 'TestAgent'"),
); );
}); });
it('should clear the cache', async () => { it('should clear the cache', async () => {
await manager.loadAgent('TestAgent', 'http://test.agent/card'); await manager.loadAgent('TestAgent', 'http://test.agent/card');
expect(manager.getAgentCard('TestAgent')).toBeDefined();
expect(manager.getClient('TestAgent')).toBeDefined();
manager.clearCache(); manager.clearCache();
expect(manager.getAgentCard('TestAgent')).toBeUndefined(); expect(manager.getAgentCard('TestAgent')).toBeUndefined();
expect(manager.getClient('TestAgent')).toBeUndefined(); expect(manager.getClient('TestAgent')).toBeUndefined();
expect(debugLogger.debug).toHaveBeenCalledWith( });
'[A2AClientManager] Cache cleared.',
it('should throw if resolveAgentCard fails', async () => {
const resolverInstance = {
resolve: vi.fn().mockRejectedValue(new Error('Resolution failed')),
};
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
); );
await expect(
manager.loadAgent('FailAgent', 'http://fail.agent'),
).rejects.toThrow('Resolution failed');
});
it('should throw if factory.createFromAgentCard fails', async () => {
const factoryInstance = {
createFromAgentCard: vi
.fn()
.mockRejectedValue(new Error('Factory failed')),
};
vi.mocked(sdkClient.ClientFactory).mockReturnValue(
factoryInstance as unknown as sdkClient.ClientFactory,
);
await expect(
manager.loadAgent('FailAgent', 'http://fail.agent'),
).rejects.toThrow('Factory failed');
});
});
describe('getAgentCard and getClient', () => {
it('should return undefined if agent is not found', () => {
expect(manager.getAgentCard('Unknown')).toBeUndefined();
expect(manager.getClient('Unknown')).toBeUndefined();
}); });
}); });
describe('sendMessageStream', () => { describe('sendMessageStream', () => {
beforeEach(async () => { beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent'); await manager.loadAgent('TestAgent', 'http://test.agent/card');
}); });
it('should send a message and return a stream', async () => { it('should send a message and return a stream', async () => {
const mockResult = { mockClient.sendMessageStream.mockReturnValue(
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult;
sendMessageStreamMock.mockReturnValue(
(async function* () { (async function* () {
yield mockResult; yield { kind: 'message' };
})(), })(),
); );
const stream = manager.sendMessageStream('TestAgent', 'Hello'); const stream = manager.sendMessageStream('TestAgent', 'Hello');
const results = []; const results = [];
for await (const res of stream) { for await (const result of stream) {
results.push(res); results.push(result);
} }
expect(results).toEqual([mockResult]); expect(results).toHaveLength(1);
expect(sendMessageStreamMock).toHaveBeenCalledWith( expect(mockClient.sendMessageStream).toHaveBeenCalled();
});
it('should use contextId and taskId when provided', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield { kind: 'message' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
contextId: 'ctx123',
taskId: 'task456',
});
// trigger execution
for await (const _ of stream) {
break;
}
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
expect.objectContaining({ expect.objectContaining({
message: expect.anything(), message: expect.objectContaining({
contextId: 'ctx123',
taskId: 'task456',
}),
}), }),
expect.any(Object), expect.any(Object),
); );
}); });
it('should use contextId and taskId when provided', async () => { it('should correctly propagate AbortSignal to the stream', async () => {
sendMessageStreamMock.mockReturnValue( mockClient.sendMessageStream.mockReturnValue(
(async function* () { (async function* () {
yield { yield { kind: 'message' };
kind: 'message',
messageId: 'a',
parts: [],
role: 'agent',
} as SendMessageResult;
})(), })(),
); );
const expectedContextId = 'user-context-id'; const controller = new AbortController();
const expectedTaskId = 'user-task-id';
const stream = manager.sendMessageStream('TestAgent', 'Hello', { const stream = manager.sendMessageStream('TestAgent', 'Hello', {
contextId: expectedContextId, signal: controller.signal,
taskId: expectedTaskId,
}); });
// trigger execution
for await (const _ of stream) { for await (const _ of stream) {
// consume stream break;
} }
const call = sendMessageStreamMock.mock.calls[0][0]; expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
expect(call.message.contextId).toBe(expectedContextId); expect.any(Object),
expect(call.message.taskId).toBe(expectedTaskId); expect.objectContaining({ signal: controller.signal }),
);
});
it('should handle a multi-chunk stream with different event types', async () => {
mockClient.sendMessageStream.mockReturnValue(
(async function* () {
yield { kind: 'message', messageId: 'm1' };
yield { kind: 'status-update', taskId: 't1' };
})(),
);
const stream = manager.sendMessageStream('TestAgent', 'Hello');
const results = [];
for await (const result of stream) {
results.push(result);
}
expect(results).toHaveLength(2);
expect(results[0].kind).toBe('message');
expect(results[1].kind).toBe('status-update');
}); });
it('should throw prefixed error on failure', async () => { it('should throw prefixed error on failure', async () => {
sendMessageStreamMock.mockImplementationOnce(() => { mockClient.sendMessageStream.mockImplementation(() => {
throw new Error('Network error'); throw new Error('Network failure');
}); });
const stream = manager.sendMessageStream('TestAgent', 'Hello'); const stream = manager.sendMessageStream('TestAgent', 'Hello');
await expect(async () => { await expect(async () => {
for await (const _ of stream) { for await (const _ of stream) {
// consume // empty
} }
}).rejects.toThrow( }).rejects.toThrow(
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network error', '[A2AClientManager] sendMessageStream Error [TestAgent]: Network failure',
); );
}); });
@@ -261,7 +341,7 @@ describe('A2AClientManager', () => {
const stream = manager.sendMessageStream('NonExistentAgent', 'Hello'); const stream = manager.sendMessageStream('NonExistentAgent', 'Hello');
await expect(async () => { await expect(async () => {
for await (const _ of stream) { for await (const _ of stream) {
// consume // empty
} }
}).rejects.toThrow("Agent 'NonExistentAgent' not found."); }).rejects.toThrow("Agent 'NonExistentAgent' not found.");
}); });
@@ -269,28 +349,23 @@ describe('A2AClientManager', () => {
describe('getTask', () => { describe('getTask', () => {
beforeEach(async () => { beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent'); await manager.loadAgent('TestAgent', 'http://test.agent/card');
}); });
it('should get a task from the correct agent', async () => { it('should get a task from the correct agent', async () => {
getTaskMock.mockResolvedValue({ const mockTask = { id: 'task123', kind: 'task' };
id: 'task123', mockClient.getTask.mockResolvedValue(mockTask);
contextId: 'a',
kind: 'task',
status: { state: 'completed' },
} as Task);
await manager.getTask('TestAgent', 'task123'); const result = await manager.getTask('TestAgent', 'task123');
expect(getTaskMock).toHaveBeenCalledWith({ expect(result).toBe(mockTask);
id: 'task123', expect(mockClient.getTask).toHaveBeenCalledWith({ id: 'task123' });
});
}); });
it('should throw prefixed error on failure', async () => { it('should throw prefixed error on failure', async () => {
getTaskMock.mockRejectedValueOnce(new Error('Network error')); mockClient.getTask.mockRejectedValue(new Error('Not found'));
await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow( await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient getTask Error [TestAgent]: Network error', 'A2AClient getTask Error [TestAgent]: Not found',
); );
}); });
@@ -303,28 +378,23 @@ describe('A2AClientManager', () => {
describe('cancelTask', () => { describe('cancelTask', () => {
beforeEach(async () => { beforeEach(async () => {
await manager.loadAgent('TestAgent', 'http://test.agent'); await manager.loadAgent('TestAgent', 'http://test.agent/card');
}); });
it('should cancel a task on the correct agent', async () => { it('should cancel a task on the correct agent', async () => {
cancelTaskMock.mockResolvedValue({ const mockTask = { id: 'task123', kind: 'task' };
id: 'task123', mockClient.cancelTask.mockResolvedValue(mockTask);
contextId: 'a',
kind: 'task',
status: { state: 'canceled' },
} as Task);
await manager.cancelTask('TestAgent', 'task123'); const result = await manager.cancelTask('TestAgent', 'task123');
expect(cancelTaskMock).toHaveBeenCalledWith({ expect(result).toBe(mockTask);
id: 'task123', expect(mockClient.cancelTask).toHaveBeenCalledWith({ id: 'task123' });
});
}); });
it('should throw prefixed error on failure', async () => { it('should throw prefixed error on failure', async () => {
cancelTaskMock.mockRejectedValueOnce(new Error('Network error')); mockClient.cancelTask.mockRejectedValue(new Error('Cannot cancel'));
await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow( await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow(
'A2AClient cancelTask Error [TestAgent]: Network error', 'A2AClient cancelTask Error [TestAgent]: Cannot cancel',
); );
}); });
@@ -334,4 +404,82 @@ describe('A2AClientManager', () => {
).rejects.toThrow("Agent 'NonExistentAgent' not found."); ).rejects.toThrow("Agent 'NonExistentAgent' not found.");
}); });
}); });
describe('Protocol Routing & URL Logic', () => {
it('should correctly split URLs to prevent .well-known doubling', async () => {
const fullUrl = 'http://localhost:9001/.well-known/agent-card.json';
const resolverInstance = {
resolve: vi.fn().mockResolvedValue({ name: 'test' } as AgentCard),
};
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
);
await manager.loadAgent('test-doubling', fullUrl);
expect(resolverInstance.resolve).toHaveBeenCalledWith(
'http://localhost:9001/',
undefined,
);
});
it('should throw if a remote agent uses a private IP (SSRF protection)', async () => {
const privateUrl = 'http://169.254.169.254/.well-known/agent-card.json';
await expect(manager.loadAgent('ssrf-agent', privateUrl)).rejects.toThrow(
/Refusing to load agent 'ssrf-agent' from private IP range/,
);
});
it('should throw if a domain resolves to a private IP (DNS SSRF protection)', async () => {
const maliciousDomainUrl =
'http://malicious.com/.well-known/agent-card.json';
vi.mocked(dnsPromises.lookup).mockImplementationOnce(
async (_h: string, options?: LookupOptions | number) => {
const addr = { address: '10.0.0.1', family: 4 };
const isAll = typeof options === 'object' && options?.all;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (isAll ? [addr] : addr) as any;
},
);
await expect(
manager.loadAgent('dns-ssrf-agent', maliciousDomainUrl),
).rejects.toThrow(/private IP range/);
});
it('should throw if a public agent card contains a private transport URL (Deep SSRF protection)', async () => {
const publicUrl = 'https://public.agent.com/card.json';
const resolverInstance = {
resolve: vi.fn().mockResolvedValue({
...mockAgentCard,
url: 'http://192.168.1.1/api', // Malicious private transport in public card
} as AgentCard),
};
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
);
// DNS for public.agent.com is public
vi.mocked(dnsPromises.lookup).mockImplementation(
async (hostname: string, options?: LookupOptions | number) => {
const isAll = typeof options === 'object' && options?.all;
if (hostname === 'public.agent.com') {
const addr = { address: '1.1.1.1', family: 4 };
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (isAll ? [addr] : addr) as any;
}
const addr = { address: '192.168.1.1', family: 4 };
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (isAll ? [addr] : addr) as any;
},
);
await expect(
manager.loadAgent('malicious-agent', publicUrl),
).rejects.toThrow(
/contains transport URL pointing to private IP range: http:\/\/192.168.1.1\/api/,
);
});
});
}); });

View File

@@ -12,20 +12,56 @@ import type {
TaskStatusUpdateEvent, TaskStatusUpdateEvent,
TaskArtifactUpdateEvent, TaskArtifactUpdateEvent,
} from '@a2a-js/sdk'; } from '@a2a-js/sdk';
import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client';
import { import {
type Client,
ClientFactory, ClientFactory,
ClientFactoryOptions, ClientFactoryOptions,
DefaultAgentCardResolver, DefaultAgentCardResolver,
RestTransportFactory,
JsonRpcTransportFactory, JsonRpcTransportFactory,
type AuthenticationHandler, RestTransportFactory,
createAuthenticatingFetchWithRetry, createAuthenticatingFetchWithRetry,
} from '@a2a-js/sdk/client'; } from '@a2a-js/sdk/client';
import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
import { Agent as UndiciAgent } from 'undici'; import { Agent as UndiciAgent } from 'undici';
import {
getGrpcChannelOptions,
getGrpcCredentials,
normalizeAgentCard,
pinUrlToIp,
splitAgentCardUrl,
} from './a2aUtils.js';
import {
isPrivateIpAsync,
safeLookup,
isLoopbackHost,
} from '../utils/fetch.js';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import { safeLookup } from '../utils/fetch.js';
/**
* Result of sending a message, which can be a full message, a task,
* or an incremental status/artifact update.
*/
export type SendMessageResult =
| Message
| Task
| TaskStatusUpdateEvent
| TaskArtifactUpdateEvent;
/**
* Internal interface representing properties we inject into the SDK
* to enable DNS rebinding protection for gRPC connections.
* TODO: Replace with official SDK pinning API once available.
*/
interface InternalGrpcExtensions {
target: string;
grpcChannelOptions: Record<string, unknown>;
}
// Local extension of RequestInit to support Node.js/undici dispatcher
interface NodeFetchInit extends RequestInit {
dispatcher?: UndiciAgent;
}
// Remote agents can take 10+ minutes (e.g. Deep Research). // Remote agents can take 10+ minutes (e.g. Deep Research).
// Use a dedicated dispatcher so the global 5-min timeout isn't affected. // Use a dedicated dispatcher so the global 5-min timeout isn't affected.
@@ -34,22 +70,18 @@ const a2aDispatcher = new UndiciAgent({
headersTimeout: A2A_TIMEOUT, headersTimeout: A2A_TIMEOUT,
bodyTimeout: A2A_TIMEOUT, bodyTimeout: A2A_TIMEOUT,
connect: { connect: {
lookup: safeLookup, // SSRF protection at connection level // SSRF protection at the connection level (mitigates DNS rebinding)
lookup: safeLookup,
}, },
}); });
const a2aFetch: typeof fetch = (input, init) => const a2aFetch: typeof fetch = (input, init) => {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection const nodeInit: NodeFetchInit = { ...init, dispatcher: a2aDispatcher };
fetch(input, { ...init, dispatcher: a2aDispatcher } as RequestInit); return fetch(input, nodeInit as RequestInit);
};
export type SendMessageResult =
| Message
| Task
| TaskStatusUpdateEvent
| TaskArtifactUpdateEvent;
/** /**
* Manages A2A clients and caches loaded agent information. * Orchestrates communication with remote A2A agents.
* Follows a singleton pattern to ensure a single client instance. * Manages protocol negotiation, authentication, and transport selection.
*/ */
export class A2AClientManager { export class A2AClientManager {
private static instance: A2AClientManager; private static instance: A2AClientManager;
@@ -70,19 +102,10 @@ export class A2AClientManager {
return A2AClientManager.instance; return A2AClientManager.instance;
} }
/**
* Resets the singleton instance. Only for testing purposes.
* @internal
*/
static resetInstanceForTesting() {
// @ts-expect-error - Resetting singleton for testing
A2AClientManager.instance = undefined;
}
/** /**
* Loads an agent by fetching its AgentCard and caches the client. * Loads an agent by fetching its AgentCard and caches the client.
* @param name The name to assign to the agent. * @param name The name to assign to the agent.
* @param agentCardUrl The full URL to the agent's card. * @param agentCardUrl {string} The full URL to the agent's card.
* @param authHandler Optional authentication handler to use for this agent. * @param authHandler Optional authentication handler to use for this agent.
* @returns The loaded AgentCard. * @returns The loaded AgentCard.
*/ */
@@ -95,27 +118,52 @@ export class A2AClientManager {
throw new Error(`Agent with name '${name}' is already loaded.`); throw new Error(`Agent with name '${name}' is already loaded.`);
} }
let fetchImpl: typeof fetch = a2aFetch; const fetchImpl = this.getFetchImpl(authHandler);
if (authHandler) {
fetchImpl = createAuthenticatingFetchWithRetry(a2aFetch, authHandler);
}
const resolver = new DefaultAgentCardResolver({ fetchImpl }); const resolver = new DefaultAgentCardResolver({ fetchImpl });
const agentCard = await this.resolveAgentCard(name, agentCardUrl, resolver);
const options = ClientFactoryOptions.createFrom( // Pin URL to IP to prevent DNS rebinding for gRPC (connection-level SSRF protection)
const grpcInterface = agentCard.additionalInterfaces?.find(
(i) => i.transport === 'GRPC',
);
const urlToPin = grpcInterface?.url ?? agentCard.url;
const { pinnedUrl, hostname } = await pinUrlToIp(urlToPin, name);
// Prepare base gRPC options
const baseGrpcOptions: ConstructorParameters<
typeof GrpcTransportFactory
>[0] = {
grpcChannelCredentials: getGrpcCredentials(urlToPin),
};
// We inject additional properties into the transport options to force
// the use of a pinned IP address and matching SSL authority. This is
// required for robust DNS Rebinding protection.
const transportOptions = {
...baseGrpcOptions,
target: pinnedUrl,
grpcChannelOptions: getGrpcChannelOptions(hostname),
} as ConstructorParameters<typeof GrpcTransportFactory>[0] &
InternalGrpcExtensions;
// Configure standard SDK client for tool registration and discovery
const clientOptions = ClientFactoryOptions.createFrom(
ClientFactoryOptions.default, ClientFactoryOptions.default,
{ {
transports: [ transports: [
new RestTransportFactory({ fetchImpl }), new RestTransportFactory({ fetchImpl }),
new JsonRpcTransportFactory({ fetchImpl }), new JsonRpcTransportFactory({ fetchImpl }),
new GrpcTransportFactory(
transportOptions as ConstructorParameters<
typeof GrpcTransportFactory
>[0],
),
], ],
cardResolver: resolver, cardResolver: resolver,
}, },
); );
const factory = new ClientFactory(clientOptions);
const factory = new ClientFactory(options); const client = await factory.createFromAgentCard(agentCard);
const client = await factory.createFromUrl(agentCardUrl, '');
const agentCard = await client.getAgentCard();
this.clients.set(name, client); this.clients.set(name, client);
this.agentCards.set(name, agentCard); this.agentCards.set(name, agentCard);
@@ -150,9 +198,7 @@ export class A2AClientManager {
options?: { contextId?: string; taskId?: string; signal?: AbortSignal }, options?: { contextId?: string; taskId?: string; signal?: AbortSignal },
): AsyncIterable<SendMessageResult> { ): AsyncIterable<SendMessageResult> {
const client = this.clients.get(agentName); const client = this.clients.get(agentName);
if (!client) { if (!client) throw new Error(`Agent '${agentName}' not found.`);
throw new Error(`Agent '${agentName}' not found.`);
}
const messageParams: MessageSendParams = { const messageParams: MessageSendParams = {
message: { message: {
@@ -168,7 +214,7 @@ export class A2AClientManager {
try { try {
yield* client.sendMessageStream(messageParams, { yield* client.sendMessageStream(messageParams, {
signal: options?.signal, signal: options?.signal,
}); }) as AsyncIterable<SendMessageResult>;
} catch (error: unknown) { } catch (error: unknown) {
const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`; const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`;
if (error instanceof Error) { if (error instanceof Error) {
@@ -206,9 +252,7 @@ export class A2AClientManager {
*/ */
async getTask(agentName: string, taskId: string): Promise<Task> { async getTask(agentName: string, taskId: string): Promise<Task> {
const client = this.clients.get(agentName); const client = this.clients.get(agentName);
if (!client) { if (!client) throw new Error(`Agent '${agentName}' not found.`);
throw new Error(`Agent '${agentName}' not found.`);
}
try { try {
return await client.getTask({ id: taskId }); return await client.getTask({ id: taskId });
} catch (error: unknown) { } catch (error: unknown) {
@@ -228,9 +272,7 @@ export class A2AClientManager {
*/ */
async cancelTask(agentName: string, taskId: string): Promise<Task> { async cancelTask(agentName: string, taskId: string): Promise<Task> {
const client = this.clients.get(agentName); const client = this.clients.get(agentName);
if (!client) { if (!client) throw new Error(`Agent '${agentName}' not found.`);
throw new Error(`Agent '${agentName}' not found.`);
}
try { try {
return await client.cancelTask({ id: taskId }); return await client.cancelTask({ id: taskId });
} catch (error: unknown) { } catch (error: unknown) {
@@ -241,4 +283,75 @@ export class A2AClientManager {
throw new Error(`${prefix}: Unexpected error: ${String(error)}`); throw new Error(`${prefix}: Unexpected error: ${String(error)}`);
} }
} }
/**
* Resolves the appropriate fetch implementation for an agent.
*/
private getFetchImpl(authHandler?: AuthenticationHandler): typeof fetch {
return authHandler
? createAuthenticatingFetchWithRetry(a2aFetch, authHandler)
: a2aFetch;
}
/**
* Resolves and normalizes an agent card from a given URL.
* Handles splitting the URL if it already contains the standard .well-known path.
* Also performs basic SSRF validation to prevent internal IP access.
*/
private async resolveAgentCard(
agentName: string,
url: string,
resolver: DefaultAgentCardResolver,
): Promise<AgentCard> {
// Validate URL to prevent SSRF (with DNS resolution)
if (await isPrivateIpAsync(url)) {
// Local/private IPs are allowed ONLY for localhost for testing.
const parsed = new URL(url);
if (!isLoopbackHost(parsed.hostname)) {
throw new Error(
`Refusing to load agent '${agentName}' from private IP range: ${url}. Remote agents must use public URLs.`,
);
}
}
const { baseUrl, path } = splitAgentCardUrl(url);
const rawCard = await resolver.resolve(baseUrl, path);
const agentCard = normalizeAgentCard(rawCard);
// Deep validation of all transport URLs within the card to prevent SSRF
await this.validateAgentCardUrls(agentName, agentCard);
return agentCard;
}
/**
* Validates all URLs (top-level and interfaces) within an AgentCard for SSRF.
*/
private async validateAgentCardUrls(
agentName: string,
card: AgentCard,
): Promise<void> {
const urlsToValidate = [card.url];
if (card.additionalInterfaces) {
for (const intf of card.additionalInterfaces) {
if (intf.url) urlsToValidate.push(intf.url);
}
}
for (const url of urlsToValidate) {
if (!url) continue;
// Ensure URL has a scheme for the parser (gRPC often provides raw IP:port)
const validationUrl = url.includes('://') ? url : `http://${url}`;
if (await isPrivateIpAsync(validationUrl)) {
const parsed = new URL(validationUrl);
if (!isLoopbackHost(parsed.hostname)) {
throw new Error(
`Refusing to load agent '${agentName}': contains transport URL pointing to private IP range: ${url}.`,
);
}
}
}
}
} }