Address comments and remove not needed security changes.

This commit is contained in:
Alisa Novikova
2026-03-11 06:22:04 -07:00
parent df41ba48cd
commit a094b8d915
18 changed files with 22 additions and 804 deletions

View File

@@ -8,8 +8,6 @@ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
import { A2AClientManager } from './a2a-client-manager.js';
import type { AgentCard } from '@a2a-js/sdk';
import * as sdkClient from '@a2a-js/sdk/client';
import * as dnsPromises from 'node:dns/promises';
import type { LookupOptions } from 'node:dns';
import { debugLogger } from '../utils/debugLogger.js';
interface MockClient {
@@ -38,10 +36,6 @@ vi.mock('../utils/debugLogger.js', () => ({
},
}));
vi.mock('node:dns/promises', () => ({
lookup: vi.fn(),
}));
describe('A2AClientManager', () => {
let manager: A2AClientManager;
const mockAgentCard: AgentCard = {
@@ -69,17 +63,6 @@ describe('A2AClientManager', () => {
manager = A2AClientManager.getInstance();
manager.clearCache();
// Default DNS mock: resolve to public IP.
// Use any cast only for return value due to complex multi-signature overloads.
vi.mocked(dnsPromises.lookup).mockImplementation(
async (_h: string, options?: LookupOptions | number) => {
const addr = { address: '93.184.216.34', family: 4 };
const isAll = typeof options === 'object' && options?.all;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (isAll ? [addr] : addr) as any;
},
);
// Re-create the instances as plain objects that can be spied on
const factoryInstance = {
createFromUrl: vi.fn(),
@@ -460,82 +443,4 @@ describe('A2AClientManager', () => {
).rejects.toThrow("Agent 'NonExistentAgent' not found.");
});
});
describe('Protocol Routing & URL Logic', () => {
it('should correctly split URLs to prevent .well-known doubling', async () => {
const fullUrl = 'http://localhost:9001/.well-known/agent-card.json';
const resolverInstance = {
resolve: vi.fn().mockResolvedValue({ name: 'test' } as AgentCard),
};
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
);
await manager.loadAgent('test-doubling', fullUrl);
expect(resolverInstance.resolve).toHaveBeenCalledWith(
'http://localhost:9001/',
undefined,
);
});
it('should throw if a remote agent uses a private IP (SSRF protection)', async () => {
const privateUrl = 'http://169.254.169.254/.well-known/agent-card.json';
await expect(manager.loadAgent('ssrf-agent', privateUrl)).rejects.toThrow(
/Refusing to load agent 'ssrf-agent' from private IP range/,
);
});
it('should throw if a domain resolves to a private IP (DNS SSRF protection)', async () => {
const maliciousDomainUrl =
'http://malicious.com/.well-known/agent-card.json';
vi.mocked(dnsPromises.lookup).mockImplementationOnce(
async (_h: string, options?: LookupOptions | number) => {
const addr = { address: '10.0.0.1', family: 4 };
const isAll = typeof options === 'object' && options?.all;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (isAll ? [addr] : addr) as any;
},
);
await expect(
manager.loadAgent('dns-ssrf-agent', maliciousDomainUrl),
).rejects.toThrow(/private IP range/);
});
it('should throw if a public agent card contains a private transport URL (Deep SSRF protection)', async () => {
const publicUrl = 'https://public.agent.com/card.json';
const resolverInstance = {
resolve: vi.fn().mockResolvedValue({
...mockAgentCard,
url: 'http://192.168.1.1/api', // Malicious private transport in public card
} as AgentCard),
};
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
);
// DNS for public.agent.com is public
vi.mocked(dnsPromises.lookup).mockImplementation(
async (hostname: string, options?: LookupOptions | number) => {
const isAll = typeof options === 'object' && options?.all;
if (hostname === 'public.agent.com') {
const addr = { address: '1.1.1.1', family: 4 };
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (isAll ? [addr] : addr) as any;
}
const addr = { address: '192.168.1.1', family: 4 };
// eslint-disable-next-line @typescript-eslint/no-explicit-any
return (isAll ? [addr] : addr) as any;
},
);
await expect(
manager.loadAgent('malicious-agent', publicUrl),
).rejects.toThrow(
/contains transport URL pointing to private IP range: http:\/\/192.168.1.1\/api/,
);
});
});
});

