feat(core): implement Stage 2 security and consistency improvements for web_fetch (#22217)

This commit is contained in:
Aishanee Shah
2026-03-16 17:38:53 -04:00
committed by GitHub
parent b6c6da3618
commit 990d010ecf
4 changed files with 250 additions and 86 deletions
+139 -69
View File
@@ -40,7 +40,7 @@ import { LRUCache } from 'mnemonist';
import type { AgentLoopContext } from '../config/agent-loop-context.js';
const URL_FETCH_TIMEOUT_MS = 10000;
const MAX_CONTENT_LENGTH = 100000;
const MAX_CONTENT_LENGTH = 250000;
const MAX_EXPERIMENTAL_FETCH_SIZE = 10 * 1024 * 1024; // 10MB
const USER_AGENT =
'Mozilla/5.0 (compatible; Google-Gemini-CLI/1.0; +https://github.com/google-gemini/gemini-cli)';
@@ -190,6 +190,18 @@ function isGroundingSupportItem(item: unknown): item is GroundingSupportItem {
return typeof item === 'object' && item !== null;
}
/**
* Sanitizes text for safe embedding in XML tags.
*/
function sanitizeXml(text: string): string {
return text
.replace(/&/g, '&')
.replace(/</g, '&lt;')
.replace(/>/g, '&gt;')
.replace(/"/g, '&quot;')
.replace(/'/g, '&apos;');
}
/**
* Parameters for the WebFetch tool
*/
@@ -263,69 +275,65 @@ class WebFetchToolInvocation extends BaseToolInvocation<
private async executeFallbackForUrl(
urlStr: string,
signal: AbortSignal,
contentBudget: number,
): Promise<string> {
const url = convertGithubUrlToRaw(urlStr);
if (this.isBlockedHost(url)) {
debugLogger.warn(`[WebFetchTool] Blocked access to host: ${url}`);
return `Error fetching ${url}: Access to blocked or private host is not allowed.`;
throw new Error(
`Access to blocked or private host ${url} is not allowed.`,
);
}
try {
const response = await retryWithBackoff(
async () => {
const res = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS, {
signal,
headers: {
'User-Agent': USER_AGENT,
},
});
if (!res.ok) {
const error = new Error(
`Request failed with status code ${res.status} ${res.statusText}`,
);
(error as ErrorWithStatus).status = res.status;
throw error;
}
return res;
},
{
retryFetchErrors: this.context.config.getRetryFetchErrors(),
onRetry: (attempt, error, delayMs) =>
this.handleRetry(attempt, error, delayMs),
const response = await retryWithBackoff(
async () => {
const res = await fetchWithTimeout(url, URL_FETCH_TIMEOUT_MS, {
signal,
},
);
const bodyBuffer = await this.readResponseWithLimit(
response,
MAX_EXPERIMENTAL_FETCH_SIZE,
);
const rawContent = bodyBuffer.toString('utf8');
const contentType = response.headers.get('content-type') || '';
let textContent: string;
// Only use html-to-text if content type is HTML, or if no content type is provided (assume HTML)
if (
contentType.toLowerCase().includes('text/html') ||
contentType === ''
) {
textContent = convert(rawContent, {
wordwrap: false,
selectors: [
{ selector: 'a', options: { ignoreHref: true } },
{ selector: 'img', format: 'skip' },
],
headers: {
'User-Agent': USER_AGENT,
},
});
} else {
// For other content types (text/plain, application/json, etc.), use raw text
textContent = rawContent;
}
if (!res.ok) {
const error = new Error(
`Request failed with status code ${res.status} ${res.statusText}`,
);
(error as ErrorWithStatus).status = res.status;
throw error;
}
return res;
},
{
retryFetchErrors: this.context.config.getRetryFetchErrors(),
onRetry: (attempt, error, delayMs) =>
this.handleRetry(attempt, error, delayMs),
signal,
},
);
return truncateString(textContent, contentBudget, TRUNCATION_WARNING);
} catch (e) {
return `Error fetching ${url}: ${getErrorMessage(e)}`;
const bodyBuffer = await this.readResponseWithLimit(
response,
MAX_EXPERIMENTAL_FETCH_SIZE,
);
const rawContent = bodyBuffer.toString('utf8');
const contentType = response.headers.get('content-type') || '';
let textContent: string;
// Only use html-to-text if content type is HTML, or if no content type is provided (assume HTML)
if (contentType.toLowerCase().includes('text/html') || contentType === '') {
textContent = convert(rawContent, {
wordwrap: false,
selectors: [
{ selector: 'a', options: { ignoreHref: true } },
{ selector: 'img', format: 'skip' },
],
});
} else {
// For other content types (text/plain, application/json, etc.), use raw text
textContent = rawContent;
}
// Cap at MAX_CONTENT_LENGTH initially to avoid excessive memory usage
// before the global budget allocation.
return truncateString(textContent, MAX_CONTENT_LENGTH, '');
}
private filterAndValidateUrls(urls: string[]): {
@@ -363,30 +371,82 @@ class WebFetchToolInvocation extends BaseToolInvocation<
signal: AbortSignal,
): Promise<ToolResult> {
const uniqueUrls = [...new Set(urls)];
const contentBudget = Math.floor(
MAX_CONTENT_LENGTH / (uniqueUrls.length || 1),
);
const results: string[] = [];
const successes: Array<{ url: string; content: string }> = [];
const errors: Array<{ url: string; message: string }> = [];
for (const url of uniqueUrls) {
results.push(
await this.executeFallbackForUrl(url, signal, contentBudget),
);
try {
const content = await this.executeFallbackForUrl(url, signal);
successes.push({ url, content });
} catch (e) {
errors.push({ url, message: getErrorMessage(e) });
}
}
const aggregatedContent = results
.map((content, i) => `URL: ${uniqueUrls[i]}\nContent:\n${content}`)
.join('\n\n---\n\n');
// Change 2: Short-circuit on total failure
if (successes.length === 0) {
const errorMessage = `All fallback fetch attempts failed: ${errors
.map((e) => `${e.url}: ${e.message}`)
.join(', ')}`;
debugLogger.error(`[WebFetchTool] ${errorMessage}`);
return {
llmContent: `Error: ${errorMessage}`,
returnDisplay: `Error: ${errorMessage}`,
error: {
message: errorMessage,
type: ToolErrorType.WEB_FETCH_FALLBACK_FAILED,
},
};
}
// Smart Budget Allocation (Water-filling algorithm) for successes
const sortedSuccesses = [...successes].sort(
(a, b) => a.content.length - b.content.length,
);
let remainingBudget = MAX_CONTENT_LENGTH;
let remainingUrls = sortedSuccesses.length;
const finalContentsByUrl = new Map<string, string>();
for (const success of sortedSuccesses) {
const fairShare = Math.floor(remainingBudget / remainingUrls);
const allocated = Math.min(success.content.length, fairShare);
const truncated = truncateString(
success.content,
allocated,
TRUNCATION_WARNING,
);
finalContentsByUrl.set(success.url, truncated);
remainingBudget -= truncated.length;
remainingUrls--;
}
const aggregatedContent = uniqueUrls
.map((url) => {
const content = finalContentsByUrl.get(url);
if (content !== undefined) {
return `<source url="${sanitizeXml(url)}">\n${sanitizeXml(content)}\n</source>`;
}
const error = errors.find((e) => e.url === url);
return `<source url="${sanitizeXml(url)}">\nError: ${sanitizeXml(error?.message || 'Unknown error')}\n</source>`;
})
.join('\n');
try {
const geminiClient = this.context.geminiClient;
const fallbackPrompt = `The user requested the following: "${this.params.prompt}".
const fallbackPrompt = `Follow the user's instructions below using the provided webpage content.
<user_instructions>
${sanitizeXml(this.params.prompt ?? '')}
</user_instructions>
I was unable to access the URL(s) directly using the primary fetch tool. Instead, I have fetched the raw content of the page(s). Please use the following content to answer the request. Do not attempt to access the URL(s) again.
---
<content>
${aggregatedContent}
---
</content>
`;
const result = await geminiClient.generateContent(
{ model: 'web-fetch-fallback' },
@@ -716,9 +776,19 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun
try {
const geminiClient = this.context.geminiClient;
const sanitizedPrompt = `Follow the user's instructions to process the authorized URLs.
<user_instructions>
${sanitizeXml(userPrompt)}
</user_instructions>
<authorized_urls>
${toFetch.join('\n')}
</authorized_urls>
`;
const response = await geminiClient.generateContent(
{ model: 'web-fetch' },
[{ role: 'user', parts: [{ text: userPrompt }] }],
[{ role: 'user', parts: [{ text: sanitizedPrompt }] }],
signal,
LlmRole.UTILITY_TOOL,
);
@@ -870,7 +940,7 @@ export class WebFetchTool extends BaseDeclarativeTool<
_toolDisplayName?: string,
): ToolInvocation<WebFetchToolParams, ToolResult> {
return new WebFetchToolInvocation(
this.context.config,
this.context,
params,
messageBus,
_toolName,