fix(core): implement robust URL validation in web_fetch tool (#10834)

This commit is contained in:
Abhi
2025-10-14 16:53:22 -04:00
committed by GitHub
parent 769fe8b161
commit 6f0107e7b7
2 changed files with 159 additions and 21 deletions

View File

@@ -5,7 +5,7 @@
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';
import { WebFetchTool } from './web-fetch.js';
import { WebFetchTool, parsePrompt } from './web-fetch.js';
import type { Config } from '../config/config.js';
import { ApprovalMode } from '../config/config.js';
import { ToolConfirmationOutcome } from './tools.js';
@@ -35,6 +35,87 @@ vi.mock('../utils/fetch.js', async (importOriginal) => {
};
});
// Unit tests for the parsePrompt URL-extraction helper exported by web-fetch.ts.
describe('parsePrompt', () => {
  it('should extract valid URLs separated by whitespace', () => {
    const result = parsePrompt('Go to https://example.com and http://google.com');
    expect(result.errors).toHaveLength(0);
    expect(result.validUrls).toHaveLength(2);
    // new URL().href normalizes bare hosts by appending the path slash.
    expect(result.validUrls[0]).toBe('https://example.com/');
    expect(result.validUrls[1]).toBe('http://google.com/');
  });

  it('should accept URLs with trailing punctuation', () => {
    // The trailing dot stays attached to the whitespace-split token and
    // new URL() accepts it as part of the hostname.
    const result = parsePrompt('Check https://example.com.');
    expect(result.errors).toHaveLength(0);
    expect(result.validUrls).toHaveLength(1);
    expect(result.validUrls[0]).toBe('https://example.com./');
  });

  it('should detect URLs wrapped in punctuation as malformed', () => {
    // A leading '(' means the token has no valid scheme, so new URL() throws.
    const result = parsePrompt('Read (https://example.com)');
    expect(result.validUrls).toHaveLength(0);
    expect(result.errors).toHaveLength(1);
    expect(result.errors[0]).toContain('Malformed URL detected');
    expect(result.errors[0]).toContain('(https://example.com)');
  });

  it('should detect unsupported protocols (httpshttps://)', () => {
    // 'httpshttps:' is a syntactically valid scheme, so the URL parses but
    // fails the http/https allowlist.
    const result = parsePrompt(
      'Summarize httpshttps://github.com/JuliaLang/julia/issues/58346',
    );
    expect(result.validUrls).toHaveLength(0);
    expect(result.errors).toHaveLength(1);
    expect(result.errors[0]).toContain('Unsupported protocol');
    expect(result.errors[0]).toContain(
      'httpshttps://github.com/JuliaLang/julia/issues/58346',
    );
  });

  it('should detect unsupported protocols (ftp://)', () => {
    const result = parsePrompt('ftp://example.com/file.txt');
    expect(result.validUrls).toHaveLength(0);
    expect(result.errors).toHaveLength(1);
    expect(result.errors[0]).toContain('Unsupported protocol');
  });

  it('should detect malformed URLs', () => {
    // http:// is not a valid URL in Node's new URL()
    const result = parsePrompt('http://');
    expect(result.validUrls).toHaveLength(0);
    expect(result.errors).toHaveLength(1);
    expect(result.errors[0]).toContain('Malformed URL detected');
  });

  it('should handle prompts with no URLs', () => {
    const result = parsePrompt('hello world');
    expect(result.validUrls).toHaveLength(0);
    expect(result.errors).toHaveLength(0);
  });

  it('should handle mixed valid and invalid URLs', () => {
    // The comma rides along with the google.com token and survives parsing.
    const result = parsePrompt('Valid: https://google.com, Invalid: ftp://bad.com');
    expect(result.validUrls).toHaveLength(1);
    expect(result.validUrls[0]).toBe('https://google.com,/');
    expect(result.errors).toHaveLength(1);
    expect(result.errors[0]).toContain('ftp://bad.com');
  });
});
describe('WebFetchTool', () => {
let mockConfig: Config;
@@ -48,16 +129,36 @@ describe('WebFetchTool', () => {
} as unknown as Config;
});
describe('execute', () => {
it('should return WEB_FETCH_NO_URL_IN_PROMPT when no URL is in the prompt for fallback', async () => {
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(true);
describe('validateToolParamValues', () => {
it('should throw if prompt is empty', () => {
const tool = new WebFetchTool(mockConfig);
const params = { prompt: 'no url here' };
expect(() => tool.build(params)).toThrow(
"The 'prompt' must contain at least one valid URL (starting with http:// or https://).",
expect(() => tool.build({ prompt: '' })).toThrow(
"The 'prompt' parameter cannot be empty",
);
});
it('should throw if prompt contains no URLs', () => {
const tool = new WebFetchTool(mockConfig);
expect(() => tool.build({ prompt: 'hello world' })).toThrow(
"The 'prompt' must contain at least one valid URL",
);
});
it('should throw if prompt contains malformed URLs (httpshttps://)', () => {
const tool = new WebFetchTool(mockConfig);
const prompt = 'fetch httpshttps://example.com';
expect(() => tool.build({ prompt })).toThrow('Error(s) in prompt URLs:');
});
it('should pass if prompt contains at least one valid URL', () => {
const tool = new WebFetchTool(mockConfig);
expect(() =>
tool.build({ prompt: 'fetch https://example.com' }),
).not.toThrow();
});
});
describe('execute', () => {
it('should return WEB_FETCH_FALLBACK_FAILED on fallback fetch failure', async () => {
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(true);
vi.spyOn(fetchUtils, 'fetchWithTimeout').mockRejectedValue(
@@ -135,7 +236,7 @@ describe('WebFetchTool', () => {
});
describe('shouldConfirmExecute', () => {
it('should return confirmation details with the correct prompt and urls', async () => {
it('should return confirmation details with the correct prompt and parsed urls', async () => {
const tool = new WebFetchTool(mockConfig);
const params = { prompt: 'fetch https://example.com' };
const invocation = tool.build(params);
@@ -147,7 +248,7 @@ describe('WebFetchTool', () => {
type: 'info',
title: 'Confirm Web Fetch',
prompt: 'fetch https://example.com',
urls: ['https://example.com'],
urls: ['https://example.com/'],
onConfirm: expect.any(Function),
});
});

View File

@@ -31,10 +31,42 @@ import {
const URL_FETCH_TIMEOUT_MS = 10000;
const MAX_CONTENT_LENGTH = 100000;
// Helper function to extract URLs from a string
function extractUrls(text: string): string[] {
const urlRegex = /(https?:\/\/[^\s]+)/g;
return text.match(urlRegex) || [];
/**
 * Parses a prompt to extract valid URLs and identify malformed ones.
 *
 * The prompt is split on whitespace; each token containing a scheme
 * separator ("://") is treated as a URL candidate and validated with the
 * WHATWG `new URL()` parser. Candidates that parse with an http/https
 * protocol are collected (normalized via `url.href`, e.g. a trailing
 * slash is added to bare hosts); candidates with any other protocol, or
 * that fail to parse at all, produce an entry in `errors`. Tokens
 * without "://" are ignored as ordinary prose.
 */
export function parsePrompt(text: string): {
  validUrls: string[];
  errors: string[];
} {
  const validUrls: string[] = [];
  const errors: string[] = [];
  const allowedProtocols = ['http:', 'https:'];

  for (const token of text.split(/\s+/)) {
    // Skip empty tokens and anything that doesn't look URL-like.
    if (!token || !token.includes('://')) {
      continue;
    }

    let parsed: URL;
    try {
      // Validate with the WHATWG URL parser.
      parsed = new URL(token);
    } catch {
      // new URL() threw, so the token is malformed per the WHATWG standard.
      errors.push(`Malformed URL detected: "${token}".`);
      continue;
    }

    if (allowedProtocols.includes(parsed.protocol)) {
      validUrls.push(parsed.href);
    } else {
      errors.push(
        `Unsupported protocol in URL: "${token}". Only http and https are supported.`,
      );
    }
  }

  return { validUrls, errors };
}
// Interfaces for grounding metadata (similar to web-search.ts)
@@ -80,7 +112,7 @@ class WebFetchToolInvocation extends BaseToolInvocation<
}
private async executeFallback(signal: AbortSignal): Promise<ToolResult> {
const urls = extractUrls(this.params.prompt);
const { validUrls: urls } = parsePrompt(this.params.prompt);
// For now, we only support one URL for fallback
let url = urls[0];
@@ -158,7 +190,8 @@ ${textContent}
// Perform GitHub URL conversion here to differentiate between user-provided
// URL and the actual URL to be fetched.
const urls = extractUrls(this.params.prompt).map((url) => {
const { validUrls } = parsePrompt(this.params.prompt);
const urls = validUrls.map((url) => {
if (url.includes('github.com') && url.includes('/blob/')) {
return url
.replace('github.com', 'raw.githubusercontent.com')
@@ -183,7 +216,7 @@ ${textContent}
async execute(signal: AbortSignal): Promise<ToolResult> {
const userPrompt = this.params.prompt;
const urls = extractUrls(userPrompt);
const { validUrls: urls } = parsePrompt(userPrompt);
const url = urls[0];
const isPrivate = isPrivateIp(url);
@@ -312,7 +345,6 @@ ${sourceListFormatted.join('\n')}`;
0,
50,
)}...": ${getErrorMessage(error)}`;
console.error(errorMessage, error);
return {
llmContent: `Error: ${errorMessage}`,
returnDisplay: `Error: ${errorMessage}`,
@@ -364,12 +396,17 @@ export class WebFetchTool extends BaseDeclarativeTool<
if (!params.prompt || params.prompt.trim() === '') {
return "The 'prompt' parameter cannot be empty and must contain URL(s) and instructions.";
}
if (
!params.prompt.includes('http://') &&
!params.prompt.includes('https://')
) {
const { validUrls, errors } = parsePrompt(params.prompt);
if (errors.length > 0) {
return `Error(s) in prompt URLs:\n- ${errors.join('\n- ')}`;
}
if (validUrls.length === 0) {
return "The 'prompt' must contain at least one valid URL (starting with http:// or https://).";
}
return null;
}