mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-14 07:10:34 -07:00
feat(a2a): enable native gRPC support and protocol routing (#21403)
Co-authored-by: Adam Weidman <adamfweidman@google.com>
This commit is contained in:
@@ -5,27 +5,12 @@
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest';
|
||||
import {
|
||||
isPrivateIp,
|
||||
isPrivateIpAsync,
|
||||
isAddressPrivate,
|
||||
safeLookup,
|
||||
safeFetch,
|
||||
fetchWithTimeout,
|
||||
PrivateIpError,
|
||||
} from './fetch.js';
|
||||
import * as dnsPromises from 'node:dns/promises';
|
||||
import * as dns from 'node:dns';
|
||||
import { isPrivateIp, isAddressPrivate, fetchWithTimeout } from './fetch.js';
|
||||
|
||||
vi.mock('node:dns/promises', () => ({
|
||||
lookup: vi.fn(),
|
||||
}));
|
||||
|
||||
// We need to mock node:dns for safeLookup since it uses the callback API
|
||||
vi.mock('node:dns', () => ({
|
||||
lookup: vi.fn(),
|
||||
}));
|
||||
|
||||
// Mock global fetch
|
||||
const originalFetch = global.fetch;
|
||||
global.fetch = vi.fn();
|
||||
@@ -114,150 +99,6 @@ describe('fetch utils', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('isPrivateIpAsync', () => {
|
||||
it('should identify private IPs directly', async () => {
|
||||
expect(await isPrivateIpAsync('http://10.0.0.1/')).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify domains resolving to private IPs', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockImplementation(
|
||||
async () =>
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
[{ address: '10.0.0.1', family: 4 }] as any,
|
||||
);
|
||||
expect(await isPrivateIpAsync('http://malicious.com/')).toBe(true);
|
||||
});
|
||||
|
||||
it('should identify domains resolving to public IPs as non-private', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockImplementation(
|
||||
async () =>
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
[{ address: '8.8.8.8', family: 4 }] as any,
|
||||
);
|
||||
expect(await isPrivateIpAsync('http://google.com/')).toBe(false);
|
||||
});
|
||||
|
||||
it('should throw error if DNS resolution fails (fail closed)', async () => {
|
||||
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
|
||||
await expect(isPrivateIpAsync('http://unreachable.com/')).rejects.toThrow(
|
||||
'Failed to verify if URL resolves to private IP',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return false for invalid URLs instead of throwing verification error', async () => {
|
||||
expect(await isPrivateIpAsync('not-a-url')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('safeLookup', () => {
|
||||
it('should filter out private IPs', async () => {
|
||||
const addresses = [
|
||||
{ address: '8.8.8.8', family: 4 },
|
||||
{ address: '10.0.0.1', family: 4 },
|
||||
];
|
||||
|
||||
vi.mocked(dns.lookup).mockImplementation(((
|
||||
_h: string,
|
||||
_o: dns.LookupOptions,
|
||||
cb: (
|
||||
err: Error | null,
|
||||
addr: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
) => {
|
||||
cb(null, addresses);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}) as any);
|
||||
|
||||
const result = await new Promise<
|
||||
Array<{ address: string; family: number }>
|
||||
>((resolve, reject) => {
|
||||
safeLookup('example.com', { all: true }, (err, filtered) => {
|
||||
if (err) reject(err);
|
||||
else resolve(filtered);
|
||||
});
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].address).toBe('8.8.8.8');
|
||||
});
|
||||
|
||||
it('should allow explicit localhost', async () => {
|
||||
const addresses = [{ address: '127.0.0.1', family: 4 }];
|
||||
|
||||
vi.mocked(dns.lookup).mockImplementation(((
|
||||
_h: string,
|
||||
_o: dns.LookupOptions,
|
||||
cb: (
|
||||
err: Error | null,
|
||||
addr: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
) => {
|
||||
cb(null, addresses);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}) as any);
|
||||
|
||||
const result = await new Promise<
|
||||
Array<{ address: string; family: number }>
|
||||
>((resolve, reject) => {
|
||||
safeLookup('localhost', { all: true }, (err, filtered) => {
|
||||
if (err) reject(err);
|
||||
else resolve(filtered);
|
||||
});
|
||||
});
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].address).toBe('127.0.0.1');
|
||||
});
|
||||
|
||||
it('should error if all resolved IPs are private', async () => {
|
||||
const addresses = [{ address: '10.0.0.1', family: 4 }];
|
||||
|
||||
vi.mocked(dns.lookup).mockImplementation(((
|
||||
_h: string,
|
||||
_o: dns.LookupOptions,
|
||||
cb: (
|
||||
err: Error | null,
|
||||
addr: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
) => {
|
||||
cb(null, addresses);
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
}) as any);
|
||||
|
||||
await expect(
|
||||
new Promise((resolve, reject) => {
|
||||
safeLookup('malicious.com', { all: true }, (err, filtered) => {
|
||||
if (err) reject(err);
|
||||
else resolve(filtered);
|
||||
});
|
||||
}),
|
||||
).rejects.toThrow(PrivateIpError);
|
||||
});
|
||||
});
|
||||
|
||||
describe('safeFetch', () => {
|
||||
it('should forward to fetch with dispatcher', async () => {
|
||||
vi.mocked(global.fetch).mockResolvedValue(new Response('ok'));
|
||||
|
||||
const response = await safeFetch('https://example.com');
|
||||
expect(response.status).toBe(200);
|
||||
expect(global.fetch).toHaveBeenCalledWith(
|
||||
'https://example.com',
|
||||
expect.objectContaining({
|
||||
dispatcher: expect.any(Object),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle Refusing to connect errors', async () => {
|
||||
vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError());
|
||||
|
||||
await expect(safeFetch('http://10.0.0.1')).rejects.toThrow(
|
||||
'Access to private network is blocked',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('fetchWithTimeout', () => {
|
||||
it('should handle timeouts', async () => {
|
||||
vi.mocked(global.fetch).mockImplementation(
|
||||
@@ -279,13 +120,5 @@ describe('fetch utils', () => {
|
||||
'Request timed out after 50ms',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle private IP errors via handleFetchError', async () => {
|
||||
vi.mocked(global.fetch).mockRejectedValue(new PrivateIpError());
|
||||
|
||||
await expect(fetchWithTimeout('http://10.0.0.1', 1000)).rejects.toThrow(
|
||||
'Access to private network is blocked: http://10.0.0.1',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,37 +6,12 @@
|
||||
|
||||
import { getErrorMessage, isNodeError } from './errors.js';
|
||||
import { URL } from 'node:url';
|
||||
import * as dns from 'node:dns';
|
||||
import { lookup } from 'node:dns/promises';
|
||||
import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
import ipaddr from 'ipaddr.js';
|
||||
|
||||
const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes
|
||||
const DEFAULT_BODY_TIMEOUT = 300000; // 5 minutes
|
||||
|
||||
// Configure default global dispatcher with higher timeouts
|
||||
setGlobalDispatcher(
|
||||
new Agent({
|
||||
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
|
||||
bodyTimeout: DEFAULT_BODY_TIMEOUT,
|
||||
}),
|
||||
);
|
||||
|
||||
// Local extension of RequestInit to support Node.js/undici dispatcher
|
||||
interface NodeFetchInit extends RequestInit {
|
||||
dispatcher?: Agent | ProxyAgent;
|
||||
}
|
||||
|
||||
/**
|
||||
* Error thrown when a connection to a private IP address is blocked for security reasons.
|
||||
*/
|
||||
export class PrivateIpError extends Error {
|
||||
constructor(message = 'Refusing to connect to private IP address') {
|
||||
super(message);
|
||||
this.name = 'PrivateIpError';
|
||||
}
|
||||
}
|
||||
|
||||
export class FetchError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
@@ -48,6 +23,14 @@ export class FetchError extends Error {
|
||||
}
|
||||
}
|
||||
|
||||
// Configure default global dispatcher with higher timeouts
|
||||
setGlobalDispatcher(
|
||||
new Agent({
|
||||
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
|
||||
bodyTimeout: DEFAULT_BODY_TIMEOUT,
|
||||
}),
|
||||
);
|
||||
|
||||
/**
|
||||
* Sanitizes a hostname by stripping IPv6 brackets if present.
|
||||
*/
|
||||
@@ -69,53 +52,6 @@ export function isLoopbackHost(hostname: string): boolean {
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* A custom DNS lookup implementation for undici agents that prevents
|
||||
* connection to private IP ranges (SSRF protection).
|
||||
*/
|
||||
export function safeLookup(
|
||||
hostname: string,
|
||||
options: dns.LookupOptions | number | null | undefined,
|
||||
callback: (
|
||||
err: Error | null,
|
||||
addresses: Array<{ address: string; family: number }>,
|
||||
) => void,
|
||||
): void {
|
||||
// Use the callback-based dns.lookup to match undici's expected signature.
|
||||
// We explicitly handle the 'all' option to ensure we get an array of addresses.
|
||||
const lookupOptions =
|
||||
typeof options === 'number' ? { family: options } : { ...options };
|
||||
const finalOptions = { ...lookupOptions, all: true };
|
||||
|
||||
dns.lookup(hostname, finalOptions, (err, addresses) => {
|
||||
if (err) {
|
||||
callback(err, []);
|
||||
return;
|
||||
}
|
||||
|
||||
const addressArray = Array.isArray(addresses) ? addresses : [];
|
||||
const filtered = addressArray.filter(
|
||||
(addr) => !isAddressPrivate(addr.address) || isLoopbackHost(hostname),
|
||||
);
|
||||
|
||||
if (filtered.length === 0 && addressArray.length > 0) {
|
||||
callback(new PrivateIpError(), []);
|
||||
return;
|
||||
}
|
||||
|
||||
callback(null, filtered);
|
||||
});
|
||||
}
|
||||
|
||||
// Dedicated dispatcher with connection-level SSRF protection (safeLookup)
|
||||
const safeDispatcher = new Agent({
|
||||
headersTimeout: DEFAULT_HEADERS_TIMEOUT,
|
||||
bodyTimeout: DEFAULT_BODY_TIMEOUT,
|
||||
connect: {
|
||||
lookup: safeLookup,
|
||||
},
|
||||
});
|
||||
|
||||
export function isPrivateIp(url: string): boolean {
|
||||
try {
|
||||
const hostname = new URL(url).hostname;
|
||||
@@ -125,37 +61,6 @@ export function isPrivateIp(url: string): boolean {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a URL resolves to a private IP address.
|
||||
* Performs DNS resolution to prevent DNS rebinding/SSRF bypasses.
|
||||
*/
|
||||
export async function isPrivateIpAsync(url: string): Promise<boolean> {
|
||||
try {
|
||||
const parsed = new URL(url);
|
||||
const hostname = parsed.hostname;
|
||||
|
||||
// Fast check for literal IPs or localhost
|
||||
if (isAddressPrivate(hostname)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Resolve DNS to check the actual target IP
|
||||
const addresses = await lookup(hostname, { all: true });
|
||||
return addresses.some((addr) => isAddressPrivate(addr.address));
|
||||
} catch (e) {
|
||||
if (
|
||||
e instanceof Error &&
|
||||
e.name === 'TypeError' &&
|
||||
e.message.includes('Invalid URL')
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
throw new Error(`Failed to verify if URL resolves to private IP: ${url}`, {
|
||||
cause: e,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* IANA Benchmark Testing Range (198.18.0.0/15).
|
||||
* Classified as 'unicast' by ipaddr.js but is reserved and should not be
|
||||
@@ -210,72 +115,15 @@ export function isAddressPrivate(address: string): boolean {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal helper to map varied fetch errors to a standardized FetchError.
|
||||
* Centralizes security-related error mapping (e.g. PrivateIpError).
|
||||
*/
|
||||
function handleFetchError(error: unknown, url: string): never {
|
||||
if (error instanceof PrivateIpError) {
|
||||
throw new FetchError(
|
||||
`Access to private network is blocked: ${url}`,
|
||||
'ERR_PRIVATE_NETWORK',
|
||||
{ cause: error },
|
||||
);
|
||||
}
|
||||
|
||||
if (error instanceof FetchError) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
throw new FetchError(
|
||||
getErrorMessage(error),
|
||||
isNodeError(error) ? error.code : undefined,
|
||||
{ cause: error },
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Enhanced fetch with SSRF protection.
|
||||
* Prevents access to private/internal networks at the connection level.
|
||||
*/
|
||||
export async function safeFetch(
|
||||
input: RequestInfo | URL,
|
||||
init?: RequestInit,
|
||||
): Promise<Response> {
|
||||
const nodeInit: NodeFetchInit = {
|
||||
...init,
|
||||
dispatcher: safeDispatcher,
|
||||
};
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line no-restricted-syntax
|
||||
return await fetch(input, nodeInit);
|
||||
} catch (error) {
|
||||
const url =
|
||||
input instanceof Request
|
||||
? input.url
|
||||
: typeof input === 'string'
|
||||
? input
|
||||
: input.toString();
|
||||
handleFetchError(error, url);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an undici ProxyAgent that incorporates safe DNS lookup.
|
||||
*/
|
||||
export function createSafeProxyAgent(proxyUrl: string): ProxyAgent {
|
||||
return new ProxyAgent({
|
||||
uri: proxyUrl,
|
||||
connect: {
|
||||
lookup: safeLookup,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs a fetch with a specified timeout and connection-level SSRF protection.
|
||||
*/
|
||||
export async function fetchWithTimeout(
|
||||
url: string,
|
||||
timeout: number,
|
||||
@@ -294,21 +142,17 @@ export async function fetchWithTimeout(
|
||||
}
|
||||
}
|
||||
|
||||
const nodeInit: NodeFetchInit = {
|
||||
...options,
|
||||
signal: controller.signal,
|
||||
dispatcher: safeDispatcher,
|
||||
};
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line no-restricted-syntax
|
||||
const response = await fetch(url, nodeInit);
|
||||
const response = await fetch(url, {
|
||||
...options,
|
||||
signal: controller.signal,
|
||||
});
|
||||
return response;
|
||||
} catch (error) {
|
||||
if (isNodeError(error) && error.code === 'ABORT_ERR') {
|
||||
throw new FetchError(`Request timed out after ${timeout}ms`, 'ETIMEDOUT');
|
||||
}
|
||||
handleFetchError(error, url.toString());
|
||||
throw new FetchError(getErrorMessage(error), undefined, { cause: error });
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
|
||||
@@ -454,7 +454,6 @@ export async function exchangeCodeForToken(
|
||||
params.append('resource', resource);
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(config.tokenUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
@@ -508,7 +507,6 @@ export async function refreshAccessToken(
|
||||
params.append('resource', resource);
|
||||
}
|
||||
|
||||
// eslint-disable-next-line no-restricted-syntax -- TODO: Migrate to safeFetch for SSRF protection
|
||||
const response = await fetch(tokenUrl, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
|
||||
Reference in New Issue
Block a user