mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-11 06:31:01 -07:00
feat(a2a): strengthen SSRF protection with DNS resolution
This commit is contained in:
@@ -8,6 +8,7 @@ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { A2AClientManager } from './a2a-client-manager.js';
|
||||
import type { AgentCard } from '@a2a-js/sdk';
|
||||
import * as sdkClient from '@a2a-js/sdk/client';
|
||||
import { lookup } from 'node:dns/promises';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
interface MockClient {
|
||||
@@ -36,6 +37,10 @@ vi.mock('../utils/debugLogger.js', () => ({
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('node:dns/promises', () => ({
|
||||
lookup: vi.fn().mockResolvedValue([{ address: '93.184.216.34' }]),
|
||||
}));
|
||||
|
||||
describe('A2AClientManager', () => {
|
||||
let manager: A2AClientManager;
|
||||
const mockAgentCard: AgentCard = {
|
||||
@@ -407,6 +412,20 @@ describe('A2AClientManager', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw if a domain resolves to a private IP (DNS SSRF protection)', async () => {
|
||||
const maliciousDomainUrl =
|
||||
'http://malicious.com/.well-known/agent-card.json';
|
||||
vi.mocked(lookup).mockResolvedValueOnce([
|
||||
{ address: '10.0.0.1', family: 4 },
|
||||
]);
|
||||
|
||||
await expect(
|
||||
manager.loadAgent('dns-ssrf-agent', maliciousDomainUrl),
|
||||
).rejects.toThrow(
|
||||
/Refusing to load agent 'dns-ssrf-agent' from private IP range/,
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw if a public agent card contains a private transport URL (Deep SSRF protection)', async () => {
|
||||
const publicUrl = 'https://public.agent.com/card.json';
|
||||
const resolverInstance = {
|
||||
|
||||
@@ -26,7 +26,7 @@ import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { Agent as UndiciAgent } from 'undici';
|
||||
import { getGrpcCredentials, normalizeAgentCard } from './a2aUtils.js';
|
||||
import { isPrivateIp } from '../utils/fetch.js';
|
||||
import { isPrivateIpAsync } from '../utils/fetch.js';
|
||||
import { debugLogger } from '../utils/debugLogger.js';
|
||||
|
||||
// Remote agents can take 10+ minutes (e.g. Deep Research).
|
||||
@@ -259,8 +259,8 @@ export class A2AClientManager {
|
||||
let baseUrl = url;
|
||||
let path: string | undefined;
|
||||
|
||||
// Validate URL to prevent SSRF
|
||||
if (isPrivateIp(url)) {
|
||||
// Validate URL to prevent SSRF (with DNS resolution)
|
||||
if (await isPrivateIpAsync(url)) {
|
||||
// Local/private IPs are allowed ONLY for localhost for testing.
|
||||
const parsed = new URL(url);
|
||||
if (parsed.hostname !== 'localhost' && parsed.hostname !== '127.0.0.1') {
|
||||
@@ -285,7 +285,7 @@ export class A2AClientManager {
|
||||
const agentCard = normalizeAgentCard(rawCard);
|
||||
|
||||
// Deep validation of all transport URLs within the card to prevent SSRF
|
||||
this.validateAgentCardUrls(agentName, agentCard);
|
||||
await this.validateAgentCardUrls(agentName, agentCard);
|
||||
|
||||
return agentCard;
|
||||
}
|
||||
@@ -293,7 +293,10 @@ export class A2AClientManager {
|
||||
/**
|
||||
* Validates all URLs (top-level and interfaces) within an AgentCard for SSRF.
|
||||
*/
|
||||
private validateAgentCardUrls(agentName: string, card: AgentCard): void {
|
||||
private async validateAgentCardUrls(
|
||||
agentName: string,
|
||||
card: AgentCard,
|
||||
): Promise<void> {
|
||||
const urlsToValidate = [card.url];
|
||||
if (card.additionalInterfaces) {
|
||||
for (const intf of card.additionalInterfaces) {
|
||||
@@ -307,7 +310,7 @@ export class A2AClientManager {
|
||||
// Ensure URL has a scheme for the parser (gRPC often provides raw IP:port)
|
||||
const validationUrl = url.includes('://') ? url : `http://${url}`;
|
||||
|
||||
if (isPrivateIp(validationUrl)) {
|
||||
if (await isPrivateIpAsync(validationUrl)) {
|
||||
const parsed = new URL(validationUrl);
|
||||
if (
|
||||
parsed.hostname !== 'localhost' &&
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import { getErrorMessage, isNodeError } from './errors.js';
|
||||
import { URL } from 'node:url';
|
||||
import { lookup } from 'node:dns/promises';
|
||||
import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||
|
||||
const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes
|
||||
@@ -44,12 +45,44 @@ export class FetchError extends Error {
|
||||
export function isPrivateIp(url: string): boolean {
|
||||
try {
|
||||
const hostname = new URL(url).hostname;
|
||||
return PRIVATE_IP_RANGES.some((range) => range.test(hostname));
|
||||
return isAddressPrivate(hostname);
|
||||
} catch (_e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Checks if a URL resolves to a private IP address.
|
||||
* Performs DNS resolution to prevent DNS rebinding/SSRF bypasses.
|
||||
*/
|
||||
export async function isPrivateIpAsync(url: string): Promise<boolean> {
|
||||
try {
|
||||
const parsed = new URL(url);
|
||||
const hostname = parsed.hostname;
|
||||
|
||||
// Fast check for literal IPs or localhost
|
||||
if (isAddressPrivate(hostname)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Resolve DNS to check the actual target IP
|
||||
const addresses = await lookup(hostname, { all: true });
|
||||
return addresses.some((addr) => isAddressPrivate(addr.address));
|
||||
} catch (_e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal helper to check if an IP address string is in a private range.
|
||||
*/
|
||||
function isAddressPrivate(address: string): boolean {
|
||||
return (
|
||||
address === 'localhost' ||
|
||||
PRIVATE_IP_RANGES.some((range) => range.test(address))
|
||||
);
|
||||
}
|
||||
|
||||
export async function fetchWithTimeout(
|
||||
url: string,
|
||||
timeout: number,
|
||||
|
||||
Reference in New Issue
Block a user