feat(core): implement Stage 2 security and consistency improvements for web_fetch (#22217)

This commit is contained in:
Aishanee Shah
2026-03-16 17:38:53 -04:00
committed by GitHub
parent b6c6da3618
commit 990d010ecf
4 changed files with 250 additions and 86 deletions
+14 -14
View File
@@ -497,7 +497,7 @@ describe('WebFetchTool', () => {
expect(result.llmContent).toBe('fallback processed response'); expect(result.llmContent).toBe('fallback processed response');
expect(result.returnDisplay).toContain( expect(result.returnDisplay).toContain(
'2 URL(s) processed using fallback fetch', 'URL(s) processed using fallback fetch',
); );
}); });
@@ -530,7 +530,7 @@ describe('WebFetchTool', () => {
// Verify private URL was NOT fetched (mockFetch would throw if it was called for private.com) // Verify private URL was NOT fetched (mockFetch would throw if it was called for private.com)
}); });
it('should return WEB_FETCH_FALLBACK_FAILED on fallback fetch failure', async () => { it('should return WEB_FETCH_FALLBACK_FAILED on total failure', async () => {
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false); vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false);
mockGenerateContent.mockRejectedValue(new Error('primary fail')); mockGenerateContent.mockRejectedValue(new Error('primary fail'));
mockFetch('https://public.ip/', new Error('fallback fetch failed')); mockFetch('https://public.ip/', new Error('fallback fetch failed'));
@@ -541,16 +541,6 @@ describe('WebFetchTool', () => {
expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_FALLBACK_FAILED); expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_FALLBACK_FAILED);
}); });
it('should return WEB_FETCH_FALLBACK_FAILED on general processing failure (when fallback also fails)', async () => {
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false);
mockGenerateContent.mockRejectedValue(new Error('API error'));
const tool = new WebFetchTool(mockConfig, bus);
const params = { prompt: 'fetch https://public.ip' };
const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal);
expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_FALLBACK_FAILED);
});
it('should log telemetry when falling back due to primary fetch failure', async () => { it('should log telemetry when falling back due to primary fetch failure', async () => {
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false); vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false);
// Mock primary fetch to return empty response, triggering fallback // Mock primary fetch to return empty response, triggering fallback
@@ -639,6 +629,14 @@ describe('WebFetchTool', () => {
const invocation = tool.build(params); const invocation = tool.build(params);
const result = await invocation.execute(new AbortController().signal); const result = await invocation.execute(new AbortController().signal);
const sanitizeXml = (text: string) =>
text
.replace(/&/g, '&')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&apos;');
if (shouldConvert) { if (shouldConvert) {
expect(convert).toHaveBeenCalledWith(content, { expect(convert).toHaveBeenCalledWith(content, {
wordwrap: false, wordwrap: false,
@@ -647,10 +645,12 @@ describe('WebFetchTool', () => {
{ selector: 'img', format: 'skip' }, { selector: 'img', format: 'skip' },
], ],
}); });
expect(result.llmContent).toContain(`Converted: ${content}`); expect(result.llmContent).toContain(
`Converted: ${sanitizeXml(content)}`,
);
} else { } else {
expect(convert).not.toHaveBeenCalled(); expect(convert).not.toHaveBeenCalled();
expect(result.llmContent).toContain(content); expect(result.llmContent).toContain(sanitizeXml(content));
} }
}, },
); );
+97 -27
View File
@@ -40,7 +40,7 @@ import { LRUCache } from 'mnemonist';
import type { AgentLoopContext } from '../config/agent-loop-context.js'; import type { AgentLoopContext } from '../config/agent-loop-context.js';
const URL_FETCH_TIMEOUT_MS = 10000; const URL_FETCH_TIMEOUT_MS = 10000;
const MAX_CONTENT_LENGTH = 100000; const MAX_CONTENT_LENGTH = 250000;
const MAX_EXPERIMENTAL_FETCH_SIZE = 10 * 1024 * 1024; // 10MB const MAX_EXPERIMENTAL_FETCH_SIZE = 10 * 1024 * 1024; // 10MB
const USER_AGENT = const USER_AGENT =
'Mozilla/5.0 (compatible; Google-Gemini-CLI/1.0; +https://github.com/google-gemini/gemini-cli)'; 'Mozilla/5.0 (compatible; Google-Gemini-CLI/1.0; +https://github.com/google-gemini/gemini-cli)';
@@ -190,6 +190,18 @@ function isGroundingSupportItem(item: unknown): item is GroundingSupportItem {
return typeof item === 'object' && item !== null; return typeof item === 'object' && item !== null;
} }
/**
* Sanitizes text for safe embedding in XML tags.
*/
function sanitizeXml(text: string): string {
return text
.replace(/&/g, '&amp;')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&apos;');
}
/** /**
* Parameters for the WebFetch tool * Parameters for the WebFetch tool
*/ */
@@ -263,15 +275,15 @@ class WebFetchToolInvocation extends BaseToolInvocation<
private async executeFallbackForUrl( private async executeFallbackForUrl(
urlStr: string, urlStr: string,
signal: AbortSignal, signal: AbortSignal,
contentBudget: number,
): Promise<string> { ): Promise<string> {
const url = convertGithubUrlToRaw(urlStr); const url = convertGithubUrlToRaw(urlStr);
if (this.isBlockedHost(url)) { if (this.isBlockedHost(url)) {
debugLogger.warn(`[WebFetchTool] Blocked access to host: ${url}`); debugLogger.warn(`[WebFetchTool] Blocked access to host: ${url}`);
return `Error fetching ${url}: Access to blocked or private host is not allowed.`; throw new Error(
`Access to blocked or private host ${url} is not allowed.`,
);
} }
try {
const response = await retryWithBackoff( const response = await retryWithBackoff(
async () => { async () => {
const res = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS, { const res = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS, {
@@ -306,10 +318,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
let textContent: string; let textContent: string;
// Only use html-to-text if content type is HTML, or if no content type is provided (assume HTML) // Only use html-to-text if content type is HTML, or if no content type is provided (assume HTML)
if ( if (contentType.toLowerCase().includes('text/html') || contentType === '') {
contentType.toLowerCase().includes('text/html') ||
contentType === ''
) {
textContent = convert(rawContent, { textContent = convert(rawContent, {
wordwrap: false, wordwrap: false,
selectors: [ selectors: [
@@ -322,10 +331,9 @@ class WebFetchToolInvocation extends BaseToolInvocation<
textContent = rawContent; textContent = rawContent;
} }
return truncateString(textContent, contentBudget, TRUNCATION_WARNING); // Cap at MAX_CONTENT_LENGTH initially to avoid excessive memory usage
} catch (e) { // before the global budget allocation.
return `Error fetching ${url}: ${getErrorMessage(e)}`; return truncateString(textContent, MAX_CONTENT_LENGTH, '');
}
} }
private filterAndValidateUrls(urls: string[]): { private filterAndValidateUrls(urls: string[]): {
@@ -363,30 +371,82 @@ class WebFetchToolInvocation extends BaseToolInvocation<
signal: AbortSignal, signal: AbortSignal,
): Promise<ToolResult> { ): Promise<ToolResult> {
const uniqueUrls = [...new Set(urls)]; const uniqueUrls = [...new Set(urls)];
const contentBudget = Math.floor( const successes: Array<{ url: string; content: string }> = [];
MAX_CONTENT_LENGTH / (uniqueUrls.length || 1), const errors: Array<{ url: string; message: string }> = [];
);
const results: string[] = [];
for (const url of uniqueUrls) { for (const url of uniqueUrls) {
results.push( try {
await this.executeFallbackForUrl(url, signal, contentBudget), const content = await this.executeFallbackForUrl(url, signal);
); successes.push({ url, content });
} catch (e) {
errors.push({ url, message: getErrorMessage(e) });
}
} }
const aggregatedContent = results // Change 2: Short-circuit on total failure
.map((content, i) => `URL: ${uniqueUrls[i]}\nContent:\n${content}`) if (successes.length === 0) {
.join('\n\n---\n\n'); const errorMessage = `All fallback fetch attempts failed: ${errors
.map((e) => `${e.url}: ${e.message}`)
.join(', ')}`;
debugLogger.error(`[WebFetchTool] ${errorMessage}`);
return {
llmContent: `Error: ${errorMessage}`,
returnDisplay: `Error: ${errorMessage}`,
error: {
message: errorMessage,
type: ToolErrorType.WEB_FETCH_FALLBACK_FAILED,
},
};
}
// Smart Budget Allocation (Water-filling algorithm) for successes
const sortedSuccesses = [...successes].sort(
(a, b) => a.content.length - b.content.length,
);
let remainingBudget = MAX_CONTENT_LENGTH;
let remainingUrls = sortedSuccesses.length;
const finalContentsByUrl = new Map<string, string>();
for (const success of sortedSuccesses) {
const fairShare = Math.floor(remainingBudget / remainingUrls);
const allocated = Math.min(success.content.length, fairShare);
const truncated = truncateString(
success.content,
allocated,
TRUNCATION_WARNING,
);
finalContentsByUrl.set(success.url, truncated);
remainingBudget -= truncated.length;
remainingUrls--;
}
const aggregatedContent = uniqueUrls
.map((url) => {
const content = finalContentsByUrl.get(url);
if (content !== undefined) {
return `<source url="${sanitizeXml(url)}">\n${sanitizeXml(content)}\n</source>`;
}
const error = errors.find((e) => e.url === url);
return `<source url="${sanitizeXml(url)}">\nError: ${sanitizeXml(error?.message || 'Unknown error')}\n</source>`;
})
.join('\n');
try { try {
const geminiClient = this.context.geminiClient; const geminiClient = this.context.geminiClient;
const fallbackPrompt = `The user requested the following: "${this.params.prompt}". const fallbackPrompt = `Follow the user's instructions below using the provided webpage content.
<user_instructions>
${sanitizeXml(this.params.prompt ?? '')}
</user_instructions>
I was unable to access the URL(s) directly using the primary fetch tool. Instead, I have fetched the raw content of the page(s). Please use the following content to answer the request. Do not attempt to access the URL(s) again. I was unable to access the URL(s) directly using the primary fetch tool. Instead, I have fetched the raw content of the page(s). Please use the following content to answer the request. Do not attempt to access the URL(s) again.
--- <content>
${aggregatedContent} ${aggregatedContent}
--- </content>
`; `;
const result = await geminiClient.generateContent( const result = await geminiClient.generateContent(
{ model: 'web-fetch-fallback' }, { model: 'web-fetch-fallback' },
@@ -716,9 +776,19 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun
try { try {
const geminiClient = this.context.geminiClient; const geminiClient = this.context.geminiClient;
const sanitizedPrompt = `Follow the user's instructions to process the authorized URLs.
<user_instructions>
${sanitizeXml(userPrompt)}
</user_instructions>
<authorized_urls>
${toFetch.join('\n')}
</authorized_urls>
`;
const response = await geminiClient.generateContent( const response = await geminiClient.generateContent(
{ model: 'web-fetch' }, { model: 'web-fetch' },
[{ role: 'user', parts: [{ text: userPrompt }] }], [{ role: 'user', parts: [{ text: sanitizedPrompt }] }],
signal, signal,
LlmRole.UTILITY_TOOL, LlmRole.UTILITY_TOOL,
); );
@@ -870,7 +940,7 @@ export class WebFetchTool extends BaseDeclarativeTool<
_toolDisplayName?: string, _toolDisplayName?: string,
): ToolInvocation<WebFetchToolParams, ToolResult> { ): ToolInvocation<WebFetchToolParams, ToolResult> {
return new WebFetchToolInvocation( return new WebFetchToolInvocation(
this.context.config, this.context,
params, params,
messageBus, messageBus,
_toolName, _toolName,
+65 -3
View File
@@ -5,7 +5,15 @@
*/ */
import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest'; import { describe, it, expect, vi, beforeEach, afterAll } from 'vitest';
import { isPrivateIp, isAddressPrivate, fetchWithTimeout } from './fetch.js'; import {
isPrivateIp,
isPrivateIpAsync,
isAddressPrivate,
fetchWithTimeout,
} from './fetch.js';
import * as dnsPromises from 'node:dns/promises';
import type { LookupAddress, LookupAllOptions } from 'node:dns';
import ipaddr from 'ipaddr.js';
vi.mock('node:dns/promises', () => ({ vi.mock('node:dns/promises', () => ({
lookup: vi.fn(), lookup: vi.fn(),
@@ -15,9 +23,25 @@ vi.mock('node:dns/promises', () => ({
const originalFetch = global.fetch; const originalFetch = global.fetch;
global.fetch = vi.fn(); global.fetch = vi.fn();
interface ErrorWithCode extends Error {
code?: string;
}
describe('fetch utils', () => { describe('fetch utils', () => {
beforeEach(() => { beforeEach(() => {
vi.clearAllMocks(); vi.clearAllMocks();
// Default DNS lookup to return a public IP, or the IP itself if valid
vi.mocked(
dnsPromises.lookup as (
hostname: string,
options: LookupAllOptions,
) => Promise<LookupAddress[]>,
).mockImplementation(async (hostname: string) => {
if (ipaddr.isValid(hostname)) {
return [{ address: hostname, family: hostname.includes(':') ? 6 : 4 }];
}
return [{ address: '93.184.216.34', family: 4 }];
});
}); });
afterAll(() => { afterAll(() => {
@@ -99,6 +123,43 @@ describe('fetch utils', () => {
}); });
}); });
describe('isPrivateIpAsync', () => {
it('should identify private IPs directly', async () => {
expect(await isPrivateIpAsync('http://10.0.0.1/')).toBe(true);
});
it('should identify domains resolving to private IPs', async () => {
vi.mocked(
dnsPromises.lookup as (
hostname: string,
options: LookupAllOptions,
) => Promise<LookupAddress[]>,
).mockImplementation(async () => [{ address: '10.0.0.1', family: 4 }]);
expect(await isPrivateIpAsync('http://malicious.com/')).toBe(true);
});
it('should identify domains resolving to public IPs as non-private', async () => {
vi.mocked(
dnsPromises.lookup as (
hostname: string,
options: LookupAllOptions,
) => Promise<LookupAddress[]>,
).mockImplementation(async () => [{ address: '8.8.8.8', family: 4 }]);
expect(await isPrivateIpAsync('http://google.com/')).toBe(false);
});
it('should throw error if DNS resolution fails (fail closed)', async () => {
vi.mocked(dnsPromises.lookup).mockRejectedValue(new Error('DNS Error'));
await expect(isPrivateIpAsync('http://unreachable.com/')).rejects.toThrow(
'Failed to verify if URL resolves to private IP',
);
});
it('should return false for invalid URLs instead of throwing verification error', async () => {
expect(await isPrivateIpAsync('not-a-url')).toBe(false);
});
});
describe('fetchWithTimeout', () => { describe('fetchWithTimeout', () => {
it('should handle timeouts', async () => { it('should handle timeouts', async () => {
vi.mocked(global.fetch).mockImplementation( vi.mocked(global.fetch).mockImplementation(
@@ -106,9 +167,10 @@ describe('fetch utils', () => {
new Promise((_resolve, reject) => { new Promise((_resolve, reject) => {
if (init?.signal) { if (init?.signal) {
init.signal.addEventListener('abort', () => { init.signal.addEventListener('abort', () => {
const error = new Error('The operation was aborted'); const error = new Error(
'The operation was aborted',
) as ErrorWithCode;
error.name = 'AbortError'; error.name = 'AbortError';
// @ts-expect-error - for mocking purposes
error.code = 'ABORT_ERR'; error.code = 'ABORT_ERR';
reject(error); reject(error);
}); });
+32
View File
@@ -8,6 +8,7 @@ import { getErrorMessage, isNodeError } from './errors.js';
import { URL } from 'node:url'; import { URL } from 'node:url';
import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici'; import { Agent, ProxyAgent, setGlobalDispatcher } from 'undici';
import ipaddr from 'ipaddr.js'; import ipaddr from 'ipaddr.js';
import { lookup } from 'node:dns/promises';
const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes const DEFAULT_HEADERS_TIMEOUT = 300000; // 5 minutes
const DEFAULT_BODY_TIMEOUT = 300000; // 5 minutes const DEFAULT_BODY_TIMEOUT = 300000; // 5 minutes
@@ -23,6 +24,13 @@ export class FetchError extends Error {
} }
} }
export class PrivateIpError extends Error {
constructor(message = 'Access to private network is blocked') {
super(message);
this.name = 'PrivateIpError';
}
}
// Configure default global dispatcher with higher timeouts // Configure default global dispatcher with higher timeouts
setGlobalDispatcher( setGlobalDispatcher(
new Agent({ new Agent({
@@ -115,6 +123,30 @@ export function isAddressPrivate(address: string): boolean {
} }
} }
/**
* Checks if a URL resolves to a private IP address.
*/
export async function isPrivateIpAsync(url: string): Promise<boolean> {
try {
const parsedUrl = new URL(url);
const hostname = parsedUrl.hostname;
if (isLoopbackHost(hostname)) {
return false;
}
const addresses = await lookup(hostname, { all: true });
return addresses.some((addr) => isAddressPrivate(addr.address));
} catch (error) {
if (error instanceof TypeError) {
return false;
}
throw new Error('Failed to verify if URL resolves to private IP', {
cause: error,
});
}
}
/** /**
* Creates an undici ProxyAgent that incorporates safe DNS lookup. * Creates an undici ProxyAgent that incorporates safe DNS lookup.
*/ */