mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-05-03 16:34:31 -07:00
feat(a2a): strengthen SSRF protection with DNS resolution
This commit is contained in:
@@ -8,6 +8,7 @@ import { vi, describe, it, expect, beforeEach, afterEach } from 'vitest';
|
|||||||
import { A2AClientManager } from './a2a-client-manager.js';
|
import { A2AClientManager } from './a2a-client-manager.js';
|
||||||
import type { AgentCard } from '@a2a-js/sdk';
|
import type { AgentCard } from '@a2a-js/sdk';
|
||||||
import * as sdkClient from '@a2a-js/sdk/client';
|
import * as sdkClient from '@a2a-js/sdk/client';
|
||||||
|
import { lookup } from 'node:dns/promises';
|
||||||
import { debugLogger } from '../utils/debugLogger.js';
|
import { debugLogger } from '../utils/debugLogger.js';
|
||||||
|
|
||||||
interface MockClient {
|
interface MockClient {
|
||||||
@@ -36,6 +37,10 @@ vi.mock('../utils/debugLogger.js', () => ({
|
|||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
vi.mock('node:dns/promises', () => ({
|
||||||
|
lookup: vi.fn().mockResolvedValue([{ address: '93.184.216.34' }]),
|
||||||
|
}));
|
||||||
|
|
||||||
describe('A2AClientManager', () => {
|
describe('A2AClientManager', () => {
|
||||||
let manager: A2AClientManager;
|
let manager: A2AClientManager;
|
||||||
const mockAgentCard: AgentCard = {
|
const mockAgentCard: AgentCard = {
|
||||||
@@ -407,6 +412,20 @@ describe('A2AClientManager', () => {
|
|||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('should throw if a domain resolves to a private IP (DNS SSRF protection)', async () => {
|
||||||
|
const maliciousDomainUrl =
|
||||||
|
'http://malicious.com/.well-known/agent-card.json';
|
||||||
|
vi.mocked(lookup).mockResolvedValueOnce([
|
||||||
|
{ address: '10.0.0.1', family: 4 },
|
||||||
|
]);
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
manager.loadAgent('dns-ssrf-agent', maliciousDomainUrl),
|
||||||
|
).rejects.toThrow(
|
||||||
|
/Refusing to load agent 'dns-ssrf-agent' from private IP range/,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
it('should throw if a public agent card contains a private transport URL (Deep SSRF protection)', async () => {
|
it('should throw if a public agent card contains a private transport URL (Deep SSRF protection)', async () => {
|
||||||
const publicUrl = 'https://public.agent.com/card.json';
|
const publicUrl = 'https://public.agent.com/card.json';
|
||||||
const resolverInstance = {
|
const resolverInstance = {
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ import { GrpcTransportFactory } from '@a2a-js/sdk/client/grpc';
|
|||||||
import { v4 as uuidv4 } from 'uuid';
|
import { v4 as uuidv4 } from 'uuid';
|
||||||
import { Agent as UndiciAgent } from 'undici';
|
import { Agent as UndiciAgent } from 'undici';
|
||||||
import { getGrpcCredentials, normalizeAgentCard } from './a2aUtils.js';
|
import { getGrpcCredentials, normalizeAgentCard } from './a2aUtils.js';
|
||||||
import { isPrivateIp } from '../utils/fetch.js';
|
import { isPrivateIpAsync } from '../utils/fetch.js';
|
||||||
import { debugLogger } from '../utils/debugLogger.js';
|
import { debugLogger } from '../utils/debugLogger.js';
|
||||||
|
|
||||||
// Remote agents can take 10+ minutes (e.g. Deep Research).
|
// Remote agents can take 10+ minutes (e.g. Deep Research).
|
||||||
@@ -259,8 +259,8 @@ export class A2AClientManager {
|
|||||||
let baseUrl = url;
|
let baseUrl = url;
|
||||||
let path: string | undefined;
|
let path: string | undefined;
|
||||||
|
|
||||||
// Validate URL to prevent SSRF
|
// Validate URL to prevent SSRF (with DNS resolution)
|
||||||
if (isPrivateIp(url)) {
|
if (await isPrivateIpAsync(url)) {
|
||||||
// Local/private IPs are allowed ONLY for localhost for testing.
|
// Local/private IPs are allowed ONLY for localhost for testing.
|
||||||
const parsed = new URL(url);
|
const parsed = new URL(url);
|
||||||
if (parsed.hostname !== 'localhost' && parsed.hostname !== '127.0.0.1') {
|
if (parsed.hostname !== 'localhost' && parsed.hostname !== '127.0.0.1') {
|
||||||
@@ -285,7 +285,7 @@ export class A2AClientManager {
|
|||||||
const agentCard = normalizeAgentCard(rawCard);
|
const agentCard = normalizeAgentCard(rawCard);
|
||||||
|
|
||||||
// Deep validation of all transport URLs within the card to prevent SSRF
|
// Deep validation of all transport URLs within the card to prevent SSRF
|
||||||
this.validateAgentCardUrls(agentName, agentCard);
|
await this.validateAgentCardUrls(agentName, agentCard);
|
||||||
|
|
||||||
return agentCard;
|
return agentCard;
|
||||||
}
|
}
|
||||||
@@ -293,7 +293,10 @@ export class A2AClientManager {
|
|||||||
/**
|
/**
|
||||||
* Validates all URLs (top-level and interfaces) within an AgentCard for SSRF.
|
* Validates all URLs (top-level and interfaces) within an AgentCard for SSRF.
|
||||||
*/
|
*/
|
||||||
private validateAgentCardUrls(agentName: string, card: AgentCard): void {
|
private async validateAgentCardUrls(
|
||||||
|
agentName: string,
|
||||||
|
card: AgentCard,
|
||||||
|
): Promise<void> {
|
||||||
const urlsToValidate = [card.url];
|
const urlsToValidate = [card.url];
|
||||||
if (card.additionalInterfaces) {
|
if (card.additionalInterfaces) {
|
||||||
for (const intf of card.additionalInterfaces) {
|
for (const intf of card.additionalInterfaces) {
|
||||||
@@ -307,7 +310,7 @@ export class A2AClientManager {
|
|||||||
// Ensure URL has a scheme for the parser (gRPC often provides raw IP:port)
|
// Ensure URL has a scheme for the parser (gRPC often provides raw IP:port)
|
||||||
const validationUrl = url.includes('://') ? url : `http://${url}`;
|
const validationUrl = url.includes('://') ? url : `http://${url}`;
|
||||||
|
|
||||||
if (isPrivateIp(validationUrl)) {
|
if (await isPrivateIpAsync(validationUrl)) {
|
||||||
const parsed = new URL(validationUrl);
|
const parsed = new URL(validationUrl);
|
||||||
if (
|
if (
|
||||||
parsed.hostname !== 'localhost' &&
|
parsed.hostname !== 'localhost' &&
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
import { getErrorMessage, isNodeError } from './errors.js';
|
import { getErrorMessage, isNodeError } from './errors.js';
|
||||||
import { URL } from 'node:url';
|
import { URL } from 'node:url';
|
||||||
|
import { lookup } from 'node:dns/promises';
|
||||||
import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici';
|
import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici';
|
||||||
|
|
||||||
const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes
|
const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes
|
||||||
@@ -44,12 +45,44 @@ export class FetchError extends Error {
|
|||||||
export function isPrivateIp(url: string): boolean {
|
export function isPrivateIp(url: string): boolean {
|
||||||
try {
|
try {
|
||||||
const hostname = new URL(url).hostname;
|
const hostname = new URL(url).hostname;
|
||||||
return PRIVATE_IP_RANGES.some((range) => range.test(hostname));
|
return isAddressPrivate(hostname);
|
||||||
} catch (_e) {
|
} catch (_e) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if a URL resolves to a private IP address.
|
||||||
|
* Performs DNS resolution to prevent DNS rebinding/SSRF bypasses.
|
||||||
|
*/
|
||||||
|
export async function isPrivateIpAsync(url: string): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
const parsed = new URL(url);
|
||||||
|
const hostname = parsed.hostname;
|
||||||
|
|
||||||
|
// Fast check for literal IPs or localhost
|
||||||
|
if (isAddressPrivate(hostname)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve DNS to check the actual target IP
|
||||||
|
const addresses = await lookup(hostname, { all: true });
|
||||||
|
return addresses.some((addr) => isAddressPrivate(addr.address));
|
||||||
|
} catch (_e) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Internal helper to check if an IP address string is in a private range.
|
||||||
|
*/
|
||||||
|
function isAddressPrivate(address: string): boolean {
|
||||||
|
return (
|
||||||
|
address === 'localhost' ||
|
||||||
|
PRIVATE_IP_RANGES.some((range) => range.test(address))
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
export async function fetchWithTimeout(
|
export async function fetchWithTimeout(
|
||||||
url: string,
|
url: string,
|
||||||
timeout: number,
|
timeout: number,
|
||||||
|
|||||||
Reference in New Issue
Block a user