diff --git a/packages/core/src/tools/web-fetch.test.ts b/packages/core/src/tools/web-fetch.test.ts index f6d74d0953..d5fd4310f6 100644 --- a/packages/core/src/tools/web-fetch.test.ts +++ b/packages/core/src/tools/web-fetch.test.ts @@ -9,6 +9,7 @@ import { WebFetchTool, parsePrompt, convertGithubUrlToRaw, + normalizeUrl, } from './web-fetch.js'; import type { Config } from '../config/config.js'; import { ApprovalMode } from '../policy/types.js'; @@ -125,6 +126,35 @@ const mockFetch = (url: string, response: Partial | Error) => } as unknown as Response; }); +describe('normalizeUrl', () => { + it('should lowercase hostname', () => { + expect(normalizeUrl('https://EXAMPLE.com/Path')).toBe( + 'https://example.com/Path', + ); + }); + + it('should remove trailing slash except for root', () => { + expect(normalizeUrl('https://example.com/path/')).toBe( + 'https://example.com/path', + ); + expect(normalizeUrl('https://example.com/')).toBe('https://example.com/'); + }); + + it('should remove default ports', () => { + expect(normalizeUrl('http://example.com:80/')).toBe('http://example.com/'); + expect(normalizeUrl('https://example.com:443/')).toBe( + 'https://example.com/', + ); + expect(normalizeUrl('https://example.com:8443/')).toBe( + 'https://example.com:8443/', + ); + }); + + it('should handle invalid URLs gracefully', () => { + expect(normalizeUrl('not-a-url')).toBe('not-a-url'); + }); +}); + describe('parsePrompt', () => { it('should extract valid URLs separated by whitespace', () => { const prompt = 'Go to https://example.com and http://google.com'; @@ -396,6 +426,35 @@ describe('WebFetchTool', () => { ); }); + it('should skip private or local URLs but fetch others', async () => { + vi.mocked(fetchUtils.isPrivateIp).mockImplementation( + (url) => url === 'https://private.com/', + ); + + const tool = new WebFetchTool(mockConfig, bus); + const params = { + prompt: + 'fetch https://private.com and https://healthy.com and http://localhost', + }; + const invocation = tool.build(params); + + mockGenerateContent.mockResolvedValueOnce({ + candidates: [{ content: { parts: [{ text: 'healthy response' }] } }], + }); + + const result = await invocation.execute(new AbortController().signal); + expect(result.llmContent).toContain('healthy response'); + expect(result.llmContent).toContain( + '[Warning] The following URLs were skipped:', + ); + expect(result.llmContent).toContain( + '[Private or Local Host] https://private.com/', + ); + expect(result.llmContent).toContain( + '[Private or Local Host] http://localhost', + ); + }); + it('should fallback to all public URLs if primary fails', async () => { vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false); @@ -419,7 +478,7 @@ describe('WebFetchTool', () => { const tool = new WebFetchTool(mockConfig, bus); const params = { - prompt: 'fetch https://url1.com and https://url2.com', + prompt: 'fetch https://url1.com and https://url2.com/', }; const invocation = tool.build(params); const result = await invocation.execute(new AbortController().signal); @@ -450,7 +509,7 @@ describe('WebFetchTool', () => { const tool = new WebFetchTool(mockConfig, bus); const params = { - prompt: 'fetch https://public.com and https://private.com', + prompt: 'fetch https://public.com/ and https://private.com', }; const invocation = tool.build(params); const result = await invocation.execute(new AbortController().signal); @@ -971,13 +1030,13 @@ describe('WebFetchTool', () => { }); it('should throw error if stream exceeds limit', async () => { - const largeChunk = new Uint8Array(11 * 1024 * 1024); + const large_chunk = new Uint8Array(11 * 1024 * 1024); mockFetch('https://example.com/large-stream', { body: { getReader: () => ({ read: vi .fn() - .mockResolvedValueOnce({ done: false, value: largeChunk }) + .mockResolvedValueOnce({ done: false, value: large_chunk }) .mockResolvedValueOnce({ done: true }), releaseLock: vi.fn(), cancel: vi.fn().mockResolvedValue(undefined), @@ -1025,7 +1084,7 @@ describe('WebFetchTool', () => { const result = await invocation.execute(new AbortController().signal); expect(result.llmContent).toContain( - 'Error: Access to private IP address http://localhost/ is not allowed.', + 'Error: Access to blocked or private host http://localhost/ is not allowed.', ); expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_PROCESSING_ERROR); }); diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index b31220f459..076eea88c6 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -76,6 +76,31 @@ function checkRateLimit(url: string): { } } +/** + * Normalizes a URL by converting hostname to lowercase, removing trailing slashes, + * and removing default ports. + */ +export function normalizeUrl(urlStr: string): string { + try { + const url = new URL(urlStr); + url.hostname = url.hostname.toLowerCase(); + // Remove trailing slash if present in pathname (except for root '/') + if (url.pathname.endsWith('/') && url.pathname.length > 1) { + url.pathname = url.pathname.slice(0, -1); + } + // Remove default ports + if ( + (url.protocol === 'http:' && url.port === '80') || + (url.protocol === 'https:' && url.port === '443') + ) { + url.port = ''; + } + return url.href; + } catch { + return urlStr; + } +} + /** * Parses a prompt to extract valid URLs and identify malformed ones. */ @@ -184,14 +209,27 @@ class WebFetchToolInvocation extends BaseToolInvocation< super(params, messageBus, _toolName, _toolDisplayName); } + private isBlockedHost(urlStr: string): boolean { + try { + const url = new URL(urlStr); + const hostname = url.hostname.toLowerCase(); + if (hostname === 'localhost' || hostname === '127.0.0.1') { + return true; + } + return isPrivateIp(urlStr); + } catch { + return true; + } + } + private async executeFallbackForUrl( urlStr: string, signal: AbortSignal, contentBudget: number, ): Promise { const url = convertGithubUrlToRaw(urlStr); - if (isPrivateIp(url)) { - return `Error fetching ${url}: Access to private IP address is not allowed.`; + if (this.isBlockedHost(url)) { + return `Error fetching ${url}: Access to blocked or private host is not allowed.`; } try { @@ -245,9 +283,7 @@ class WebFetchToolInvocation extends BaseToolInvocation< return truncateString(textContent, contentBudget, TRUNCATION_WARNING); } catch (e) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const error = e as Error; - return `Error fetching ${url}: ${error.message}`; + return `Error fetching ${url}: ${getErrorMessage(e)}`; } } @@ -290,9 +326,7 @@ ${aggregatedContent} returnDisplay: `Content for ${urls.length} URL(s) processed using fallback fetch.`, }; } catch (e) { - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - const error = e as Error; - const errorMessage = `Error during fallback processing: ${error.message}`; + const errorMessage = `Error during fallback processing: ${getErrorMessage(e)}`; return { llmContent: `Error: ${errorMessage}`, returnDisplay: `Error: ${errorMessage}`, @@ -413,8 +447,8 @@ ${aggregatedContent} // Convert GitHub blob URL to raw URL url = convertGithubUrlToRaw(url); - if (isPrivateIp(url)) { - const errorMessage = `Access to private IP address ${url} is not allowed.`; + if (this.isBlockedHost(url)) { + const errorMessage = `Access to blocked or private host ${url} is not allowed.`; return { llmContent: `Error: ${errorMessage}`, returnDisplay: `Error: ${errorMessage}`, @@ -548,14 +582,14 @@ Response: ${truncateString(rawResponseText, 10000, '\n\n... [Error response trun const userPrompt = this.params.prompt!; const { validUrls } = parsePrompt(userPrompt); - // Filter unique URLs and perform pre-flight checks (Rate Limit & Private IP) - const uniqueUrls = [...new Set(validUrls)]; + // Unit 1: Normalization & Deduplication + const uniqueUrls = [...new Set(validUrls.map(normalizeUrl))]; const toFetch: string[] = []; const skipped: string[] = []; for (const url of uniqueUrls) { - if (isPrivateIp(url)) { - skipped.push(`[Private IP] ${url}`); + if (this.isBlockedHost(url)) { + skipped.push(`[Private or Local Host] ${url}`); continue; } if (!checkRateLimit(url).allowed) {