mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 14:10:37 -07:00
feat(a2a): enable native gRPC support and protocol routing
This commit is contained in:
@@ -5,96 +5,119 @@
|
||||
*/
|
||||
|
||||
import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
A2AClientManager,
|
||||
type SendMessageResult,
|
||||
} from './a2a-client-manager.js';
|
||||
import type { AgentCard, Task } from '@a2a-js/sdk';
|
||||
import {
|
||||
ClientFactory,
|
||||
DefaultAgentCardResolver,
|
||||
createAuthenticatingFetchWithRetry,
|
||||
ClientFactoryOptions,
|
||||
type AuthenticationHandler,
|
||||
type Client,
|
||||
} from '@a2a-js/sdk/client';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import type { AgentCard } from '@a2a-js/sdk';
|
||||
import * as sdkClient from '@a2a-js/sdk/client';
|
||||
import * as dnsPromises from 'node:dns/promises';
|
||||
import type { LookupOptions } from 'node:dns';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
interface MockClient {
|
||||
sendMessageStream: ReturnType<typeof vi.fn>;
|
||||
getTask: ReturnType<typeof vi.fn>;
|
||||
cancelTask: ReturnType<typeof vi.fn>;
|
||||
}
|
||||
|
||||
vi.mock('@a2a-js/sdk/client', async (importOriginal) => {
|
||||
const actual = await importOriginal();
|
||||
return {
|
||||
...(actual as Record<string, unknown>),
|
||||
createAuthenticatingFetchWithRetry: vi.fn(),
|
||||
ClientFactory: vi.fn(),
|
||||
DefaultAgentCardResolver: vi.fn(),
|
||||
ClientFactoryOptions: {
|
||||
createFrom: vi.fn(),
|
||||
default: {},
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('../utils/debugLogger.js', () => ({
|
||||
debugLogger: {
|
||||
debug: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('@a2a-js/sdk/client', () => {
|
||||
const ClientFactory = vi.fn();
|
||||
const DefaultAgentCardResolver = vi.fn();
|
||||
const RestTransportFactory = vi.fn();
|
||||
const JsonRpcTransportFactory = vi.fn();
|
||||
const ClientFactoryOptions = {
|
||||
default: {},
|
||||
createFrom: vi.fn(),
|
||||
};
|
||||
const createAuthenticatingFetchWithRetry = vi.fn();
|
||||
|
||||
DefaultAgentCardResolver.prototype.resolve = vi.fn();
|
||||
ClientFactory.prototype.createFromUrl = vi.fn();
|
||||
|
||||
return {
|
||||
ClientFactory,
|
||||
ClientFactoryOptions,
|
||||
DefaultAgentCardResolver,
|
||||
RestTransportFactory,
|
||||
JsonRpcTransportFactory,
|
||||
createAuthenticatingFetchWithRetry,
|
||||
};
|
||||
});
|
||||
vi.mock('node:dns/promises', () => ({
|
||||
lookup: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('A2AClientManager', () => {
|
||||
let manager: A2AClientManager;
|
||||
const mockAgentCard: AgentCard = {
|
||||
name: 'test-agent',
|
||||
description: 'A test agent',
|
||||
url: 'http://test.agent',
|
||||
version: '1.0.0',
|
||||
protocolVersion: '0.1.0',
|
||||
capabilities: {},
|
||||
skills: [],
|
||||
defaultInputModes: [],
|
||||
defaultOutputModes: [],
|
||||
};
|
||||
|
||||
const mockClient: MockClient = {
|
||||
sendMessageStream: vi.fn(),
|
||||
getTask: vi.fn(),
|
||||
cancelTask: vi.fn(),
|
||||
};
|
||||
|
||||
// Stable mocks initialized once
|
||||
const sendMessageStreamMock = vi.fn();
|
||||
const getTaskMock = vi.fn();
|
||||
const cancelTaskMock = vi.fn();
|
||||
const getAgentCardMock = vi.fn();
|
||||
const authFetchMock = vi.fn();
|
||||
|
||||
const mockClient = {
|
||||
sendMessageStream: sendMessageStreamMock,
|
||||
getTask: getTaskMock,
|
||||
cancelTask: cancelTaskMock,
|
||||
getAgentCard: getAgentCardMock,
|
||||
} as unknown as Client;
|
||||
|
||||
const mockAgentCard: Partial<AgentCard> = { name: 'TestAgent' };
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
A2AClientManager.resetInstanceForTesting();
|
||||
manager = A2AClientManager.getInstance();
|
||||
manager.clearCache();
|
||||
|
||||
// Default mock implementations
|
||||
getAgentCardMock.mockResolvedValue({
|
||||
// Default DNS mock: resolve to public IP.
|
||||
// Use any cast only for return value due to complex multi-signature overloads.
|
||||
vi.mocked(dnsPromises.lookup).mockImplementation(
|
||||
async (_h: string, options?: LookupOptions | number) => {
|
||||
const addr = { address: '93.184.216.34', family: 4 };
|
||||
const isAll = typeof options === 'object' && options?.all;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (isAll ? [addr] : addr) as any;
|
||||
},
|
||||
);
|
||||
|
||||
// Re-create the instances as plain objects that can be spied on
|
||||
const factoryInstance = {
|
||||
createFromUrl: vi.fn(),
|
||||
createFromAgentCard: vi.fn(),
|
||||
};
|
||||
const resolverInstance = {
|
||||
resolve: vi.fn(),
|
||||
};
|
||||
|
||||
vi.mocked(sdkClient.ClientFactory).mockReturnValue(
|
||||
factoryInstance as unknown as sdkClient.ClientFactory,
|
||||
);
|
||||
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
|
||||
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
|
||||
);
|
||||
|
||||
vi.spyOn(factoryInstance, 'createFromUrl').mockResolvedValue(
|
||||
mockClient as unknown as sdkClient.Client,
|
||||
);
|
||||
vi.spyOn(factoryInstance, 'createFromAgentCard').mockResolvedValue(
|
||||
mockClient as unknown as sdkClient.Client,
|
||||
);
|
||||
vi.spyOn(resolverInstance, 'resolve').mockResolvedValue({
|
||||
...mockAgentCard,
|
||||
url: 'http://test.agent/real/endpoint',
|
||||
} as AgentCard);
|
||||
|
||||
vi.mocked(ClientFactory.prototype.createFromUrl).mockResolvedValue(
|
||||
mockClient,
|
||||
vi.spyOn(sdkClient.ClientFactoryOptions, 'createFrom').mockImplementation(
|
||||
(_defaults, overrides) =>
|
||||
overrides as unknown as sdkClient.ClientFactoryOptions,
|
||||
);
|
||||
|
||||
vi.mocked(DefaultAgentCardResolver.prototype.resolve).mockResolvedValue({
|
||||
...mockAgentCard,
|
||||
url: 'http://test.agent/real/endpoint',
|
||||
} as AgentCard);
|
||||
|
||||
vi.mocked(ClientFactoryOptions.createFrom).mockImplementation(
|
||||
(_defaults, overrides) => overrides as ClientFactoryOptions,
|
||||
);
|
||||
|
||||
vi.mocked(createAuthenticatingFetchWithRetry).mockReturnValue(
|
||||
authFetchMock,
|
||||
vi.mocked(sdkClient.createAuthenticatingFetchWithRetry).mockImplementation(
|
||||
() =>
|
||||
authFetchMock.mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response),
|
||||
);
|
||||
|
||||
vi.stubGlobal(
|
||||
@@ -123,137 +146,194 @@ describe('A2AClientManager', () => {
|
||||
'TestAgent',
|
||||
'http://test.agent/card',
|
||||
);
|
||||
expect(agentCard).toMatchObject(mockAgentCard);
|
||||
expect(manager.getAgentCard('TestAgent')).toBe(agentCard);
|
||||
expect(manager.getClient('TestAgent')).toBeDefined();
|
||||
});
|
||||
|
||||
it('should configure ClientFactory with REST, JSON-RPC, and gRPC transports', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
expect(sdkClient.ClientFactoryOptions.createFrom).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should throw an error if an agent with the same name is already loaded', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
await expect(
|
||||
manager.loadAgent('TestAgent', 'http://another.agent/card'),
|
||||
manager.loadAgent('TestAgent', 'http://test.agent/card'),
|
||||
).rejects.toThrow("Agent with name 'TestAgent' is already loaded.");
|
||||
});
|
||||
|
||||
it('should use native fetch by default', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
expect(createAuthenticatingFetchWithRetry).not.toHaveBeenCalled();
|
||||
expect(
|
||||
sdkClient.createAuthenticatingFetchWithRetry,
|
||||
).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use provided custom authentication handler', async () => {
|
||||
const customAuthHandler = {
|
||||
headers: vi.fn(),
|
||||
shouldRetryWithHeaders: vi.fn(),
|
||||
const authHandler: sdkClient.AuthenticationHandler = {
|
||||
headers: async () => ({}),
|
||||
shouldRetryWithHeaders: async () => undefined,
|
||||
};
|
||||
await manager.loadAgent(
|
||||
'CustomAuthAgent',
|
||||
'http://custom.agent/card',
|
||||
customAuthHandler as unknown as AuthenticationHandler,
|
||||
);
|
||||
|
||||
expect(createAuthenticatingFetchWithRetry).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
customAuthHandler,
|
||||
'TestAgent',
|
||||
'http://test.agent/card',
|
||||
authHandler,
|
||||
);
|
||||
expect(sdkClient.createAuthenticatingFetchWithRetry).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should log a debug message upon loading an agent', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
expect(debugLogger.debug).toHaveBeenCalledWith(
|
||||
"[A2AClientManager] Loaded agent 'TestAgent' from http://test.agent/card",
|
||||
expect.stringContaining("Loaded agent 'TestAgent'"),
|
||||
);
|
||||
});
|
||||
|
||||
it('should clear the cache', async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
expect(manager.getAgentCard('TestAgent')).toBeDefined();
|
||||
expect(manager.getClient('TestAgent')).toBeDefined();
|
||||
|
||||
manager.clearCache();
|
||||
|
||||
expect(manager.getAgentCard('TestAgent')).toBeUndefined();
|
||||
expect(manager.getClient('TestAgent')).toBeUndefined();
|
||||
expect(debugLogger.debug).toHaveBeenCalledWith(
|
||||
'[A2AClientManager] Cache cleared.',
|
||||
});
|
||||
|
||||
it('should throw if resolveAgentCard fails', async () => {
|
||||
const resolverInstance = {
|
||||
resolve: vi.fn().mockRejectedValue(new Error('Resolution failed')),
|
||||
};
|
||||
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
|
||||
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
|
||||
);
|
||||
|
||||
await expect(
|
||||
manager.loadAgent('FailAgent', 'http://fail.agent'),
|
||||
).rejects.toThrow('Resolution failed');
|
||||
});
|
||||
|
||||
it('should throw if factory.createFromAgentCard fails', async () => {
|
||||
const factoryInstance = {
|
||||
createFromAgentCard: vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Factory failed')),
|
||||
};
|
||||
vi.mocked(sdkClient.ClientFactory).mockReturnValue(
|
||||
factoryInstance as unknown as sdkClient.ClientFactory,
|
||||
);
|
||||
|
||||
await expect(
|
||||
manager.loadAgent('FailAgent', 'http://fail.agent'),
|
||||
).rejects.toThrow('Factory failed');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getAgentCard and getClient', () => {
|
||||
it('should return undefined if agent is not found', () => {
|
||||
expect(manager.getAgentCard('Unknown')).toBeUndefined();
|
||||
expect(manager.getClient('Unknown')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessageStream', () => {
|
||||
beforeEach(async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent');
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
});
|
||||
|
||||
it('should send a message and return a stream', async () => {
|
||||
const mockResult = {
|
||||
kind: 'message',
|
||||
messageId: 'a',
|
||||
parts: [],
|
||||
role: 'agent',
|
||||
} as SendMessageResult;
|
||||
|
||||
sendMessageStreamMock.mockReturnValue(
|
||||
mockClient.sendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield mockResult;
|
||||
yield { kind: 'message' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello');
|
||||
const results = [];
|
||||
for await (const res of stream) {
|
||||
results.push(res);
|
||||
for await (const result of stream) {
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
expect(results).toEqual([mockResult]);
|
||||
expect(sendMessageStreamMock).toHaveBeenCalledWith(
|
||||
expect(results).toHaveLength(1);
|
||||
expect(mockClient.sendMessageStream).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should use contextId and taskId when provided', async () => {
|
||||
mockClient.sendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { kind: 'message' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
|
||||
contextId: 'ctx123',
|
||||
taskId: 'task456',
|
||||
});
|
||||
// trigger execution
|
||||
for await (const _ of stream) {
|
||||
break;
|
||||
}
|
||||
|
||||
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
message: expect.anything(),
|
||||
message: expect.objectContaining({
|
||||
contextId: 'ctx123',
|
||||
taskId: 'task456',
|
||||
}),
|
||||
}),
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use contextId and taskId when provided', async () => {
|
||||
sendMessageStreamMock.mockReturnValue(
|
||||
it('should correctly propagate AbortSignal to the stream', async () => {
|
||||
mockClient.sendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield {
|
||||
kind: 'message',
|
||||
messageId: 'a',
|
||||
parts: [],
|
||||
role: 'agent',
|
||||
} as SendMessageResult;
|
||||
yield { kind: 'message' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const expectedContextId = 'user-context-id';
|
||||
const expectedTaskId = 'user-task-id';
|
||||
|
||||
const controller = new AbortController();
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
|
||||
contextId: expectedContextId,
|
||||
taskId: expectedTaskId,
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
// trigger execution
|
||||
for await (const _ of stream) {
|
||||
// consume stream
|
||||
break;
|
||||
}
|
||||
|
||||
const call = sendMessageStreamMock.mock.calls[0][0];
|
||||
expect(call.message.contextId).toBe(expectedContextId);
|
||||
expect(call.message.taskId).toBe(expectedTaskId);
|
||||
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({ signal: controller.signal }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle a multi-chunk stream with different event types', async () => {
|
||||
mockClient.sendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield { kind: 'message', messageId: 'm1' };
|
||||
yield { kind: 'status-update', taskId: 't1' };
|
||||
})(),
|
||||
);
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello');
|
||||
const results = [];
|
||||
for await (const result of stream) {
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
expect(results).toHaveLength(2);
|
||||
expect(results[0].kind).toBe('message');
|
||||
expect(results[1].kind).toBe('status-update');
|
||||
});
|
||||
|
||||
it('should throw prefixed error on failure', async () => {
|
||||
sendMessageStreamMock.mockImplementationOnce(() => {
|
||||
throw new Error('Network error');
|
||||
mockClient.sendMessageStream.mockImplementation(() => {
|
||||
throw new Error('Network failure');
|
||||
});
|
||||
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello');
|
||||
await expect(async () => {
|
||||
for await (const _ of stream) {
|
||||
// consume
|
||||
// empty
|
||||
}
|
||||
}).rejects.toThrow(
|
||||
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network error',
|
||||
'[A2AClientManager] sendMessageStream Error [TestAgent]: Network failure',
|
||||
);
|
||||
});
|
||||
|
||||
@@ -261,7 +341,7 @@ describe('A2AClientManager', () => {
|
||||
const stream = manager.sendMessageStream('NonExistentAgent', 'Hello');
|
||||
await expect(async () => {
|
||||
for await (const _ of stream) {
|
||||
// consume
|
||||
// empty
|
||||
}
|
||||
}).rejects.toThrow("Agent 'NonExistentAgent' not found.");
|
||||
});
|
||||
@@ -269,28 +349,23 @@ describe('A2AClientManager', () => {
|
||||
|
||||
describe('getTask', () => {
|
||||
beforeEach(async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent');
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
});
|
||||
|
||||
it('should get a task from the correct agent', async () => {
|
||||
getTaskMock.mockResolvedValue({
|
||||
id: 'task123',
|
||||
contextId: 'a',
|
||||
kind: 'task',
|
||||
status: { state: 'completed' },
|
||||
} as Task);
|
||||
const mockTask = { id: 'task123', kind: 'task' };
|
||||
mockClient.getTask.mockResolvedValue(mockTask);
|
||||
|
||||
await manager.getTask('TestAgent', 'task123');
|
||||
expect(getTaskMock).toHaveBeenCalledWith({
|
||||
id: 'task123',
|
||||
});
|
||||
const result = await manager.getTask('TestAgent', 'task123');
|
||||
expect(result).toBe(mockTask);
|
||||
expect(mockClient.getTask).toHaveBeenCalledWith({ id: 'task123' });
|
||||
});
|
||||
|
||||
it('should throw prefixed error on failure', async () => {
|
||||
getTaskMock.mockRejectedValueOnce(new Error('Network error'));
|
||||
mockClient.getTask.mockRejectedValue(new Error('Not found'));
|
||||
|
||||
await expect(manager.getTask('TestAgent', 'task123')).rejects.toThrow(
|
||||
'A2AClient getTask Error [TestAgent]: Network error',
|
||||
'A2AClient getTask Error [TestAgent]: Not found',
|
||||
);
|
||||
});
|
||||
|
||||
@@ -303,28 +378,23 @@ describe('A2AClientManager', () => {
|
||||
|
||||
describe('cancelTask', () => {
|
||||
beforeEach(async () => {
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent');
|
||||
await manager.loadAgent('TestAgent', 'http://test.agent/card');
|
||||
});
|
||||
|
||||
it('should cancel a task on the correct agent', async () => {
|
||||
cancelTaskMock.mockResolvedValue({
|
||||
id: 'task123',
|
||||
contextId: 'a',
|
||||
kind: 'task',
|
||||
status: { state: 'canceled' },
|
||||
} as Task);
|
||||
const mockTask = { id: 'task123', kind: 'task' };
|
||||
mockClient.cancelTask.mockResolvedValue(mockTask);
|
||||
|
||||
await manager.cancelTask('TestAgent', 'task123');
|
||||
expect(cancelTaskMock).toHaveBeenCalledWith({
|
||||
id: 'task123',
|
||||
});
|
||||
const result = await manager.cancelTask('TestAgent', 'task123');
|
||||
expect(result).toBe(mockTask);
|
||||
expect(mockClient.cancelTask).toHaveBeenCalledWith({ id: 'task123' });
|
||||
});
|
||||
|
||||
it('should throw prefixed error on failure', async () => {
|
||||
cancelTaskMock.mockRejectedValueOnce(new Error('Network error'));
|
||||
mockClient.cancelTask.mockRejectedValue(new Error('Cannot cancel'));
|
||||
|
||||
await expect(manager.cancelTask('TestAgent', 'task123')).rejects.toThrow(
|
||||
'A2AClient cancelTask Error [TestAgent]: Network error',
|
||||
'A2AClient cancelTask Error [TestAgent]: Cannot cancel',
|
||||
);
|
||||
});
|
||||
|
||||
@@ -334,4 +404,82 @@ describe('A2AClientManager', () => {
|
||||
).rejects.toThrow("Agent 'NonExistentAgent' not found.");
|
||||
});
|
||||
});
|
||||
|
||||
describe('Protocol Routing & URL Logic', () => {
|
||||
it('should correctly split URLs to prevent .well-known doubling', async () => {
|
||||
const fullUrl = 'http://localhost:9001/.well-known/agent-card.json';
|
||||
const resolverInstance = {
|
||||
resolve: vi.fn().mockResolvedValue({ name: 'test' } as AgentCard),
|
||||
};
|
||||
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
|
||||
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
|
||||
);
|
||||
|
||||
await manager.loadAgent('test-doubling', fullUrl);
|
||||
|
||||
expect(resolverInstance.resolve).toHaveBeenCalledWith(
|
||||
'http://localhost:9001/',
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw if a remote agent uses a private IP (SSRF protection)', async () => {
|
||||
const privateUrl = 'http://169.254.169.254/.well-known/agent-card.json';
|
||||
await expect(manager.loadAgent('ssrf-agent', privateUrl)).rejects.toThrow(
|
||||
/Refusing to load agent 'ssrf-agent' from private IP range/,
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw if a domain resolves to a private IP (DNS SSRF protection)', async () => {
|
||||
const maliciousDomainUrl =
|
||||
'http://malicious.com/.well-known/agent-card.json';
|
||||
|
||||
vi.mocked(dnsPromises.lookup).mockImplementationOnce(
|
||||
async (_h: string, options?: LookupOptions | number) => {
|
||||
const addr = { address: '10.0.0.1', family: 4 };
|
||||
const isAll = typeof options === 'object' && options?.all;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (isAll ? [addr] : addr) as any;
|
||||
},
|
||||
);
|
||||
|
||||
await expect(
|
||||
manager.loadAgent('dns-ssrf-agent', maliciousDomainUrl),
|
||||
).rejects.toThrow(/private IP range/);
|
||||
});
|
||||
|
||||
it('should throw if a public agent card contains a private transport URL (Deep SSRF protection)', async () => {
|
||||
const publicUrl = 'https://public.agent.com/card.json';
|
||||
const resolverInstance = {
|
||||
resolve: vi.fn().mockResolvedValue({
|
||||
...mockAgentCard,
|
||||
url: 'http://192.168.1.1/api', // Malicious private transport in public card
|
||||
} as AgentCard),
|
||||
};
|
||||
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
|
||||
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
|
||||
);
|
||||
|
||||
// DNS for public.agent.com is public
|
||||
vi.mocked(dnsPromises.lookup).mockImplementation(
|
||||
async (hostname: string, options?: LookupOptions | number) => {
|
||||
const isAll = typeof options === 'object' && options?.all;
|
||||
if (hostname === 'public.agent.com') {
|
||||
const addr = { address: '1.1.1.1', family: 4 };
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (isAll ? [addr] : addr) as any;
|
||||
}
|
||||
const addr = { address: '192.168.1.1', family: 4 };
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (isAll ? [addr] : addr) as any;
|
||||
},
|
||||
);
|
||||
|
||||
await expect(
|
||||
manager.loadAgent('malicious-agent', publicUrl),
|
||||
).rejects.toThrow(
|
||||
/contains transport URL pointing to private IP range: http:\/\/192.168.1.1\/api/,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -12,20 +12,56 @@ import type {
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
} from '@a2a-js/sdk';
|
||||
import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client';
|
||||
import {
|
||||
type Client,
|
||||
ClientFactory,
|
||||
ClientFactoryOptions,
|
||||
DefaultAgentCardResolver,
|
||||
RestTransportFactory,
|
||||
JsonRpcTransportFactory,
|
||||
type AuthenticationHandler,
|
||||
RestTransportFactory,
|
||||
createAuthenticatingFetchWithRetry,
|
||||
} from '@a2a-js/sdk/client';
|
||||
import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { Agent as UndiciAgent } from 'undici';
|
||||
import {
|
||||
getGrpcChannelOptions,
|
||||
getGrpcCredentials,
|
||||
normalizeAgentCard,
|
||||
pinUrlToIp,
|
||||
splitAgentCardUrl,
|
||||
} from './a2aUtils.js';
|
||||
import {
|
||||
isPrivateIpAsync,
|
||||
safeLookup,
|
||||
isLoopbackHost,
|
||||
} from '../utils/fetch.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { safeLookup } from '../utils/fetch.js';
|
||||
|
||||
/**
|
||||
* Result of sending a message, which can be a full message, a task,
|
||||
* or an incremental status/artifact update.
|
||||
*/
|
||||
export type SendMessageResult =
|
||||
| Message
|
||||
| Task
|
||||
| TaskStatusUpdateEvent
|
||||
| TaskArtifactUpdateEvent;
|
||||
|
||||
/**
|
||||
* Internal interface representing properties we inject into the SDK
|
||||
* to enable DNS rebinding protection for gRPC connections.
|
||||
* TODO: Replace with official SDK pinning API once available.
|
||||
*/
|
||||
interface InternalGrpcExtensions {
|
||||
target: string;
|
||||
grpcChannelOptions: Record<string, unknown>;
|
||||
}
|
||||
|
||||
// Local extension of RequestInit to support Node.js/undici dispatcher
|
||||
interface NodeFetchInit extends RequestInit {
|
||||
dispatcher?: UndiciAgent;
|
||||
}
|
||||
|
||||
// Remote agents can take 10+ minutes (e.g. Deep Research).
|
||||
// Use a dedicated dispatcher so the global 5-min timeout isn't affected.
|
||||
@@ -34,22 +70,18 @@ const a2aDispatcher = new UndiciAgent({
|
||||
headersTimeout: A2A_TIMEOUT,
|
||||
bodyTimeout: A2A_TIMEOUT,
|
||||
connect: {
|
||||
lookup: safeLookup, // SSRF protection at connection level
|
||||
// SSRF protection at the connection level (mitigates DNS rebinding)
|
||||
lookup: safeLookup,
|
||||
},
|
||||
});
|
||||
const a2aFetch: typeof fetch = (input, init) =>
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
fetch(input, { ...init, dispatcher: a2aDispatcher } as RequestInit);
|
||||
|
||||
export type SendMessageResult =
|
||||
| Message
|
||||
| Task
|
||||
| TaskStatusUpdateEvent
|
||||
| TaskArtifactUpdateEvent;
|
||||
const a2aFetch: typeof fetch = (input, init) => {
|
||||
const nodeInit: NodeFetchInit = { ...init, dispatcher: a2aDispatcher };
|
||||
return fetch(input, nodeInit as RequestInit);
|
||||
};
|
||||
|
||||
/**
|
||||
* Manages A2A clients and caches loaded agent information.
|
||||
* Follows a singleton pattern to ensure a single client instance.
|
||||
* Orchestrates communication with remote A2A agents.
|
||||
* Manages protocol negotiation, authentication, and transport selection.
|
||||
*/
|
||||
export class A2AClientManager {
|
||||
private static instance: A2AClientManager;
|
||||
@@ -70,19 +102,10 @@ export class A2AClientManager {
|
||||
return A2AClientManager.instance;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resets the singleton instance. Only for testing purposes.
|
||||
* @internal
|
||||
*/
|
||||
static resetInstanceForTesting() {
|
||||
// @ts-expect-error - Resetting singleton for testing
|
||||
A2AClientManager.instance = undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads an agent by fetching its AgentCard and caches the client.
|
||||
* @param name The name to assign to the agent.
|
||||
* @param agentCardUrl The full URL to the agent's card.
|
||||
* @param agentCardUrl {string} The full URL to the agent's card.
|
||||
* @param authHandler Optional authentication handler to use for this agent.
|
||||
* @returns The loaded AgentCard.
|
||||
*/
|
||||
@@ -95,27 +118,52 @@ export class A2AClientManager {
|
||||
throw new Error(`Agent with name '${name}' is already loaded.`);
|
||||
}
|
||||
|
||||
let fetchImpl: typeof fetch = a2aFetch;
|
||||
if (authHandler) {
|
||||
fetchImpl = createAuthenticatingFetchWithRetry(a2aFetch, authHandler);
|
||||
}
|
||||
|
||||
const fetchImpl = this.getFetchImpl(authHandler);
|
||||
const resolver = new DefaultAgentCardResolver({ fetchImpl });
|
||||
const agentCard = await this.resolveAgentCard(name, agentCardUrl, resolver);
|
||||
|
||||
const options = ClientFactoryOptions.createFrom(
|
||||
// Pin URL to IP to prevent DNS rebinding for gRPC (connection-level SSRF protection)
|
||||
const grpcInterface = agentCard.additionalInterfaces?.find(
|
||||
(i) => i.transport === 'GRPC',
|
||||
);
|
||||
const urlToPin = grpcInterface?.url ?? agentCard.url;
|
||||
const { pinnedUrl, hostname } = await pinUrlToIp(urlToPin, name);
|
||||
|
||||
// Prepare base gRPC options
|
||||
const baseGrpcOptions: ConstructorParameters<
|
||||
typeof GrpcTransportFactory
|
||||
>[0] = {
|
||||
grpcChannelCredentials: getGrpcCredentials(urlToPin),
|
||||
};
|
||||
|
||||
// We inject additional properties into the transport options to force
|
||||
// the use of a pinned IP address and matching SSL authority. This is
|
||||
// required for robust DNS Rebinding protection.
|
||||
const transportOptions = {
|
||||
...baseGrpcOptions,
|
||||
target: pinnedUrl,
|
||||
grpcChannelOptions: getGrpcChannelOptions(hostname),
|
||||
} as ConstructorParameters<typeof GrpcTransportFactory>[0] &
|
||||
InternalGrpcExtensions;
|
||||
|
||||
// Configure standard SDK client for tool registration and discovery
|
||||
const clientOptions = ClientFactoryOptions.createFrom(
|
||||
ClientFactoryOptions.default,
|
||||
{
|
||||
transports: [
|
||||
new RestTransportFactory({ fetchImpl }),
|
||||
new JsonRpcTransportFactory({ fetchImpl }),
|
||||
new GrpcTransportFactory(
|
||||
transportOptions as ConstructorParameters<
|
||||
typeof GrpcTransportFactory
|
||||
>[0],
|
||||
),
|
||||
],
|
||||
cardResolver: resolver,
|
||||
},
|
||||
);
|
||||
|
||||
const factory = new ClientFactory(options);
|
||||
const client = await factory.createFromUrl(agentCardUrl, '');
|
||||
const agentCard = await client.getAgentCard();
|
||||
const factory = new ClientFactory(clientOptions);
|
||||
const client = await factory.createFromAgentCard(agentCard);
|
||||
|
||||
this.clients.set(name, client);
|
||||
this.agentCards.set(name, agentCard);
|
||||
@@ -150,9 +198,7 @@ export class A2AClientManager {
|
||||
options?: { contextId?: string; taskId?: string; signal?: AbortSignal },
|
||||
): AsyncIterable<SendMessageResult> {
|
||||
const client = this.clients.get(agentName);
|
||||
if (!client) {
|
||||
throw new Error(`Agent '${agentName}' not found.`);
|
||||
}
|
||||
if (!client) throw new Error(`Agent '${agentName}' not found.`);
|
||||
|
||||
const messageParams: MessageSendParams = {
|
||||
message: {
|
||||
@@ -168,7 +214,7 @@ export class A2AClientManager {
|
||||
try {
|
||||
yield* client.sendMessageStream(messageParams, {
|
||||
signal: options?.signal,
|
||||
});
|
||||
}) as AsyncIterable<SendMessageResult>;
|
||||
} catch (error: unknown) {
|
||||
const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`;
|
||||
if (error instanceof Error) {
|
||||
@@ -206,9 +252,7 @@ export class A2AClientManager {
|
||||
*/
|
||||
async getTask(agentName: string, taskId: string): Promise<Task> {
|
||||
const client = this.clients.get(agentName);
|
||||
if (!client) {
|
||||
throw new Error(`Agent '${agentName}' not found.`);
|
||||
}
|
||||
if (!client) throw new Error(`Agent '${agentName}' not found.`);
|
||||
try {
|
||||
return await client.getTask({ id: taskId });
|
||||
} catch (error: unknown) {
|
||||
@@ -228,9 +272,7 @@ export class A2AClientManager {
|
||||
*/
|
||||
async cancelTask(agentName: string, taskId: string): Promise<Task> {
|
||||
const client = this.clients.get(agentName);
|
||||
if (!client) {
|
||||
throw new Error(`Agent '${agentName}' not found.`);
|
||||
}
|
||||
if (!client) throw new Error(`Agent '${agentName}' not found.`);
|
||||
try {
|
||||
return await client.cancelTask({ id: taskId });
|
||||
} catch (error: unknown) {
|
||||
@@ -241,4 +283,75 @@ export class A2AClientManager {
|
||||
throw new Error(`${prefix}: Unexpected error: ${String(error)}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves the appropriate fetch implementation for an agent.
|
||||
*/
|
||||
private getFetchImpl(authHandler?: AuthenticationHandler): typeof fetch {
|
||||
return authHandler
|
||||
? createAuthenticatingFetchWithRetry(a2aFetch, authHandler)
|
||||
: a2aFetch;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves and normalizes an agent card from a given URL.
|
||||
* Handles splitting the URL if it already contains the standard .well-known path.
|
||||
* Also performs basic SSRF validation to prevent internal IP access.
|
||||
*/
|
||||
private async resolveAgentCard(
|
||||
agentName: string,
|
||||
url: string,
|
||||
resolver: DefaultAgentCardResolver,
|
||||
): Promise<AgentCard> {
|
||||
// Validate URL to prevent SSRF (with DNS resolution)
|
||||
if (await isPrivateIpAsync(url)) {
|
||||
// Local/private IPs are allowed ONLY for localhost for testing.
|
||||
const parsed = new URL(url);
|
||||
if (!isLoopbackHost(parsed.hostname)) {
|
||||
throw new Error(
|
||||
`Refusing to load agent '${agentName}' from private IP range: ${url}. Remote agents must use public URLs.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const { baseUrl, path } = splitAgentCardUrl(url);
|
||||
const rawCard = await resolver.resolve(baseUrl, path);
|
||||
const agentCard = normalizeAgentCard(rawCard);
|
||||
|
||||
// Deep validation of all transport URLs within the card to prevent SSRF
|
||||
await this.validateAgentCardUrls(agentName, agentCard);
|
||||
|
||||
return agentCard;
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates all URLs (top-level and interfaces) within an AgentCard for SSRF.
|
||||
*/
|
||||
private async validateAgentCardUrls(
|
||||
agentName: string,
|
||||
card: AgentCard,
|
||||
): Promise<void> {
|
||||
const urlsToValidate = [card.url];
|
||||
if (card.additionalInterfaces) {
|
||||
for (const intf of card.additionalInterfaces) {
|
||||
if (intf.url) urlsToValidate.push(intf.url);
|
||||
}
|
||||
}
|
||||
|
||||
for (const url of urlsToValidate) {
|
||||
if (!url) continue;
|
||||
|
||||
// Ensure URL has a scheme for the parser (gRPC often provides raw IP:port)
|
||||
const validationUrl = url.includes('://') ? url : `http://${url}`;
|
||||
|
||||
if (await isPrivateIpAsync(validationUrl)) {
|
||||
const parsed = new URL(validationUrl);
|
||||
if (!isLoopbackHost(parsed.hostname)) {
|
||||
throw new Error(
|
||||
`Refusing to load agent '${agentName}': contains transport URL pointing to private IP range: ${url}.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user