mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-28 14:04:41 -07:00
feat(core): implement Stage 1 improvements for webfetch tool (#21313)
This commit is contained in:
@@ -9,6 +9,7 @@ import {
|
||||
WebFetchTool,
|
||||
parsePrompt,
|
||||
convertGithubUrlToRaw,
|
||||
normalizeUrl,
|
||||
} from './web-fetch.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { ApprovalMode } from '../policy/types.js';
|
||||
@@ -43,7 +44,7 @@ vi.mock('html-to-text', () => ({
|
||||
|
||||
vi.mock('../telemetry/index.js', () => ({
|
||||
logWebFetchFallbackAttempt: vi.fn(),
|
||||
WebFetchFallbackAttemptEvent: vi.fn(),
|
||||
WebFetchFallbackAttemptEvent: vi.fn((reason) => ({ reason })),
|
||||
}));
|
||||
|
||||
vi.mock('../utils/fetch.js', async (importOriginal) => {
|
||||
@@ -125,6 +126,35 @@ const mockFetch = (url: string, response: Partial<Response> | Error) =>
|
||||
} as unknown as Response;
|
||||
});
|
||||
|
||||
describe('normalizeUrl', () => {
|
||||
it('should lowercase hostname', () => {
|
||||
expect(normalizeUrl('https://EXAMPLE.com/Path')).toBe(
|
||||
'https://example.com/Path',
|
||||
);
|
||||
});
|
||||
|
||||
it('should remove trailing slash except for root', () => {
|
||||
expect(normalizeUrl('https://example.com/path/')).toBe(
|
||||
'https://example.com/path',
|
||||
);
|
||||
expect(normalizeUrl('https://example.com/')).toBe('https://example.com/');
|
||||
});
|
||||
|
||||
it('should remove default ports', () => {
|
||||
expect(normalizeUrl('http://example.com:80/')).toBe('http://example.com/');
|
||||
expect(normalizeUrl('https://example.com:443/')).toBe(
|
||||
'https://example.com/',
|
||||
);
|
||||
expect(normalizeUrl('https://example.com:8443/')).toBe(
|
||||
'https://example.com:8443/',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle invalid URLs gracefully', () => {
|
||||
expect(normalizeUrl('not-a-url')).toBe('not-a-url');
|
||||
});
|
||||
});
|
||||
|
||||
describe('parsePrompt', () => {
|
||||
it('should extract valid URLs separated by whitespace', () => {
|
||||
const prompt = 'Go to https://example.com and http://google.com';
|
||||
@@ -355,49 +385,164 @@ describe('WebFetchTool', () => {
|
||||
// The 11th time should fail due to rate limit
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_PROCESSING_ERROR);
|
||||
expect(result.error?.message).toContain('Rate limit exceeded for host');
|
||||
expect(result.error?.message).toContain(
|
||||
'All requested URLs were skipped',
|
||||
);
|
||||
});
|
||||
|
||||
it('should skip rate-limited URLs but fetch others', async () => {
|
||||
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false);
|
||||
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = {
|
||||
prompt: 'fetch https://ratelimit-multi.com and https://healthy.com',
|
||||
};
|
||||
const invocation = tool.build(params);
|
||||
|
||||
// Hit rate limit for one host
|
||||
for (let i = 0; i < 10; i++) {
|
||||
mockGenerateContent.mockResolvedValueOnce({
|
||||
candidates: [{ content: { parts: [{ text: 'response' }] } }],
|
||||
});
|
||||
await tool
|
||||
.build({ prompt: 'fetch https://ratelimit-multi.com' })
|
||||
.execute(new AbortController().signal);
|
||||
}
|
||||
// 11th call - should be rate limited and not use a mock
|
||||
await tool
|
||||
.build({ prompt: 'fetch https://ratelimit-multi.com' })
|
||||
.execute(new AbortController().signal);
|
||||
|
||||
mockGenerateContent.mockResolvedValueOnce({
|
||||
candidates: [{ content: { parts: [{ text: 'healthy response' }] } }],
|
||||
});
|
||||
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
expect(result.llmContent).toContain('healthy response');
|
||||
expect(result.llmContent).toContain(
|
||||
'[Warning] The following URLs were skipped:',
|
||||
);
|
||||
expect(result.llmContent).toContain(
|
||||
'[Rate limit exceeded] https://ratelimit-multi.com/',
|
||||
);
|
||||
});
|
||||
|
||||
it('should skip private or local URLs but fetch others and log telemetry', async () => {
|
||||
vi.mocked(fetchUtils.isPrivateIp).mockImplementation(
|
||||
(url) => url === 'https://private.com/',
|
||||
);
|
||||
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = {
|
||||
prompt:
|
||||
'fetch https://private.com and https://healthy.com and http://localhost',
|
||||
};
|
||||
const invocation = tool.build(params);
|
||||
|
||||
mockGenerateContent.mockResolvedValueOnce({
|
||||
candidates: [{ content: { parts: [{ text: 'healthy response' }] } }],
|
||||
});
|
||||
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(logWebFetchFallbackAttempt).toHaveBeenCalledTimes(2);
|
||||
expect(logWebFetchFallbackAttempt).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.objectContaining({ reason: 'private_ip_skipped' }),
|
||||
);
|
||||
|
||||
expect(result.llmContent).toContain('healthy response');
|
||||
expect(result.llmContent).toContain(
|
||||
'[Warning] The following URLs were skipped:',
|
||||
);
|
||||
expect(result.llmContent).toContain(
|
||||
'[Blocked Host] https://private.com/',
|
||||
);
|
||||
expect(result.llmContent).toContain('[Blocked Host] http://localhost');
|
||||
});
|
||||
|
||||
it('should fallback to all public URLs if primary fails', async () => {
|
||||
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false);
|
||||
|
||||
// Primary fetch fails
|
||||
mockGenerateContent.mockRejectedValueOnce(new Error('primary fail'));
|
||||
|
||||
// Mock fallback fetch for BOTH URLs
|
||||
mockFetch('https://url1.com/', {
|
||||
text: () => Promise.resolve('content 1'),
|
||||
});
|
||||
mockFetch('https://url2.com/', {
|
||||
text: () => Promise.resolve('content 2'),
|
||||
});
|
||||
|
||||
// Mock fallback LLM call
|
||||
mockGenerateContent.mockResolvedValueOnce({
|
||||
candidates: [
|
||||
{ content: { parts: [{ text: 'fallback processed response' }] } },
|
||||
],
|
||||
});
|
||||
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = {
|
||||
prompt: 'fetch https://url1.com and https://url2.com/',
|
||||
};
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(result.llmContent).toBe('fallback processed response');
|
||||
expect(result.returnDisplay).toContain(
|
||||
'2 URL(s) processed using fallback fetch',
|
||||
);
|
||||
});
|
||||
|
||||
it('should NOT include private URLs in fallback', async () => {
|
||||
vi.mocked(fetchUtils.isPrivateIp).mockImplementation(
|
||||
(url) => url === 'https://private.com/',
|
||||
);
|
||||
|
||||
// Primary fetch fails
|
||||
mockGenerateContent.mockRejectedValueOnce(new Error('primary fail'));
|
||||
|
||||
// Mock fallback fetch only for public URL
|
||||
mockFetch('https://public.com/', {
|
||||
text: () => Promise.resolve('public content'),
|
||||
});
|
||||
|
||||
// Mock fallback LLM call
|
||||
mockGenerateContent.mockResolvedValueOnce({
|
||||
candidates: [{ content: { parts: [{ text: 'fallback response' }] } }],
|
||||
});
|
||||
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = {
|
||||
prompt: 'fetch https://public.com/ and https://private.com',
|
||||
};
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(result.llmContent).toBe('fallback response');
|
||||
// Verify private URL was NOT fetched (mockFetch would throw if it was called for private.com)
|
||||
});
|
||||
|
||||
it('should return WEB_FETCH_FALLBACK_FAILED on fallback fetch failure', async () => {
|
||||
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(true);
|
||||
mockFetch('https://private.ip/', new Error('fetch failed'));
|
||||
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false);
|
||||
mockGenerateContent.mockRejectedValue(new Error('primary fail'));
|
||||
mockFetch('https://public.ip/', new Error('fallback fetch failed'));
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = { prompt: 'fetch https://private.ip' };
|
||||
const params = { prompt: 'fetch https://public.ip' };
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_FALLBACK_FAILED);
|
||||
});
|
||||
|
||||
it('should return WEB_FETCH_PROCESSING_ERROR on general processing failure', async () => {
|
||||
it('should return WEB_FETCH_FALLBACK_FAILED on general processing failure (when fallback also fails)', async () => {
|
||||
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false);
|
||||
mockGenerateContent.mockRejectedValue(new Error('API error'));
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = { prompt: 'fetch https://public.ip' };
|
||||
const invocation = tool.build(params);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_PROCESSING_ERROR);
|
||||
});
|
||||
|
||||
it('should log telemetry when falling back due to private IP', async () => {
|
||||
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(true);
|
||||
// Mock fetchWithTimeout to succeed so fallback proceeds
|
||||
mockFetch('https://private.ip/', {
|
||||
text: () => Promise.resolve('some content'),
|
||||
});
|
||||
mockGenerateContent.mockResolvedValue({
|
||||
candidates: [{ content: { parts: [{ text: 'fallback response' }] } }],
|
||||
});
|
||||
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const params = { prompt: 'fetch https://private.ip' };
|
||||
const invocation = tool.build(params);
|
||||
await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(logWebFetchFallbackAttempt).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
expect.any(WebFetchFallbackAttemptEvent),
|
||||
);
|
||||
expect(WebFetchFallbackAttemptEvent).toHaveBeenCalledWith('private_ip');
|
||||
expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_FALLBACK_FAILED);
|
||||
});
|
||||
|
||||
it('should log telemetry when falling back due to primary fetch failure', async () => {
|
||||
@@ -422,7 +567,7 @@ describe('WebFetchTool', () => {
|
||||
|
||||
expect(logWebFetchFallbackAttempt).toHaveBeenCalledWith(
|
||||
mockConfig,
|
||||
expect.any(WebFetchFallbackAttemptEvent),
|
||||
expect.objectContaining({ reason: 'primary_failed' }),
|
||||
);
|
||||
expect(WebFetchFallbackAttemptEvent).toHaveBeenCalledWith(
|
||||
'primary_failed',
|
||||
@@ -891,13 +1036,13 @@ describe('WebFetchTool', () => {
|
||||
});
|
||||
|
||||
it('should throw error if stream exceeds limit', async () => {
|
||||
const largeChunk = new Uint8Array(11 * 1024 * 1024);
|
||||
const large_chunk = new Uint8Array(11 * 1024 * 1024);
|
||||
mockFetch('https://example.com/large-stream', {
|
||||
body: {
|
||||
getReader: () => ({
|
||||
read: vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ done: false, value: largeChunk })
|
||||
.mockResolvedValueOnce({ done: false, value: large_chunk })
|
||||
.mockResolvedValueOnce({ done: true }),
|
||||
releaseLock: vi.fn(),
|
||||
cancel: vi.fn().mockResolvedValue(undefined),
|
||||
@@ -934,5 +1079,20 @@ describe('WebFetchTool', () => {
|
||||
expect(result.llmContent).toContain('Error: Invalid URL "not-a-url"');
|
||||
expect(result.error?.type).toBe(ToolErrorType.INVALID_TOOL_PARAMS);
|
||||
});
|
||||
|
||||
it('should block private IP (experimental)', async () => {
|
||||
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(true);
|
||||
const tool = new WebFetchTool(mockConfig, bus);
|
||||
const invocation = tool['createInvocation'](
|
||||
{ url: 'http://localhost' },
|
||||
bus,
|
||||
);
|
||||
const result = await invocation.execute(new AbortController().signal);
|
||||
|
||||
expect(result.llmContent).toContain(
|
||||
'Error: Access to blocked or private host http://localhost/ is not allowed.',
|
||||
);
|
||||
expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_PROCESSING_ERROR);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user