mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 14:10:37 -07:00
fix(security): rate limit web_fetch tool to mitigate DDoS via prompt injection (#19567)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
12
integration-tests/concurrency-limit.responses
Normal file
12
integration-tests/concurrency-limit.responses
Normal file
@@ -0,0 +1,12 @@
|
||||
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/1"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/2"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/3"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/4"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/5"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/6"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/7"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/8"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/9"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/10"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/11"}}}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":500,"totalTokenCount":600}}]}
|
||||
{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 1 content"}],"role":"model"},"finishReason":"STOP","index":0}]}}
|
||||
{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 2 content"}],"role":"model"},"finishReason":"STOP","index":0}]}}
|
||||
{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 3 content"}],"role":"model"},"finishReason":"STOP","index":0}]}}
|
||||
{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 4 content"}],"role":"model"},"finishReason":"STOP","index":0}]}}
|
||||
{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 5 content"}],"role":"model"},"finishReason":"STOP","index":0}]}}
|
||||
{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 6 content"}],"role":"model"},"finishReason":"STOP","index":0}]}}
|
||||
{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 7 content"}],"role":"model"},"finishReason":"STOP","index":0}]}}
|
||||
{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 8 content"}],"role":"model"},"finishReason":"STOP","index":0}]}}
|
||||
{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 9 content"}],"role":"model"},"finishReason":"STOP","index":0}]}}
|
||||
{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 10 content"}],"role":"model"},"finishReason":"STOP","index":0}]}}
|
||||
{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"Some requests were rate limited: Rate limit exceeded for host. Please wait 60 seconds before trying again."}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":1000,"candidatesTokenCount":50,"totalTokenCount":1050}}]}
|
||||
48
integration-tests/concurrency-limit.test.ts
Normal file
48
integration-tests/concurrency-limit.test.ts
Normal file
@@ -0,0 +1,48 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2026 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { TestRig } from './test-helper.js';
|
||||
import { join } from 'node:path';
|
||||
|
||||
/**
 * Integration test for the web_fetch rate limiter.
 *
 * The fake model responses make the agent issue eleven web_fetch calls against
 * example.com in one turn; the suite checks that at least one call is rejected
 * with a rate limit error and that the error reaches the final output.
 */
describe('web-fetch rate limiting', () => {
  let rig: TestRig;

  // Each test gets its own rig so state never leaks between cases.
  beforeEach(() => {
    rig = new TestRig();
  });

  afterEach(async () => {
    if (!rig) {
      return;
    }
    await rig.cleanup();
  });

  it('should rate limit multiple requests to the same host', async () => {
    const fakeResponsesPath = join(
      import.meta.dirname,
      'concurrency-limit.responses',
    );
    rig.setup('web-fetch rate limit', {
      settings: { tools: { core: ['web_fetch'] } },
      fakeResponsesPath,
    });

    const result = await rig.run({ args: `Fetch 11 pages from example.com` });

    // At least one web_fetch call must have been logged with a rate limit error.
    const rateLimitedCalls = rig
      .readToolLogs()
      .filter(
        ({ toolRequest }) =>
          toolRequest.name === 'web_fetch' &&
          toolRequest.error?.includes('Rate limit exceeded'),
      );

    expect(rateLimitedCalls.length).toBeGreaterThan(0);
    // The error must also be visible in the run's final output.
    expect(result).toContain('Rate limit exceeded');
  });
});
|
||||
@@ -77,7 +77,10 @@ export async function checkPolicy(
|
||||
}
|
||||
}
|
||||
|
||||
return { decision, rule: result.rule };
|
||||
return {
|
||||
decision,
|
||||
rule: result.rule,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -192,6 +192,8 @@ export class ToolExecutor {
|
||||
tool: call.tool,
|
||||
invocation: call.invocation,
|
||||
durationMs: startTime ? Date.now() - startTime : undefined,
|
||||
startTime,
|
||||
endTime: Date.now(),
|
||||
outcome: call.outcome,
|
||||
};
|
||||
}
|
||||
@@ -263,6 +265,8 @@ export class ToolExecutor {
|
||||
response: successResponse,
|
||||
invocation: call.invocation,
|
||||
durationMs: startTime ? Date.now() - startTime : undefined,
|
||||
startTime,
|
||||
endTime: Date.now(),
|
||||
outcome: call.outcome,
|
||||
};
|
||||
}
|
||||
@@ -287,6 +291,8 @@ export class ToolExecutor {
|
||||
response,
|
||||
tool: call.tool,
|
||||
durationMs: startTime ? Date.now() - startTime : undefined,
|
||||
startTime,
|
||||
endTime: Date.now(),
|
||||
outcome: call.outcome,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -86,6 +86,8 @@ export type ErroredToolCall = {
|
||||
response: ToolCallResponseInfo;
|
||||
tool?: AnyDeclarativeTool;
|
||||
durationMs?: number;
|
||||
startTime?: number;
|
||||
endTime?: number;
|
||||
outcome?: ToolConfirmationOutcome;
|
||||
schedulerId?: string;
|
||||
approvalMode?: ApprovalMode;
|
||||
@@ -98,6 +100,8 @@ export type SuccessfulToolCall = {
|
||||
response: ToolCallResponseInfo;
|
||||
invocation: AnyToolInvocation;
|
||||
durationMs?: number;
|
||||
startTime?: number;
|
||||
endTime?: number;
|
||||
outcome?: ToolConfirmationOutcome;
|
||||
schedulerId?: string;
|
||||
approvalMode?: ApprovalMode;
|
||||
@@ -125,6 +129,8 @@ export type CancelledToolCall = {
|
||||
tool: AnyDeclarativeTool;
|
||||
invocation: AnyToolInvocation;
|
||||
durationMs?: number;
|
||||
startTime?: number;
|
||||
endTime?: number;
|
||||
outcome?: ToolConfirmationOutcome;
|
||||
schedulerId?: string;
|
||||
approvalMode?: ApprovalMode;
|
||||
|
||||
@@ -243,6 +243,8 @@ export class ToolCallEvent implements BaseTelemetryEvent {
|
||||
mcp_server_name?: string;
|
||||
extension_name?: string;
|
||||
extension_id?: string;
|
||||
start_time?: number;
|
||||
end_time?: number;
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
metadata?: { [key: string]: any };
|
||||
|
||||
@@ -256,6 +258,8 @@ export class ToolCallEvent implements BaseTelemetryEvent {
|
||||
prompt_id: string,
|
||||
tool_type: 'native' | 'mcp',
|
||||
error?: string,
|
||||
start_time?: number,
|
||||
end_time?: number,
|
||||
);
|
||||
constructor(
|
||||
call?: CompletedToolCall,
|
||||
@@ -266,6 +270,8 @@ export class ToolCallEvent implements BaseTelemetryEvent {
|
||||
prompt_id?: string,
|
||||
tool_type?: 'native' | 'mcp',
|
||||
error?: string,
|
||||
start_time?: number,
|
||||
end_time?: number,
|
||||
) {
|
||||
this['event.name'] = 'tool_call';
|
||||
this['event.timestamp'] = new Date().toISOString();
|
||||
@@ -282,6 +288,8 @@ export class ToolCallEvent implements BaseTelemetryEvent {
|
||||
this.error_type = call.response.errorType;
|
||||
this.prompt_id = call.request.prompt_id;
|
||||
this.content_length = call.response.contentLength;
|
||||
this.start_time = call.startTime;
|
||||
this.end_time = call.endTime;
|
||||
if (
|
||||
typeof call.tool !== 'undefined' &&
|
||||
call.tool instanceof DiscoveredMCPTool
|
||||
@@ -332,6 +340,8 @@ export class ToolCallEvent implements BaseTelemetryEvent {
|
||||
this.prompt_id = prompt_id!;
|
||||
this.tool_type = tool_type!;
|
||||
this.error = error;
|
||||
this.start_time = start_time;
|
||||
this.end_time = end_time;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -351,6 +361,8 @@ export class ToolCallEvent implements BaseTelemetryEvent {
|
||||
mcp_server_name: this.mcp_server_name,
|
||||
extension_name: this.extension_name,
|
||||
extension_id: this.extension_id,
|
||||
start_time: this.start_time,
|
||||
end_time: this.end_time,
|
||||
metadata: this.metadata,
|
||||
};
|
||||
|
||||
|
||||
@@ -183,6 +183,26 @@ describe('WebFetchTool', () => {
|
||||
});
|
||||
|
||||
describe('execute', () => {
|
||||
it('should return WEB_FETCH_PROCESSING_ERROR on rate limit exceeded', async () => {
  // Treat the host as public and stub the model response for each fetch.
  vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false);
  mockGenerateContent.mockResolvedValue({
    candidates: [{ content: { parts: [{ text: 'response' }] } }],
  });

  const invocation = new WebFetchTool(mockConfig, bus).build({
    prompt: 'fetch https://ratelimit.example.com',
  });
  const runOnce = () => invocation.execute(new AbortController().signal);

  // Run ten sequential calls against the same host to hit the limit.
  let remaining = 10;
  while (remaining > 0) {
    await runOnce();
    remaining -= 1;
  }

  // The eleventh call must be rejected by the rate limiter.
  const { error } = await runOnce();
  expect(error?.type).toBe(ToolErrorType.WEB_FETCH_PROCESSING_ERROR);
  expect(error?.message).toContain('Rate limit exceeded for host');
});
|
||||
|
||||
it('should return WEB_FETCH_FALLBACK_FAILED on fallback fetch failure', async () => {
|
||||
vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(true);
|
||||
vi.spyOn(fetchUtils, 'fetchWithTimeout').mockRejectedValue(
|
||||
|
||||
@@ -33,10 +33,46 @@ import { debugLogger } from '../utils/debugLogger.js';
|
||||
import { retryWithBackoff } from '../utils/retry.js';
|
||||
import { WEB_FETCH_DEFINITION } from './definitions/coreTools.js';
|
||||
import { resolveToolDeclaration } from './definitions/resolver.js';
|
||||
import { LRUCache } from 'mnemonist';
|
||||
|
||||
const URL_FETCH_TIMEOUT_MS = 10000;
|
||||
const MAX_CONTENT_LENGTH = 100000;
|
||||
|
||||
// Rate limiting configuration.
// checkRateLimit() allows a hostname at most MAX_REQUESTS_PER_WINDOW requests
// whose timestamps fall within the trailing RATE_LIMIT_WINDOW_MS milliseconds.
|
||||
const RATE_LIMIT_WINDOW_MS = 60000; // 1 minute
|
||||
// Maximum number of allowed requests per hostname inside one window.
const MAX_REQUESTS_PER_WINDOW = 10;
|
||||
// Maps a hostname to the Date.now() timestamps of its allowed requests.
// NOTE(review): 1000 is presumably the LRU capacity (number of hostnames
// tracked before eviction) — confirm against the mnemonist LRUCache docs.
const hostRequestHistory = new LRUCache<string, number[]>(1000);
|
||||
|
||||
/**
 * Sliding-window rate limiter keyed by the hostname of `url`.
 *
 * Timestamps of allowed requests are kept per hostname in
 * `hostRequestHistory`. A request is allowed while fewer than
 * MAX_REQUESTS_PER_WINDOW timestamps fall inside the trailing
 * RATE_LIMIT_WINDOW_MS; allowed requests are recorded, rejected ones are not.
 *
 * @param url The URL about to be fetched.
 * @returns `{ allowed: true }` when the request may proceed, otherwise
 *   `{ allowed: false, waitTimeMs }`, where `waitTimeMs` is the non-negative
 *   time until the oldest recorded request leaves the window.
 */
function checkRateLimit(url: string): {
  allowed: boolean;
  waitTimeMs?: number;
} {
  try {
    const { hostname } = new URL(url);
    const now = Date.now();
    const cutoff = now - RATE_LIMIT_WINDOW_MS;

    // Keep only the timestamps that are still inside the window.
    const recent = (hostRequestHistory.get(hostname) ?? []).filter(
      (stamp) => stamp > cutoff,
    );

    if (recent.length < MAX_REQUESTS_PER_WINDOW) {
      // Under the limit: record this request and let it through.
      recent.push(now);
      hostRequestHistory.set(hostname, recent);
      return { allowed: true };
    }

    // Over the limit: store the pruned history without recording this attempt.
    hostRequestHistory.set(hostname, recent);
    // The host is allowed again once its oldest in-window request expires.
    const msUntilOldestExpires = recent[0] + RATE_LIMIT_WINDOW_MS - now;
    return { allowed: false, waitTimeMs: Math.max(0, msUntilOldestExpires) };
  } catch (_e) {
    // An unparseable URL is not rate limited here; the original comment notes
    // that parsePrompt is expected to catch such input.
    return { allowed: true };
  }
}
|
||||
|
||||
/**
|
||||
* Parses a prompt to extract valid URLs and identify malformed ones.
|
||||
*/
|
||||
@@ -258,6 +294,23 @@ ${textContent}
|
||||
const userPrompt = this.params.prompt;
|
||||
const { validUrls: urls } = parsePrompt(userPrompt);
|
||||
const url = urls[0];
|
||||
|
||||
// Enforce rate limiting
|
||||
const rateLimitResult = checkRateLimit(url);
|
||||
if (!rateLimitResult.allowed) {
|
||||
const waitTimeSecs = Math.ceil((rateLimitResult.waitTimeMs || 0) / 1000);
|
||||
const errorMessage = `Rate limit exceeded for host. Please wait ${waitTimeSecs} seconds before trying again.`;
|
||||
debugLogger.warn(`[WebFetchTool] Rate limit exceeded for ${url}`);
|
||||
return {
|
||||
llmContent: `Error: ${errorMessage}`,
|
||||
returnDisplay: `Error: ${errorMessage}`,
|
||||
error: {
|
||||
message: errorMessage,
|
||||
type: ToolErrorType.WEB_FETCH_PROCESSING_ERROR,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const isPrivate = isPrivateIp(url);
|
||||
|
||||
if (isPrivate) {
|
||||
|
||||
@@ -208,6 +208,7 @@ export interface ParsedLog {
|
||||
stdout?: string;
|
||||
stderr?: string;
|
||||
error?: string;
|
||||
error_type?: string;
|
||||
prompt_id?: string;
|
||||
};
|
||||
scopeMetrics?: {
|
||||
@@ -1255,6 +1256,8 @@ export class TestRig {
|
||||
success: boolean;
|
||||
duration_ms: number;
|
||||
prompt_id?: string;
|
||||
error?: string;
|
||||
error_type?: string;
|
||||
};
|
||||
}[] = [];
|
||||
|
||||
@@ -1272,6 +1275,8 @@ export class TestRig {
|
||||
success: logData.attributes.success ?? false,
|
||||
duration_ms: logData.attributes.duration_ms ?? 0,
|
||||
prompt_id: logData.attributes.prompt_id,
|
||||
error: logData.attributes.error,
|
||||
error_type: logData.attributes.error_type,
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user