mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 06:31:01 -07:00
feat(a2a): add DNS rebinding protection and robust URL reconstruction
This commit is contained in:
@@ -8,7 +8,8 @@ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import type { AgentCard } from '@a2a-js/sdk';
|
||||
import * as sdkClient from '@a2a-js/sdk/client';
|
||||
import { lookup } from 'node:dns/promises';
|
||||
import * as dnsPromises from 'node:dns/promises';
|
||||
import type { LookupOptions } from 'node:dns';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
interface MockClient {
|
||||
@@ -38,7 +39,7 @@ vi.mock('../utils/debugLogger.js', () => ({
|
||||
}));
|
||||
|
||||
vi.mock('node:dns/promises', () => ({
|
||||
lookup: vi.fn().mockResolvedValue([{ address: '93.184.216.34' }]),
|
||||
lookup: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('A2AClientManager', () => {
|
||||
@@ -65,8 +66,19 @@ describe('A2AClientManager', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
A2AClientManager.resetInstanceForTesting();
|
||||
manager = A2AClientManager.getInstance();
|
||||
manager.clearCache();
|
||||
|
||||
// Default DNS mock: resolve to public IP.
|
||||
// Use any cast only for return value due to complex multi-signature overloads.
|
||||
vi.mocked(dnsPromises.lookup).mockImplementation(
|
||||
async (_h: string, options?: LookupOptions | number) => {
|
||||
const addr = { address: '93.184.216.34', family: 4 };
|
||||
const isAll = typeof options === 'object' && options?.all;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (isAll ? [addr] : addr) as any;
|
||||
},
|
||||
);
|
||||
|
||||
// Re-create the instances as plain objects that can be spied on
|
||||
const factoryInstance = {
|
||||
@@ -253,7 +265,10 @@ describe('A2AClientManager', () => {
|
||||
contextId: 'ctx123',
|
||||
taskId: 'task456',
|
||||
});
|
||||
await stream[Symbol.asyncIterator]().next();
|
||||
// trigger execution
|
||||
for await (const _ of stream) {
|
||||
break;
|
||||
}
|
||||
|
||||
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
@@ -277,7 +292,10 @@ describe('A2AClientManager', () => {
|
||||
const stream = manager.sendMessageStream('TestAgent', 'Hello', {
|
||||
signal: controller.signal,
|
||||
});
|
||||
await stream[Symbol.asyncIterator]().next();
|
||||
// trigger execution
|
||||
for await (const _ of stream) {
|
||||
break;
|
||||
}
|
||||
|
||||
expect(mockClient.sendMessageStream).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
@@ -415,15 +433,19 @@ describe('A2AClientManager', () => {
|
||||
it('should throw if a domain resolves to a private IP (DNS SSRF protection)', async () => {
|
||||
const maliciousDomainUrl =
|
||||
'http://malicious.com/.well-known/agent-card.json';
|
||||
vi.mocked(lookup).mockResolvedValueOnce([
|
||||
{ address: '10.0.0.1', family: 4 },
|
||||
]);
|
||||
|
||||
vi.mocked(dnsPromises.lookup).mockImplementationOnce(
|
||||
async (_h: string, options?: LookupOptions | number) => {
|
||||
const addr = { address: '10.0.0.1', family: 4 };
|
||||
const isAll = typeof options === 'object' && options?.all;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (isAll ? [addr] : addr) as any;
|
||||
},
|
||||
);
|
||||
|
||||
await expect(
|
||||
manager.loadAgent('dns-ssrf-agent', maliciousDomainUrl),
|
||||
).rejects.toThrow(
|
||||
/Refusing to load agent 'dns-ssrf-agent' from private IP range/,
|
||||
);
|
||||
).rejects.toThrow(/private IP range/);
|
||||
});
|
||||
|
||||
it('should throw if a public agent card contains a private transport URL (Deep SSRF protection)', async () => {
|
||||
@@ -438,6 +460,21 @@ describe('A2AClientManager', () => {
|
||||
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
|
||||
);
|
||||
|
||||
// DNS for public.agent.com is public
|
||||
vi.mocked(dnsPromises.lookup).mockImplementation(
|
||||
async (hostname: string, options?: LookupOptions | number) => {
|
||||
const isAll = typeof options === 'object' && options?.all;
|
||||
if (hostname === 'public.agent.com') {
|
||||
const addr = { address: '1.1.1.1', family: 4 };
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (isAll ? [addr] : addr) as any;
|
||||
}
|
||||
const addr = { address: '192.168.1.1', family: 4 };
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return (isAll ? [addr] : addr) as any;
|
||||
},
|
||||
);
|
||||
|
||||
await expect(
|
||||
manager.loadAgent('malicious-agent', publicUrl),
|
||||
).rejects.toThrow(
|
||||
@@ -463,5 +500,24 @@ describe('A2AClientManager', () => {
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it('should correctly handle URLs with standardPath in the hash fragment', async () => {
|
||||
const fragmentUrl =
|
||||
'http://localhost:9001/.well-known/agent-card.json#.well-known/agent-card.json';
|
||||
const resolverInstance = {
|
||||
resolve: vi.fn().mockResolvedValue({ name: 'test' } as AgentCard),
|
||||
};
|
||||
vi.mocked(sdkClient.DefaultAgentCardResolver).mockReturnValue(
|
||||
resolverInstance as unknown as sdkClient.DefaultAgentCardResolver,
|
||||
);
|
||||
|
||||
await manager.loadAgent('fragment-agent', fragmentUrl);
|
||||
|
||||
// Should correctly ignore the hash fragment and use the path from the URL object
|
||||
expect(resolverInstance.resolve).toHaveBeenCalledWith(
|
||||
'http://localhost:9001/',
|
||||
'.well-known/agent-card.json',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -12,45 +12,65 @@ import type {
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
} from '@a2a-js/sdk';
|
||||
import type { AuthenticationHandler, Client } from '@a2a-js/sdk/client';
|
||||
import {
|
||||
ClientFactory,
|
||||
ClientFactoryOptions,
|
||||
DefaultAgentCardResolver,
|
||||
RestTransportFactory,
|
||||
JsonRpcTransportFactory,
|
||||
type AuthenticationHandler,
|
||||
RestTransportFactory,
|
||||
createAuthenticatingFetchWithRetry,
|
||||
type Client,
|
||||
} from '@a2a-js/sdk/client';
|
||||
import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { Agent as UndiciAgent } from 'undici';
|
||||
import { getGrpcCredentials, normalizeAgentCard } from './a2aUtils.js';
|
||||
import { isPrivateIpAsync } from '../utils/fetch.js';
|
||||
import {
|
||||
getGrpcChannelOptions,
|
||||
getGrpcCredentials,
|
||||
normalizeAgentCard,
|
||||
pinUrlToIp,
|
||||
} from './a2aUtils.js';
|
||||
import {
|
||||
isPrivateIpAsync,
|
||||
safeLookup,
|
||||
isLoopbackHost,
|
||||
} from '../utils/fetch.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
/**
|
||||
* Result of sending a message, which can be a full message, a task,
|
||||
* or an incremental status/artifact update.
|
||||
*/
|
||||
export type SendMessageResult =
|
||||
| Message
|
||||
| Task
|
||||
| TaskStatusUpdateEvent
|
||||
| TaskArtifactUpdateEvent;
|
||||
|
||||
// Local extension of RequestInit to support Node.js/undici dispatcher
|
||||
interface NodeFetchInit extends RequestInit {
|
||||
dispatcher?: UndiciAgent;
|
||||
}
|
||||
|
||||
// Remote agents can take 10+ minutes (e.g. Deep Research).
|
||||
// Use a dedicated dispatcher so the global 5-min timeout isn't affected.
|
||||
const A2A_TIMEOUT = 1800000; // 30 minutes
|
||||
const a2aDispatcher = new UndiciAgent({
|
||||
headersTimeout: A2A_TIMEOUT,
|
||||
bodyTimeout: A2A_TIMEOUT,
|
||||
connect: {
|
||||
// SSRF protection at the connection level (mitigates DNS rebinding)
|
||||
lookup: safeLookup,
|
||||
},
|
||||
});
|
||||
const a2aFetch: typeof fetch = (input, init) =>
|
||||
// @ts-expect-error The `dispatcher` property is a Node.js extension to fetch not present in standard types.
|
||||
fetch(input, { ...init, dispatcher: a2aDispatcher });
|
||||
|
||||
export type SendMessageResult =
|
||||
| Message
|
||||
| Task
|
||||
| TaskStatusUpdateEvent
|
||||
| TaskArtifactUpdateEvent;
|
||||
const a2aFetch: typeof fetch = (input, init) => {
|
||||
const nodeInit: NodeFetchInit = { ...init, dispatcher: a2aDispatcher };
|
||||
return fetch(input, nodeInit);
|
||||
};
|
||||
|
||||
/**
|
||||
* Orchestrates communication with A2A agents.
|
||||
*
|
||||
* This manager handles agent discovery, card caching, and client lifecycle.
|
||||
* It provides a unified messaging interface using the standard A2A SDK.
|
||||
* Orchestrates communication with remote A2A agents.
|
||||
* Manages protocol negotiation, authentication, and transport selection.
|
||||
*/
|
||||
export class A2AClientManager {
|
||||
private static instance: A2AClientManager;
|
||||
@@ -71,19 +91,10 @@ export class A2AClientManager {
|
||||
return A2AClientManager.instance;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resets the singleton instance. Only for testing purposes.
|
||||
* @internal
|
||||
*/
|
||||
static resetInstanceForTesting() {
|
||||
// @ts-expect-error - Resetting singleton for testing
|
||||
A2AClientManager.instance = undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads an agent by fetching its AgentCard and caches the client.
|
||||
* @param name The name to assign to the agent.
|
||||
* @param agentCardUrl The full URL to the agent's card.
|
||||
* @param agentCardUrl {string} The full URL to the agent's card.
|
||||
* @param authHandler Optional authentication handler to use for this agent.
|
||||
* @returns The loaded AgentCard.
|
||||
*/
|
||||
@@ -100,6 +111,29 @@ export class A2AClientManager {
|
||||
const resolver = new DefaultAgentCardResolver({ fetchImpl });
|
||||
const agentCard = await this.resolveAgentCard(name, agentCardUrl, resolver);
|
||||
|
||||
// Pin URL to IP to prevent DNS rebinding for gRPC (connection-level SSRF protection)
|
||||
const grpcInterface = agentCard.additionalInterfaces?.find(
|
||||
(i) => i.transport === 'GRPC',
|
||||
);
|
||||
const urlToPin = grpcInterface?.url ?? agentCard.url;
|
||||
const { pinnedUrl, hostname } = await pinUrlToIp(urlToPin, name);
|
||||
|
||||
// Prepare base gRPC options
|
||||
const baseGrpcOptions: ConstructorParameters<
|
||||
typeof GrpcTransportFactory
|
||||
>[0] = {
|
||||
grpcChannelCredentials: getGrpcCredentials(urlToPin),
|
||||
};
|
||||
|
||||
// Include extension properties required for SSRF pinning using object spread.
|
||||
// This allows us to pass additional properties that the SDK uses internally
|
||||
// without triggering narrow-to-broad type assertion warnings.
|
||||
const fullGrpcOptions = {
|
||||
...baseGrpcOptions,
|
||||
target: pinnedUrl,
|
||||
grpcChannelOptions: getGrpcChannelOptions(hostname),
|
||||
};
|
||||
|
||||
// Configure standard SDK client for tool registration and discovery
|
||||
const clientOptions = ClientFactoryOptions.createFrom(
|
||||
ClientFactoryOptions.default,
|
||||
@@ -107,9 +141,11 @@ export class A2AClientManager {
|
||||
transports: [
|
||||
new RestTransportFactory({ fetchImpl }),
|
||||
new JsonRpcTransportFactory({ fetchImpl }),
|
||||
new GrpcTransportFactory({
|
||||
grpcChannelCredentials: getGrpcCredentials(agentCard.url),
|
||||
}),
|
||||
new GrpcTransportFactory(
|
||||
fullGrpcOptions as ConstructorParameters<
|
||||
typeof GrpcTransportFactory
|
||||
>[0],
|
||||
),
|
||||
],
|
||||
cardResolver: resolver,
|
||||
},
|
||||
@@ -166,7 +202,7 @@ export class A2AClientManager {
|
||||
try {
|
||||
yield* client.sendMessageStream(messageParams, {
|
||||
signal: options?.signal,
|
||||
});
|
||||
}) as AsyncIterable<SendMessageResult>;
|
||||
} catch (error: unknown) {
|
||||
const prefix = `[A2AClientManager] sendMessageStream Error [${agentName}]`;
|
||||
if (error instanceof Error) {
|
||||
@@ -263,7 +299,7 @@ export class A2AClientManager {
|
||||
if (await isPrivateIpAsync(url)) {
|
||||
// Local/private IPs are allowed ONLY for localhost for testing.
|
||||
const parsed = new URL(url);
|
||||
if (parsed.hostname !== 'localhost' && parsed.hostname !== '127.0.0.1') {
|
||||
if (!isLoopbackHost(parsed.hostname)) {
|
||||
throw new Error(
|
||||
`Refusing to load agent '${agentName}' from private IP range: ${url}. Remote agents must use public URLs.`,
|
||||
);
|
||||
@@ -275,7 +311,14 @@ export class A2AClientManager {
|
||||
if (parsedUrl.pathname.endsWith(standardPath)) {
|
||||
// Correctly split the URL into baseUrl and standard path
|
||||
path = standardPath;
|
||||
baseUrl = url.substring(0, url.lastIndexOf(standardPath));
|
||||
// Reconstruct baseUrl from parsed components to avoid issues with hashes or query params.
|
||||
parsedUrl.pathname = parsedUrl.pathname.substring(
|
||||
0,
|
||||
parsedUrl.pathname.lastIndexOf(standardPath),
|
||||
);
|
||||
parsedUrl.search = '';
|
||||
parsedUrl.hash = '';
|
||||
baseUrl = parsedUrl.toString();
|
||||
}
|
||||
} catch (e) {
|
||||
throw new Error(`Invalid agent card URL: ${url}`, { cause: e });
|
||||
@@ -312,10 +355,7 @@ export class A2AClientManager {
|
||||
|
||||
if (await isPrivateIpAsync(validationUrl)) {
|
||||
const parsed = new URL(validationUrl);
|
||||
if (
|
||||
parsed.hostname !== 'localhost' &&
|
||||
parsed.hostname !== '127.0.0.1'
|
||||
) {
|
||||
if (!isLoopbackHost(parsed.hostname)) {
|
||||
throw new Error(
|
||||
`Refusing to load agent '${agentName}': contains transport URL pointing to private IP range: ${url}.`,
|
||||
);
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest';
|
||||
import {
|
||||
extractMessageText,
|
||||
extractIdsFromResponse,
|
||||
@@ -13,6 +13,7 @@ import {
|
||||
AUTH_REQUIRED_MSG,
|
||||
normalizeAgentCard,
|
||||
getGrpcCredentials,
|
||||
pinUrlToIp,
|
||||
} from './a2aUtils.js';
|
||||
import type { SendMessageResult } from './a2a-client-manager.js';
|
||||
import type {
|
||||
@@ -24,8 +25,17 @@ import type {
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
} from '@a2a-js/sdk';
|
||||
import * as dnsPromises from 'node:dns/promises';
|
||||
|
||||
vi.mock('node:dns/promises', () => ({
|
||||
lookup: vi.fn(),
|
||||
}));
|
||||
|
||||
describe('a2aUtils', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('getGrpcCredentials', () => {
|
||||
it('should return secure credentials for https', () => {
|
||||
const credentials = getGrpcCredentials('https://test.agent');
|
||||
@@ -38,6 +48,69 @@ describe('a2aUtils', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('pinUrlToIp', () => {
|
||||
it('should resolve and pin hostname to IP', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockResolvedValue([
|
||||
{ address: '93.184.216.34', family: 4 },
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
] as any);
|
||||
|
||||
const { pinnedUrl, hostname } = await pinUrlToIp(
|
||||
'http://example.com:9000',
|
||||
'test-agent',
|
||||
);
|
||||
expect(hostname).toBe('example.com');
|
||||
expect(pinnedUrl).toBe('http://93.184.216.34:9000/');
|
||||
});
|
||||
|
||||
it('should handle raw host:port strings (standard for gRPC)', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockResolvedValue([
|
||||
{ address: '93.184.216.34', family: 4 },
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
] as any);
|
||||
|
||||
const { pinnedUrl, hostname } = await pinUrlToIp(
|
||||
'example.com:9000',
|
||||
'test-agent',
|
||||
);
|
||||
expect(hostname).toBe('example.com');
|
||||
expect(pinnedUrl).toBe('93.184.216.34:9000');
|
||||
});
|
||||
|
||||
it('should throw error if resolution fails (fail closed)', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
|
||||
|
||||
await expect(
|
||||
pinUrlToIp('http://unreachable.com', 'test-agent'),
|
||||
).rejects.toThrow("Failed to resolve host for agent 'test-agent'");
|
||||
});
|
||||
|
||||
it('should throw error if resolved to private IP', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockResolvedValue([
|
||||
{ address: '10.0.0.1', family: 4 },
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
] as any);
|
||||
|
||||
await expect(
|
||||
pinUrlToIp('http://malicious.com', 'test-agent'),
|
||||
).rejects.toThrow('resolves to private IP range');
|
||||
});
|
||||
|
||||
it('should allow localhost/127.0.0.1/::1 exceptions', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockResolvedValue([
|
||||
{ address: '127.0.0.1', family: 4 },
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
] as any);
|
||||
|
||||
const { pinnedUrl, hostname } = await pinUrlToIp(
|
||||
'http://localhost:9000',
|
||||
'test-agent',
|
||||
);
|
||||
expect(hostname).toBe('localhost');
|
||||
expect(pinnedUrl).toBe('http://127.0.0.1:9000/');
|
||||
});
|
||||
});
|
||||
|
||||
describe('isTerminalState', () => {
|
||||
it('should return true for completed, failed, canceled, and rejected', () => {
|
||||
expect(isTerminalState('completed')).toBe(true);
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import * as grpc from '@grpc/grpc-js';
|
||||
import { lookup } from 'node:dns/promises';
|
||||
import type {
|
||||
Message,
|
||||
Part,
|
||||
@@ -14,9 +15,12 @@ import type {
|
||||
Artifact,
|
||||
TaskState,
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
AgentCard,
|
||||
AgentInterface,
|
||||
Task,
|
||||
} from '@a2a-js/sdk';
|
||||
import { isAddressPrivate } from '../utils/fetch.js';
|
||||
import type { SendMessageResult } from './a2a-client-manager.js';
|
||||
|
||||
export const AUTH_REQUIRED_MSG = `[Authorization Required] The agent has indicated it requires authorization to proceed. Please follow the agent's instructions.`;
|
||||
@@ -36,80 +40,68 @@ export class A2AResultReassembler {
|
||||
update(chunk: SendMessageResult) {
|
||||
if (!('kind' in chunk)) return;
|
||||
|
||||
switch (chunk.kind) {
|
||||
case 'status-update':
|
||||
this.appendStateInstructions(chunk.status?.state);
|
||||
this.pushMessage(chunk.status?.message);
|
||||
break;
|
||||
if (isStatusUpdateEvent(chunk)) {
|
||||
this.appendStateInstructions(chunk.status?.state);
|
||||
this.pushMessage(chunk.status?.message);
|
||||
} else if (isArtifactUpdateEvent(chunk)) {
|
||||
if (chunk.artifact) {
|
||||
const id = chunk.artifact.artifactId;
|
||||
const existing = this.artifacts.get(id);
|
||||
|
||||
case 'artifact-update':
|
||||
if (chunk.artifact) {
|
||||
const id = chunk.artifact.artifactId;
|
||||
const existing = this.artifacts.get(id);
|
||||
|
||||
if (chunk.append && existing) {
|
||||
for (const part of chunk.artifact.parts) {
|
||||
existing.parts.push(structuredClone(part));
|
||||
}
|
||||
} else {
|
||||
this.artifacts.set(id, structuredClone(chunk.artifact));
|
||||
}
|
||||
|
||||
const newText = extractPartsText(chunk.artifact.parts, '');
|
||||
let chunks = this.artifactChunks.get(id);
|
||||
if (!chunks) {
|
||||
chunks = [];
|
||||
this.artifactChunks.set(id, chunks);
|
||||
}
|
||||
if (chunk.append) {
|
||||
chunks.push(newText);
|
||||
} else {
|
||||
chunks.length = 0;
|
||||
chunks.push(newText);
|
||||
if (chunk.append && existing) {
|
||||
for (const part of chunk.artifact.parts) {
|
||||
existing.parts.push(structuredClone(part));
|
||||
}
|
||||
} else {
|
||||
this.artifacts.set(id, structuredClone(chunk.artifact));
|
||||
}
|
||||
break;
|
||||
|
||||
case 'task':
|
||||
this.appendStateInstructions(chunk.status?.state);
|
||||
this.pushMessage(chunk.status?.message);
|
||||
if (chunk.artifacts) {
|
||||
for (const art of chunk.artifacts) {
|
||||
this.artifacts.set(art.artifactId, structuredClone(art));
|
||||
this.artifactChunks.set(art.artifactId, [
|
||||
extractPartsText(art.parts, ''),
|
||||
]);
|
||||
}
|
||||
const newText = extractPartsText(chunk.artifact.parts, '');
|
||||
let chunks = this.artifactChunks.get(id);
|
||||
if (!chunks) {
|
||||
chunks = [];
|
||||
this.artifactChunks.set(id, chunks);
|
||||
}
|
||||
// History Fallback: Some agent implementations do not populate the
|
||||
// status.message in their final terminal response, instead archiving
|
||||
// the final answer in the task's history array. To ensure we don't
|
||||
// present an empty result, we fallback to the most recent agent message
|
||||
// in the history only when the task is terminal and no other content
|
||||
// (message log or artifacts) has been reassembled.
|
||||
if (
|
||||
isTerminalState(chunk.status?.state) &&
|
||||
this.messageLog.length === 0 &&
|
||||
this.artifacts.size === 0 &&
|
||||
chunk.history &&
|
||||
chunk.history.length > 0
|
||||
) {
|
||||
const lastAgentMsg = [...chunk.history]
|
||||
.reverse()
|
||||
.find((m) => m.role?.toLowerCase().includes('agent'));
|
||||
if (lastAgentMsg) {
|
||||
this.pushMessage(lastAgentMsg);
|
||||
}
|
||||
if (chunk.append) {
|
||||
chunks.push(newText);
|
||||
} else {
|
||||
chunks.length = 0;
|
||||
chunks.push(newText);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'message': {
|
||||
this.pushMessage(chunk);
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
} else if (isTask(chunk)) {
|
||||
this.appendStateInstructions(chunk.status?.state);
|
||||
this.pushMessage(chunk.status?.message);
|
||||
if (chunk.artifacts) {
|
||||
for (const art of chunk.artifacts) {
|
||||
this.artifacts.set(art.artifactId, structuredClone(art));
|
||||
this.artifactChunks.set(art.artifactId, [
|
||||
extractPartsText(art.parts, ''),
|
||||
]);
|
||||
}
|
||||
}
|
||||
// History Fallback: Some agent implementations do not populate the
|
||||
// status.message in their final terminal response, instead archiving
|
||||
// the final answer in the task's history array. To ensure we don't
|
||||
// present an empty result, we fallback to the most recent agent message
|
||||
// in the history only when the task is terminal and no other content
|
||||
// (message log or artifacts) has been reassembled.
|
||||
if (
|
||||
isTerminalState(chunk.status?.state) &&
|
||||
this.messageLog.length === 0 &&
|
||||
this.artifacts.size === 0 &&
|
||||
chunk.history &&
|
||||
chunk.history.length > 0
|
||||
) {
|
||||
const lastAgentMsg = [...chunk.history]
|
||||
.reverse()
|
||||
.find((m) => m.role?.toLowerCase().includes('agent'));
|
||||
if (lastAgentMsg) {
|
||||
this.pushMessage(lastAgentMsg);
|
||||
}
|
||||
}
|
||||
} else if (isMessage(chunk)) {
|
||||
this.pushMessage(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -270,6 +262,84 @@ export function getGrpcCredentials(url: string): grpc.ChannelCredentials {
|
||||
: grpc.credentials.createInsecure();
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns gRPC channel options to ensure SSL/authority matches the original hostname
|
||||
* when connecting via a pinned IP address.
|
||||
*/
|
||||
export function getGrpcChannelOptions(
|
||||
hostname: string,
|
||||
): Record<string, unknown> {
|
||||
return {
|
||||
'grpc.default_authority': hostname,
|
||||
'grpc.ssl_target_name_override': hostname,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves a hostname to its IP address and validates it against SSRF.
|
||||
* Returns the pinned IP-based URL and the original hostname.
|
||||
*/
|
||||
export async function pinUrlToIp(
|
||||
url: string,
|
||||
agentName: string,
|
||||
): Promise<{ pinnedUrl: string; hostname: string }> {
|
||||
if (!url) return { pinnedUrl: url, hostname: '' };
|
||||
|
||||
// gRPC URLs in A2A can be 'host:port' or 'dns:///host:port' or have schemes.
|
||||
// We normalize to host:port for resolution.
|
||||
const hasScheme = url.includes('://');
|
||||
const normalizedUrl = hasScheme ? url : `http://${url}`;
|
||||
|
||||
try {
|
||||
const parsed = new URL(normalizedUrl);
|
||||
const hostname = parsed.hostname;
|
||||
|
||||
const sanitizedHost =
|
||||
hostname.startsWith('[') && hostname.endsWith(']')
|
||||
? hostname.slice(1, -1)
|
||||
: hostname;
|
||||
|
||||
// Resolve DNS to check the actual target IP and pin it
|
||||
const addresses = await lookup(hostname, { all: true });
|
||||
const publicAddresses = addresses.filter(
|
||||
(addr) =>
|
||||
!isAddressPrivate(addr.address) ||
|
||||
sanitizedHost === 'localhost' ||
|
||||
sanitizedHost === '127.0.0.1' ||
|
||||
sanitizedHost === '::1',
|
||||
);
|
||||
|
||||
if (publicAddresses.length === 0 && addresses.length > 0) {
|
||||
throw new Error(
|
||||
`Refusing to load agent '${agentName}': transport URL '${url}' resolves to private IP range.`,
|
||||
);
|
||||
}
|
||||
|
||||
const pinnedIp = publicAddresses[0].address;
|
||||
const pinnedHostname = pinnedIp.includes(':') ? `[${pinnedIp}]` : pinnedIp;
|
||||
|
||||
// Reconstruct URL with IP
|
||||
parsed.hostname = pinnedHostname;
|
||||
let pinnedUrl = parsed.toString();
|
||||
|
||||
// If original didn't have scheme, remove it (standard for gRPC targets)
|
||||
if (!hasScheme) {
|
||||
pinnedUrl = pinnedUrl.replace(/^http:\/\//, '');
|
||||
// URL.toString() might append a trailing slash
|
||||
if (pinnedUrl.endsWith('/') && !url.endsWith('/')) {
|
||||
pinnedUrl = pinnedUrl.slice(0, -1);
|
||||
}
|
||||
}
|
||||
|
||||
return { pinnedUrl, hostname };
|
||||
} catch (e) {
|
||||
if (e instanceof Error && e.message.includes('Refusing')) throw e;
|
||||
throw new Error(`Failed to resolve host for agent '${agentName}': ${url}`, {
|
||||
cause: e,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Extracts contextId and taskId from a Message, Task, or Update response.
|
||||
* Follows the pattern from the A2A CLI sample to maintain conversational continuity.
|
||||
@@ -285,7 +355,7 @@ export function extractIdsFromResponse(result: SendMessageResult): {
|
||||
|
||||
if ('kind' in result) {
|
||||
const kind = result.kind;
|
||||
if (kind === 'message' || kind === 'artifact-update') {
|
||||
if (kind === 'message' || isArtifactUpdateEvent(result)) {
|
||||
taskId = result.taskId;
|
||||
contextId = result.contextId;
|
||||
} else if (kind === 'task') {
|
||||
@@ -393,6 +463,20 @@ function isStatusUpdateEvent(
|
||||
return result.kind === 'status-update';
|
||||
}
|
||||
|
||||
function isArtifactUpdateEvent(
|
||||
result: SendMessageResult,
|
||||
): result is TaskArtifactUpdateEvent {
|
||||
return result.kind === 'artifact-update';
|
||||
}
|
||||
|
||||
function isMessage(result: SendMessageResult): result is Message {
|
||||
return result.kind === 'message';
|
||||
}
|
||||
|
||||
function isTask(result: SendMessageResult): result is Task {
|
||||
return result.kind === 'task';
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the given state is a terminal state for a task.
|
||||
*/
|
||||
|
||||
257
packages/core/src/utils/fetch.test.ts
Normal file
257
packages/core/src/utils/fetch.test.ts
Normal file
@@ -0,0 +1,257 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest';
|
||||
import {
|
||||
isPrivateIp,
|
||||
isPrivateIpAsync,
|
||||
isAddressPrivate,
|
||||
safeLookup,
|
||||
safeFetch,
|
||||
} from './fetch.js';
|
||||
import * as dnsPromises from 'node:dns/promises';
|
||||
import * as dns from 'node:dns';
|
||||
|
||||
vi.mock('node:dns/promises', () => ({
|
||||
lookup: vi.fn(),
|
||||
}));
|
||||
|
||||
// We need to mock node:dns for safeLookup since it uses the callback API
|
||||
vi.mock('node:dns', () => ({
|
||||
lookup: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock global fetch
|
||||
const originalFetch = global.fetch;
|
||||
global.fetch = vi.fn();
|
||||
|
||||
describe('fetch utils', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterAll(() => {
|
||||
global.fetch = originalFetch;
|
||||
});
|
||||
|
||||
describe('isAddressPrivate', () => {
|
||||
it('should identify private IPv4 addresses', () => {
|
||||
expect(isAddressPrivate('10.0.0.1')).toBe(true);
|
||||
expect(isAddressPrivate('127.0.0.1')).toBe(true);
|
||||
expect(isAddressPrivate('172.16.0.1')).toBe(true);
|
||||
expect(isAddressPrivate('192.168.1.1')).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify non-routable and reserved IPv4 addresses (RFC 6890)', () => {
|
||||
expect(isAddressPrivate('0.0.0.0')).toBe(true);
|
||||
expect(isAddressPrivate('100.64.0.1')).toBe(true);
|
||||
expect(isAddressPrivate('192.0.0.1')).toBe(true);
|
||||
expect(isAddressPrivate('192.0.2.1')).toBe(true);
|
||||
expect(isAddressPrivate('192.88.99.1')).toBe(true);
|
||||
expect(isAddressPrivate('198.18.0.1')).toBe(true);
|
||||
expect(isAddressPrivate('198.51.100.1')).toBe(true);
|
||||
expect(isAddressPrivate('203.0.113.1')).toBe(true);
|
||||
expect(isAddressPrivate('224.0.0.1')).toBe(true);
|
||||
expect(isAddressPrivate('240.0.0.1')).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify private IPv6 addresses', () => {
|
||||
expect(isAddressPrivate('::1')).toBe(true);
|
||||
expect(isAddressPrivate('fc00::')).toBe(true);
|
||||
expect(isAddressPrivate('fd00::')).toBe(true);
|
||||
expect(isAddressPrivate('fe80::')).toBe(true);
|
||||
expect(isAddressPrivate('febf::')).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify special local addresses', () => {
|
||||
expect(isAddressPrivate('0.0.0.0')).toBe(true);
|
||||
expect(isAddressPrivate('::')).toBe(true);
|
||||
expect(isAddressPrivate('localhost')).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify link-local addresses', () => {
|
||||
expect(isAddressPrivate('169.254.169.254')).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify IPv4-mapped IPv6 private addresses', () => {
|
||||
expect(isAddressPrivate('::ffff:127.0.0.1')).toBe(true);
|
||||
expect(isAddressPrivate('::ffff:10.0.0.1')).toBe(true);
|
||||
expect(isAddressPrivate('::ffff:169.254.169.254')).toBe(true);
|
||||
expect(isAddressPrivate('::ffff:192.168.1.1')).toBe(true);
|
||||
expect(isAddressPrivate('::ffff:172.16.0.1')).toBe(true);
|
||||
expect(isAddressPrivate('::ffff:0.0.0.0')).toBe(true);
|
||||
expect(isAddressPrivate('::ffff:100.64.0.1')).toBe(true);
|
||||
expect(isAddressPrivate('::ffff:a9fe:101')).toBe(true); // 169.254.1.1
|
||||
});
|
||||
|
||||
it('should identify public addresses as non-private', () => {
|
||||
expect(isAddressPrivate('8.8.8.8')).toBe(false);
|
||||
expect(isAddressPrivate('93.184.216.34')).toBe(false);
|
||||
expect(isAddressPrivate('2001:4860:4860::8888')).toBe(false);
|
||||
expect(isAddressPrivate('::ffff:8.8.8.8')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPrivateIp', () => {
|
||||
it('should identify private IPs in URLs', () => {
|
||||
expect(isPrivateIp('http://10.0.0.1/')).toBe(true);
|
||||
expect(isPrivateIp('https://127.0.0.1:8080/')).toBe(true);
|
||||
expect(isPrivateIp('http://localhost/')).toBe(true);
|
||||
expect(isPrivateIp('http://[::1]/')).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify public IPs in URLs as non-private', () => {
|
||||
expect(isPrivateIp('http://8.8.8.8/')).toBe(false);
|
||||
expect(isPrivateIp('https://google.com/')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPrivateIpAsync', () => {
|
||||
it('should identify private IPs directly', async () => {
|
||||
expect(await isPrivateIpAsync('http://10.0.0.1/')).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify domains resolving to private IPs', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockImplementation(
|
||||
async () =>
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
[{ address: '10.0.0.1', family: 4 }] as any,
|
||||
);
|
||||
expect(await isPrivateIpAsync('http://malicious.com/')).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify domains resolving to public IPs as non-private', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockImplementation(
|
||||
async () =>
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
[{ address: '8.8.8.8', family: 4 }] as any,
|
||||
);
|
||||
expect(await isPrivateIpAsync('http://google.com/')).toBe(false);
|
||||
});
|
||||
|
||||
it('should throw error if DNS resolution fails (fail closed)', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
|
||||
await expect(isPrivateIpAsync('http://unreachable.com/')).rejects.toThrow(
|
||||
'Failed to verify if URL resolves to private IP',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return false for invalid URLs instead of throwing verification error', async () => {
|
||||
expect(await isPrivateIpAsync('not-a-url')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('safeLookup', () => {
|
||||
it('should filter out private IPs', async () => {
|
||||
const addresses = [
|
||||
{ address: '8.8.8.8', family: 4 },
|
||||
{ address: '10.0.0.1', family: 4 },
|
||||
];
|
||||
|
||||
vi.mocked(dns.lookup).mockImplementation(((
|
||||
_h: string,
|
||||
_o: dns.LookupOptions,
|
||||
cb: (
|
||||
err: Error | null,
|
||||
addr: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
) => {
|
||||
cb(null, addresses);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}) as any);
|
||||
|
||||
const result = await new Promise<
|
||||
Array<{ address: string; family: number }>
|
||||
>((resolve, reject) => {
|
||||
safeLookup('example.com', { all: true }, (err, filtered) => {
|
||||
if (err) reject(err);
|
||||
else resolve(filtered);
|
||||
});
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].address).toBe('8.8.8.8');
|
||||
});
|
||||
|
||||
it('should allow explicit localhost', async () => {
|
||||
const addresses = [{ address: '127.0.0.1', family: 4 }];
|
||||
|
||||
vi.mocked(dns.lookup).mockImplementation(((
|
||||
_h: string,
|
||||
_o: dns.LookupOptions,
|
||||
cb: (
|
||||
err: Error | null,
|
||||
addr: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
) => {
|
||||
cb(null, addresses);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}) as any);
|
||||
|
||||
const result = await new Promise<
|
||||
Array<{ address: string; family: number }>
|
||||
>((resolve, reject) => {
|
||||
safeLookup('localhost', { all: true }, (err, filtered) => {
|
||||
if (err) reject(err);
|
||||
else resolve(filtered);
|
||||
});
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].address).toBe('127.0.0.1');
|
||||
});
|
||||
|
||||
it('should error if all resolved IPs are private', async () => {
|
||||
const addresses = [{ address: '10.0.0.1', family: 4 }];
|
||||
|
||||
vi.mocked(dns.lookup).mockImplementation(((
|
||||
_h: string,
|
||||
_o: dns.LookupOptions,
|
||||
cb: (
|
||||
err: Error | null,
|
||||
addr: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
) => {
|
||||
cb(null, addresses);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}) as any);
|
||||
|
||||
await expect(
|
||||
new Promise((resolve, reject) => {
|
||||
safeLookup('malicious.com', { all: true }, (err, filtered) => {
|
||||
if (err) reject(err);
|
||||
else resolve(filtered);
|
||||
});
|
||||
}),
|
||||
).rejects.toThrow('Refusing to connect to private IP address');
|
||||
});
|
||||
});
|
||||
|
||||
describe('safeFetch', () => {
|
||||
it('should forward to fetch with dispatcher', async () => {
|
||||
vi.mocked(global.fetch).mockResolvedValue(new Response('ok'));
|
||||
|
||||
const response = await safeFetch('https://example.com');
|
||||
expect(response.status).toBe(200);
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
expect.objectContaining({
|
||||
dispatcher: expect.any(Object),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle Refusing to connect errors', async () => {
|
||||
vi.mocked(global.fetch).mockRejectedValue(
|
||||
new Error('Refusing to connect to private IP address'),
|
||||
);
|
||||
|
||||
await expect(safeFetch('http://10.0.0.1')).rejects.toThrow(
|
||||
'Access to private network is blocked',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import { getErrorMessage, isNodeError } from './errors.js';
|
||||
import { URL } from 'node:url';
|
||||
import * as dns from 'node:dns';
|
||||
import { lookup } from 'node:dns/promises';
|
||||
import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
|
||||
@@ -23,14 +24,28 @@ setGlobalDispatcher(
|
||||
const PRIVATE_IP_RANGES = [
|
||||
/^10\./,
|
||||
/^127\./,
|
||||
/^0\./,
|
||||
/^100\.(6[4-9]|[7-9][0-9]|1[0-1][0-9]|12[0-7])\./,
|
||||
/^169\.254\./,
|
||||
/^172\.(1[6-9]|2[0-9]|3[0-1])\./,
|
||||
/^192\.0\.(0|2)\./,
|
||||
/^192\.88\.99\./,
|
||||
/^192\.168\./,
|
||||
/^198\.(1[8-9]|51\.100)\./,
|
||||
/^203\.0\.113\./,
|
||||
/^2(2[4-9]|3[0-9]|[4-5][0-9])\./,
|
||||
/^::1$/,
|
||||
/^fc00:/,
|
||||
/^fe80:/,
|
||||
/^::$/,
|
||||
/^f[cd]00:/i, // fc00::/7 (ULA)
|
||||
/^fe[89ab][0-9a-f]:/i, // fe80::/10 (Link-local)
|
||||
/^::ffff:(10\.|127\.|0\.|100\.(6[4-9]|[7-9][0-9]|1[0-1][0-9]|12[0-7])\.|169\.254\.|172\.(1[6-9]|2[0-9]|3[0-1])\.|192\.0\.(0|2)\.|192\.88\.99\.|192\.168\.|198\.(1[8-9]|51\.100)\.|203\.0\.113\.|2(2[4-9]|3[0-9]|[4-5][0-9])\.|a9fe:|ac1[0-9a-f]:|c0a8:|0+:)/i, // IPv4-mapped IPv6
|
||||
];
|
||||
|
||||
// Local extension of RequestInit to support Node.js/undici dispatcher
|
||||
interface NodeFetchInit extends RequestInit {
|
||||
dispatcher?: Agent | ProxyAgent;
|
||||
}
|
||||
|
||||
export class FetchError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
@@ -42,6 +57,74 @@ export class FetchError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitizes a hostname by stripping IPv6 brackets if present.
|
||||
*/
|
||||
export function sanitizeHostname(hostname: string): string {
|
||||
return hostname.startsWith('[') && hostname.endsWith(']')
|
||||
? hostname.slice(1, -1)
|
||||
: hostname;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a hostname is a local loopback address allowed for development/testing.
|
||||
*/
|
||||
export function isLoopbackHost(hostname: string): boolean {
|
||||
const sanitized = sanitizeHostname(hostname);
|
||||
return (
|
||||
sanitized === 'localhost' ||
|
||||
sanitized === '127.0.0.1' ||
|
||||
sanitized === '::1'
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* A custom DNS lookup implementation for undici agents that prevents
|
||||
* connection to private IP ranges (SSRF protection).
|
||||
*/
|
||||
export function safeLookup(
|
||||
hostname: string,
|
||||
options: dns.LookupOptions | number | null | undefined,
|
||||
callback: (
|
||||
err: Error | null,
|
||||
addresses: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
): void {
|
||||
// Use the callback-based dns.lookup to match undici's expected signature.
|
||||
// We explicitly handle the 'all' option to ensure we get an array of addresses.
|
||||
const lookupOptions =
|
||||
typeof options === 'number' ? { family: options } : { ...options };
|
||||
const finalOptions = { ...lookupOptions, all: true };
|
||||
|
||||
dns.lookup(hostname, finalOptions, (err, addresses) => {
|
||||
if (err) {
|
||||
callback(err, []);
|
||||
return;
|
||||
}
|
||||
|
||||
const addressArray = Array.isArray(addresses) ? addresses : [];
|
||||
const filtered = addressArray.filter(
|
||||
(addr) => !isAddressPrivate(addr.address) || isLoopbackHost(hostname),
|
||||
);
|
||||
|
||||
if (filtered.length === 0 && addressArray.length > 0) {
|
||||
callback(new Error(`Refusing to connect to private IP address`), []);
|
||||
return;
|
||||
}
|
||||
|
||||
callback(null, filtered);
|
||||
});
|
||||
}
|
||||
|
||||
// Dedicated dispatcher with connection-level SSRF protection (safeLookup)
|
||||
const safeDispatcher = new Agent({
|
||||
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
|
||||
bodyTimeout: DEFAULT_BODY_TIMEOUT,
|
||||
connect: {
|
||||
lookup: safeLookup,
|
||||
},
|
||||
});
|
||||
|
||||
export function isPrivateIp(url: string): boolean {
|
||||
try {
|
||||
const hostname = new URL(url).hostname;
|
||||
@@ -68,21 +151,82 @@ export async function isPrivateIpAsync(url: string): Promise<boolean> {
|
||||
// Resolve DNS to check the actual target IP
|
||||
const addresses = await lookup(hostname, { all: true });
|
||||
return addresses.some((addr) => isAddressPrivate(addr.address));
|
||||
} catch (_e) {
|
||||
return false;
|
||||
} catch (e) {
|
||||
if (
|
||||
e instanceof Error &&
|
||||
e.name === 'TypeError' &&
|
||||
e.message.includes('Invalid URL')
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
throw new Error(`Failed to verify if URL resolves to private IP: ${url}`, {
|
||||
cause: e,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal helper to check if an IP address string is in a private range.
|
||||
*/
|
||||
function isAddressPrivate(address: string): boolean {
|
||||
export function isAddressPrivate(address: string): boolean {
|
||||
const sanitized = sanitizeHostname(address);
|
||||
|
||||
return (
|
||||
address === 'localhost' ||
|
||||
PRIVATE_IP_RANGES.some((range) => range.test(address))
|
||||
sanitized === 'localhost' ||
|
||||
PRIVATE_IP_RANGES.some((range) => range.test(sanitized))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Enhanced fetch with SSRF protection.
|
||||
* Prevents access to private/internal networks at the connection level.
|
||||
*/
|
||||
export async function safeFetch(
|
||||
url: string | URL,
|
||||
init?: RequestInit,
|
||||
): Promise<Response> {
|
||||
const nodeInit: NodeFetchInit = {
|
||||
...init,
|
||||
dispatcher: safeDispatcher,
|
||||
};
|
||||
|
||||
try {
|
||||
return await fetch(url, nodeInit);
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
// Re-map refusing to connect errors to standard FetchError
|
||||
if (error.message.includes('Refusing to connect to private IP address')) {
|
||||
throw new FetchError(
|
||||
`Access to private network is blocked: ${url.toString()}`,
|
||||
'ERR_PRIVATE_NETWORK',
|
||||
{ cause: error },
|
||||
);
|
||||
}
|
||||
throw new FetchError(
|
||||
getErrorMessage(error),
|
||||
isNodeError(error) ? error.code : undefined,
|
||||
{ cause: error },
|
||||
);
|
||||
}
|
||||
throw new FetchError(String(error));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an undici ProxyAgent that incorporates safe DNS lookup.
|
||||
*/
|
||||
export function createSafeProxyAgent(proxyUrl: string): ProxyAgent {
|
||||
return new ProxyAgent({
|
||||
uri: proxyUrl,
|
||||
connect: {
|
||||
lookup: safeLookup,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a fetch with a specified timeout and connection-level SSRF protection.
|
||||
*/
|
||||
export async function fetchWithTimeout(
|
||||
url: string,
|
||||
timeout: number,
|
||||
@@ -101,16 +245,29 @@ export async function fetchWithTimeout(
|
||||
}
|
||||
}
|
||||
|
||||
const nodeInit: NodeFetchInit = {
|
||||
...options,
|
||||
signal: controller.signal,
|
||||
dispatcher: safeDispatcher,
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await fetch(url, {
|
||||
...options,
|
||||
signal: controller.signal,
|
||||
});
|
||||
const response = await fetch(url, nodeInit);
|
||||
return response;
|
||||
} catch (error) {
|
||||
if (isNodeError(error) && error.code === 'ABORT_ERR') {
|
||||
throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT');
|
||||
}
|
||||
if (
|
||||
error instanceof Error &&
|
||||
error.message.includes('Refusing to connect to private IP address')
|
||||
) {
|
||||
throw new FetchError(
|
||||
`Access to private network is blocked: ${url}`,
|
||||
'ERR_PRIVATE_NETWORK',
|
||||
{ cause: error },
|
||||
);
|
||||
}
|
||||
throw new FetchError(getErrorMessage(error), undefined, { cause: error });
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
|
||||
Reference in New Issue
Block a user