diff --git a/evals/grep_search_functionality.eval.ts b/evals/grep_search_functionality.eval.ts new file mode 100644 index 0000000000..77df3b950f --- /dev/null +++ b/evals/grep_search_functionality.eval.ts @@ -0,0 +1,170 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, expect } from 'vitest'; +import { evalTest, TestRig } from './test-helper.js'; +import { + assertModelHasOutput, + checkModelOutputContent, +} from './test-helper.js'; + +describe('grep_search_functionality', () => { + const TEST_PREFIX = 'Grep Search Functionality: '; + + evalTest('USUALLY_PASSES', { + name: 'should find a simple string in a file', + files: { + 'test.txt': `hello + world + hello world`, + }, + prompt: 'Find "world" in test.txt', + assert: async (rig: TestRig, result: string) => { + await rig.waitForToolCall('grep_search'); + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [/L2: world/, /L3: hello world/], + testName: `${TEST_PREFIX}simple search`, + }); + }, + }); + + evalTest('USUALLY_PASSES', { + name: 'should perform a case-sensitive search', + files: { + 'test.txt': `Hello + hello`, + }, + prompt: 'Find "Hello" in test.txt, case-sensitively.', + assert: async (rig: TestRig, result: string) => { + const wasToolCalled = await rig.waitForToolCall( + 'grep_search', + undefined, + (args) => { + const params = JSON.parse(args); + return params.case_sensitive === true; + }, + ); + expect( + wasToolCalled, + 'Expected grep_search to be called with case_sensitive: true', + ).toBe(true); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [/L1: Hello/], + forbiddenContent: [/L2: hello/], + testName: `${TEST_PREFIX}case-sensitive search`, + }); + }, + }); + + evalTest('USUALLY_PASSES', { + name: 'should return only file names when names_only is used', + files: { + 'file1.txt': 'match me', + 'file2.txt': 'match me', + }, + prompt: 'Find the files 
containing "match me".', + assert: async (rig: TestRig, result: string) => { + const wasToolCalled = await rig.waitForToolCall( + 'grep_search', + undefined, + (args) => { + const params = JSON.parse(args); + return params.names_only === true; + }, + ); + expect( + wasToolCalled, + 'Expected grep_search to be called with names_only: true', + ).toBe(true); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [/file1.txt/, /file2.txt/], + forbiddenContent: [/L1:/], + testName: `${TEST_PREFIX}names_only search`, + }); + }, + }); + + evalTest('USUALLY_PASSES', { + name: 'should search only within the specified include glob', + files: { + 'file.js': 'my_function();', + 'file.ts': 'my_function();', + }, + prompt: 'Find "my_function" in .js files.', + assert: async (rig: TestRig, result: string) => { + const wasToolCalled = await rig.waitForToolCall( + 'grep_search', + undefined, + (args) => { + const params = JSON.parse(args); + return params.include === '*.js'; + }, + ); + expect( + wasToolCalled, + 'Expected grep_search to be called with include: "*.js"', + ).toBe(true); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [/file.js/], + forbiddenContent: [/file.ts/], + testName: `${TEST_PREFIX}include glob search`, + }); + }, + }); + + evalTest('USUALLY_PASSES', { + name: 'should search within a specific subdirectory', + files: { + 'src/main.js': 'unique_string_1', + 'lib/main.js': 'unique_string_2', + }, + prompt: 'Find "unique_string" in the src directory.', + assert: async (rig: TestRig, result: string) => { + const wasToolCalled = await rig.waitForToolCall( + 'grep_search', + undefined, + (args) => { + const params = JSON.parse(args); + return params.dir_path === 'src'; + }, + ); + expect( + wasToolCalled, + 'Expected grep_search to be called with dir_path: "src"', + ).toBe(true); + + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [/unique_string_1/], + 
forbiddenContent: [/unique_string_2/], + testName: `${TEST_PREFIX}subdirectory search`, + }); + }, + }); + + evalTest('USUALLY_PASSES', { + name: 'should report no matches correctly', + files: { + 'file.txt': 'nothing to see here', + }, + prompt: 'Find "nonexistent" in file.txt', + assert: async (rig: TestRig, result: string) => { + await rig.waitForToolCall('grep_search'); + assertModelHasOutput(result); + checkModelOutputContent(result, { + expectedContent: [/No matches found/], + testName: `${TEST_PREFIX}no matches`, + }); + }, + }); +}); diff --git a/integration-tests/concurrency-limit.responses b/integration-tests/concurrency-limit.responses new file mode 100644 index 0000000000..e2bd5efe2a --- /dev/null +++ b/integration-tests/concurrency-limit.responses @@ -0,0 +1,12 @@ +{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/1"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/2"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/3"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/4"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/5"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/6"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/7"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/8"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/9"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/10"}}},{"functionCall":{"name":"web_fetch","args":{"prompt":"fetch https://example.com/11"}}}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":500,"totalTokenCount":600}}]} 
+{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 1 content"}],"role":"model"},"finishReason":"STOP","index":0}]}} +{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 2 content"}],"role":"model"},"finishReason":"STOP","index":0}]}} +{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 3 content"}],"role":"model"},"finishReason":"STOP","index":0}]}} +{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 4 content"}],"role":"model"},"finishReason":"STOP","index":0}]}} +{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 5 content"}],"role":"model"},"finishReason":"STOP","index":0}]}} +{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 6 content"}],"role":"model"},"finishReason":"STOP","index":0}]}} +{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 7 content"}],"role":"model"},"finishReason":"STOP","index":0}]}} +{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 8 content"}],"role":"model"},"finishReason":"STOP","index":0}]}} +{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 9 content"}],"role":"model"},"finishReason":"STOP","index":0}]}} +{"method":"generateContent","response":{"candidates":[{"content":{"parts":[{"text":"Page 10 content"}],"role":"model"},"finishReason":"STOP","index":0}]}} +{"method":"generateContentStream","response":[{"candidates":[{"content":{"parts":[{"text":"Some requests were rate limited: Rate limit exceeded for host. 
Please wait 60 seconds before trying again."}],"role":"model"},"finishReason":"STOP","index":0}],"usageMetadata":{"promptTokenCount":1000,"candidatesTokenCount":50,"totalTokenCount":1050}}]} diff --git a/integration-tests/concurrency-limit.test.ts b/integration-tests/concurrency-limit.test.ts new file mode 100644 index 0000000000..ba165b3393 --- /dev/null +++ b/integration-tests/concurrency-limit.test.ts @@ -0,0 +1,48 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect, beforeEach, afterEach } from 'vitest'; +import { TestRig } from './test-helper.js'; +import { join } from 'node:path'; + +describe('web-fetch rate limiting', () => { + let rig: TestRig; + + beforeEach(() => { + rig = new TestRig(); + }); + + afterEach(async () => { + if (rig) { + await rig.cleanup(); + } + }); + + it('should rate limit multiple requests to the same host', async () => { + rig.setup('web-fetch rate limit', { + settings: { tools: { core: ['web_fetch'] } }, + fakeResponsesPath: join( + import.meta.dirname, + 'concurrency-limit.responses', + ), + }); + + const result = await rig.run({ + args: `Fetch 11 pages from example.com`, + }); + + // We expect to find at least one tool call that failed with a rate limit error. 
+ const toolLogs = rig.readToolLogs(); + const rateLimitedCalls = toolLogs.filter( + (log) => + log.toolRequest.name === 'web_fetch' && + log.toolRequest.error?.includes('Rate limit exceeded'), + ); + + expect(rateLimitedCalls.length).toBeGreaterThan(0); + expect(result).toContain('Rate limit exceeded'); + }); +}); diff --git a/packages/cli/src/config/extensionRegistryClient.test.ts b/packages/cli/src/config/extensionRegistryClient.test.ts index 187390ceb0..4b9699d5e3 100644 --- a/packages/cli/src/config/extensionRegistryClient.test.ts +++ b/packages/cli/src/config/extensionRegistryClient.test.ts @@ -224,4 +224,59 @@ describe('ExtensionRegistryClient', () => { 'Failed to fetch extensions: Not Found', ); }); + + it('should not return irrelevant results', async () => { + fetchMock.mockResolvedValue({ + ok: true, + json: async () => [ + ...mockExtensions, + { + id: 'dataplex', + extensionName: 'dataplex', + extensionDescription: 'Connect to Dataplex Universal Catalog...', + fullName: 'google-cloud/dataplex', + rank: 6, + stars: 6, + url: '', + repoDescription: '', + lastUpdated: '', + extensionVersion: '1.0.0', + avatarUrl: '', + hasMCP: false, + hasContext: false, + isGoogleOwned: true, + licenseKey: '', + hasHooks: false, + hasCustomCommands: false, + hasSkills: false, + }, + { + id: 'conductor', + extensionName: 'conductor', + extensionDescription: 'A conductor extension that actually matches.', + fullName: 'someone/conductor', + rank: 100, + stars: 100, + url: '', + repoDescription: '', + lastUpdated: '', + extensionVersion: '1.0.0', + avatarUrl: '', + hasMCP: false, + hasContext: false, + isGoogleOwned: false, + licenseKey: '', + hasHooks: false, + hasCustomCommands: false, + hasSkills: false, + }, + ], + }); + + const results = await client.searchExtensions('conductor'); + const ids = results.map((r) => r.id); + + expect(ids).not.toContain('dataplex'); + expect(ids).toContain('conductor'); + }); }); diff --git a/packages/cli/src/config/extensionRegistryClient.ts 
b/packages/cli/src/config/extensionRegistryClient.ts index aeda50dc48..3735f0a798 100644 --- a/packages/cli/src/config/extensionRegistryClient.ts +++ b/packages/cli/src/config/extensionRegistryClient.ts @@ -79,7 +79,7 @@ export class ExtensionRegistryClient { const fzf = new AsyncFzf(allExtensions, { selector: (ext: RegistryExtension) => `${ext.extensionName} ${ext.extensionDescription} ${ext.fullName}`, - fuzzy: 'v2', + fuzzy: true, }); const results = await fzf.find(query); return results.map((r: { item: RegistryExtension }) => r.item); @@ -108,7 +108,6 @@ export class ExtensionRegistryClient { // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion return (await response.json()) as RegistryExtension[]; } catch (error) { - // Clear the promise on failure so that subsequent calls can try again ExtensionRegistryClient.fetchPromise = null; throw error; } diff --git a/packages/cli/src/ui/commands/extensionsCommand.ts b/packages/cli/src/ui/commands/extensionsCommand.ts index c7359a2a46..0a8a8d74e3 100644 --- a/packages/cli/src/ui/commands/extensionsCommand.ts +++ b/packages/cli/src/ui/commands/extensionsCommand.ts @@ -20,6 +20,7 @@ import { import { type CommandContext, type SlashCommand, + type SlashCommandActionReturn, CommandKind, } from './types.js'; import open from 'open'; @@ -35,6 +36,7 @@ import { stat } from 'node:fs/promises'; import { ExtensionSettingScope } from '../../config/extensions/extensionSettings.js'; import { type ConfigLogger } from '../../commands/extensions/utils.js'; import { ConfigExtensionDialog } from '../components/ConfigExtensionDialog.js'; +import { ExtensionRegistryView } from '../components/views/ExtensionRegistryView.js'; import React from 'react'; function showMessageIfNoExtensions( @@ -265,7 +267,28 @@ async function restartAction( } } -async function exploreAction(context: CommandContext) { +async function exploreAction( + context: CommandContext, +): Promise { + const settings = context.services.settings.merged; + 
const useRegistryUI = settings.experimental?.extensionRegistry; + + if (useRegistryUI) { + const extensionManager = context.services.config?.getExtensionLoader(); + if (extensionManager instanceof ExtensionManager) { + return { + type: 'custom_dialog' as const, + component: React.createElement(ExtensionRegistryView, { + onSelect: (extension) => { + debugLogger.debug(`Selected extension: ${extension.extensionName}`); + }, + onClose: () => context.ui.removeComponent(), + extensionManager, + }), + }; + } + } + const extensionsUrl = 'https://geminicli.com/extensions/'; // Only check for NODE_ENV for explicit test mode, not for unit test framework diff --git a/packages/cli/src/ui/components/shared/SearchableList.test.tsx b/packages/cli/src/ui/components/shared/SearchableList.test.tsx new file mode 100644 index 0000000000..42b118e251 --- /dev/null +++ b/packages/cli/src/ui/components/shared/SearchableList.test.tsx @@ -0,0 +1,233 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { render } from '../../../test-utils/render.js'; +import { waitFor } from '../../../test-utils/async.js'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { + SearchableList, + type SearchableListProps, + type SearchListState, + type GenericListItem, +} from './SearchableList.js'; +import { KeypressProvider } from '../../contexts/KeypressContext.js'; +import { useTextBuffer } from './text-buffer.js'; + +const useMockSearch = (props: { + items: GenericListItem[]; + initialQuery?: string; + onSearch?: (query: string) => void; +}): SearchListState => { + const { onSearch, items, initialQuery = '' } = props; + const [text, setText] = React.useState(initialQuery); + const filteredItems = React.useMemo( + () => + items.filter((item: GenericListItem) => + item.label.toLowerCase().includes(text.toLowerCase()), + ), + [items, text], + ); + + React.useEffect(() => { + onSearch?.(text); + }, [text, 
onSearch]); + + const searchBuffer = useTextBuffer({ + initialText: text, + onChange: setText, + viewport: { width: 100, height: 1 }, + singleLine: true, + }); + + return { + filteredItems, + searchBuffer, + searchQuery: text, + setSearchQuery: setText, + maxLabelWidth: 10, + }; +}; + +vi.mock('../../contexts/UIStateContext.js', () => ({ + useUIState: () => ({ + mainAreaWidth: 100, + }), +})); + +const mockItems: GenericListItem[] = [ + { + key: 'item-1', + label: 'Item One', + description: 'Description for item one', + }, + { + key: 'item-2', + label: 'Item Two', + description: 'Description for item two', + }, + { + key: 'item-3', + label: 'Item Three', + description: 'Description for item three', + }, +]; + +describe('SearchableList', () => { + let mockOnSelect: ReturnType; + let mockOnClose: ReturnType; + + beforeEach(() => { + vi.clearAllMocks(); + mockOnSelect = vi.fn(); + mockOnClose = vi.fn(); + }); + + const renderList = ( + props: Partial> = {}, + ) => { + const defaultProps: SearchableListProps = { + title: 'Test List', + items: mockItems, + onSelect: mockOnSelect, + onClose: mockOnClose, + useSearch: useMockSearch, + ...props, + }; + + return render( + + + , + ); + }; + + it('should render all items initially', async () => { + const { lastFrame, waitUntilReady } = renderList(); + await waitUntilReady(); + const frame = lastFrame(); + + expect(frame).toContain('Test List'); + + expect(frame).toContain('Item One'); + expect(frame).toContain('Item Two'); + expect(frame).toContain('Item Three'); + + expect(frame).toContain('Description for item one'); + }); + + it('should reset selection to top when items change if resetSelectionOnItemsChange is true', async () => { + const { lastFrame, stdin, waitUntilReady } = renderList({ + resetSelectionOnItemsChange: true, + }); + await waitUntilReady(); + + await React.act(async () => { + stdin.write('\u001B[B'); // Down arrow + }); + + await waitFor(() => { + const frame = lastFrame(); + expect(frame).toContain('> 
Item Two'); + }); + + await React.act(async () => { + stdin.write('One'); + }); + + await waitFor(() => { + const frame = lastFrame(); + expect(frame).toContain('Item One'); + expect(frame).not.toContain('Item Two'); + }); + + await React.act(async () => { + // Backspace "One" (3 chars) + stdin.write('\u007F\u007F\u007F'); + }); + + await waitFor(() => { + const frame = lastFrame(); + expect(frame).toContain('Item Two'); + expect(frame).toContain('> Item One'); + expect(frame).not.toContain('> Item Two'); + }); + }); + + it('should filter items based on search query', async () => { + const { lastFrame, stdin } = renderList(); + + await React.act(async () => { + stdin.write('Two'); + }); + + await waitFor(() => { + const frame = lastFrame(); + expect(frame).toContain('Item Two'); + expect(frame).not.toContain('Item One'); + expect(frame).not.toContain('Item Three'); + }); + }); + + it('should show "No items found." when no items match', async () => { + const { lastFrame, stdin } = renderList(); + + await React.act(async () => { + stdin.write('xyz123'); + }); + + await waitFor(() => { + const frame = lastFrame(); + expect(frame).toContain('No items found.'); + }); + }); + + it('should handle selection with Enter', async () => { + const { stdin } = renderList(); + + await React.act(async () => { + stdin.write('\r'); // Enter + }); + + await waitFor(() => { + expect(mockOnSelect).toHaveBeenCalledWith(mockItems[0]); + }); + }); + + it('should handle navigation and selection', async () => { + const { stdin } = renderList(); + + await React.act(async () => { + stdin.write('\u001B[B'); // Down arrow + }); + + await React.act(async () => { + stdin.write('\r'); // Enter + }); + + await waitFor(() => { + expect(mockOnSelect).toHaveBeenCalledWith(mockItems[1]); + }); + }); + + it('should handle close with Esc', async () => { + const { stdin } = renderList(); + + await React.act(async () => { + stdin.write('\u001B'); // Esc + }); + + await waitFor(() => { + 
expect(mockOnClose).toHaveBeenCalled(); + }); + }); + + it('should match snapshot', async () => { + const { lastFrame, waitUntilReady } = renderList(); + await waitUntilReady(); + expect(lastFrame()).toMatchSnapshot(); + }); +}); diff --git a/packages/cli/src/ui/components/shared/SearchableList.tsx b/packages/cli/src/ui/components/shared/SearchableList.tsx new file mode 100644 index 0000000000..a20a44be42 --- /dev/null +++ b/packages/cli/src/ui/components/shared/SearchableList.tsx @@ -0,0 +1,231 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import React, { useMemo, useCallback } from 'react'; +import { Box, Text } from 'ink'; +import { theme } from '../../semantic-colors.js'; +import { useSelectionList } from '../../hooks/useSelectionList.js'; +import { TextInput } from './TextInput.js'; +import type { TextBuffer } from './text-buffer.js'; +import { useKeypress } from '../../hooks/useKeypress.js'; +import { keyMatchers, Command } from '../../keyMatchers.js'; + +/** + * Generic interface for items in a searchable list. + */ +export interface GenericListItem { + key: string; + label: string; + description?: string; + [key: string]: unknown; +} + +/** + * State returned by the search hook. + */ +export interface SearchListState { + filteredItems: T[]; + searchBuffer: TextBuffer | undefined; + searchQuery: string; + setSearchQuery: (query: string) => void; + maxLabelWidth: number; +} + +/** + * Props for the SearchableList component. 
+ */ +export interface SearchableListProps { + title?: string; + items: T[]; + onSelect: (item: T) => void; + onClose: () => void; + searchPlaceholder?: string; + /** Custom item renderer */ + renderItem?: ( + item: T, + isActive: boolean, + labelWidth: number, + ) => React.ReactNode; + /** Optional header content */ + header?: React.ReactNode; + /** Optional footer content */ + footer?: (info: { + startIndex: number; + endIndex: number; + totalVisible: number; + }) => React.ReactNode; + maxItemsToShow?: number; + /** Hook to handle search logic */ + useSearch: (props: { + items: T[]; + onSearch?: (query: string) => void; + }) => SearchListState; + onSearch?: (query: string) => void; + /** Whether to reset selection to the top when items change (e.g. after search) */ + resetSelectionOnItemsChange?: boolean; +} + +/** + * A generic searchable list component with keyboard navigation. + */ +export function SearchableList({ + title, + items, + onSelect, + onClose, + searchPlaceholder = 'Search...', + renderItem, + header, + footer, + maxItemsToShow = 10, + useSearch, + onSearch, + resetSelectionOnItemsChange = false, +}: SearchableListProps): React.JSX.Element { + const { filteredItems, searchBuffer, maxLabelWidth } = useSearch({ + items, + onSearch, + }); + + const selectionItems = useMemo( + () => + filteredItems.map((item) => ({ + key: item.key, + value: item, + })), + [filteredItems], + ); + + const handleSelectValue = useCallback( + (item: T) => { + onSelect(item); + }, + [onSelect], + ); + + const { activeIndex, setActiveIndex } = useSelectionList({ + items: selectionItems, + onSelect: handleSelectValue, + isFocused: true, + showNumbers: false, + wrapAround: true, + }); + + // Reset selection to top when items change if requested + const prevItemsRef = React.useRef(filteredItems); + React.useEffect(() => { + if (resetSelectionOnItemsChange && filteredItems !== prevItemsRef.current) { + setActiveIndex(0); + } + prevItemsRef.current = filteredItems; + }, 
[filteredItems, setActiveIndex, resetSelectionOnItemsChange]); + + // Handle global Escape key to close the list + useKeypress( + (key) => { + if (keyMatchers[Command.ESCAPE](key)) { + onClose(); + return true; + } + return false; + }, + { isActive: true }, + ); + + const scrollOffset = Math.max( + 0, + Math.min( + activeIndex - Math.floor(maxItemsToShow / 2), + Math.max(0, filteredItems.length - maxItemsToShow), + ), + ); + + const visibleItems = filteredItems.slice( + scrollOffset, + scrollOffset + maxItemsToShow, + ); + + const defaultRenderItem = ( + item: T, + isActive: boolean, + labelWidth: number, + ) => ( + + + {isActive ? '> ' : ' '} + {item.label.padEnd(labelWidth)} + + {item.description && ( + + + {item.description} + + + )} + + ); + + return ( + + {title && ( + + + {title} + + + )} + + {searchBuffer && ( + + + + )} + + {header && {header}} + + + {filteredItems.length === 0 ? ( + + No items found. + + ) : ( + visibleItems.map((item, index) => { + const isSelected = activeIndex === scrollOffset + index; + return ( + + {renderItem + ? renderItem(item, isSelected, maxLabelWidth) + : defaultRenderItem(item, isSelected, maxLabelWidth)} + + ); + }) + )} + + + {footer && ( + + {footer({ + startIndex: scrollOffset, + endIndex: scrollOffset + visibleItems.length, + totalVisible: filteredItems.length, + })} + + )} + + ); +} diff --git a/packages/cli/src/ui/components/shared/__snapshots__/SearchableList.test.tsx.snap b/packages/cli/src/ui/components/shared/__snapshots__/SearchableList.test.tsx.snap new file mode 100644 index 0000000000..e596373e01 --- /dev/null +++ b/packages/cli/src/ui/components/shared/__snapshots__/SearchableList.test.tsx.snap @@ -0,0 +1,19 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`SearchableList > should match snapshot 1`] = ` +" Test List + + ╭────────────────────────────────────────────────────────────────────────────────────────────────╮ + │ Search... 
│ + ╰────────────────────────────────────────────────────────────────────────────────────────────────╯ + + > Item One + Description for item one + + Item Two + Description for item two + + Item Three + Description for item three +" +`; diff --git a/packages/cli/src/ui/components/views/ExtensionRegistryView.test.tsx b/packages/cli/src/ui/components/views/ExtensionRegistryView.test.tsx new file mode 100644 index 0000000000..58f687eb6d --- /dev/null +++ b/packages/cli/src/ui/components/views/ExtensionRegistryView.test.tsx @@ -0,0 +1,204 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import React from 'react'; +import { render } from '../../../test-utils/render.js'; +import { waitFor } from '../../../test-utils/async.js'; +import { describe, it, expect, vi, beforeEach } from 'vitest'; +import { ExtensionRegistryView } from './ExtensionRegistryView.js'; +import { type ExtensionManager } from '../../../config/extension-manager.js'; +import { useExtensionRegistry } from '../../hooks/useExtensionRegistry.js'; +import { useExtensionUpdates } from '../../hooks/useExtensionUpdates.js'; +import { useRegistrySearch } from '../../hooks/useRegistrySearch.js'; +import { type RegistryExtension } from '../../../config/extensionRegistryClient.js'; +import { useUIState } from '../../contexts/UIStateContext.js'; +import { useConfig } from '../../contexts/ConfigContext.js'; +import { KeypressProvider } from '../../contexts/KeypressContext.js'; +import { + type SearchListState, + type GenericListItem, +} from '../shared/SearchableList.js'; +import { type TextBuffer } from '../shared/text-buffer.js'; + +// Mocks +vi.mock('../../hooks/useExtensionRegistry.js'); +vi.mock('../../hooks/useExtensionUpdates.js'); +vi.mock('../../hooks/useRegistrySearch.js'); +vi.mock('../../../config/extension-manager.js'); +vi.mock('../../contexts/UIStateContext.js'); +vi.mock('../../contexts/ConfigContext.js'); + +const mockExtensions: RegistryExtension[] = 
[ + { + id: 'ext1', + extensionName: 'Test Extension 1', + extensionDescription: 'Description 1', + fullName: 'author/ext1', + extensionVersion: '1.0.0', + rank: 1, + stars: 10, + url: 'http://example.com', + repoDescription: 'Repo Desc 1', + avatarUrl: 'http://avatar.com', + lastUpdated: '2023-01-01', + hasMCP: false, + hasContext: false, + hasHooks: false, + hasSkills: false, + hasCustomCommands: false, + isGoogleOwned: false, + licenseKey: 'mit', + }, + { + id: 'ext2', + extensionName: 'Test Extension 2', + extensionDescription: 'Description 2', + fullName: 'author/ext2', + extensionVersion: '2.0.0', + rank: 2, + stars: 20, + url: 'http://example.com/2', + repoDescription: 'Repo Desc 2', + avatarUrl: 'http://avatar.com/2', + lastUpdated: '2023-01-02', + hasMCP: true, + hasContext: true, + hasHooks: true, + hasSkills: true, + hasCustomCommands: true, + isGoogleOwned: true, + licenseKey: 'apache-2.0', + }, +]; + +describe('ExtensionRegistryView', () => { + let mockExtensionManager: ExtensionManager; + let mockOnSelect: ReturnType; + let mockOnClose: ReturnType; + let mockSearch: ReturnType; + + beforeEach(() => { + vi.clearAllMocks(); + + mockExtensionManager = { + getExtensions: vi.fn().mockReturnValue([]), + } as unknown as ExtensionManager; + + mockOnSelect = vi.fn(); + mockOnClose = vi.fn(); + mockSearch = vi.fn(); + + vi.mocked(useExtensionRegistry).mockReturnValue({ + extensions: mockExtensions, + loading: false, + error: null, + search: mockSearch, + }); + + vi.mocked(useExtensionUpdates).mockReturnValue({ + extensionsUpdateState: new Map(), + } as unknown as ReturnType); + + // Mock useRegistrySearch implementation + vi.mocked(useRegistrySearch).mockImplementation( + (props: { items: GenericListItem[]; onSearch?: (q: string) => void }) => + ({ + filteredItems: props.items, // Pass through items + searchBuffer: { + text: '', + cursorOffset: 0, + viewport: { width: 10, height: 1 }, + visualCursor: [0, 0] as [number, number], + viewportVisualLines: [{ text: 
'', visualRowIndex: 0 }], + visualScrollRow: 0, + lines: [''], + cursor: [0, 0] as [number, number], + selectionAnchor: undefined, + } as unknown as TextBuffer, + searchQuery: '', + setSearchQuery: vi.fn(), + maxLabelWidth: 10, + }) as unknown as SearchListState, + ); + + vi.mocked(useUIState).mockReturnValue({ + mainAreaWidth: 100, + } as unknown as ReturnType); + + vi.mocked(useConfig).mockReturnValue({ + getEnableExtensionReloading: vi.fn().mockReturnValue(false), + } as unknown as ReturnType); + }); + + const renderView = () => + render( + + + , + ); + + it('should render extensions', async () => { + const { lastFrame } = renderView(); + await waitFor(() => { + expect(lastFrame()).toContain('Test Extension 1'); + expect(lastFrame()).toContain('Test Extension 2'); + }); + }); + + it('should use useRegistrySearch hook', () => { + renderView(); + expect(useRegistrySearch).toHaveBeenCalled(); + }); + + it('should call search function when typing', async () => { + // Mock useRegistrySearch to trigger onSearch + vi.mocked(useRegistrySearch).mockImplementation( + (props: { + items: GenericListItem[]; + onSearch?: (q: string) => void; + }): SearchListState => { + const { onSearch } = props; + // Simulate typing + React.useEffect(() => { + if (onSearch) { + onSearch('test query'); + } + }, [onSearch]); + return { + filteredItems: props.items, + searchBuffer: { + text: 'test query', + cursorOffset: 10, + viewport: { width: 10, height: 1 }, + visualCursor: [0, 10] as [number, number], + viewportVisualLines: [{ text: 'test query', visualRowIndex: 0 }], + visualScrollRow: 0, + lines: ['test query'], + cursor: [0, 10] as [number, number], + selectionAnchor: undefined, + } as unknown as TextBuffer, + searchQuery: 'test query', + setSearchQuery: vi.fn(), + maxLabelWidth: 10, + } as unknown as SearchListState; + }, + ); + + renderView(); + + await waitFor(() => { + expect(useRegistrySearch).toHaveBeenCalledWith( + expect.objectContaining({ + onSearch: mockSearch, + }), + ); + 
}); + }); +}); diff --git a/packages/cli/src/ui/components/views/ExtensionRegistryView.tsx b/packages/cli/src/ui/components/views/ExtensionRegistryView.tsx new file mode 100644 index 0000000000..9a7c15144a --- /dev/null +++ b/packages/cli/src/ui/components/views/ExtensionRegistryView.tsx @@ -0,0 +1,200 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type React from 'react'; +import { useMemo, useCallback } from 'react'; +import { Box, Text } from 'ink'; +import type { RegistryExtension } from '../../../config/extensionRegistryClient.js'; + +import { + SearchableList, + type GenericListItem, +} from '../shared/SearchableList.js'; +import { theme } from '../../semantic-colors.js'; + +import { useExtensionRegistry } from '../../hooks/useExtensionRegistry.js'; +import { ExtensionUpdateState } from '../../state/extensions.js'; +import { useExtensionUpdates } from '../../hooks/useExtensionUpdates.js'; +import { useConfig } from '../../contexts/ConfigContext.js'; +import type { ExtensionManager } from '../../../config/extension-manager.js'; +import { useRegistrySearch } from '../../hooks/useRegistrySearch.js'; + +interface ExtensionRegistryViewProps { + onSelect?: (extension: RegistryExtension) => void; + onClose?: () => void; + extensionManager: ExtensionManager; +} + +interface ExtensionItem extends GenericListItem { + extension: RegistryExtension; +} + +export function ExtensionRegistryView({ + onSelect, + onClose, + extensionManager, +}: ExtensionRegistryViewProps): React.JSX.Element { + const { extensions, loading, error, search } = useExtensionRegistry(); + const config = useConfig(); + + const { extensionsUpdateState } = useExtensionUpdates( + extensionManager, + () => 0, + config.getEnableExtensionReloading(), + ); + + const installedExtensions = extensionManager.getExtensions(); + + const items: ExtensionItem[] = useMemo( + () => + extensions.map((ext) => ({ + key: ext.id, + label: ext.extensionName, + 
description: ext.extensionDescription || ext.repoDescription, + extension: ext, + })), + [extensions], + ); + + const handleSelect = useCallback( + (item: ExtensionItem) => { + onSelect?.(item.extension); + }, + [onSelect], + ); + + const renderItem = useCallback( + (item: ExtensionItem, isActive: boolean, _labelWidth: number) => { + const isInstalled = installedExtensions.some( + (e) => e.name === item.extension.extensionName, + ); + const updateState = extensionsUpdateState.get( + item.extension.extensionName, + ); + const hasUpdate = updateState === ExtensionUpdateState.UPDATE_AVAILABLE; + + return ( + + + + + {isActive ? '> ' : ' '} + + + + + {item.label} + + + + | + + {isInstalled && ( + + [Installed] + + )} + {hasUpdate && ( + + [Update available] + + )} + + + {item.description} + + + + + + + {' '} + {item.extension.stars || 0} + + + + ); + }, + [installedExtensions, extensionsUpdateState], + ); + + const header = useMemo( + () => ( + + + + Browse and search extensions from the registry. + + + + + {installedExtensions.length && + `${installedExtensions.length} installed`} + + + + ), + [installedExtensions.length], + ); + + const footer = useCallback( + ({ + startIndex, + endIndex, + totalVisible, + }: { + startIndex: number; + endIndex: number; + totalVisible: number; + }) => ( + + ({startIndex + 1}-{endIndex}) / {totalVisible} + + ), + [], + ); + + if (loading) { + return ( + + Loading extensions... 
+ + ); + } + + if (error) { + return ( + + Error loading extensions: + {error} + + ); + } + + return ( + + title="Extensions" + items={items} + onSelect={handleSelect} + onClose={onClose || (() => {})} + searchPlaceholder="Search extension gallery" + renderItem={renderItem} + header={header} + footer={footer} + maxItemsToShow={8} + useSearch={useRegistrySearch} + onSearch={search} + resetSelectionOnItemsChange={true} + /> + ); +} diff --git a/packages/cli/src/ui/hooks/useExtensionRegistry.ts b/packages/cli/src/ui/hooks/useExtensionRegistry.ts new file mode 100644 index 0000000000..cfd85ef229 --- /dev/null +++ b/packages/cli/src/ui/hooks/useExtensionRegistry.ts @@ -0,0 +1,101 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useState, useEffect, useMemo, useCallback, useRef } from 'react'; +import { + ExtensionRegistryClient, + type RegistryExtension, +} from '../../config/extensionRegistryClient.js'; + +export interface UseExtensionRegistryResult { + extensions: RegistryExtension[]; + loading: boolean; + error: string | null; + search: (query: string) => void; +} + +export function useExtensionRegistry( + initialQuery = '', +): UseExtensionRegistryResult { + const [extensions, setExtensions] = useState([]); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + + const client = useMemo(() => new ExtensionRegistryClient(), []); + + // Ref to track the latest query to avoid race conditions + const latestQueryRef = useRef(initialQuery); + + // Ref for debounce timeout + const debounceTimeoutRef = useRef(undefined); + + const searchExtensions = useCallback( + async (query: string) => { + try { + setLoading(true); + const results = await client.searchExtensions(query); + + // Only update if this is still the latest query + if (query === latestQueryRef.current) { + // Check if results are different from current extensions + setExtensions((prev) => { + if ( + prev.length === 
results.length && + prev.every((ext, i) => ext.id === results[i].id) + ) { + return prev; + } + return results; + }); + setError(null); + setLoading(false); + } + } catch (err) { + if (query === latestQueryRef.current) { + setError(err instanceof Error ? err.message : String(err)); + setExtensions([]); + setLoading(false); + } + } + }, + [client], + ); + + const search = useCallback( + (query: string) => { + latestQueryRef.current = query; + + // Clear existing timeout + if (debounceTimeoutRef.current) { + clearTimeout(debounceTimeoutRef.current); + } + + // Debounce + debounceTimeoutRef.current = setTimeout(() => { + void searchExtensions(query); + }, 300); + }, + [searchExtensions], + ); + + // Initial load + useEffect(() => { + void searchExtensions(initialQuery); + + return () => { + if (debounceTimeoutRef.current) { + clearTimeout(debounceTimeoutRef.current); + } + }; + }, [initialQuery, searchExtensions]); + + return { + extensions, + loading, + error, + search, + }; +} diff --git a/packages/cli/src/ui/hooks/useRegistrySearch.ts b/packages/cli/src/ui/hooks/useRegistrySearch.ts new file mode 100644 index 0000000000..e1a1c4191b --- /dev/null +++ b/packages/cli/src/ui/hooks/useRegistrySearch.ts @@ -0,0 +1,67 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { useState, useEffect } from 'react'; +import { + useTextBuffer, + type TextBuffer, +} from '../components/shared/text-buffer.js'; +import { useUIState } from '../contexts/UIStateContext.js'; +import type { GenericListItem } from '../components/shared/SearchableList.js'; + +const MIN_VIEWPORT_WIDTH = 20; +const VIEWPORT_WIDTH_OFFSET = 8; + +export interface UseRegistrySearchResult { + filteredItems: T[]; + searchBuffer: TextBuffer | undefined; + searchQuery: string; + setSearchQuery: (query: string) => void; + maxLabelWidth: number; +} + +export function useRegistrySearch(props: { + items: T[]; + initialQuery?: string; + onSearch?: (query: string) => 
void; +}): UseRegistrySearchResult { + const { items, initialQuery = '', onSearch } = props; + + const [searchQuery, setSearchQuery] = useState(initialQuery); + + useEffect(() => { + onSearch?.(searchQuery); + }, [searchQuery, onSearch]); + + const { mainAreaWidth } = useUIState(); + const viewportWidth = Math.max( + MIN_VIEWPORT_WIDTH, + mainAreaWidth - VIEWPORT_WIDTH_OFFSET, + ); + + const searchBuffer = useTextBuffer({ + initialText: searchQuery, + initialCursorOffset: searchQuery.length, + viewport: { + width: viewportWidth, + height: 1, + }, + singleLine: true, + onChange: (text) => setSearchQuery(text), + }); + + const maxLabelWidth = 0; + + const filteredItems = items; + + return { + filteredItems, + searchBuffer, + searchQuery, + setSearchQuery, + maxLabelWidth, + }; +} diff --git a/packages/core/src/scheduler/policy.ts b/packages/core/src/scheduler/policy.ts index 247b696f22..579f081d2c 100644 --- a/packages/core/src/scheduler/policy.ts +++ b/packages/core/src/scheduler/policy.ts @@ -77,7 +77,10 @@ export async function checkPolicy( } } - return { decision, rule: result.rule }; + return { + decision, + rule: result.rule, + }; } /** diff --git a/packages/core/src/scheduler/tool-executor.ts b/packages/core/src/scheduler/tool-executor.ts index 116598a2b9..b94b0e5184 100644 --- a/packages/core/src/scheduler/tool-executor.ts +++ b/packages/core/src/scheduler/tool-executor.ts @@ -192,6 +192,8 @@ export class ToolExecutor { tool: call.tool, invocation: call.invocation, durationMs: startTime ? Date.now() - startTime : undefined, + startTime, + endTime: Date.now(), outcome: call.outcome, }; } @@ -263,6 +265,8 @@ export class ToolExecutor { response: successResponse, invocation: call.invocation, durationMs: startTime ? Date.now() - startTime : undefined, + startTime, + endTime: Date.now(), outcome: call.outcome, }; } @@ -287,6 +291,8 @@ export class ToolExecutor { response, tool: call.tool, durationMs: startTime ? 
Date.now() - startTime : undefined, + startTime, + endTime: Date.now(), outcome: call.outcome, }; } diff --git a/packages/core/src/scheduler/types.ts b/packages/core/src/scheduler/types.ts index 7da611f23a..5fe6028bac 100644 --- a/packages/core/src/scheduler/types.ts +++ b/packages/core/src/scheduler/types.ts @@ -86,6 +86,8 @@ export type ErroredToolCall = { response: ToolCallResponseInfo; tool?: AnyDeclarativeTool; durationMs?: number; + startTime?: number; + endTime?: number; outcome?: ToolConfirmationOutcome; schedulerId?: string; approvalMode?: ApprovalMode; @@ -98,6 +100,8 @@ export type SuccessfulToolCall = { response: ToolCallResponseInfo; invocation: AnyToolInvocation; durationMs?: number; + startTime?: number; + endTime?: number; outcome?: ToolConfirmationOutcome; schedulerId?: string; approvalMode?: ApprovalMode; @@ -125,6 +129,8 @@ export type CancelledToolCall = { tool: AnyDeclarativeTool; invocation: AnyToolInvocation; durationMs?: number; + startTime?: number; + endTime?: number; outcome?: ToolConfirmationOutcome; schedulerId?: string; approvalMode?: ApprovalMode; diff --git a/packages/core/src/telemetry/types.ts b/packages/core/src/telemetry/types.ts index 497ff97469..e1a4079f3d 100644 --- a/packages/core/src/telemetry/types.ts +++ b/packages/core/src/telemetry/types.ts @@ -243,6 +243,8 @@ export class ToolCallEvent implements BaseTelemetryEvent { mcp_server_name?: string; extension_name?: string; extension_id?: string; + start_time?: number; + end_time?: number; // eslint-disable-next-line @typescript-eslint/no-explicit-any metadata?: { [key: string]: any }; @@ -256,6 +258,8 @@ export class ToolCallEvent implements BaseTelemetryEvent { prompt_id: string, tool_type: 'native' | 'mcp', error?: string, + start_time?: number, + end_time?: number, ); constructor( call?: CompletedToolCall, @@ -266,6 +270,8 @@ export class ToolCallEvent implements BaseTelemetryEvent { prompt_id?: string, tool_type?: 'native' | 'mcp', error?: string, + start_time?: number, + 
end_time?: number, ) { this['event.name'] = 'tool_call'; this['event.timestamp'] = new Date().toISOString(); @@ -282,6 +288,8 @@ export class ToolCallEvent implements BaseTelemetryEvent { this.error_type = call.response.errorType; this.prompt_id = call.request.prompt_id; this.content_length = call.response.contentLength; + this.start_time = call.startTime; + this.end_time = call.endTime; if ( typeof call.tool !== 'undefined' && call.tool instanceof DiscoveredMCPTool @@ -332,6 +340,8 @@ export class ToolCallEvent implements BaseTelemetryEvent { this.prompt_id = prompt_id!; this.tool_type = tool_type!; this.error = error; + this.start_time = start_time; + this.end_time = end_time; } } @@ -351,6 +361,8 @@ export class ToolCallEvent implements BaseTelemetryEvent { mcp_server_name: this.mcp_server_name, extension_name: this.extension_name, extension_id: this.extension_id, + start_time: this.start_time, + end_time: this.end_time, metadata: this.metadata, }; diff --git a/packages/core/src/tools/definitions/__snapshots__/coreToolsModelSnapshots.test.ts.snap b/packages/core/src/tools/definitions/__snapshots__/coreToolsModelSnapshots.test.ts.snap index 9767829f0e..cdbb5d44a8 100644 --- a/packages/core/src/tools/definitions/__snapshots__/coreToolsModelSnapshots.test.ts.snap +++ b/packages/core/src/tools/definitions/__snapshots__/coreToolsModelSnapshots.test.ts.snap @@ -1089,7 +1089,7 @@ exports[`coreTools snapshots for specific models > Model: gemini-3-pro-preview > exports[`coreTools snapshots for specific models > Model: gemini-3-pro-preview > snapshot for tool: grep_search_ripgrep 1`] = ` { - "description": "Searches for a regular expression pattern within file contents.", + "description": "Searches for a regular expression pattern within file contents. This tool is FAST and optimized, powered by ripgrep. 
PREFERRED over standard \`run_shell_command("grep ...")\` due to better performance and automatic output limiting (defaults to 100 matches, but can be increased via \`total_max_matches\`).", "name": "grep_search", "parametersJsonSchema": { "properties": { diff --git a/packages/core/src/tools/definitions/model-family-sets/gemini-3.ts b/packages/core/src/tools/definitions/model-family-sets/gemini-3.ts index 71e8aaec1c..ce5f3fe429 100644 --- a/packages/core/src/tools/definitions/model-family-sets/gemini-3.ts +++ b/packages/core/src/tools/definitions/model-family-sets/gemini-3.ts @@ -131,7 +131,7 @@ The user has the ability to modify \`content\`. If modified, this will be stated grep_search_ripgrep: { name: GREP_TOOL_NAME, description: - 'Searches for a regular expression pattern within file contents.', + 'Searches for a regular expression pattern within file contents. This tool is FAST and optimized, powered by ripgrep. PREFERRED over standard `run_shell_command("grep ...")` due to better performance and automatic output limiting (defaults to 100 matches, but can be increased via `total_max_matches`).', parametersJsonSchema: { type: 'object', properties: { diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 19430c2f9a..3e592825dd 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -2056,6 +2056,90 @@ describe('connectToMcpServer with OAuth', () => { capturedTransport._requestInit?.headers?.['Authorization']; expect(authHeader).toBe('Bearer test-access-token-from-discovery'); }); + + it('should use discoverOAuthFromWWWAuthenticate when it succeeds and skip discoverOAuthConfig', async () => { + const serverUrl = 'http://test-server.com/mcp'; + const authUrl = 'http://auth.example.com/auth'; + const tokenUrl = 'http://auth.example.com/token'; + const wwwAuthHeader = `Bearer realm="test", 
resource_metadata="http://test-server.com/.well-known/oauth-protected-resource"`; + + vi.mocked(mockedClient.connect).mockRejectedValueOnce( + new StreamableHTTPError( + 401, + `Unauthorized\nwww-authenticate: ${wwwAuthHeader}`, + ), + ); + + vi.mocked(OAuthUtils.discoverOAuthFromWWWAuthenticate).mockResolvedValue({ + authorizationUrl: authUrl, + tokenUrl, + scopes: ['read'], + }); + + vi.mocked(mockedClient.connect).mockResolvedValueOnce(undefined); + + const client = await connectToMcpServer( + '0.0.1', + 'test-server', + { httpUrl: serverUrl, oauth: { enabled: true } }, + false, + workspaceContext, + EMPTY_CONFIG, + ); + + expect(client).toBe(mockedClient); + expect(OAuthUtils.discoverOAuthFromWWWAuthenticate).toHaveBeenCalledWith( + wwwAuthHeader, + serverUrl, + ); + expect(OAuthUtils.discoverOAuthConfig).not.toHaveBeenCalled(); + expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); + }); + + it('should fall back to extractBaseUrl + discoverOAuthConfig when discoverOAuthFromWWWAuthenticate returns null', async () => { + const serverUrl = 'http://test-server.com/mcp'; + const baseUrl = 'http://test-server.com'; + const authUrl = 'http://auth.example.com/auth'; + const tokenUrl = 'http://auth.example.com/token'; + const wwwAuthHeader = `Bearer realm="test"`; + + vi.mocked(mockedClient.connect).mockRejectedValueOnce( + new StreamableHTTPError( + 401, + `Unauthorized\nwww-authenticate: ${wwwAuthHeader}`, + ), + ); + + vi.mocked(OAuthUtils.discoverOAuthFromWWWAuthenticate).mockResolvedValue( + null, + ); + vi.mocked(OAuthUtils.extractBaseUrl).mockReturnValue(baseUrl); + vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({ + authorizationUrl: authUrl, + tokenUrl, + scopes: ['read'], + }); + + vi.mocked(mockedClient.connect).mockResolvedValueOnce(undefined); + + const client = await connectToMcpServer( + '0.0.1', + 'test-server', + { httpUrl: serverUrl, oauth: { enabled: true } }, + false, + workspaceContext, + EMPTY_CONFIG, + ); + + 
expect(client).toBe(mockedClient); + expect(OAuthUtils.discoverOAuthFromWWWAuthenticate).toHaveBeenCalledWith( + wwwAuthHeader, + serverUrl, + ); + expect(OAuthUtils.extractBaseUrl).toHaveBeenCalledWith(serverUrl); + expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(baseUrl); + expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); + }); }); describe('connectToMcpServer - HTTP→SSE fallback', () => { diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index a838cf76e5..ccc6bbec3c 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -719,18 +719,17 @@ async function handleAutomaticOAuth( try { debugLogger.log(`🔐 '${mcpServerName}' requires OAuth authentication`); - // Always try to parse the resource metadata URI from the www-authenticate header - let oauthConfig; - const resourceMetadataUri = - OAuthUtils.parseWWWAuthenticateHeader(wwwAuthenticate); - if (resourceMetadataUri) { - oauthConfig = await OAuthUtils.discoverOAuthConfig(resourceMetadataUri); - } else if (hasNetworkTransport(mcpServerConfig)) { + const serverUrl = mcpServerConfig.httpUrl || mcpServerConfig.url; + + // Try to discover OAuth config from the WWW-Authenticate header first + let oauthConfig = await OAuthUtils.discoverOAuthFromWWWAuthenticate( + wwwAuthenticate, + serverUrl, + ); + + if (!oauthConfig && hasNetworkTransport(mcpServerConfig)) { // Fallback: try to discover OAuth config from the base URL - const serverUrl = new URL( - mcpServerConfig.httpUrl || mcpServerConfig.url!, - ); - const baseUrl = `${serverUrl.protocol}//${serverUrl.host}`; + const baseUrl = OAuthUtils.extractBaseUrl(serverUrl!); oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl); } @@ -754,8 +753,6 @@ async function handleAutomaticOAuth( }; // Perform OAuth authentication - // Pass the server URL for proper discovery - const serverUrl = mcpServerConfig.httpUrl || mcpServerConfig.url; debugLogger.log( `Starting OAuth 
authentication for server '${mcpServerName}'...`, ); diff --git a/packages/core/src/tools/web-fetch.test.ts b/packages/core/src/tools/web-fetch.test.ts index f0c6ff2c7e..2e06a46ee5 100644 --- a/packages/core/src/tools/web-fetch.test.ts +++ b/packages/core/src/tools/web-fetch.test.ts @@ -183,6 +183,26 @@ describe('WebFetchTool', () => { }); describe('execute', () => { + it('should return WEB_FETCH_PROCESSING_ERROR on rate limit exceeded', async () => { + vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(false); + mockGenerateContent.mockResolvedValue({ + candidates: [{ content: { parts: [{ text: 'response' }] } }], + }); + const tool = new WebFetchTool(mockConfig, bus); + const params = { prompt: 'fetch https://ratelimit.example.com' }; + const invocation = tool.build(params); + + // Execute 10 times to hit the limit + for (let i = 0; i < 10; i++) { + await invocation.execute(new AbortController().signal); + } + + // The 11th time should fail due to rate limit + const result = await invocation.execute(new AbortController().signal); + expect(result.error?.type).toBe(ToolErrorType.WEB_FETCH_PROCESSING_ERROR); + expect(result.error?.message).toContain('Rate limit exceeded for host'); + }); + it('should return WEB_FETCH_FALLBACK_FAILED on fallback fetch failure', async () => { vi.spyOn(fetchUtils, 'isPrivateIp').mockReturnValue(true); vi.spyOn(fetchUtils, 'fetchWithTimeout').mockRejectedValue( diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index 214cf4916b..9b6f832971 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -33,10 +33,46 @@ import { debugLogger } from '../utils/debugLogger.js'; import { retryWithBackoff } from '../utils/retry.js'; import { WEB_FETCH_DEFINITION } from './definitions/coreTools.js'; import { resolveToolDeclaration } from './definitions/resolver.js'; +import { LRUCache } from 'mnemonist'; const URL_FETCH_TIMEOUT_MS = 10000; const MAX_CONTENT_LENGTH = 100000; 
+// Rate limiting configuration +const RATE_LIMIT_WINDOW_MS = 60000; // 1 minute +const MAX_REQUESTS_PER_WINDOW = 10; +const hostRequestHistory = new LRUCache(1000); + +function checkRateLimit(url: string): { + allowed: boolean; + waitTimeMs?: number; +} { + try { + const hostname = new URL(url).hostname; + const now = Date.now(); + const windowStart = now - RATE_LIMIT_WINDOW_MS; + + let history = hostRequestHistory.get(hostname) || []; + // Clean up old timestamps + history = history.filter((timestamp) => timestamp > windowStart); + + if (history.length >= MAX_REQUESTS_PER_WINDOW) { + // Calculate wait time based on the oldest timestamp in the current window + const oldestTimestamp = history[0]; + const waitTimeMs = oldestTimestamp + RATE_LIMIT_WINDOW_MS - now; + hostRequestHistory.set(hostname, history); // Update cleaned history + return { allowed: false, waitTimeMs: Math.max(0, waitTimeMs) }; + } + + history.push(now); + hostRequestHistory.set(hostname, history); + return { allowed: true }; + } catch (_e) { + // If URL parsing fails, we fallback to allowed (should be caught by parsePrompt anyway) + return { allowed: true }; + } +} + /** * Parses a prompt to extract valid URLs and identify malformed ones. */ @@ -258,6 +294,23 @@ ${textContent} const userPrompt = this.params.prompt; const { validUrls: urls } = parsePrompt(userPrompt); const url = urls[0]; + + // Enforce rate limiting + const rateLimitResult = checkRateLimit(url); + if (!rateLimitResult.allowed) { + const waitTimeSecs = Math.ceil((rateLimitResult.waitTimeMs || 0) / 1000); + const errorMessage = `Rate limit exceeded for host. 
Please wait ${waitTimeSecs} seconds before trying again.`; + debugLogger.warn(`[WebFetchTool] Rate limit exceeded for ${url}`); + return { + llmContent: `Error: ${errorMessage}`, + returnDisplay: `Error: ${errorMessage}`, + error: { + message: errorMessage, + type: ToolErrorType.WEB_FETCH_PROCESSING_ERROR, + }, + }; + } + const isPrivate = isPrivateIp(url); if (isPrivate) { diff --git a/packages/test-utils/src/test-rig.ts b/packages/test-utils/src/test-rig.ts index 6e32ec7790..1cd55b84f7 100644 --- a/packages/test-utils/src/test-rig.ts +++ b/packages/test-utils/src/test-rig.ts @@ -208,6 +208,7 @@ export interface ParsedLog { stdout?: string; stderr?: string; error?: string; + error_type?: string; prompt_id?: string; }; scopeMetrics?: { @@ -1255,6 +1256,8 @@ export class TestRig { success: boolean; duration_ms: number; prompt_id?: string; + error?: string; + error_type?: string; }; }[] = []; @@ -1272,6 +1275,8 @@ export class TestRig { success: logData.attributes.success ?? false, duration_ms: logData.attributes.duration_ms ?? 0, prompt_id: logData.attributes.prompt_id, + error: logData.attributes.error, + error_type: logData.attributes.error_type, }, }); }