mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-12 14:22:00 -07:00
feat(a2a): add DNS rebinding protection and robust URL reconstruction
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
|
||||
import { getErrorMessage, isNodeError } from './errors.js';
|
||||
import { URL } from 'node:url';
|
||||
import * as dns from 'node:dns';
|
||||
import { lookup } from 'node:dns/promises';
|
||||
import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
|
||||
@@ -23,14 +24,28 @@ setGlobalDispatcher(
|
||||
const PRIVATE_IP_RANGES = [
|
||||
/^10\./,
|
||||
/^127\./,
|
||||
/^0\./,
|
||||
/^100\.(6[4-9]|[7-9][0-9]|1[0-1][0-9]|12[0-7])\./,
|
||||
/^169\.254\./,
|
||||
/^172\.(1[6-9]|2[0-9]|3[0-1])\./,
|
||||
/^192\.0\.(0|2)\./,
|
||||
/^192\.88\.99\./,
|
||||
/^192\.168\./,
|
||||
/^198\.(1[8-9]|51\.100)\./,
|
||||
/^203\.0\.113\./,
|
||||
/^2(2[4-9]|3[0-9]|[4-5][0-9])\./,
|
||||
/^::1$/,
|
||||
/^fc00:/,
|
||||
/^fe80:/,
|
||||
/^::$/,
|
||||
/^f[cd]00:/i, // fc00::/7 (ULA)
|
||||
/^fe[89ab][0-9a-f]:/i, // fe80::/10 (Link-local)
|
||||
/^::ffff:(10\.|127\.|0\.|100\.(6[4-9]|[7-9][0-9]|1[0-1][0-9]|12[0-7])\.|169\.254\.|172\.(1[6-9]|2[0-9]|3[0-1])\.|192\.0\.(0|2)\.|192\.88\.99\.|192\.168\.|198\.(1[8-9]|51\.100)\.|203\.0\.113\.|2(2[4-9]|3[0-9]|[4-5][0-9])\.|a9fe:|ac1[0-9a-f]:|c0a8:|0+:)/i, // IPv4-mapped IPv6
|
||||
];
|
||||
|
||||
// Local extension of RequestInit to support Node.js/undici dispatcher
|
||||
interface NodeFetchInit extends RequestInit {
|
||||
dispatcher?: Agent | ProxyAgent;
|
||||
}
|
||||
|
||||
export class FetchError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
@@ -42,6 +57,74 @@ export class FetchError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitizes a hostname by stripping IPv6 brackets if present.
|
||||
*/
|
||||
export function sanitizeHostname(hostname: string): string {
|
||||
return hostname.startsWith('[') && hostname.endsWith(']')
|
||||
? hostname.slice(1, -1)
|
||||
: hostname;
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a hostname is a local loopback address allowed for development/testing.
|
||||
*/
|
||||
export function isLoopbackHost(hostname: string): boolean {
|
||||
const sanitized = sanitizeHostname(hostname);
|
||||
return (
|
||||
sanitized === 'localhost' ||
|
||||
sanitized === '127.0.0.1' ||
|
||||
sanitized === '::1'
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* A custom DNS lookup implementation for undici agents that prevents
|
||||
* connection to private IP ranges (SSRF protection).
|
||||
*/
|
||||
export function safeLookup(
|
||||
hostname: string,
|
||||
options: dns.LookupOptions | number | null | undefined,
|
||||
callback: (
|
||||
err: Error | null,
|
||||
addresses: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
): void {
|
||||
// Use the callback-based dns.lookup to match undici's expected signature.
|
||||
// We explicitly handle the 'all' option to ensure we get an array of addresses.
|
||||
const lookupOptions =
|
||||
typeof options === 'number' ? { family: options } : { ...options };
|
||||
const finalOptions = { ...lookupOptions, all: true };
|
||||
|
||||
dns.lookup(hostname, finalOptions, (err, addresses) => {
|
||||
if (err) {
|
||||
callback(err, []);
|
||||
return;
|
||||
}
|
||||
|
||||
const addressArray = Array.isArray(addresses) ? addresses : [];
|
||||
const filtered = addressArray.filter(
|
||||
(addr) => !isAddressPrivate(addr.address) || isLoopbackHost(hostname),
|
||||
);
|
||||
|
||||
if (filtered.length === 0 && addressArray.length > 0) {
|
||||
callback(new Error(`Refusing to connect to private IP address`), []);
|
||||
return;
|
||||
}
|
||||
|
||||
callback(null, filtered);
|
||||
});
|
||||
}
|
||||
|
||||
// Dedicated dispatcher with connection-level SSRF protection (safeLookup)
|
||||
const safeDispatcher = new Agent({
|
||||
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
|
||||
bodyTimeout: DEFAULT_BODY_TIMEOUT,
|
||||
connect: {
|
||||
lookup: safeLookup,
|
||||
},
|
||||
});
|
||||
|
||||
export function isPrivateIp(url: string): boolean {
|
||||
try {
|
||||
const hostname = new URL(url).hostname;
|
||||
@@ -68,21 +151,82 @@ export async function isPrivateIpAsync(url: string): Promise<boolean> {
|
||||
// Resolve DNS to check the actual target IP
|
||||
const addresses = await lookup(hostname, { all: true });
|
||||
return addresses.some((addr) => isAddressPrivate(addr.address));
|
||||
} catch (_e) {
|
||||
return false;
|
||||
} catch (e) {
|
||||
if (
|
||||
e instanceof Error &&
|
||||
e.name === 'TypeError' &&
|
||||
e.message.includes('Invalid URL')
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
throw new Error(`Failed to verify if URL resolves to private IP: ${url}`, {
|
||||
cause: e,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal helper to check if an IP address string is in a private range.
|
||||
*/
|
||||
function isAddressPrivate(address: string): boolean {
|
||||
export function isAddressPrivate(address: string): boolean {
|
||||
const sanitized = sanitizeHostname(address);
|
||||
|
||||
return (
|
||||
address === 'localhost' ||
|
||||
PRIVATE_IP_RANGES.some((range) => range.test(address))
|
||||
sanitized === 'localhost' ||
|
||||
PRIVATE_IP_RANGES.some((range) => range.test(sanitized))
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Enhanced fetch with SSRF protection.
|
||||
* Prevents access to private/internal networks at the connection level.
|
||||
*/
|
||||
export async function safeFetch(
|
||||
url: string | URL,
|
||||
init?: RequestInit,
|
||||
): Promise<Response> {
|
||||
const nodeInit: NodeFetchInit = {
|
||||
...init,
|
||||
dispatcher: safeDispatcher,
|
||||
};
|
||||
|
||||
try {
|
||||
return await fetch(url, nodeInit);
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
// Re-map refusing to connect errors to standard FetchError
|
||||
if (error.message.includes('Refusing to connect to private IP address')) {
|
||||
throw new FetchError(
|
||||
`Access to private network is blocked: ${url.toString()}`,
|
||||
'ERR_PRIVATE_NETWORK',
|
||||
{ cause: error },
|
||||
);
|
||||
}
|
||||
throw new FetchError(
|
||||
getErrorMessage(error),
|
||||
isNodeError(error) ? error.code : undefined,
|
||||
{ cause: error },
|
||||
);
|
||||
}
|
||||
throw new FetchError(String(error));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an undici ProxyAgent that incorporates safe DNS lookup.
|
||||
*/
|
||||
export function createSafeProxyAgent(proxyUrl: string): ProxyAgent {
|
||||
return new ProxyAgent({
|
||||
uri: proxyUrl,
|
||||
connect: {
|
||||
lookup: safeLookup,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a fetch with a specified timeout and connection-level SSRF protection.
|
||||
*/
|
||||
export async function fetchWithTimeout(
|
||||
url: string,
|
||||
timeout: number,
|
||||
@@ -101,16 +245,29 @@ export async function fetchWithTimeout(
|
||||
}
|
||||
}
|
||||
|
||||
const nodeInit: NodeFetchInit = {
|
||||
...options,
|
||||
signal: controller.signal,
|
||||
dispatcher: safeDispatcher,
|
||||
};
|
||||
|
||||
try {
|
||||
const response = await fetch(url, {
|
||||
...options,
|
||||
signal: controller.signal,
|
||||
});
|
||||
const response = await fetch(url, nodeInit);
|
||||
return response;
|
||||
} catch (error) {
|
||||
if (isNodeError(error) && error.code === 'ABORT_ERR') {
|
||||
throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT');
|
||||
}
|
||||
if (
|
||||
error instanceof Error &&
|
||||
error.message.includes('Refusing to connect to private IP address')
|
||||
) {
|
||||
throw new FetchError(
|
||||
`Access to private network is blocked: ${url}`,
|
||||
'ERR_PRIVATE_NETWORK',
|
||||
{ cause: error },
|
||||
);
|
||||
}
|
||||
throw new FetchError(getErrorMessage(error), undefined, { cause: error });
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
|
||||
Reference in New Issue
Block a user