View File

@@ -24,19 +24,7 @@ import {
import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc';
import { v4 as uuidv4 } from 'uuid';
import { Agent as UndiciAgent } from 'undici';
import {
getGrpcChannelOptions,
getGrpcCredentials,
normalizeAgentCard,
pinUrlToIp,
splitAgentCardUrl,
} from './a2aUtils.js';
import {
isPrivateIpAsync,
safeLookup,
isLoopbackHost,
safeFetch,
} from '../utils/fetch.js';
import { getGrpcCredentials, normalizeAgentCard } from './a2aUtils.js';
import { debugLogger } from '../utils/debugLogger.js';
import { classifyAgentError } from './a2a-errors.js';
@@ -50,29 +38,16 @@ export type SendMessageResult =
| TaskStatusUpdateEvent
| TaskArtifactUpdateEvent;
/**
* Internal interface representing properties we inject into the SDK
* to enable DNS rebinding protection for gRPC connections.
* TODO: Replace with official SDK pinning API once available.
*/
interface InternalGrpcExtensions {
target: string;
grpcChannelOptions: Record<string, unknown>;
}
// Remote agents can take 10+ minutes (e.g. Deep Research).
// Use a dedicated dispatcher so the global 5-min timeout isn't affected.
const A2A_TIMEOUT = 1800000; // 30 minutes
const a2aDispatcher = new UndiciAgent({
headersTimeout: A2A_TIMEOUT,
bodyTimeout: A2A_TIMEOUT,
connect: {
// SSRF protection at the connection level (mitigates DNS rebinding)
lookup: safeLookup,
},
});
const a2aFetch: typeof fetch = (input, init) =>
safeFetch(input, { ...init, dispatcher: a2aDispatcher });
// @ts-expect-error The `dispatcher` property is a Node.js extension to fetch not present in standard types.
fetch(input, { ...init, dispatcher: a2aDispatcher });
/**
* Orchestrates communication with remote A2A agents.
@@ -139,30 +114,6 @@ export class A2AClientManager {
const resolver = new DefaultAgentCardResolver({ fetchImpl: cardFetch });
const agentCard = await this.resolveAgentCard(name, agentCardUrl, resolver);
// Pin URL to IP to prevent DNS rebinding for gRPC (connection-level SSRF protection)
const grpcInterface = agentCard.additionalInterfaces?.find(
(i) => i.transport === 'GRPC',
);
const urlToPin = grpcInterface?.url ?? agentCard.url;
const { pinnedUrl, hostname } = await pinUrlToIp(urlToPin, name);
// Prepare base gRPC options
const baseGrpcOptions: ConstructorParameters<
typeof GrpcTransportFactory
>[0] = {
grpcChannelCredentials: getGrpcCredentials(urlToPin),
};
// We inject additional properties into the transport options to force
// the use of a pinned IP address and matching SSL authority. This is
// required for robust DNS Rebinding protection.
const transportOptions = {
...baseGrpcOptions,
target: pinnedUrl,
grpcChannelOptions: getGrpcChannelOptions(hostname),
} as ConstructorParameters<typeof GrpcTransportFactory>[0] &
InternalGrpcExtensions;
// Configure standard SDK client for tool registration and discovery
const clientOptions = ClientFactoryOptions.createFrom(
ClientFactoryOptions.default,
@@ -170,11 +121,9 @@ export class A2AClientManager {
transports: [
new RestTransportFactory({ fetchImpl: authFetch }),
new JsonRpcTransportFactory({ fetchImpl: authFetch }),
new GrpcTransportFactory(
transportOptions as ConstructorParameters<
typeof GrpcTransportFactory
>[0],
),
new GrpcTransportFactory({
grpcChannelCredentials: getGrpcCredentials(agentCard.url),
}),
],
cardResolver: resolver,
},
@@ -182,8 +131,7 @@ export class A2AClientManager {
try {
const factory = new ClientFactory(clientOptions);
const client = await factory.createFromUrl(agentCardUrl, '');
const agentCard = await client.getAgentCard();
const client = await factory.createFromAgentCard(agentCard);
this.clients.set(name, client);
this.agentCards.set(name, agentCard);
@@ -307,65 +255,13 @@ export class A2AClientManager {
}
}
/**
* Resolves and normalizes an agent card from a given URL.
* Handles splitting the URL if it already contains the standard .well-known path.
* Also performs basic SSRF validation to prevent internal IP access.
*/
private async resolveAgentCard(
agentName: string,
url: string,
resolver: DefaultAgentCardResolver,
): Promise<AgentCard> {
// Validate URL to prevent SSRF (with DNS resolution)
if (await isPrivateIpAsync(url)) {
// Local/private IPs are allowed ONLY for localhost for testing.
const parsed = new URL(url);
if (!isLoopbackHost(parsed.hostname)) {
throw new Error(
`Refusing to load agent '${agentName}' from private IP range: ${url}. Remote agents must use public URLs.`,
);
}
}
const { baseUrl, path } = splitAgentCardUrl(url);
const rawCard = await resolver.resolve(baseUrl, path);
const rawCard = await resolver.resolve(url);
const agentCard = normalizeAgentCard(rawCard);
// Deep validation of all transport URLs within the card to prevent SSRF
await this.validateAgentCardUrls(agentName, agentCard);
return agentCard;
}
/**
* Validates all URLs (top-level and interfaces) within an AgentCard for SSRF.
*/
private async validateAgentCardUrls(
agentName: string,
card: AgentCard,
): Promise<void> {
const urlsToValidate = [card.url];
if (card.additionalInterfaces) {
for (const intf of card.additionalInterfaces) {
if (intf.url) urlsToValidate.push(intf.url);
}
}
for (const url of urlsToValidate) {
if (!url) continue;
// Ensure URL has a scheme for the parser (gRPC often provides raw IP:port)
const validationUrl = url.includes('://') ? url : `http://${url}`;
if (await isPrivateIpAsync(validationUrl)) {
const parsed = new URL(validationUrl);
if (!isLoopbackHost(parsed.hostname)) {
throw new Error(
`Refusing to load agent '${agentName}': contains transport URL pointing to private IP range: ${url}.`,
);
}
}
}
}
}

