diff --git a/packages/core/src/agents/a2a-client-manager.test.ts b/packages/core/src/agents/a2a-client-manager.test.ts index 484a68314b..f52a9ccee6 100644 --- a/packages/core/src/agents/a2a-client-manager.test.ts +++ b/packages/core/src/agents/a2a-client-manager.test.ts @@ -8,6 +8,7 @@ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; import { A2AClientManager } from './a2a-client-manager.js'; import type { AgentCard } from '@a2a-js/sdk'; import * as sdkClient from '@a2a-js/sdk/client'; +import { lookup } from 'node:dns/promises'; import { debugLogger } from '../utils/debugLogger.js'; interface MockClient { @@ -36,6 +37,10 @@ vi.mock('../utils/debugLogger.js', () => ({ }, })); +vi.mock('node:dns/promises', () => ({ + lookup: vi.fn().mockResolvedValue([{ address: '93.184.216.34' }]), +})); + describe('A2AClientManager', () => { let manager: A2AClientManager; const mockAgentCard: AgentCard = { @@ -407,6 +412,20 @@ describe('A2AClientManager', () => { ); }); + it('should throw if a domain resolves to a private IP (DNS SSRF protection)', async () => { + const maliciousDomainUrl = + 'http://malicious.com/.well-known/agent-card.json'; + vi.mocked(lookup).mockResolvedValueOnce([ + { address: '10.0.0.1', family: 4 }, + ]); + + await expect( + manager.loadAgent('dns-ssrf-agent', maliciousDomainUrl), + ).rejects.toThrow( + /Refusing to load agent 'dns-ssrf-agent' from private IP range/, + ); + }); + it('should throw if a public agent card contains a private transport URL (Deep SSRF protection)', async () => { const publicUrl = 'https://public.agent.com/card.json'; const resolverInstance = { diff --git a/packages/core/src/agents/a2a-client-manager.ts b/packages/core/src/agents/a2a-client-manager.ts index 4619a4c4fe..8099de5ce6 100644 --- a/packages/core/src/agents/a2a-client-manager.ts +++ b/packages/core/src/agents/a2a-client-manager.ts @@ -26,7 +26,7 @@ import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc'; import { v4 as uuidv4 } from 'uuid'; import { Agent as UndiciAgent } from 'undici'; import { getGrpcCredentials, normalizeAgentCard } from './a2aUtils.js'; -import { isPrivateIp } from '../utils/fetch.js'; +import { isPrivateIpAsync } from '../utils/fetch.js'; import { debugLogger } from '../utils/debugLogger.js'; // Remote agents can take 10+ minutes (e.g. Deep Research). @@ -259,8 +259,8 @@ export class A2AClientManager { let baseUrl = url; let path: string | undefined; - // Validate URL to prevent SSRF - if (isPrivateIp(url)) { + // Validate URL to prevent SSRF (with DNS resolution) + if (await isPrivateIpAsync(url)) { // Local/private IPs are allowed ONLY for localhost for testing. const parsed = new URL(url); if (parsed.hostname !== 'localhost' && parsed.hostname !== '127.0.0.1') { @@ -285,7 +285,7 @@ export class A2AClientManager { const agentCard = normalizeAgentCard(rawCard); // Deep validation of all transport URLs within the card to prevent SSRF - this.validateAgentCardUrls(agentName, agentCard); + await this.validateAgentCardUrls(agentName, agentCard); return agentCard; } @@ -293,7 +293,10 @@ export class A2AClientManager { /** * Validates all URLs (top-level and interfaces) within an AgentCard for SSRF. */ - private validateAgentCardUrls(agentName: string, card: AgentCard): void { + private async validateAgentCardUrls( + agentName: string, + card: AgentCard, + ): Promise { const urlsToValidate = [card.url]; if (card.additionalInterfaces) { for (const intf of card.additionalInterfaces) { @@ -307,7 +310,7 @@ export class A2AClientManager { // Ensure URL has a scheme for the parser (gRPC often provides raw IP:port) const validationUrl = url.includes('://') ? url : `http://${url}`; - if (isPrivateIp(validationUrl)) { + if (await isPrivateIpAsync(validationUrl)) { const parsed = new URL(validationUrl); if ( parsed.hostname !== 'localhost' && diff --git a/packages/core/src/utils/fetch.ts b/packages/core/src/utils/fetch.ts index 28b776e3d5..4f1c81b5fb 100644 --- a/packages/core/src/utils/fetch.ts +++ b/packages/core/src/utils/fetch.ts @@ -6,6 +6,7 @@ import { getErrorMessage, isNodeError } from './errors.js'; import { URL } from 'node:url'; +import { lookup } from 'node:dns/promises'; import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici'; const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes @@ -44,12 +45,44 @@ export class FetchError extends Error { export function isPrivateIp(url: string): boolean { try { const hostname = new URL(url).hostname; - return PRIVATE_IP_RANGES.some((range) => range.test(hostname)); + return isAddressPrivate(hostname); } catch (_e) { return false; } } +/** + * Checks if a URL resolves to a private IP address. + * Performs DNS resolution to prevent DNS rebinding/SSRF bypasses. + */ +export async function isPrivateIpAsync(url: string): Promise { + try { + const parsed = new URL(url); + const hostname = parsed.hostname; + + // Fast check for literal IPs or localhost + if (isAddressPrivate(hostname)) { + return true; + } + + // Resolve DNS to check the actual target IP + const addresses = await lookup(hostname, { all: true }); + return addresses.some((addr) => isAddressPrivate(addr.address)); + } catch (_e) { + return false; + } +} + +/** + * Internal helper to check if an IP address string is in a private range. + */ +function isAddressPrivate(address: string): boolean { + return ( + address === 'localhost' || + PRIVATE_IP_RANGES.some((range) => range.test(address)) + ); +} + export async function fetchWithTimeout( url: string, timeout: number,