mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
refactor(core): harden webfetch security and error handling
- Implemented strict host blocking for localhost/127.0.0.1 in all paths. - Applied consistent URL normalization and deduplication. - Standardized error extraction using getErrorMessage. - Updated warning strings for better transparency on skipped URLs.
This commit is contained in:
@@ -9,6 +9,7 @@ import {
|
||||
WebFetchTool,
|
||||
parsePrompt,
|
||||
convertGithubUrlToRaw,
|
||||
normalizeUrl,
|
||||
} from './web-fetch.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
@@ -125,6 +126,35 @@ const mockFetch = (url: string, response: Partial<Response> | Error) =>
|
||||
} as unknown as Response;
|
||||
});
|
||||
|
||||
describe('normalizeUrl', () => {
|
||||
it('should lowercase hostname', () => {
|
||||
expect(normalizeUrl('https://EXAMPLE.com/Path')).toBe(
|
||||
'https://example.com/Path',
|
||||
);
|
||||
});
|
||||
|
||||
it('should remove trailing slash except for root', () => {
|
||||
expect(normalizeUrl('https://example.com/path/')).toBe(
|
||||
'https://example.com/path',
|
||||
);
|
||||
expect(normalizeUrl('https://example.com/')).toBe('https://example.com/');
|
||||
});
|
||||
|
||||
it('should remove default ports', () => {
|
||||
expect(normalizeUrl('http://example.com:80/')).toBe('http://example.com/');
|
||||
expect(normalizeUrl('https://example.com:443/')).toBe(
|
||||
'https://example.com/',
|
||||
);
|
||||
expect(normalizeUrl('https://example.com:8443/')).toBe(
|
||||
'https://example.com:8443/',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle invalid URLs gracefully', () => {
|
||||
expect(normalizeUrl('not-a-url')).toBe('not-a-url');
|
||||
});
|
||||
});
|
||||
|
||||
describe('parsePrompt', () => {
|
||||
it('should extract valid URLs separated by whitespace', () => {
|
||||
const prompt = 'Go to https://example.com and http://google.com';
|
||||
@@ -396,6 +426,35 @@ describe('WebFetchTool', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should skip private or local URLs but fetch others', async () => {
|
||||
vi.mocked(fetchUtils.isPrivateIp).mockImplementation(
|
||||
(url) => url === 'https://private.com/',
|
||||
);
|
||||
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = {
|
||||
prompt:
|
||||
'fetch https://private.com and https://healthy.com and http://localhost',
|
||||
};
|
||||
const invocation = tool.build(params);
|
||||
|
||||
mockGenerateContent.mockResolvedValueOnce({
|
||||
candidates: [{ content: { parts: [{ text: 'healthy response' }] } }],
|
||||
});
|
||||
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
expect(result.llmContent).toContain('healthy response');
|
||||
expect(result.llmContent).toContain(
|
||||
'[Warning] The following URLs were skipped:',
|
||||
);
|
||||
expect(result.llmContent).toContain(
|
||||
'[Private or Local Host] https://private.com/',
|
||||
);
|
||||
expect(result.llmContent).toContain(
|
||||
'[Private or Local Host] http://localhost',
|
||||
);
|
||||
});
|
||||
|
||||
it('should fallback to all public URLs if primary fails', async () => {
|
||||
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false);
|
||||
|
||||
@@ -419,7 +478,7 @@ describe('WebFetchTool', () => {
|
||||
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = {
|
||||
prompt: 'fetch https://url1.com and https://url2.com',
|
||||
prompt: 'fetch https://url1.com and https://url2.com/',
|
||||
};
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
@@ -450,7 +509,7 @@ describe('WebFetchTool', () => {
|
||||
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = {
|
||||
prompt: 'fetch https://public.com and https://private.com',
|
||||
prompt: 'fetch https://public.com/ and https://private.com',
|
||||
};
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
@@ -971,13 +1030,13 @@ describe('WebFetchTool', () => {
|
||||
});
|
||||
|
||||
it('should throw error if stream exceeds limit', async () => {
|
||||
const largeChunk = new Uint8Array(11 * 1024 * 1024);
|
||||
const large_chunk = new Uint8Array(11 * 1024 * 1024);
|
||||
mockFetch('https://example.com/large-stream', {
|
||||
body: {
|
||||
getReader: () => ({
|
||||
read: vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ done: false, value: largeChunk })
|
||||
.mockResolvedValueOnce({ done: false, value: large_chunk })
|
||||
.mockResolvedValueOnce({ done: true }),
|
||||
releaseLock: vi.fn(),
|
||||
cancel: vi.fn().mockResolvedValue(undefined),
|
||||
@@ -1025,7 +1084,7 @@ describe('WebFetchTool', () => {
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(result.llmContent).toContain(
|
||||
'Error: Access to private IP address http://localhost/ is not allowed.',
|
||||
'Error: Access to blocked or private host http://localhost/ is not allowed.',
|
||||
);
|
||||
expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_PROCESSING_ERROR);
|
||||
});
|
||||
|
||||
@@ -76,6 +76,31 @@ function checkRateLimit(url: string): {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalizes a URL by converting hostname to lowercase, removing trailing slashes,
|
||||
* and removing default ports.
|
||||
*/
|
||||
export function normalizeUrl(urlStr: string): string {
|
||||
try {
|
||||
const url = new URL(urlStr);
|
||||
url.hostname = url.hostname.toLowerCase();
|
||||
// Remove trailing slash if present in pathname (except for root '/')
|
||||
if (url.pathname.endsWith('/') && url.pathname.length > 1) {
|
||||
url.pathname = url.pathname.slice(0, -1);
|
||||
}
|
||||
// Remove default ports
|
||||
if (
|
||||
(url.protocol === 'http:' && url.port === '80') ||
|
||||
(url.protocol === 'https:' && url.port === '443')
|
||||
) {
|
||||
url.port = '';
|
||||
}
|
||||
return url.href;
|
||||
} catch {
|
||||
return urlStr;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Parses a prompt to extract valid URLs and identify malformed ones.
|
||||
*/
|
||||
@@ -184,14 +209,27 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
super(params, messageBus, _toolName, _toolDisplayName);
|
||||
}
|
||||
|
||||
private isBlockedHost(urlStr: string): boolean {
|
||||
try {
|
||||
const url = new URL(urlStr);
|
||||
const hostname = url.hostname.toLowerCase();
|
||||
if (hostname === 'localhost' || hostname === '127.0.0.1') {
|
||||
return true;
|
||||
}
|
||||
return isPrivateIp(urlStr);
|
||||
} catch {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
private async executeFallbackForUrl(
|
||||
urlStr: string,
|
||||
signal: AbortSignal,
|
||||
contentBudget: number,
|
||||
): Promise<string> {
|
||||
const url = convertGithubUrlToRaw(urlStr);
|
||||
if (isPrivateIp(url)) {
|
||||
return `Error fetching ${url}: Access to private IP address is not allowed.`;
|
||||
if (this.isBlockedHost(url)) {
|
||||
return `Error fetching ${url}: Access to blocked or private host is not allowed.`;
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -245,9 +283,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
|
||||
|
||||
return truncateString(textContent, contentBudget, TRUNCATION_WARNING);
|
||||
} catch (e) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const error = e as Error;
|
||||
return `Error fetching ${url}: ${error.message}`;
|
||||
return `Error fetching ${url}: ${getErrorMessage(e)}`;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,9 +326,7 @@ ${aggregatedContent}
|
||||
returnDisplay: `Content for ${urls.length} URL(s) processed using fallback fetch.`,
|
||||
};
|
||||
} catch (e) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion
|
||||
const error = e as Error;
|
||||
const errorMessage = `Error during fallback processing: ${error.message}`;
|
||||
const errorMessage = `Error during fallback processing: ${getErrorMessage(e)}`;
|
||||
return {
|
||||
llmContent: `Error: ${errorMessage}`,
|
||||
returnDisplay: `Error: ${errorMessage}`,
|
||||
@@ -413,8 +447,8 @@ ${aggregatedContent}
|
||||
// Convert GitHub blob URL to raw URL
|
||||
url = convertGithubUrlToRaw(url);
|
||||
|
||||
if (isPrivateIp(url)) {
|
||||
const errorMessage = `Access to private IP address ${url} is not allowed.`;
|
||||
if (this.isBlockedHost(url)) {
|
||||
const errorMessage = `Access to blocked or private host ${url} is not allowed.`;
|
||||
return {
|
||||
llmContent: `Error: ${errorMessage}`,
|
||||
returnDisplay: `Error: ${errorMessage}`,
|
||||
@@ -548,14 +582,14 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun
|
||||
const userPrompt = this.params.prompt!;
|
||||
const { validUrls } = parsePrompt(userPrompt);
|
||||
|
||||
// Filter unique URLs and perform pre-flight checks (Rate Limit & Private IP)
|
||||
const uniqueUrls = [...new Set(validUrls)];
|
||||
// Unit 1: Normalization & Deduplication
|
||||
const uniqueUrls = [...new Set(validUrls.map(normalizeUrl))];
|
||||
const toFetch: string[] = [];
|
||||
const skipped: string[] = [];
|
||||
|
||||
for (const url of uniqueUrls) {
|
||||
if (isPrivateIp(url)) {
|
||||
skipped.push(`[Private IP] ${url}`);
|
||||
if (this.isBlockedHost(url)) {
|
||||
skipped.push(`[Private or Local Host] ${url}`);
|
||||
continue;
|
||||
}
|
||||
if (!checkRateLimit(url).allowed) {
|
||||
|
||||
Reference in New Issue
Block a user