View File

@@ -13,8 +13,6 @@ import {
AUTH_REQUIRED_MSG,
normalizeAgentCard,
getGrpcCredentials,
pinUrlToIp,
splitAgentCardUrl,
} from './a2aUtils.js';
import type { SendMessageResult } from './a2a-client-manager.js';
import type {
@@ -26,12 +24,6 @@ import type {
TaskStatusUpdateEvent,
TaskArtifactUpdateEvent,
} from '@a2a-js/sdk';
import * as dnsPromises from 'node:dns/promises';
import type { LookupAddress } from 'node:dns';
vi.mock('node:dns/promises', () => ({
lookup: vi.fn(),
}));
describe('a2aUtils', () => {
beforeEach(() => {
@@ -54,77 +46,6 @@ describe('a2aUtils', () => {
});
});
describe('pinUrlToIp', () => {
it('should resolve and pin hostname to IP', async () => {
vi.mocked(
dnsPromises.lookup as unknown as (
hostname: string,
options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]);
const { pinnedUrl, hostname } = await pinUrlToIp(
'http://example.com:9000',
'test-agent',
);
expect(hostname).toBe('example.com');
expect(pinnedUrl).toBe('http://93.184.216.34:9000/');
});
it('should handle raw host:port strings (standard for gRPC)', async () => {
vi.mocked(
dnsPromises.lookup as unknown as (
hostname: string,
options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '93.184.216.34', family: 4 }]);
const { pinnedUrl, hostname } = await pinUrlToIp(
'example.com:9000',
'test-agent',
);
expect(hostname).toBe('example.com');
expect(pinnedUrl).toBe('93.184.216.34:9000');
});
it('should throw error if resolution fails (fail closed)', async () => {
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
await expect(
pinUrlToIp('http://unreachable.com', 'test-agent'),
).rejects.toThrow("Failed to resolve host for agent 'test-agent'");
});
it('should throw error if resolved to private IP', async () => {
vi.mocked(
dnsPromises.lookup as unknown as (
hostname: string,
options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '10.0.0.1', family: 4 }]);
await expect(
pinUrlToIp('http://malicious.com', 'test-agent'),
).rejects.toThrow('resolves to private IP range');
});
it('should allow localhost/127.0.0.1/::1 exceptions', async () => {
vi.mocked(
dnsPromises.lookup as unknown as (
hostname: string,
options: { all: true },
) => Promise<LookupAddress[]>,
).mockResolvedValue([{ address: '127.0.0.1', family: 4 }]);
const { pinnedUrl, hostname } = await pinUrlToIp(
'http://localhost:9000',
'test-agent',
);
expect(hostname).toBe('localhost');
expect(pinnedUrl).toBe('http://127.0.0.1:9000/');
});
});
describe('isTerminalState', () => {
it('should return true for completed, failed, canceled, and rejected', () => {
expect(isTerminalState('completed')).toBe(true);
@@ -450,49 +371,6 @@ describe('a2aUtils', () => {
});
});
describe('splitAgentCardUrl', () => {
const standard = '.well-known/agent-card.json';
it('should return baseUrl as-is if it does not end with standard path', () => {
const url = 'http://localhost:9001/custom/path';
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
});
it('should split correctly if URL ends with standard path', () => {
const url = `http://localhost:9001/${standard}`;
expect(splitAgentCardUrl(url)).toEqual({
baseUrl: 'http://localhost:9001/',
path: undefined,
});
});
it('should handle trailing slash in baseUrl when splitting', () => {
const url = `http://example.com/api/${standard}`;
expect(splitAgentCardUrl(url)).toEqual({
baseUrl: 'http://example.com/api/',
path: undefined,
});
});
it('should ignore hashes and query params when splitting', () => {
const url = `http://localhost:9001/${standard}?foo=bar#baz`;
expect(splitAgentCardUrl(url)).toEqual({
baseUrl: 'http://localhost:9001/',
path: undefined,
});
});
it('should return original URL if parsing fails', () => {
const url = 'not-a-url';
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
});
it('should handle standard path appearing earlier in the path', () => {
const url = `http://localhost:9001/${standard}/something-else`;
expect(splitAgentCardUrl(url)).toEqual({ baseUrl: url });
});
});
describe('A2AResultReassembler', () => {
it('should reassemble sequential messages and incremental artifacts', () => {
const reassembler = new A2AResultReassembler();

View File

@@ -5,7 +5,6 @@
*/
import * as grpc from '@grpc/grpc-js';
import { lookup } from 'node:dns/promises';
import { z } from 'zod';
import type {
Message,
@@ -18,7 +17,6 @@ import type {
AgentCard,
AgentInterface,
} from '@a2a-js/sdk';
import { isAddressPrivate } from '../utils/fetch.js';
import type { SendMessageResult } from './a2a-client-manager.js';
export const AUTH_REQUIRED_MSG = `[Authorization Required] The agent has indicated it requires authorization to proceed. Please follow the agent's instructions.`;
@@ -289,118 +287,6 @@ export function getGrpcCredentials(url: string): grpc.ChannelCredentials {
: grpc.credentials.createInsecure();
}
/**
* Returns gRPC channel options to ensure SSL/authority matches the original hostname
* when connecting via a pinned IP address.
*/
export function getGrpcChannelOptions(
hostname: string,
): Record<string, unknown> {
return {
'grpc.default_authority': hostname,
'grpc.ssl_target_name_override': hostname,
};
}
/**
* Resolves a hostname to its IP address and validates it against SSRF.
* Returns the pinned IP-based URL and the original hostname.
*/
export async function pinUrlToIp(
url: string,
agentName: string,
): Promise<{ pinnedUrl: string; hostname: string }> {
if (!url) return { pinnedUrl: url, hostname: '' };
// gRPC URLs in A2A can be 'host:port' or 'dns:///host:port' or have schemes.
// We normalize to host:port for resolution.
const hasScheme = url.includes('://');
const normalizedUrl = hasScheme ? url : `http://${url}`;
try {
const parsed = new URL(normalizedUrl);
const hostname = parsed.hostname;
const sanitizedHost =
hostname.startsWith('[') && hostname.endsWith(']')
? hostname.slice(1, -1)
: hostname;
// Resolve DNS to check the actual target IP and pin it
const addresses = await lookup(hostname, { all: true });
const publicAddresses = addresses.filter(
(addr) =>
!isAddressPrivate(addr.address) ||
sanitizedHost === 'localhost' ||
sanitizedHost === '127.0.0.1' ||
sanitizedHost === '::1',
);
if (publicAddresses.length === 0) {
if (addresses.length > 0) {
throw new Error(
`Refusing to load agent '${agentName}': transport URL '${url}' resolves to private IP range.`,
);
}
throw new Error(
`Failed to resolve any public IP addresses for host: ${hostname}`,
);
}
const pinnedIp = publicAddresses[0].address;
const pinnedHostname = pinnedIp.includes(':') ? `[${pinnedIp}]` : pinnedIp;
// Reconstruct URL with IP
parsed.hostname = pinnedHostname;
let pinnedUrl = parsed.toString();
// If original didn't have scheme, remove it (standard for gRPC targets)
if (!hasScheme) {
pinnedUrl = pinnedUrl.replace(/^http:\/\//, '');
// URL.toString() might append a trailing slash
if (pinnedUrl.endsWith('/') && !url.endsWith('/')) {
pinnedUrl = pinnedUrl.slice(0, -1);
}
}
return { pinnedUrl, hostname };
} catch (e) {
if (e instanceof Error && e.message.includes('Refusing')) throw e;
throw new Error(`Failed to resolve host for agent '${agentName}': ${url}`, {
cause: e,
});
}
}
/**
* Splts an agent card URL into a baseUrl and a standard path if it already
* contains '.well-known/agent-card.json'.
*/
export function splitAgentCardUrl(url: string): {
baseUrl: string;
path?: string;
} {
const standardPath = '.well-known/agent-card.json';
try {
const parsedUrl = new URL(url);
if (parsedUrl.pathname.endsWith(standardPath)) {
// Reconstruct baseUrl from parsed components to avoid issues with hashes or query params.
parsedUrl.pathname = parsedUrl.pathname.substring(
0,
parsedUrl.pathname.lastIndexOf(standardPath),
);
parsedUrl.search = '';
parsedUrl.hash = '';
// We return undefined for path if it's the standard one,
// because the SDK's DefaultAgentCardResolver appends it automatically.
return { baseUrl: parsedUrl.toString(), path: undefined };
}
} catch (_e) {
// Ignore URL parsing errors here, let the resolver handle them.
}
return { baseUrl: url };
}
/**
* Extracts contextId and taskId from a Message, Task, or Update response.
* Follows the pattern from the A2A CLI sample to maintain conversational continuity.

View File

@@ -700,7 +700,6 @@ async function fetchAndCacheUserInfo(client: OAuth2Client): Promise<void> {
return;
}
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(
'https://www.googleapis.com/oauth2/v2/userinfo',
{

View File

@@ -111,7 +111,6 @@ export class MCPOAuthProvider {
scope: config.scopes?.join(' ') || '',
};
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(registrationUrl, {
method: 'POST',
headers: {
@@ -301,7 +300,6 @@ export class MCPOAuthProvider {
? { Accept: 'text/event-stream' }
: { Accept: 'application/json' };
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(mcpServerUrl, {
method: 'HEAD',
headers,

View File

@@ -97,7 +97,6 @@ export class OAuthUtils {
resourceMetadataUrl: string,
): Promise<OAuthProtectedResourceMetadata | null> {
try {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(resourceMetadataUrl);
if (!response.ok) {
return null;
@@ -122,7 +121,6 @@ export class OAuthUtils {
authServerMetadataUrl: string,
): Promise<OAuthAuthorizationServerMetadata | null> {
try {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(authServerMetadataUrl);
if (!response.ok) {
return null;

View File

@@ -534,7 +534,6 @@ export class ClearcutLogger {
let result: LogResponse = {};
try {
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(CLEARCUT_URL, {
method: 'POST',
body: safeJsonStringify(request),

View File

@@ -1903,7 +1903,6 @@ export async function connectToMcpServer(
acceptHeader = 'application/json';
}
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(urlToFetch, {
method: 'HEAD',
headers: {

View File

@@ -5,27 +5,12 @@
*/
import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest';
import {
isPrivateIp,
isPrivateIpAsync,
isAddressPrivate,
safeLookup,
safeFetch,
fetchWithTimeout,
PrivateIpError,
} from './fetch.js';
import * as dnsPromises from 'node:dns/promises';
import * as dns from 'node:dns';
import { isPrivateIp, isAddressPrivate, fetchWithTimeout } from './fetch.js';
vi.mock('node:dns/promises', () => ({
lookup: vi.fn(),
}));
// We need to mock node:dns for safeLookup since it uses the callback API
vi.mock('node:dns', () => ({
lookup: vi.fn(),
}));
// Mock global fetch
const originalFetch = global.fetch;
global.fetch = vi.fn();
@@ -114,150 +99,6 @@ describe('fetch utils', () => {
});
});
describe('isPrivateIpAsync', () => {
it('should identify private IPs directly', async () => {
expect(await isPrivateIpAsync('http://10.0.0.1/')).toBe(true);
});
it('should identify domains resolving to private IPs', async () => {
vi.mocked(dnsPromises.lookup).mockImplementation(
async () =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[{ address: '10.0.0.1', family: 4 }] as any,
);
expect(await isPrivateIpAsync('http://malicious.com/')).toBe(true);
});
it('should identify domains resolving to public IPs as non-private', async () => {
vi.mocked(dnsPromises.lookup).mockImplementation(
async () =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
[{ address: '8.8.8.8', family: 4 }] as any,
);
expect(await isPrivateIpAsync('http://google.com/')).toBe(false);
});
it('should throw error if DNS resolution fails (fail closed)', async () => {
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
await expect(isPrivateIpAsync('http://unreachable.com/')).rejects.toThrow(
'Failed to verify if URL resolves to private IP',
);
});
it('should return false for invalid URLs instead of throwing verification error', async () => {
expect(await isPrivateIpAsync('not-a-url')).toBe(false);
});
});
describe('safeLookup', () => {
it('should filter out private IPs', async () => {
const addresses = [
{ address: '8.8.8.8', family: 4 },
{ address: '10.0.0.1', family: 4 },
];
vi.mocked(dns.lookup).mockImplementation(((
_h: string,
_o: dns.LookupOptions,
cb: (
err: Error | null,
addr: Array<{ address: string; family: number }>,
) => void,
) => {
cb(null, addresses);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
}) as any);
const result = await new Promise<
Array<{ address: string; family: number }>
>((resolve, reject) => {
safeLookup('example.com', { all: true }, (err, filtered) => {
if (err) reject(err);
else resolve(filtered);
});
});
expect(result).toHaveLength(1);
expect(result[0].address).toBe('8.8.8.8');
});
it('should allow explicit localhost', async () => {
const addresses = [{ address: '127.0.0.1', family: 4 }];
vi.mocked(dns.lookup).mockImplementation(((
_h: string,
_o: dns.LookupOptions,
cb: (
err: Error | null,
addr: Array<{ address: string; family: number }>,
) => void,
) => {
cb(null, addresses);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
}) as any);
const result = await new Promise<
Array<{ address: string; family: number }>
>((resolve, reject) => {
safeLookup('localhost', { all: true }, (err, filtered) => {
if (err) reject(err);
else resolve(filtered);
});
});
expect(result).toHaveLength(1);
expect(result[0].address).toBe('127.0.0.1');
});
it('should error if all resolved IPs are private', async () => {
const addresses = [{ address: '10.0.0.1', family: 4 }];
vi.mocked(dns.lookup).mockImplementation(((
_h: string,
_o: dns.LookupOptions,
cb: (
err: Error | null,
addr: Array<{ address: string; family: number }>,
) => void,
) => {
cb(null, addresses);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
}) as any);
await expect(
new Promise((resolve, reject) => {
safeLookup('malicious.com', { all: true }, (err, filtered) => {
if (err) reject(err);
else resolve(filtered);
});
}),
).rejects.toThrow(PrivateIpError);
});
});
describe('safeFetch', () => {
it('should forward to fetch with dispatcher', async () => {
vi.mocked(global.fetch).mockResolvedValue(new Response('ok'));
const response = await safeFetch('https://example.com');
expect(response.status).toBe(200);
expect(global.fetch).toHaveBeenCalledWith(
'https://example.com',
expect.objectContaining({
dispatcher: expect.any(Object),
}),
);
});
it('should handle Refusing to connect errors', async () => {
vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError());
await expect(safeFetch('http://10.0.0.1')).rejects.toThrow(
'Access to private network is blocked',
);
});
});
describe('fetchWithTimeout', () => {
it('should handle timeouts', async () => {
vi.mocked(global.fetch).mockImplementation(
@@ -279,13 +120,5 @@ describe('fetch utils', () => {
'Request timed out after 50ms',
);
});
it('should handle private IP errors via handleFetchError', async () => {
vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError());
await expect(fetchWithTimeout('http://10.0.0.1', 1000)).rejects.toThrow(
'Access to private network is blocked: http://10.0.0.1',
);
});
});
});

View File

@@ -6,37 +6,12 @@
import { getErrorMessage, isNodeError } from './errors.js';
import { URL } from 'node:url';
import * as dns from 'node:dns';
import { lookup } from 'node:dns/promises';
import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici';
import ipaddr from 'ipaddr.js';
const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes
const DEFAULT_BODY_TIMEOUT = 300000; // 5 minutes
// Configure default global dispatcher with higher timeouts
setGlobalDispatcher(
new Agent({
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
bodyTimeout: DEFAULT_BODY_TIMEOUT,
}),
);
// Local extension of RequestInit to support Node.js/undici dispatcher
interface NodeFetchInit extends RequestInit {
dispatcher?: Agent | ProxyAgent;
}
/**
* Error thrown when a connection to a private IP address is blocked for security reasons.
*/
export class PrivateIpError extends Error {
constructor(message = 'Refusing to connect to private IP address') {
super(message);
this.name = 'PrivateIpError';
}
}
export class FetchError extends Error {
constructor(
message: string,
@@ -48,6 +23,14 @@ export class FetchError extends Error {
}
}
// Configure default global dispatcher with higher timeouts
setGlobalDispatcher(
new Agent({
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
bodyTimeout: DEFAULT_BODY_TIMEOUT,
}),
);
/**
* Sanitizes a hostname by stripping IPv6 brackets if present.
*/
@@ -69,53 +52,6 @@ export function isLoopbackHost(hostname: string): boolean {
);
}
/**
* A custom DNS lookup implementation for undici agents that prevents
* connection to private IP ranges (SSRF protection).
*/
export function safeLookup(
hostname: string,
options: dns.LookupOptions | number | null | undefined,
callback: (
err: Error | null,
addresses: Array<{ address: string; family: number }>,
) => void,
): void {
// Use the callback-based dns.lookup to match undici's expected signature.
// We explicitly handle the 'all' option to ensure we get an array of addresses.
const lookupOptions =
typeof options === 'number' ? { family: options } : { ...options };
const finalOptions = { ...lookupOptions, all: true };
dns.lookup(hostname, finalOptions, (err, addresses) => {
if (err) {
callback(err, []);
return;
}
const addressArray = Array.isArray(addresses) ? addresses : [];
const filtered = addressArray.filter(
(addr) => !isAddressPrivate(addr.address) || isLoopbackHost(hostname),
);
if (filtered.length === 0 && addressArray.length > 0) {
callback(new PrivateIpError(), []);
return;
}
callback(null, filtered);
});
}
// Dedicated dispatcher with connection-level SSRF protection (safeLookup)
const safeDispatcher = new Agent({
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
bodyTimeout: DEFAULT_BODY_TIMEOUT,
connect: {
lookup: safeLookup,
},
});
export function isPrivateIp(url: string): boolean {
try {
const hostname = new URL(url).hostname;
@@ -125,37 +61,6 @@ export function isPrivateIp(url: string): boolean {
}
}
/**
* Checks if a URL resolves to a private IP address.
* Performs DNS resolution to prevent DNS rebinding/SSRF bypasses.
*/
export async function isPrivateIpAsync(url: string): Promise<boolean> {
try {
const parsed = new URL(url);
const hostname = parsed.hostname;
// Fast check for literal IPs or localhost
if (isAddressPrivate(hostname)) {
return true;
}
// Resolve DNS to check the actual target IP
const addresses = await lookup(hostname, { all: true });
return addresses.some((addr) => isAddressPrivate(addr.address));
} catch (e) {
if (
e instanceof Error &&
e.name === 'TypeError' &&
e.message.includes('Invalid URL')
) {
return false;
}
throw new Error(`Failed to verify if URL resolves to private IP: ${url}`, {
cause: e,
});
}
}
/**
* IANA Benchmark Testing Range (198.18.0.0/15).
* Classified as 'unicast' by ipaddr.js but is reserved and should not be
@@ -210,72 +115,15 @@ export function isAddressPrivate(address: string): boolean {
}
}
/**
* Internal helper to map varied fetch errors to a standardized FetchError.
* Centralizes security-related error mapping (e.g. PrivateIpError).
*/
function handleFetchError(error: unknown, url: string): never {
if (error instanceof PrivateIpError) {
throw new FetchError(
`Access to private network is blocked: ${url}`,
'ERR_PRIVATE_NETWORK',
{ cause: error },
);
}
if (error instanceof FetchError) {
throw error;
}
throw new FetchError(
getErrorMessage(error),
isNodeError(error) ? error.code : undefined,
{ cause: error },
);
}
/**
* Enhanced fetch with SSRF protection.
* Prevents access to private/internal networks at the connection level.
*/
export async function safeFetch(
input: RequestInfo | URL,
init?: NodeFetchInit,
): Promise<Response> {
const nodeInit: NodeFetchInit = {
dispatcher: safeDispatcher,
...init,
};
try {
// eslint-disable-next-line no-restricted-syntax
return await fetch(input, nodeInit as RequestInit);
} catch (error) {
const url =
input instanceof Request
? input.url
: typeof input === 'string'
? input
: input.toString();
handleFetchError(error, url);
}
}
/**
* Creates an undici ProxyAgent that incorporates safe DNS lookup.
*/
export function createSafeProxyAgent(proxyUrl: string): ProxyAgent {
return new ProxyAgent({
uri: proxyUrl,
connect: {
lookup: safeLookup,
},
});
}
/**
* Performs a fetch with a specified timeout and connection-level SSRF protection.
*/
export async function fetchWithTimeout(
url: string,
timeout: number,
@@ -294,21 +142,17 @@ export async function fetchWithTimeout(
}
}
const nodeInit: NodeFetchInit = {
...options,
signal: controller.signal,
dispatcher: safeDispatcher,
};
try {
// eslint-disable-next-line no-restricted-syntax
const response = await fetch(url, nodeInit);
const response = await fetch(url, {
...options,
signal: controller.signal,
});
return response;
} catch (error) {
if (isNodeError(error) && error.code === 'ABORT_ERR') {
throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT');
}
handleFetchError(error, url.toString());
throw new FetchError(getErrorMessage(error), undefined, { cause: error });
} finally {
clearTimeout(timeoutId);
}

View File

@@ -454,7 +454,6 @@ export async function exchangeCodeForToken(
params.append('resource', resource);
}
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(config.tokenUrl, {
method: 'POST',
headers: {
@@ -508,7 +507,6 @@ export async function refreshAccessToken(
params.append('resource', resource);
}
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
const response = await fetch(tokenUrl, {
method: 'POST',
headers: {