diff --git a/packages/core/src/agents/a2a-client-manager.test.ts b/packages/core/src/agents/a2a-client-manager.test.ts index f52a9ccee6..96c91ae33f 100644 --- a/packages/core/src/agents/a2a-client-manager.test.ts +++ b/packages/core/src/agents/a2a-client-manager.test.ts @@ -8,7 +8,8 @@ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest'; import { A2AClientManager } from './a2a-client-manager.js'; import type { AgentCard } from '@a2a-js/sdk'; import * as sdkClient from '@a2a-js/sdk/client'; -import { lookup } from 'node:dns/promises'; +import * as dnsPromises from 'node:dns/promises'; +import type { LookupOptions } from 'node:dns'; import { debugLogger } from '../utils/debugLogger.js'; interface MockClient { @@ -38,7 +39,7 @@ vi.mock('../utils/debugLogger.js', () => ({ })); vi.mock('node:dns/promises', () => ({ - lookup: vi.fn().mockResolvedValue([{ address: '93.184.216.34' }]), + lookup: vi.fn(), })); describe('A2AClientManager', () => { @@ -65,8 +66,19 @@ describe('A2AClientManager', () => { beforeEach(() => { vi.clearAllMocks(); - A2AClientManager.resetInstanceForTesting(); manager = A2AClientManager.getInstance(); + manager.clearCache(); + + // Default DNS mock: resolve to public IP. + // Use any cast only for return value due to complex multi-signature overloads. + vi.mocked(dnsPromises.lookup).mockImplementation( + async (_h: string, options?: LookupOptions | number) => { + const addr = { address: '93.184.216.34', family: 4 }; + const isAll = typeof options === 'object' && options?.all; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return (isAll ? [addr] : addr) as any; + }, + ); // Re-create the instances as plain objects that can be spied on const factoryInstance = { @@ -253,7 +265,10 @@ describe('A2AClientManager', () => { contextId: 'ctx123', taskId: 'task456', }); - await stream[Symbol.asyncIterator]().next(); + // trigger execution + for await (const _ of stream) { + break; + } expect(mockClient.sendMessageStream).toHaveBeenCalledWith( expect.objectContaining({ @@ -277,7 +292,10 @@ describe('A2AClientManager', () => { const stream = manager.sendMessageStream('TestAgent', 'Hello', { signal: controller.signal, }); - await stream[Symbol.asyncIterator]().next(); + // trigger execution + for await (const _ of stream) { + break; + } expect(mockClient.sendMessageStream).toHaveBeenCalledWith( expect.any(Object), @@ -415,15 +433,19 @@ describe('A2AClientManager', () => { it('should throw if a domain resolves to a private IP (DNS SSRF protection)', async () => { const maliciousDomainUrl = 'http://malicious.com/.well-known/agent-card.json'; - vi.mocked(lookup).mockResolvedValueOnce([ - { address: '10.0.0.1', family: 4 }, - ]); + + vi.mocked(dnsPromises.lookup).mockImplementationOnce( + async (_h: string, options?: LookupOptions | number) => { + const addr = { address: '10.0.0.1', family: 4 }; + const isAll = typeof options === 'object' && options?.all; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return (isAll ? [addr] : addr) as any; + }, + ); await expect( manager.loadAgent('dns-ssrf-agent', maliciousDomainUrl), - ).rejects.toThrow( - /Refusing to load agent 'dns-ssrf-agent' from private IP range/, - ); + ).rejects.toThrow(/private IP range/); }); it('should throw if a public agent card contains a private transport URL (Deep SSRF protection)', async () => { @@ -438,6 +460,21 @@ describe('A2AClientManager', () => { resolverInstance as unknown as sdkClient.DefaultAgentCardResolver, ); + // DNS for public.agent.com is public + vi.mocked(dnsPromises.lookup).mockImplementation( + async (hostname: string, options?: LookupOptions | number) => { + const isAll = typeof options === 'object' && options?.all; + if (hostname === 'public.agent.com') { + const addr = { address: '1.1.1.1', family: 4 }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return (isAll ? [addr] : addr) as any; + } + const addr = { address: '192.168.1.1', family: 4 }; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + return (isAll ? [addr] : addr) as any; + }, + ); + await expect( manager.loadAgent('malicious-agent', publicUrl), ).rejects.toThrow( @@ -463,5 +500,24 @@ describe('A2AClientManager', () => { undefined, ); }); + + it('should correctly handle URLs with standardPath in the hash fragment', async () => { + const fragmentUrl = + 'http://localhost:9001/.well-known/agent-card.json#.well-known/agent-card.json'; + const resolverInstance = { + resolve: vi.fn().mockResolvedValue({ name: 'test' } as AgentCard), + }; + vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue( + resolverInstance as unknown as sdkClient.DefaultAgentCardResolver, + ); + + await manager.loadAgent('fragment-agent', fragmentUrl); + + // Should correctly ignore the hash fragment and use the path from the URL object + expect(resolverInstance.resolve).toHaveBeenCalledWith( + 'http://localhost:9001/', + '.well-known/agent-card.json', + ); + }); }); }); diff --git a/packages/core/src/agents/a2a-client-manager.ts b/packages/core/src/agents/a2a-client-manager.ts index 8099de5ce6..cefde87daf 100644 --- a/packages/core/src/agents/a2a-client-manager.ts +++ b/packages/core/src/agents/a2a-client-manager.ts @@ -12,45 +12,65 @@ import type { TaskStatusUpdateEvent, TaskArtifactUpdateEvent, } from '@a2a-js/sdk'; +import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client'; import { ClientFactory, ClientFactoryOptions, DefaultAgentCardResolver, - RestTransportFactory, JsonRpcTransportFactory, - type AuthenticationHandler, + RestTransportFactory, createAuthenticatingFetchWithRetry, - type Client, } from '@a2a-js/sdk/client'; import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc'; import { v4 as uuidv4 } from 'uuid'; import { Agent as UndiciAgent } from 'undici'; -import { getGrpcCredentials, normalizeAgentCard } from './a2aUtils.js'; -import { isPrivateIpAsync } from '../utils/fetch.js'; +import { + getGrpcChannelOptions, + getGrpcCredentials, + normalizeAgentCard, + pinUrlToIp, +} from './a2aUtils.js'; +import { + isPrivateIpAsync, + safeLookup, + isLoopbackHost, +} from '../utils/fetch.js'; import { debugLogger } from '../utils/debugLogger.js'; +/** + * Result of sending a message, which can be a full message, a task, + * or an incremental status/artifact update. + */ +export type SendMessageResult = + | Message + | Task + | TaskStatusUpdateEvent + | TaskArtifactUpdateEvent; + +// Local extension of RequestInit to support Node.js/undici dispatcher +interface NodeFetchInit extends RequestInit { + dispatcher?: UndiciAgent; +} + // Remote agents can take 10+ minutes (e.g. Deep Research). // Use a dedicated dispatcher so the global 5-min timeout isn't affected. const A2A_TIMEOUT = 1800000; // 30 minutes const a2aDispatcher = new UndiciAgent({ headersTimeout: A2A_TIMEOUT, bodyTimeout: A2A_TIMEOUT, + connect: { + // SSRF protection at the connection level (mitigates DNS rebinding) + lookup: safeLookup, + }, }); -const a2aFetch: typeof fetch = (input, init) => - // @ts-expect-error The `dispatcher` property is a Node.js extension to fetch not present in standard types. - fetch(input, { ...init, dispatcher: a2aDispatcher }); - -export type SendMessageResult = - | Message - | Task - | TaskStatusUpdateEvent - | TaskArtifactUpdateEvent; +const a2aFetch: typeof fetch = (input, init) => { + const nodeInit: NodeFetchInit = { ...init, dispatcher: a2aDispatcher }; + return fetch(input, nodeInit); +}; /** - * Orchestrates communication with A2A agents. - * - * This manager handles agent discovery, card caching, and client lifecycle. - * It provides a unified messaging interface using the standard A2A SDK. + * Orchestrates communication with remote A2A agents. + * Manages protocol negotiation, authentication, and transport selection. */ export class A2AClientManager { private static instance: A2AClientManager; @@ -71,19 +91,10 @@ export class A2AClientManager { return A2AClientManager.instance; } - /** - * Resets the singleton instance. Only for testing purposes. - * @internal - */ - static resetInstanceForTesting() { - // @ts-expect-error - Resetting singleton for testing - A2AClientManager.instance = undefined; - } - /** * Loads an agent by fetching its AgentCard and caches the client. * @param name The name to assign to the agent. - * @param agentCardUrl The full URL to the agent's card. + * @param agentCardUrl {string} The full URL to the agent's card. * @param authHandler Optional authentication handler to use for this agent. * @returns The loaded AgentCard. */ @@ -100,6 +111,29 @@ export class A2AClientManager { const resolver = new DefaultAgentCardResolver({ fetchImpl }); const agentCard = await this.resolveAgentCard(name, agentCardUrl, resolver); + // Pin URL to IP to prevent DNS rebinding for gRPC (connection-level SSRF protection) + const grpcInterface = agentCard.additionalInterfaces?.find( + (i) => i.transport === 'GRPC', + ); + const urlToPin = grpcInterface?.url ?? agentCard.url; + const { pinnedUrl, hostname } = await pinUrlToIp(urlToPin, name); + + // Prepare base gRPC options + const baseGrpcOptions: ConstructorParameters< + typeof GrpcTransportFactory + >[0] = { + grpcChannelCredentials: getGrpcCredentials(urlToPin), + }; + + // Include extension properties required for SSRF pinning using object spread. + // This allows us to pass additional properties that the SDK uses internally + // without triggering narrow-to-broad type assertion warnings. + const fullGrpcOptions = { + ...baseGrpcOptions, + target: pinnedUrl, + grpcChannelOptions: getGrpcChannelOptions(hostname), + }; + // Configure standard SDK client for tool registration and discovery const clientOptions = ClientFactoryOptions.createFrom( ClientFactoryOptions.default, @@ -107,9 +141,11 @@ export class A2AClientManager { transports: [ new RestTransportFactory({ fetchImpl }), new JsonRpcTransportFactory({ fetchImpl }), - new GrpcTransportFactory({ - grpcChannelCredentials: getGrpcCredentials(agentCard.url), - }), + new GrpcTransportFactory( + fullGrpcOptions as ConstructorParameters< + typeof GrpcTransportFactory + >[0], + ), ], cardResolver: resolver, }, @@ -166,7 +202,7 @@ export class A2AClientManager { try { yield* client.sendMessageStream(messageParams, { signal: options?.signal, - }); + }) as AsyncIterable; } catch (error: unknown) { const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`; if (error instanceof Error) { @@ -263,7 +299,7 @@ export class A2AClientManager { if (await isPrivateIpAsync(url)) { // Local/private IPs are allowed ONLY for localhost for testing. const parsed = new URL(url); - if (parsed.hostname !== 'localhost' && parsed.hostname !== '127.0.0.1') { + if (!isLoopbackHost(parsed.hostname)) { throw new Error( `Refusing to load agent '${agentName}' from private IP range: ${url}. Remote agents must use public URLs.`, ); @@ -275,7 +311,14 @@ export class A2AClientManager { if (parsedUrl.pathname.endsWith(standardPath)) { // Correctly split the URL into baseUrl and standard path path = standardPath; - baseUrl = url.substring(0, url.lastIndexOf(standardPath)); + // Reconstruct baseUrl from parsed components to avoid issues with hashes or query params. + parsedUrl.pathname = parsedUrl.pathname.substring( + 0, + parsedUrl.pathname.lastIndexOf(standardPath), + ); + parsedUrl.search = ''; + parsedUrl.hash = ''; + baseUrl = parsedUrl.toString(); } } catch (e) { throw new Error(`Invalid agent card URL: ${url}`, { cause: e }); @@ -312,10 +355,7 @@ export class A2AClientManager { if (await isPrivateIpAsync(validationUrl)) { const parsed = new URL(validationUrl); - if ( - parsed.hostname !== 'localhost' && - parsed.hostname !== '127.0.0.1' - ) { + if (!isLoopbackHost(parsed.hostname)) { throw new Error( `Refusing to load agent '${agentName}': contains transport URL pointing to private IP range: ${url}.`, ); diff --git a/packages/core/src/agents/a2aUtils.test.ts b/packages/core/src/agents/a2aUtils.test.ts index e9bae9a083..2257975c66 100644 --- a/packages/core/src/agents/a2aUtils.test.ts +++ b/packages/core/src/agents/a2aUtils.test.ts @@ -4,7 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect } from 'vitest'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; import { extractMessageText, extractIdsFromResponse, @@ -13,6 +13,7 @@ import { AUTH_REQUIRED_MSG, normalizeAgentCard, getGrpcCredentials, + pinUrlToIp, } from './a2aUtils.js'; import type { SendMessageResult } from './a2a-client-manager.js'; import type { @@ -24,8 +25,17 @@ import type { TaskStatusUpdateEvent, TaskArtifactUpdateEvent, } from '@a2a-js/sdk'; +import * as dnsPromises from 'node:dns/promises'; + +vi.mock('node:dns/promises', () => ({ + lookup: vi.fn(), +})); describe('a2aUtils', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + describe('getGrpcCredentials', () => { it('should return secure credentials for https', () => { const credentials = getGrpcCredentials('https://test.agent'); @@ -38,6 +48,69 @@ describe('a2aUtils', () => { }); }); + describe('pinUrlToIp', () => { + it('should resolve and pin hostname to IP', async () => { + vi.mocked(dnsPromises.lookup).mockResolvedValue([ + { address: '93.184.216.34', family: 4 }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ] as any); + + const { pinnedUrl, hostname } = await pinUrlToIp( + 'http://example.com:9000', + 'test-agent', + ); + expect(hostname).toBe('example.com'); + expect(pinnedUrl).toBe('http://93.184.216.34:9000/'); + }); + + it('should handle raw host:port strings (standard for gRPC)', async () => { + vi.mocked(dnsPromises.lookup).mockResolvedValue([ + { address: '93.184.216.34', family: 4 }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ] as any); + + const { pinnedUrl, hostname } = await pinUrlToIp( + 'example.com:9000', + 'test-agent', + ); + expect(hostname).toBe('example.com'); + expect(pinnedUrl).toBe('93.184.216.34:9000'); + }); + + it('should throw error if resolution fails (fail closed)', async () => { + vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error')); + + await expect( + pinUrlToIp('http://unreachable.com', 'test-agent'), + ).rejects.toThrow("Failed to resolve host for agent 'test-agent'"); + }); + + it('should throw error if resolved to private IP', async () => { + vi.mocked(dnsPromises.lookup).mockResolvedValue([ + { address: '10.0.0.1', family: 4 }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ] as any); + + await expect( + pinUrlToIp('http://malicious.com', 'test-agent'), + ).rejects.toThrow('resolves to private IP range'); + }); + + it('should allow localhost/127.0.0.1/::1 exceptions', async () => { + vi.mocked(dnsPromises.lookup).mockResolvedValue([ + { address: '127.0.0.1', family: 4 }, + // eslint-disable-next-line @typescript-eslint/no-explicit-any + ] as any); + + const { pinnedUrl, hostname } = await pinUrlToIp( + 'http://localhost:9000', + 'test-agent', + ); + expect(hostname).toBe('localhost'); + expect(pinnedUrl).toBe('http://127.0.0.1:9000/'); + }); + }); + describe('isTerminalState', () => { it('should return true for completed, failed, canceled, and rejected', () => { expect(isTerminalState('completed')).toBe(true); diff --git a/packages/core/src/agents/a2aUtils.ts b/packages/core/src/agents/a2aUtils.ts index 917dfb33a2..3e45870182 100644 --- a/packages/core/src/agents/a2aUtils.ts +++ b/packages/core/src/agents/a2aUtils.ts @@ -5,6 +5,7 @@ */ import * as grpc from '@grpc/grpc-js'; +import { lookup } from 'node:dns/promises'; import type { Message, Part, @@ -14,9 +15,12 @@ import type { Artifact, TaskState, TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, AgentCard, AgentInterface, + Task, } from '@a2a-js/sdk'; +import { isAddressPrivate } from '../utils/fetch.js'; import type { SendMessageResult } from './a2a-client-manager.js'; export const AUTH_REQUIRED_MSG = `[Authorization Required] The agent has indicated it requires authorization to proceed. Please follow the agent's instructions.`; @@ -36,80 +40,68 @@ export class A2AResultReassembler { update(chunk: SendMessageResult) { if (!('kind' in chunk)) return; - switch (chunk.kind) { - case 'status-update': - this.appendStateInstructions(chunk.status?.state); - this.pushMessage(chunk.status?.message); - break; + if (isStatusUpdateEvent(chunk)) { + this.appendStateInstructions(chunk.status?.state); + this.pushMessage(chunk.status?.message); + } else if (isArtifactUpdateEvent(chunk)) { + if (chunk.artifact) { + const id = chunk.artifact.artifactId; + const existing = this.artifacts.get(id); - case 'artifact-update': - if (chunk.artifact) { - const id = chunk.artifact.artifactId; - const existing = this.artifacts.get(id); - - if (chunk.append && existing) { - for (const part of chunk.artifact.parts) { - existing.parts.push(structuredClone(part)); - } - } else { - this.artifacts.set(id, structuredClone(chunk.artifact)); - } - - const newText = extractPartsText(chunk.artifact.parts, ''); - let chunks = this.artifactChunks.get(id); - if (!chunks) { - chunks = []; - this.artifactChunks.set(id, chunks); - } - if (chunk.append) { - chunks.push(newText); - } else { - chunks.length = 0; - chunks.push(newText); + if (chunk.append && existing) { + for (const part of chunk.artifact.parts) { + existing.parts.push(structuredClone(part)); } + } else { + this.artifacts.set(id, structuredClone(chunk.artifact)); } - break; - case 'task': - this.appendStateInstructions(chunk.status?.state); - this.pushMessage(chunk.status?.message); - if (chunk.artifacts) { - for (const art of chunk.artifacts) { - this.artifacts.set(art.artifactId, structuredClone(art)); - this.artifactChunks.set(art.artifactId, [ - extractPartsText(art.parts, ''), - ]); - } + const newText = extractPartsText(chunk.artifact.parts, ''); + let chunks = this.artifactChunks.get(id); + if (!chunks) { + chunks = []; + this.artifactChunks.set(id, chunks); } - // History Fallback: Some agent implementations do not populate the - // status.message in their final terminal response, instead archiving - // the final answer in the task's history array. To ensure we don't - // present an empty result, we fallback to the most recent agent message - // in the history only when the task is terminal and no other content - // (message log or artifacts) has been reassembled. - if ( - isTerminalState(chunk.status?.state) && - this.messageLog.length === 0 && - this.artifacts.size === 0 && - chunk.history && - chunk.history.length > 0 - ) { - const lastAgentMsg = [...chunk.history] - .reverse() - .find((m) => m.role?.toLowerCase().includes('agent')); - if (lastAgentMsg) { - this.pushMessage(lastAgentMsg); - } + if (chunk.append) { + chunks.push(newText); + } else { + chunks.length = 0; + chunks.push(newText); } - break; - - case 'message': { - this.pushMessage(chunk); - break; } - - default: - break; + } else if (isTask(chunk)) { + this.appendStateInstructions(chunk.status?.state); + this.pushMessage(chunk.status?.message); + if (chunk.artifacts) { + for (const art of chunk.artifacts) { + this.artifacts.set(art.artifactId, structuredClone(art)); + this.artifactChunks.set(art.artifactId, [ + extractPartsText(art.parts, ''), + ]); + } + } + // History Fallback: Some agent implementations do not populate the + // status.message in their final terminal response, instead archiving + // the final answer in the task's history array. To ensure we don't + // present an empty result, we fallback to the most recent agent message + // in the history only when the task is terminal and no other content + // (message log or artifacts) has been reassembled. + if ( + isTerminalState(chunk.status?.state) && + this.messageLog.length === 0 && + this.artifacts.size === 0 && + chunk.history && + chunk.history.length > 0 + ) { + const lastAgentMsg = [...chunk.history] + .reverse() + .find((m) => m.role?.toLowerCase().includes('agent')); + if (lastAgentMsg) { + this.pushMessage(lastAgentMsg); + } + } + } else if (isMessage(chunk)) { + this.pushMessage(chunk); } } @@ -270,6 +262,84 @@ export function getGrpcCredentials(url: string): grpc.ChannelCredentials { : grpc.credentials.createInsecure(); } +/** + * Returns gRPC channel options to ensure SSL/authority matches the original hostname + * when connecting via a pinned IP address. + */ +export function getGrpcChannelOptions( + hostname: string, +): Record { + return { + 'grpc.default_authority': hostname, + 'grpc.ssl_target_name_override': hostname, + }; +} + +/** + * Resolves a hostname to its IP address and validates it against SSRF. + * Returns the pinned IP-based URL and the original hostname. + */ +export async function pinUrlToIp( + url: string, + agentName: string, +): Promise<{ pinnedUrl: string; hostname: string }> { + if (!url) return { pinnedUrl: url, hostname: '' }; + + // gRPC URLs in A2A can be 'host:port' or 'dns:///host:port' or have schemes. + // We normalize to host:port for resolution. + const hasScheme = url.includes('://'); + const normalizedUrl = hasScheme ? url : `http://${url}`; + + try { + const parsed = new URL(normalizedUrl); + const hostname = parsed.hostname; + + const sanitizedHost = + hostname.startsWith('[') && hostname.endsWith(']') + ? hostname.slice(1, -1) + : hostname; + + // Resolve DNS to check the actual target IP and pin it + const addresses = await lookup(hostname, { all: true }); + const publicAddresses = addresses.filter( + (addr) => + !isAddressPrivate(addr.address) || + sanitizedHost === 'localhost' || + sanitizedHost === '127.0.0.1' || + sanitizedHost === '::1', + ); + + if (publicAddresses.length === 0 && addresses.length > 0) { + throw new Error( + `Refusing to load agent '${agentName}': transport URL '${url}' resolves to private IP range.`, + ); + } + + const pinnedIp = publicAddresses[0].address; + const pinnedHostname = pinnedIp.includes(':') ? `[${pinnedIp}]` : pinnedIp; + + // Reconstruct URL with IP + parsed.hostname = pinnedHostname; + let pinnedUrl = parsed.toString(); + + // If original didn't have scheme, remove it (standard for gRPC targets) + if (!hasScheme) { + pinnedUrl = pinnedUrl.replace(/^http:\/\//, ''); + // URL.toString() might append a trailing slash + if (pinnedUrl.endsWith('/') && !url.endsWith('/')) { + pinnedUrl = pinnedUrl.slice(0, -1); + } + } + + return { pinnedUrl, hostname }; + } catch (e) { + if (e instanceof Error && e.message.includes('Refusing')) throw e; + throw new Error(`Failed to resolve host for agent '${agentName}': ${url}`, { + cause: e, + }); + } +} + /** * Extracts contextId and taskId from a Message, Task, or Update response. * Follows the pattern from the A2A CLI sample to maintain conversational continuity. @@ -285,7 +355,7 @@ export function extractIdsFromResponse(result: SendMessageResult): { if ('kind' in result) { const kind = result.kind; - if (kind === 'message' || kind === 'artifact-update') { + if (kind === 'message' || isArtifactUpdateEvent(result)) { taskId = result.taskId; contextId = result.contextId; } else if (kind === 'task') { @@ -393,6 +463,20 @@ function isStatusUpdateEvent( return result.kind === 'status-update'; } +function isArtifactUpdateEvent( + result: SendMessageResult, +): result is TaskArtifactUpdateEvent { + return result.kind === 'artifact-update'; +} + +function isMessage(result: SendMessageResult): result is Message { + return result.kind === 'message'; +} + +function isTask(result: SendMessageResult): result is Task { + return result.kind === 'task'; +} + /** * Returns true if the given state is a terminal state for a task. */ diff --git a/packages/core/src/utils/fetch.test.ts b/packages/core/src/utils/fetch.test.ts new file mode 100644 index 0000000000..7f0f9c3edd --- /dev/null +++ b/packages/core/src/utils/fetch.test.ts @@ -0,0 +1,257 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest'; +import { + isPrivateIp, + isPrivateIpAsync, + isAddressPrivate, + safeLookup, + safeFetch, +} from './fetch.js'; +import * as dnsPromises from 'node:dns/promises'; +import * as dns from 'node:dns'; + +vi.mock('node:dns/promises', () => ({ + lookup: vi.fn(), +})); + +// We need to mock node:dns for safeLookup since it uses the callback API +vi.mock('node:dns', () => ({ + lookup: vi.fn(), +})); + +// Mock global fetch +const originalFetch = global.fetch; +global.fetch = vi.fn(); + +describe('fetch utils', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + afterAll(() => { + global.fetch = originalFetch; + }); + + describe('isAddressPrivate', () => { + it('should identify private IPv4 addresses', () => { + expect(isAddressPrivate('10.0.0.1')).toBe(true); + expect(isAddressPrivate('127.0.0.1')).toBe(true); + expect(isAddressPrivate('172.16.0.1')).toBe(true); + expect(isAddressPrivate('192.168.1.1')).toBe(true); + }); + + it('should identify non-routable and reserved IPv4 addresses (RFC 6890)', () => { + expect(isAddressPrivate('0.0.0.0')).toBe(true); + expect(isAddressPrivate('100.64.0.1')).toBe(true); + expect(isAddressPrivate('192.0.0.1')).toBe(true); + expect(isAddressPrivate('192.0.2.1')).toBe(true); + expect(isAddressPrivate('192.88.99.1')).toBe(true); + expect(isAddressPrivate('198.18.0.1')).toBe(true); + expect(isAddressPrivate('198.51.100.1')).toBe(true); + expect(isAddressPrivate('203.0.113.1')).toBe(true); + expect(isAddressPrivate('224.0.0.1')).toBe(true); + expect(isAddressPrivate('240.0.0.1')).toBe(true); + }); + + it('should identify private IPv6 addresses', () => { + expect(isAddressPrivate('::1')).toBe(true); + expect(isAddressPrivate('fc00::')).toBe(true); + expect(isAddressPrivate('fd00::')).toBe(true); + expect(isAddressPrivate('fe80::')).toBe(true); + expect(isAddressPrivate('febf::')).toBe(true); + }); + + it('should identify special local addresses', () => { + expect(isAddressPrivate('0.0.0.0')).toBe(true); + expect(isAddressPrivate('::')).toBe(true); + expect(isAddressPrivate('localhost')).toBe(true); + }); + + it('should identify link-local addresses', () => { + expect(isAddressPrivate('169.254.169.254')).toBe(true); + }); + + it('should identify IPv4-mapped IPv6 private addresses', () => { + expect(isAddressPrivate('::ffff:127.0.0.1')).toBe(true); + expect(isAddressPrivate('::ffff:10.0.0.1')).toBe(true); + expect(isAddressPrivate('::ffff:169.254.169.254')).toBe(true); + expect(isAddressPrivate('::ffff:192.168.1.1')).toBe(true); + expect(isAddressPrivate('::ffff:172.16.0.1')).toBe(true); + expect(isAddressPrivate('::ffff:0.0.0.0')).toBe(true); + expect(isAddressPrivate('::ffff:100.64.0.1')).toBe(true); + expect(isAddressPrivate('::ffff:a9fe:101')).toBe(true); // 169.254.1.1 + }); + + it('should identify public addresses as non-private', () => { + expect(isAddressPrivate('8.8.8.8')).toBe(false); + expect(isAddressPrivate('93.184.216.34')).toBe(false); + expect(isAddressPrivate('2001:4860:4860::8888')).toBe(false); + expect(isAddressPrivate('::ffff:8.8.8.8')).toBe(false); + }); + }); + + describe('isPrivateIp', () => { + it('should identify private IPs in URLs', () => { + expect(isPrivateIp('http://10.0.0.1/')).toBe(true); + expect(isPrivateIp('https://127.0.0.1:8080/')).toBe(true); + expect(isPrivateIp('http://localhost/')).toBe(true); + expect(isPrivateIp('http://[::1]/')).toBe(true); + }); + + it('should identify public IPs in URLs as non-private', () => { + expect(isPrivateIp('http://8.8.8.8/')).toBe(false); + expect(isPrivateIp('https://google.com/')).toBe(false); + }); + }); + + describe('isPrivateIpAsync', () => { + it('should identify private IPs directly', async () => { + expect(await isPrivateIpAsync('http://10.0.0.1/')).toBe(true); + }); + + it('should identify domains resolving to private IPs', async () => { + vi.mocked(dnsPromises.lookup).mockImplementation( + async () => + // eslint-disable-next-line @typescript-eslint/no-explicit-any + [{ address: '10.0.0.1', family: 4 }] as any, + ); + expect(await isPrivateIpAsync('http://malicious.com/')).toBe(true); + }); + + it('should identify domains resolving to public IPs as non-private', async () => { + vi.mocked(dnsPromises.lookup).mockImplementation( + async () => + // eslint-disable-next-line @typescript-eslint/no-explicit-any + [{ address: '8.8.8.8', family: 4 }] as any, + ); + expect(await isPrivateIpAsync('http://google.com/')).toBe(false); + }); + + it('should throw error if DNS resolution fails (fail closed)', async () => { + vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error')); + await expect(isPrivateIpAsync('http://unreachable.com/')).rejects.toThrow( + 'Failed to verify if URL resolves to private IP', + ); + }); + + it('should return false for invalid URLs instead of throwing verification error', async () => { + expect(await isPrivateIpAsync('not-a-url')).toBe(false); + }); + }); + + describe('safeLookup', () => { + it('should filter out private IPs', async () => { + const addresses = [ + { address: '8.8.8.8', family: 4 }, + { address: '10.0.0.1', family: 4 }, + ]; + + vi.mocked(dns.lookup).mockImplementation((( + _h: string, + _o: dns.LookupOptions, + cb: ( + err: Error | null, + addr: Array<{ address: string; family: number }>, + ) => void, + ) => { + cb(null, addresses); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + }) as any); + + const result = await new Promise< + Array<{ address: string; family: number }> + >((resolve, reject) => { + safeLookup('example.com', { all: true }, (err, filtered) => { + if (err) reject(err); + else resolve(filtered); + }); + }); + + expect(result).toHaveLength(1); + expect(result[0].address).toBe('8.8.8.8'); + }); + + it('should allow explicit localhost', async () => { + const addresses = [{ address: '127.0.0.1', family: 4 }]; + + vi.mocked(dns.lookup).mockImplementation((( + _h: string, + _o: dns.LookupOptions, + cb: ( + err: Error | null, + addr: Array<{ address: string; family: number }>, + ) => void, + ) => { + cb(null, addresses); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + }) as any); + + const result = await new Promise< + Array<{ address: string; family: number }> + >((resolve, reject) => { + safeLookup('localhost', { all: true }, (err, filtered) => { + if (err) reject(err); + else resolve(filtered); + }); + }); + + expect(result).toHaveLength(1); + expect(result[0].address).toBe('127.0.0.1'); + }); + + it('should error if all resolved IPs are private', async () => { + const addresses = [{ address: '10.0.0.1', family: 4 }]; + + vi.mocked(dns.lookup).mockImplementation((( + _h: string, + _o: dns.LookupOptions, + cb: ( + err: Error | null, + addr: Array<{ address: string; family: number }>, + ) => void, + ) => { + cb(null, addresses); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + }) as any); + + await expect( + new Promise((resolve, reject) => { + safeLookup('malicious.com', { all: true }, (err, filtered) => { + if (err) reject(err); + else resolve(filtered); + }); + }), + ).rejects.toThrow('Refusing to connect to private IP address'); + }); + }); + + describe('safeFetch', () => { + it('should forward to fetch with dispatcher', async () => { + vi.mocked(global.fetch).mockResolvedValue(new Response('ok')); + + const response = await safeFetch('https://example.com'); + expect(response.status).toBe(200); + expect(global.fetch).toHaveBeenCalledWith( + 'https://example.com', + expect.objectContaining({ + dispatcher: expect.any(Object), + }), + ); + }); + + it('should handle Refusing to connect errors', async () => { + vi.mocked(global.fetch).mockRejectedValue( + new Error('Refusing to connect to private IP address'), + ); + + await expect(safeFetch('http://10.0.0.1')).rejects.toThrow( + 'Access to private network is blocked', + ); + }); + }); +}); diff --git a/packages/core/src/utils/fetch.ts b/packages/core/src/utils/fetch.ts index 4f1c81b5fb..efd096833e 100644 --- a/packages/core/src/utils/fetch.ts +++ b/packages/core/src/utils/fetch.ts @@ -6,6 +6,7 @@ import { getErrorMessage, isNodeError } from './errors.js'; import { URL } from 'node:url'; +import * as dns from 'node:dns'; import { lookup } from 'node:dns/promises'; import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici'; @@ -23,14 +24,28 @@ setGlobalDispatcher( const PRIVATE_IP_RANGES = [ /^10\./, /^127\./, + /^0\./, + /^100\.(6[4-9]|[7-9][0-9]|1[0-1][0-9]|12[0-7])\./, /^169\.254\./, /^172\.(1[6-9]|2[0-9]|3[0-1])\./, + /^192\.0\.(0|2)\./, + /^192\.88\.99\./, /^192\.168\./, + /^198\.(1[8-9]|51\.100)\./, + /^203\.0\.113\./, + /^2(2[4-9]|3[0-9]|[4-5][0-9])\./, /^::1$/, - /^fc00:/, - /^fe80:/, + /^::$/, + /^f[cd]00:/i, // fc00::/7 (ULA) + /^fe[89ab][0-9a-f]:/i, // fe80::/10 (Link-local) + /^::ffff:(10\.|127\.|0\.|100\.(6[4-9]|[7-9][0-9]|1[0-1][0-9]|12[0-7])\.|169\.254\.|172\.(1[6-9]|2[0-9]|3[0-1])\.|192\.0\.(0|2)\.|192\.88\.99\.|192\.168\.|198\.(1[8-9]|51\.100)\.|203\.0\.113\.|2(2[4-9]|3[0-9]|[4-5][0-9])\.|a9fe:|ac1[0-9a-f]:|c0a8:|0+:)/i, // IPv4-mapped IPv6 ]; +// Local extension of RequestInit to support Node.js/undici dispatcher +interface NodeFetchInit extends RequestInit { + dispatcher?: Agent | ProxyAgent; +} + export class FetchError extends Error { constructor( message: string, @@ -42,6 +57,74 @@ export class FetchError extends Error { } } +/** + * Sanitizes a hostname by stripping IPv6 brackets if present. + */ +export function sanitizeHostname(hostname: string): string { + return hostname.startsWith('[') && hostname.endsWith(']') + ? hostname.slice(1, -1) + : hostname; +} + +/** + * Checks if a hostname is a local loopback address allowed for development/testing. + */ +export function isLoopbackHost(hostname: string): boolean { + const sanitized = sanitizeHostname(hostname); + return ( + sanitized === 'localhost' || + sanitized === '127.0.0.1' || + sanitized === '::1' + ); +} + +/** + * A custom DNS lookup implementation for undici agents that prevents + * connection to private IP ranges (SSRF protection). + */ +export function safeLookup( + hostname: string, + options: dns.LookupOptions | number | null | undefined, + callback: ( + err: Error | null, + addresses: Array<{ address: string; family: number }>, + ) => void, +): void { + // Use the callback-based dns.lookup to match undici's expected signature. + // We explicitly handle the 'all' option to ensure we get an array of addresses. + const lookupOptions = + typeof options === 'number' ? { family: options } : { ...options }; + const finalOptions = { ...lookupOptions, all: true }; + + dns.lookup(hostname, finalOptions, (err, addresses) => { + if (err) { + callback(err, []); + return; + } + + const addressArray = Array.isArray(addresses) ? addresses : []; + const filtered = addressArray.filter( + (addr) => !isAddressPrivate(addr.address) || isLoopbackHost(hostname), + ); + + if (filtered.length === 0 && addressArray.length > 0) { + callback(new Error(`Refusing to connect to private IP address`), []); + return; + } + + callback(null, filtered); + }); +} + +// Dedicated dispatcher with connection-level SSRF protection (safeLookup) +const safeDispatcher = new Agent({ + headersTimeout: DEFAULT_HEADERS_TIMEOUT, + bodyTimeout: DEFAULT_BODY_TIMEOUT, + connect: { + lookup: safeLookup, + }, +}); + export function isPrivateIp(url: string): boolean { try { const hostname = new URL(url).hostname; @@ -68,21 +151,82 @@ export async function isPrivateIpAsync(url: string): Promise { // Resolve DNS to check the actual target IP const addresses = await lookup(hostname, { all: true }); return addresses.some((addr) => isAddressPrivate(addr.address)); - } catch (_e) { - return false; + } catch (e) { + if ( + e instanceof Error && + e.name === 'TypeError' && + e.message.includes('Invalid URL') + ) { + return false; + } + throw new Error(`Failed to verify if URL resolves to private IP: ${url}`, { + cause: e, + }); } } /** * Internal helper to check if an IP address string is in a private range. */ -function isAddressPrivate(address: string): boolean { +export function isAddressPrivate(address: string): boolean { + const sanitized = sanitizeHostname(address); + return ( - address === 'localhost' || - PRIVATE_IP_RANGES.some((range) => range.test(address)) + sanitized === 'localhost' || + PRIVATE_IP_RANGES.some((range) => range.test(sanitized)) ); } +/** + * Enhanced fetch with SSRF protection. + * Prevents access to private/internal networks at the connection level. + */ +export async function safeFetch( + url: string | URL, + init?: RequestInit, +): Promise { + const nodeInit: NodeFetchInit = { + ...init, + dispatcher: safeDispatcher, + }; + + try { + return await fetch(url, nodeInit); + } catch (error) { + if (error instanceof Error) { + // Re-map refusing to connect errors to standard FetchError + if (error.message.includes('Refusing to connect to private IP address')) { + throw new FetchError( + `Access to private network is blocked: ${url.toString()}`, + 'ERR_PRIVATE_NETWORK', + { cause: error }, + ); + } + throw new FetchError( + getErrorMessage(error), + isNodeError(error) ? error.code : undefined, + { cause: error }, + ); + } + throw new FetchError(String(error)); + } +} + +/** + * Creates an undici ProxyAgent that incorporates safe DNS lookup. + */ +export function createSafeProxyAgent(proxyUrl: string): ProxyAgent { + return new ProxyAgent({ + uri: proxyUrl, + connect: { + lookup: safeLookup, + }, + }); +} + +/** + * Performs a fetch with a specified timeout and connection-level SSRF protection. + */ export async function fetchWithTimeout( url: string, timeout: number, @@ -101,16 +245,29 @@ export async function fetchWithTimeout( } } + const nodeInit: NodeFetchInit = { + ...options, + signal: controller.signal, + dispatcher: safeDispatcher, + }; + try { - const response = await fetch(url, { - ...options, - signal: controller.signal, - }); + const response = await fetch(url, nodeInit); return response; } catch (error) { if (isNodeError(error) && error.code === 'ABORT_ERR') { throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT'); } + if ( + error instanceof Error && + error.message.includes('Refusing to connect to private IP address') + ) { + throw new FetchError( + `Access to private network is blocked: ${url}`, + 'ERR_PRIVATE_NETWORK', + { cause: error }, + ); + } throw new FetchError(getErrorMessage(error), undefined, { cause: error }); } finally { clearTimeout(timeoutId);