From 990d010ecfc9d25fb887b23b495c07426252f307 Mon Sep 17 00:00:00 2001 From: Aishanee Shah Date: Mon, 16 Mar 2026 17:38:53 -0400 Subject: [PATCH] feat(core): implement Stage 2 security and consistency improvements for web_fetch (#22217) --- packages/core/src/tools/web-fetch.test.ts | 28 +-- packages/core/src/tools/web-fetch.ts | 208 +++++++++++++++------- packages/core/src/utils/fetch.test.ts | 68 ++++++- packages/core/src/utils/fetch.ts | 32 ++++ 4 files changed, 250 insertions(+), 86 deletions(-) diff --git a/packages/core/src/tools/web-fetch.test.ts b/packages/core/src/tools/web-fetch.test.ts index 8e928499cc..2b65a24930 100644 --- a/packages/core/src/tools/web-fetch.test.ts +++ b/packages/core/src/tools/web-fetch.test.ts @@ -497,7 +497,7 @@ describe('WebFetchTool', () => { expect(result.llmContent).toBe('fallback processed response'); expect(result.returnDisplay).toContain( - '2 URL(s) processed using fallback fetch', + 'URL(s) processed using fallback fetch', ); }); @@ -530,7 +530,7 @@ describe('WebFetchTool', () => { // Verify private URL was NOT fetched (mockFetch would throw if it was called for private.com) }); - it('should return WEB_FETCH_FALLBACK_FAILED on fallback fetch failure', async () => { + it('should return WEB_FETCH_FALLBACK_FAILED on total failure', async () => { vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false); mockGenerateContent.mockRejectedValue(new Error('primary fail')); mockFetch('https://public.ip/', new Error('fallback fetch failed')); @@ -541,16 +541,6 @@ describe('WebFetchTool', () => { expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_FALLBACK_FAILED); }); - it('should return WEB_FETCH_FALLBACK_FAILED on general processing failure (when fallback also fails)', async () => { - vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false); - mockGenerateContent.mockRejectedValue(new Error('API error')); - const tool = new WebFetchTool(mockConfig, bus); - const params = { prompt: 'fetch https://public.ip' }; - const invocation = tool.build(params); - const result = await invocation.execute(new AbortController().signal); - expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_FALLBACK_FAILED); - }); - it('should log telemetry when falling back due to primary fetch failure', async () => { vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false); // Mock primary fetch to return empty response, triggering fallback @@ -639,6 +629,14 @@ describe('WebFetchTool', () => { const invocation = tool.build(params); const result = await invocation.execute(new AbortController().signal); + const sanitizeXml = (text: string) => + text + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, '''); + if (shouldConvert) { expect(convert).toHaveBeenCalledWith(content, { wordwrap: false, @@ -647,10 +645,12 @@ describe('WebFetchTool', () => { { selector: 'img', format: 'skip' }, ], }); - expect(result.llmContent).toContain(`Converted: ${content}`); + expect(result.llmContent).toContain( + `Converted: ${sanitizeXml(content)}`, + ); } else { expect(convert).not.toHaveBeenCalled(); - expect(result.llmContent).toContain(content); + expect(result.llmContent).toContain(sanitizeXml(content)); } }, ); diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index 365c2b55ed..27a60c4259 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -40,7 +40,7 @@ import { LRUCache } from 'mnemonist'; import type { AgentLoopContext } from '../config/agent-loop-context.js'; const URL_FETCH_TIMEOUT_MS = 10000; -const MAX_CONTENT_LENGTH = 100000; +const MAX_CONTENT_LENGTH = 250000; const MAX_EXPERIMENTAL_FETCH_SIZE = 10 * 1024 * 1024; // 10MB const USER_AGENT = 'Mozilla/5.0 (compatible; Google-Gemini-CLI/1.0; +https://github.com/google-gemini/gemini-cli)'; @@ -190,6 +190,18 @@ function isGroundingSupportItem(item: unknown): item is GroundingSupportItem { return typeof item === 'object' && item !== null; } +/** + * Sanitizes text for safe embedding in XML tags. + */ +function sanitizeXml(text: string): string { + return text + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"') + .replace(/'/g, '''); +} + /** * Parameters for the WebFetch tool */ @@ -263,69 +275,65 @@ class WebFetchToolInvocation extends BaseToolInvocation< private async executeFallbackForUrl( urlStr: string, signal: AbortSignal, - contentBudget: number, ): Promise { const url = convertGithubUrlToRaw(urlStr); if (this.isBlockedHost(url)) { debugLogger.warn(`[WebFetchTool] Blocked access to host: ${url}`); - return `Error fetching ${url}: Access to blocked or private host is not allowed.`; + throw new Error( + `Access to blocked or private host ${url} is not allowed.`, + ); } - try { - const response = await retryWithBackoff( - async () => { - const res = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS, { - signal, - headers: { - 'User-Agent': USER_AGENT, - }, - }); - if (!res.ok) { - const error = new Error( - `Request failed with status code ${res.status} ${res.statusText}`, - ); - (error as ErrorWithStatus).status = res.status; - throw error; - } - return res; - }, - { - retryFetchErrors: this.context.config.getRetryFetchErrors(), - onRetry: (attempt, error, delayMs) => - this.handleRetry(attempt, error, delayMs), + const response = await retryWithBackoff( + async () => { + const res = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS, { signal, - }, - ); - - const bodyBuffer = await this.readResponseWithLimit( - response, - MAX_EXPERIMENTAL_FETCH_SIZE, - ); - const rawContent = bodyBuffer.toString('utf8'); - const contentType = response.headers.get('content-type') || ''; - let textContent: string; - - // Only use html-to-text if content type is HTML, or if no content type is provided (assume HTML) - if ( - contentType.toLowerCase().includes('text/html') || - contentType === '' - ) { - textContent = convert(rawContent, { - wordwrap: false, - selectors: [ - { selector: 'a', options: { ignoreHref: true } }, - { selector: 'img', format: 'skip' }, - ], + headers: { + 'User-Agent': USER_AGENT, + }, }); - } else { - // For other content types (text/plain, application/json, etc.), use raw text - textContent = rawContent; - } + if (!res.ok) { + const error = new Error( + `Request failed with status code ${res.status} ${res.statusText}`, + ); + (error as ErrorWithStatus).status = res.status; + throw error; + } + return res; + }, + { + retryFetchErrors: this.context.config.getRetryFetchErrors(), + onRetry: (attempt, error, delayMs) => + this.handleRetry(attempt, error, delayMs), + signal, + }, + ); - return truncateString(textContent, contentBudget, TRUNCATION_WARNING); - } catch (e) { - return `Error fetching ${url}: ${getErrorMessage(e)}`; + const bodyBuffer = await this.readResponseWithLimit( + response, + MAX_EXPERIMENTAL_FETCH_SIZE, + ); + const rawContent = bodyBuffer.toString('utf8'); + const contentType = response.headers.get('content-type') || ''; + let textContent: string; + + // Only use html-to-text if content type is HTML, or if no content type is provided (assume HTML) + if (contentType.toLowerCase().includes('text/html') || contentType === '') { + textContent = convert(rawContent, { + wordwrap: false, + selectors: [ + { selector: 'a', options: { ignoreHref: true } }, + { selector: 'img', format: 'skip' }, + ], + }); + } else { + // For other content types (text/plain, application/json, etc.), use raw text + textContent = rawContent; } + + // Cap at MAX_CONTENT_LENGTH initially to avoid excessive memory usage + // before the global budget allocation. + return truncateString(textContent, MAX_CONTENT_LENGTH, ''); } private filterAndValidateUrls(urls: string[]): { @@ -363,30 +371,82 @@ class WebFetchToolInvocation extends BaseToolInvocation< signal: AbortSignal, ): Promise { const uniqueUrls = [...new Set(urls)]; - const contentBudget = Math.floor( - MAX_CONTENT_LENGTH / (uniqueUrls.length || 1), - ); - const results: string[] = []; + const successes: Array<{ url: string; content: string }> = []; + const errors: Array<{ url: string; message: string }> = []; for (const url of uniqueUrls) { - results.push( - await this.executeFallbackForUrl(url, signal, contentBudget), - ); + try { + const content = await this.executeFallbackForUrl(url, signal); + successes.push({ url, content }); + } catch (e) { + errors.push({ url, message: getErrorMessage(e) }); + } } - const aggregatedContent = results - .map((content, i) => `URL: ${uniqueUrls[i]}\nContent:\n${content}`) - .join('\n\n---\n\n'); + // Change 2: Short-circuit on total failure + if (successes.length === 0) { + const errorMessage = `All fallback fetch attempts failed: ${errors + .map((e) => `${e.url}: ${e.message}`) + .join(', ')}`; + debugLogger.error(`[WebFetchTool] ${errorMessage}`); + return { + llmContent: `Error: ${errorMessage}`, + returnDisplay: `Error: ${errorMessage}`, + error: { + message: errorMessage, + type: ToolErrorType.WEB_FETCH_FALLBACK_FAILED, + }, + }; + } + + // Smart Budget Allocation (Water-filling algorithm) for successes + const sortedSuccesses = [...successes].sort( + (a, b) => a.content.length - b.content.length, + ); + + let remainingBudget = MAX_CONTENT_LENGTH; + let remainingUrls = sortedSuccesses.length; + const finalContentsByUrl = new Map(); + + for (const success of sortedSuccesses) { + const fairShare = Math.floor(remainingBudget / remainingUrls); + const allocated = Math.min(success.content.length, fairShare); + + const truncated = truncateString( + success.content, + allocated, + TRUNCATION_WARNING, + ); + + finalContentsByUrl.set(success.url, truncated); + remainingBudget -= truncated.length; + remainingUrls--; + } + + const aggregatedContent = uniqueUrls + .map((url) => { + const content = finalContentsByUrl.get(url); + if (content !== undefined) { + return `\n${sanitizeXml(content)}\n`; + } + const error = errors.find((e) => e.url === url); + return `\nError: ${sanitizeXml(error?.message || 'Unknown error')}\n`; + }) + .join('\n'); try { const geminiClient = this.context.geminiClient; - const fallbackPrompt = `The user requested the following: "${this.params.prompt}". + const fallbackPrompt = `Follow the user's instructions below using the provided webpage content. + + +${sanitizeXml(this.params.prompt ?? '')} + I was unable to access the URL(s) directly using the primary fetch tool. Instead, I have fetched the raw content of the page(s). Please use the following content to answer the request. Do not attempt to access the URL(s) again. ---- + ${aggregatedContent} ---- + `; const result = await geminiClient.generateContent( { model: 'web-fetch-fallback' }, @@ -716,9 +776,19 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun try { const geminiClient = this.context.geminiClient; + const sanitizedPrompt = `Follow the user's instructions to process the authorized URLs. + + +${sanitizeXml(userPrompt)} + + + +${toFetch.join('\n')} + +`; const response = await geminiClient.generateContent( { model: 'web-fetch' }, - [{ role: 'user', parts: [{ text: userPrompt }] }], + [{ role: 'user', parts: [{ text: sanitizedPrompt }] }], signal, LlmRole.UTILITY_TOOL, ); @@ -870,7 +940,7 @@ export class WebFetchTool extends BaseDeclarativeTool< _toolDisplayName?: string, ): ToolInvocation { return new WebFetchToolInvocation( - this.context.config, + this.context, params, messageBus, _toolName, diff --git a/packages/core/src/utils/fetch.test.ts b/packages/core/src/utils/fetch.test.ts index 4ac0c7b344..c4644c3cba 100644 --- a/packages/core/src/utils/fetch.test.ts +++ b/packages/core/src/utils/fetch.test.ts @@ -5,7 +5,15 @@ */ import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest'; -import { isPrivateIp, isAddressPrivate, fetchWithTimeout } from './fetch.js'; +import { + isPrivateIp, + isPrivateIpAsync, + isAddressPrivate, + fetchWithTimeout, +} from './fetch.js'; +import * as dnsPromises from 'node:dns/promises'; +import type { LookupAddress, LookupAllOptions } from 'node:dns'; +import ipaddr from 'ipaddr.js'; vi.mock('node:dns/promises', () => ({ lookup: vi.fn(), @@ -15,9 +23,25 @@ vi.mock('node:dns/promises', () => ({ const originalFetch = global.fetch; global.fetch = vi.fn(); +interface ErrorWithCode extends Error { + code?: string; +} + describe('fetch utils', () => { beforeEach(() => { vi.clearAllMocks(); + // Default DNS lookup to return a public IP, or the IP itself if valid + vi.mocked( + dnsPromises.lookup as ( + hostname: string, + options: LookupAllOptions, + ) => Promise, + ).mockImplementation(async (hostname: string) => { + if (ipaddr.isValid(hostname)) { + return [{ address: hostname, family: hostname.includes(':') ? 6 : 4 }]; + } + return [{ address: '93.184.216.34', family: 4 }]; + }); }); afterAll(() => { @@ -99,6 +123,43 @@ describe('fetch utils', () => { }); }); + describe('isPrivateIpAsync', () => { + it('should identify private IPs directly', async () => { + expect(await isPrivateIpAsync('http://10.0.0.1/')).toBe(true); + }); + + it('should identify domains resolving to private IPs', async () => { + vi.mocked( + dnsPromises.lookup as ( + hostname: string, + options: LookupAllOptions, + ) => Promise, + ).mockImplementation(async () => [{ address: '10.0.0.1', family: 4 }]); + expect(await isPrivateIpAsync('http://malicious.com/')).toBe(true); + }); + + it('should identify domains resolving to public IPs as non-private', async () => { + vi.mocked( + dnsPromises.lookup as ( + hostname: string, + options: LookupAllOptions, + ) => Promise, + ).mockImplementation(async () => [{ address: '8.8.8.8', family: 4 }]); + expect(await isPrivateIpAsync('http://google.com/')).toBe(false); + }); + + it('should throw error if DNS resolution fails (fail closed)', async () => { + vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error')); + await expect(isPrivateIpAsync('http://unreachable.com/')).rejects.toThrow( + 'Failed to verify if URL resolves to private IP', + ); + }); + + it('should return false for invalid URLs instead of throwing verification error', async () => { + expect(await isPrivateIpAsync('not-a-url')).toBe(false); + }); + }); + describe('fetchWithTimeout', () => { it('should handle timeouts', async () => { vi.mocked(global.fetch).mockImplementation( @@ -106,9 +167,10 @@ describe('fetch utils', () => { new Promise((_resolve, reject) => { if (init?.signal) { init.signal.addEventListener('abort', () => { - const error = new Error('The operation was aborted'); + const error = new Error( + 'The operation was aborted', + ) as ErrorWithCode; error.name = 'AbortError'; - // @ts-expect-error - for mocking purposes error.code = 'ABORT_ERR'; reject(error); }); diff --git a/packages/core/src/utils/fetch.ts b/packages/core/src/utils/fetch.ts index e339ea7fed..8f1ddf864f 100644 --- a/packages/core/src/utils/fetch.ts +++ b/packages/core/src/utils/fetch.ts @@ -8,6 +8,7 @@ import { getErrorMessage, isNodeError } from './errors.js'; import { URL } from 'node:url'; import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici'; import ipaddr from 'ipaddr.js'; +import { lookup } from 'node:dns/promises'; const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes const DEFAULT_BODY_TIMEOUT = 300000; // 5 minutes @@ -23,6 +24,13 @@ export class FetchError extends Error { } } +export class PrivateIpError extends Error { + constructor(message = 'Access to private network is blocked') { + super(message); + this.name = 'PrivateIpError'; + } +} + // Configure default global dispatcher with higher timeouts setGlobalDispatcher( new Agent({ @@ -115,6 +123,30 @@ export function isAddressPrivate(address: string): boolean { } } +/** + * Checks if a URL resolves to a private IP address. + */ +export async function isPrivateIpAsync(url: string): Promise { + try { + const parsedUrl = new URL(url); + const hostname = parsedUrl.hostname; + + if (isLoopbackHost(hostname)) { + return false; + } + + const addresses = await lookup(hostname, { all: true }); + return addresses.some((addr) => isAddressPrivate(addr.address)); + } catch (error) { + if (error instanceof TypeError) { + return false; + } + throw new Error('Failed to verify if URL resolves to private IP', { + cause: error, + }); + } +} + /** * Creates an undici ProxyAgent that incorporates safe DNS lookup. */