diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index e7d675ed15..9a76646dbe 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -752,11 +752,11 @@ describe('mergeMcpServers', () => { }); describe('mergeExcludeTools', () => { - const defaultExcludes = [ + const defaultExcludes = new Set([ SHELL_TOOL_NAME, EDIT_TOOL_NAME, WRITE_FILE_TOOL_NAME, - ]; + ]); const originalIsTTY = process.stdin.isTTY; beforeEach(() => { @@ -799,7 +799,7 @@ describe('mergeExcludeTools', () => { argv, ); expect(config.getExcludeTools()).toEqual( - expect.arrayContaining(['tool1', 'tool2', 'tool3', 'tool4', 'tool5']), + new Set(['tool1', 'tool2', 'tool3', 'tool4', 'tool5']), ); expect(config.getExcludeTools()).toHaveLength(5); }); @@ -821,7 +821,7 @@ describe('mergeExcludeTools', () => { const argv = await parseArguments({} as Settings); const config = await loadCliConfig(settings, 'test-session', argv); expect(config.getExcludeTools()).toEqual( - expect.arrayContaining(['tool1', 'tool2', 'tool3']), + new Set(['tool1', 'tool2', 'tool3']), ); expect(config.getExcludeTools()).toHaveLength(3); }); @@ -852,7 +852,7 @@ describe('mergeExcludeTools', () => { const argv = await parseArguments({} as Settings); const config = await loadCliConfig(settings, 'test-session', argv); expect(config.getExcludeTools()).toEqual( - expect.arrayContaining(['tool1', 'tool2', 'tool3', 'tool4']), + new Set(['tool1', 'tool2', 'tool3', 'tool4']), ); expect(config.getExcludeTools()).toHaveLength(4); }); @@ -863,7 +863,7 @@ describe('mergeExcludeTools', () => { process.argv = ['node', 'script.js']; const argv = await parseArguments({} as Settings); const config = await loadCliConfig(settings, 'test-session', argv); - expect(config.getExcludeTools()).toEqual([]); + expect(config.getExcludeTools()).toEqual(new Set([])); }); it('should return default excludes when no excludeTools are specified and it is not interactive', async () => { @@ -881,9 +881,7 @@ describe('mergeExcludeTools', () => { const settings: Settings = { tools: { exclude: ['tool1', 'tool2'] } }; vi.spyOn(ExtensionManager.prototype, 'getExtensions').mockReturnValue([]); const config = await loadCliConfig(settings, 'test-session', argv); - expect(config.getExcludeTools()).toEqual( - expect.arrayContaining(['tool1', 'tool2']), - ); + expect(config.getExcludeTools()).toEqual(new Set(['tool1', 'tool2'])); expect(config.getExcludeTools()).toHaveLength(2); }); @@ -903,9 +901,7 @@ describe('mergeExcludeTools', () => { process.argv = ['node', 'script.js']; const argv = await parseArguments({} as Settings); const config = await loadCliConfig(settings, 'test-session', argv); - expect(config.getExcludeTools()).toEqual( - expect.arrayContaining(['tool1', 'tool2']), - ); + expect(config.getExcludeTools()).toEqual(new Set(['tool1', 'tool2'])); expect(config.getExcludeTools()).toHaveLength(2); }); diff --git a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts index d79714d8f2..0828b6684e 100644 --- a/packages/cli/src/ui/hooks/atCommandProcessor.test.ts +++ b/packages/cli/src/ui/hooks/atCommandProcessor.test.ts @@ -57,6 +57,7 @@ describe('handleAtCommand', () => { getToolRegistry, getTargetDir: () => testRootDir, isSandboxed: () => false, + getExcludeTools: vi.fn(), getFileService: () => new FileDiscoveryService(testRootDir), getFileFilteringRespectGitIgnore: () => true, getFileFilteringRespectGeminiIgnore: () => true, diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 242639c8fb..368f42bd0f 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -872,27 +872,6 @@ describe('Server Config (config.ts)', () => { expect(wasShellToolRegistered).toBe(true); }); - it('should not register a tool if excludeTools contains the non-minified class name', async () => { - const params: ConfigParameters = { - ...baseParams, - coreTools: undefined, // all tools enabled by default - excludeTools: ['ShellTool'], - }; - const config = new Config(params); - await config.initialize(); - - const registerToolMock = ( - (await vi.importMock('../tools/tool-registry')) as { - ToolRegistry: { prototype: { registerTool: Mock } }; - } - ).ToolRegistry.prototype.registerTool; - - const wasShellToolRegistered = ( - registerToolMock as Mock - ).mock.calls.some((call) => call[0] instanceof vi.mocked(ShellTool)); - expect(wasShellToolRegistered).toBe(false); - }); - it('should register a tool if coreTools contains an argument-specific pattern with the non-minified class name', async () => { const params: ConfigParameters = { ...baseParams, diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index b53530c654..6a115f0385 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -826,7 +826,7 @@ export class Config { * * May change over time. */ - getExcludeTools(): string[] | undefined { + getExcludeTools(): Set | undefined { const excludeToolsSet = new Set([...(this.excludeTools ?? [])]); for (const extension of this.getExtensionLoader().getExtensions()) { if (!extension.isActive) { @@ -836,7 +836,7 @@ export class Config { excludeToolsSet.add(tool); } } - return [...excludeToolsSet]; + return excludeToolsSet; } getToolDiscoveryCommand(): string | undefined { @@ -1282,7 +1282,6 @@ export class Config { const className = ToolClass.name; const toolName = ToolClass.Name || className; const coreTools = this.getCoreTools(); - const excludeTools = this.getExcludeTools() || []; // On some platforms, the className can be minified to _ClassName. const normalizedClassName = className.replace(/^_+/, ''); @@ -1297,14 +1296,6 @@ export class Config { ); } - const isExcluded = excludeTools.some( - (tool) => tool === toolName || tool === normalizedClassName, - ); - - if (isExcluded) { - isEnabled = false; - } - if (isEnabled) { // Pass message bus to tools when feature flag is enabled // This first implementation is only focused on the general case of @@ -1363,15 +1354,12 @@ export class Config { ); if (definition) { // We must respect the main allowed/exclude lists for agents too. - const excludeTools = this.getExcludeTools() || []; - const allowedTools = this.getAllowedTools(); - const isExcluded = excludeTools.includes(definition.name); const isAllowed = !allowedTools || allowedTools.includes(definition.name); - if (isAllowed && !isExcluded) { + if (isAllowed) { const messageBusEnabled = this.getEnableMessageBusIntegration(); const wrapper = new SubagentToolWrapper( definition, diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index f5c9205e9d..13a50b0603 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -243,6 +243,10 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool< ); } + getFullyQualifiedPrefix(): string { + return `${this.serverName}__`; + } + asFullyQualifiedTool(): DiscoveredMCPTool { return new DiscoveredMCPTool( this.mcpTool, @@ -251,7 +255,7 @@ export class DiscoveredMCPTool extends BaseDeclarativeTool< this.description, this.parameterSchema, this.trust, - `${this.serverName}__${this.serverToolName}`, + `${this.getFullyQualifiedPrefix()}${this.serverToolName}`, this.cliConfig, this.extensionName, this.extensionId, diff --git a/packages/core/src/tools/shell.test.ts b/packages/core/src/tools/shell.test.ts index 2d76057a73..d93e815a1e 100644 --- a/packages/core/src/tools/shell.test.ts +++ b/packages/core/src/tools/shell.test.ts @@ -82,7 +82,7 @@ describe('ShellTool', () => { getAllowedTools: vi.fn().mockReturnValue([]), getApprovalMode: vi.fn().mockReturnValue('strict'), getCoreTools: vi.fn().mockReturnValue([]), - getExcludeTools: vi.fn().mockReturnValue([]), + getExcludeTools: vi.fn().mockReturnValue(new Set([])), getDebugMode: vi.fn().mockReturnValue(false), getTargetDir: vi.fn().mockReturnValue(tempRootDir), getSummarizeToolOutputConfig: vi.fn().mockReturnValue(undefined), diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index 1d3ddb786b..17218aaaa0 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -5,7 +5,7 @@ */ /* eslint-disable @typescript-eslint/no-explicit-any */ -import type { Mocked } from 'vitest'; +import type { Mocked, MockInstance } from 'vitest'; import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; import type { ConfigParameters } from '../config/config.js'; import { Config } from '../config/config.js'; @@ -109,6 +109,9 @@ describe('ToolRegistry', () => { let config: Config; let toolRegistry: ToolRegistry; let mockConfigGetToolDiscoveryCommand: ReturnType; + let mockConfigGetExcludedTools: MockInstance< + typeof Config.prototype.getExcludeTools + >; beforeEach(() => { vi.mocked(fs.existsSync).mockReturnValue(true); @@ -132,6 +135,7 @@ describe('ToolRegistry', () => { config, 'getToolDiscoveryCommand', ); + mockConfigGetExcludedTools = vi.spyOn(config, 'getExcludeTools'); vi.spyOn(config, 'getMcpServers'); vi.spyOn(config, 'getMcpServerCommand'); vi.spyOn(config, 'getPromptRegistry').mockReturnValue({ @@ -152,6 +156,75 @@ describe('ToolRegistry', () => { }); }); + describe('excluded tools', () => { + const simpleTool = new MockTool({ + name: 'tool-a', + displayName: 'Tool a', + }); + const excludedTool = new ExcludedMockTool({ + name: 'excluded-tool-class', + displayName: 'Excluded Tool Class', + }); + const mockCallable = {} as CallableTool; + const mcpTool = new DiscoveredMCPTool( + mockCallable, + 'mcp-server', + 'excluded-mcp-tool', + 'description', + {}, + ); + const allowedTool = new MockTool({ + name: 'allowed-tool', + displayName: 'Allowed Tool', + }); + + it.each([ + { + name: 'should match simple names', + tools: [simpleTool], + excludedTools: ['tool-a'], + }, + { + name: 'should match simple MCP tool names, when qualified or unqualified', + tools: [mcpTool, mcpTool.asFullyQualifiedTool()], + excludedTools: [mcpTool.name], + }, + { + name: 'should match qualified MCP tool names when qualified or unqualified', + tools: [mcpTool, mcpTool.asFullyQualifiedTool()], + excludedTools: [`${mcpTool.getFullyQualifiedPrefix()}${mcpTool.name}`], + }, + { + name: 'should match class names', + tools: [excludedTool], + excludedTools: ['ExcludedMockTool'], + }, + ])('$name', ({ tools, excludedTools }) => { + toolRegistry.registerTool(allowedTool); + for (const tool of tools) { + toolRegistry.registerTool(tool); + } + mockConfigGetExcludedTools.mockReturnValue(new Set(excludedTools)); + + expect(toolRegistry.getAllTools()).toEqual([allowedTool]); + expect(toolRegistry.getAllToolNames()).toEqual([allowedTool.name]); + expect(toolRegistry.getFunctionDeclarations()).toEqual( + toolRegistry.getFunctionDeclarationsFiltered([allowedTool.name]), + ); + for (const tool of tools) { + expect(toolRegistry.getTool(tool.name)).toBeUndefined(); + expect( + toolRegistry.getFunctionDeclarationsFiltered([tool.name]), + ).toHaveLength(0); + if (tool instanceof DiscoveredMCPTool) { + expect(toolRegistry.getToolsByServer(tool.serverName)).toHaveLength( + 0, + ); + } + } + }); + }); + describe('getAllTools', () => { it('should return all registered tools sorted alphabetically by displayName', () => { // Register tools with displayNames in non-alphabetical order @@ -521,3 +594,12 @@ describe('ToolRegistry', () => { }); }); }); + +/** + * Used for tests that exclude by class name. + */ +class ExcludedMockTool extends MockTool { + constructor(options: ConstructorParameters[0]) { + super(options); + } +} diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index 59f45a6826..c59e82e932 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -189,7 +189,9 @@ Signal: Signal number or \`(none)\` if no signal was received. export class ToolRegistry { // The tools keyed by tool name as seen by the LLM. - private tools: Map = new Map(); + // This includes tools which are currently not active, use `getActiveTools` + // and `isActive` to get only the active tools. + private allKnownTools: Map = new Map(); private config: Config; private messageBus?: MessageBus; @@ -207,10 +209,14 @@ export class ToolRegistry { /** * Registers a tool definition. + * + * Note that excluded tools are still registered to allow for enabling them + * later in the session. + * * @param tool - The tool object containing schema and execution logic. */ registerTool(tool: AnyDeclarativeTool): void { - if (this.tools.has(tool.name)) { + if (this.allKnownTools.has(tool.name)) { if (tool instanceof DiscoveredMCPTool) { tool = tool.asFullyQualifiedTool(); } else { @@ -220,7 +226,7 @@ export class ToolRegistry { ); } } - this.tools.set(tool.name, tool); + this.allKnownTools.set(tool.name, tool); } /** @@ -229,7 +235,7 @@ export class ToolRegistry { * 2. Discovered tools. * 3. MCP tools ordered by server name. * - * This is a stable sort in that ties preseve existing order. + * This is a stable sort in that tries preserve existing order. */ sortTools(): void { const getPriority = (tool: AnyDeclarativeTool): number => { @@ -238,8 +244,8 @@ export class ToolRegistry { return 0; // Built-in }; - this.tools = new Map( - Array.from(this.tools.entries()).sort((a, b) => { + this.allKnownTools = new Map( + Array.from(this.allKnownTools.entries()).sort((a, b) => { const toolA = a[1]; const toolB = b[1]; const priorityA = getPriority(toolA); @@ -261,9 +267,9 @@ export class ToolRegistry { } private removeDiscoveredTools(): void { - for (const tool of this.tools.values()) { + for (const tool of this.allKnownTools.values()) { if (tool instanceof DiscoveredTool || tool instanceof DiscoveredMCPTool) { - this.tools.delete(tool.name); + this.allKnownTools.delete(tool.name); } } } @@ -273,9 +279,9 @@ export class ToolRegistry { * @param serverName The name of the server to remove tools from. */ removeMcpToolsByServer(serverName: string): void { - for (const [name, tool] of this.tools.entries()) { + for (const [name, tool] of this.allKnownTools.entries()) { if (tool instanceof DiscoveredMCPTool && tool.serverName === serverName) { - this.tools.delete(name); + this.allKnownTools.delete(name); } } } @@ -416,6 +422,45 @@ export class ToolRegistry { } } + /** + * @returns All the tools that are not excluded. + */ + private getActiveTools(): AnyDeclarativeTool[] { + const excludedTools = this.config.getExcludeTools() ?? new Set([]); + const activeTools: AnyDeclarativeTool[] = []; + for (const tool of this.allKnownTools.values()) { + if (this.isActiveTool(tool, excludedTools)) { + activeTools.push(tool); + } + } + return activeTools; + } + + /** + * @param tool + * @param excludeTools (optional, helps performance for repeated calls) + * @returns Whether or not the `tool` is not excluded. + */ + private isActiveTool( + tool: AnyDeclarativeTool, + excludeTools?: Set, + ): boolean { + excludeTools ??= this.config.getExcludeTools() ?? new Set([]); + const normalizedClassName = tool.constructor.name.replace(/^_+/, ''); + const possibleNames = [tool.name, normalizedClassName]; + if (tool instanceof DiscoveredMCPTool) { + // Check both the unqualified and qualified name for MCP tools. + if (tool.name.startsWith(tool.getFullyQualifiedPrefix())) { + possibleNames.push( + tool.name.substring(tool.getFullyQualifiedPrefix().length), + ); + } else { + possibleNames.push(`${tool.getFullyQualifiedPrefix()}${tool.name}`); + } + } + return !possibleNames.some((name) => excludeTools.has(name)); + } + /** * Retrieves the list of tool schemas (FunctionDeclaration array). * Extracts the declarations from the ToolListUnion structure. @@ -424,7 +469,7 @@ export class ToolRegistry { */ getFunctionDeclarations(): FunctionDeclaration[] { const declarations: FunctionDeclaration[] = []; - this.tools.forEach((tool) => { + this.getActiveTools().forEach((tool) => { declarations.push(tool.schema); }); return declarations; @@ -438,8 +483,8 @@ export class ToolRegistry { getFunctionDeclarationsFiltered(toolNames: string[]): FunctionDeclaration[] { const declarations: FunctionDeclaration[] = []; for (const name of toolNames) { - const tool = this.tools.get(name); - if (tool) { + const tool = this.allKnownTools.get(name); + if (tool && this.isActiveTool(tool)) { declarations.push(tool.schema); } } @@ -447,17 +492,18 @@ export class ToolRegistry { } /** - * Returns an array of all registered and discovered tool names. + * Returns an array of all registered and discovered tool names which are not + * excluded via configuration. */ getAllToolNames(): string[] { - return Array.from(this.tools.keys()); + return this.getActiveTools().map((tool) => tool.name); } /** * Returns an array of all registered and discovered tool instances. */ getAllTools(): AnyDeclarativeTool[] { - return Array.from(this.tools.values()).sort((a, b) => + return this.getActiveTools().sort((a, b) => a.displayName.localeCompare(b.displayName), ); } @@ -467,7 +513,7 @@ export class ToolRegistry { */ getToolsByServer(serverName: string): AnyDeclarativeTool[] { const serverTools: AnyDeclarativeTool[] = []; - for (const tool of this.tools.values()) { + for (const tool of this.getActiveTools()) { if ((tool as DiscoveredMCPTool)?.serverName === serverName) { serverTools.push(tool); } @@ -479,6 +525,10 @@ export class ToolRegistry { * Get the definition of a specific tool. */ getTool(name: string): AnyDeclarativeTool | undefined { - return this.tools.get(name); + const tool = this.allKnownTools.get(name); + if (tool && this.isActiveTool(tool)) { + return tool; + } + return; } } diff --git a/packages/core/src/utils/extensionLoader.test.ts b/packages/core/src/utils/extensionLoader.test.ts index cb175b2f64..38b4a60223 100644 --- a/packages/core/src/utils/extensionLoader.test.ts +++ b/packages/core/src/utils/extensionLoader.test.ts @@ -4,10 +4,19 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, expect, it, vi, beforeEach, afterEach } from 'vitest'; +import { + describe, + expect, + it, + vi, + beforeEach, + afterEach, + type MockInstance, +} from 'vitest'; import { SimpleExtensionLoader } from './extensionLoader.js'; -import type { Config } from '../config/config.js'; +import type { Config, GeminiCLIExtension } from '../config/config.js'; import { type McpClientManager } from '../tools/mcp-client-manager.js'; +import type { GeminiClient } from '../core/client.js'; const mockRefreshServerHierarchicalMemory = vi.hoisted(() => vi.fn()); @@ -23,15 +32,20 @@ describe('SimpleExtensionLoader', () => { let mockConfig: Config; let extensionReloadingEnabled: boolean; let mockMcpClientManager: McpClientManager; - const activeExtension = { + let mockGeminiClientSetTools: MockInstance< + typeof GeminiClient.prototype.setTools + >; + + const activeExtension: GeminiCLIExtension = { name: 'test-extension', isActive: true, version: '1.0.0', path: '/path/to/extension', contextFiles: [], + excludeTools: ['some-tool'], id: '123', }; - const inactiveExtension = { + const inactiveExtension: GeminiCLIExtension = { name: 'test-extension', isActive: false, version: '1.0.0', @@ -46,9 +60,14 @@ describe('SimpleExtensionLoader', () => { stopExtension: vi.fn(), } as unknown as McpClientManager; extensionReloadingEnabled = false; + mockGeminiClientSetTools = vi.fn(); mockConfig = { getMcpClientManager: () => mockMcpClientManager, getEnableExtensionReloading: () => extensionReloadingEnabled, + getGeminiClient: vi.fn(() => ({ + isInitialized: () => true, + setTools: mockGeminiClientSetTools, + })), } as unknown as Config; }); @@ -106,11 +125,14 @@ describe('SimpleExtensionLoader', () => { mockMcpClientManager.startExtension, ).toHaveBeenCalledExactlyOnceWith(activeExtension); expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce(); + expect(mockGeminiClientSetTools).toHaveBeenCalledOnce(); } else { expect(mockMcpClientManager.startExtension).not.toHaveBeenCalled(); expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled(); + expect(mockGeminiClientSetTools).not.toHaveBeenCalledOnce(); } mockRefreshServerHierarchicalMemory.mockClear(); + mockGeminiClientSetTools.mockClear(); await loader.unloadExtension(activeExtension); if (reloadingEnabled) { @@ -118,9 +140,11 @@ describe('SimpleExtensionLoader', () => { mockMcpClientManager.stopExtension, ).toHaveBeenCalledExactlyOnceWith(activeExtension); expect(mockRefreshServerHierarchicalMemory).toHaveBeenCalledOnce(); + expect(mockGeminiClientSetTools).toHaveBeenCalledOnce(); } else { expect(mockMcpClientManager.stopExtension).not.toHaveBeenCalled(); expect(mockRefreshServerHierarchicalMemory).not.toHaveBeenCalled(); + expect(mockGeminiClientSetTools).not.toHaveBeenCalledOnce(); } }); diff --git a/packages/core/src/utils/extensionLoader.ts b/packages/core/src/utils/extensionLoader.ts index 411ff6a7d1..ab8bed2ce9 100644 --- a/packages/core/src/utils/extensionLoader.ts +++ b/packages/core/src/utils/extensionLoader.ts @@ -73,6 +73,8 @@ export abstract class ExtensionLoader { }); try { await this.config.getMcpClientManager()!.startExtension(extension); + await this.maybeRefreshGeminiTools(extension); + // Note: Context files are loaded only once all extensions are done // loading/unloading to reduce churn, see the `maybeRefreshMemories` call // below. @@ -80,9 +82,6 @@ export abstract class ExtensionLoader { // TODO: Update custom command updating away from the event based system // and call directly into a custom command manager here. See the // useSlashCommandProcessor hook which responds to events fired here today. - - // TODO: Move all enablement of extension features here, including at least: - // - excluded tool configuration } finally { this.startCompletedCount++; this.eventEmitter?.emit('extensionsStarting', { @@ -115,6 +114,21 @@ export abstract class ExtensionLoader { } } + /** + * Refreshes the gemini tools list if it is initialized and the extension has + * any excludeTools settings. + */ + private async maybeRefreshGeminiTools( + extension: GeminiCLIExtension, + ): Promise { + if (extension.excludeTools && extension.excludeTools.length > 0) { + const geminiClient = this.config?.getGeminiClient(); + if (geminiClient?.isInitialized()) { + await geminiClient.setTools(); + } + } + } + /** * If extension reloading is enabled and `start` has already been called, * then calls `startExtension` to include all extension features into the @@ -150,6 +164,8 @@ export abstract class ExtensionLoader { try { await this.config.getMcpClientManager()!.stopExtension(extension); + await this.maybeRefreshGeminiTools(extension); + // Note: Context files are loaded only once all extensions are done // loading/unloading to reduce churn, see the `maybeRefreshMemories` call // below. @@ -157,9 +173,6 @@ export abstract class ExtensionLoader { // TODO: Update custom command updating away from the event based system // and call directly into a custom command manager here. See the // useSlashCommandProcessor hook which responds to events fired here today. - - // TODO: Remove all extension features here, including at least: - // - excluded tools } finally { this.stopCompletedCount++; this.eventEmitter?.emit('extensionsStopping', { diff --git a/packages/core/src/utils/shell-utils.test.ts b/packages/core/src/utils/shell-utils.test.ts index 1ae4049f2b..6414664e8d 100644 --- a/packages/core/src/utils/shell-utils.test.ts +++ b/packages/core/src/utils/shell-utils.test.ts @@ -58,7 +58,7 @@ beforeEach(() => { ); config = { getCoreTools: () => [], - getExcludeTools: () => [], + getExcludeTools: () => new Set([]), getAllowedTools: () => [], } as unknown as Config; }); @@ -89,7 +89,7 @@ describe('isCommandAllowed', () => { }); it('should block a command if it is in the blocked list', () => { - config.getExcludeTools = () => ['ShellTool(badCommand --danger)']; + config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']); const result = isCommandAllowed('badCommand --danger', config); expect(result.allowed).toBe(false); expect(result.reason).toBe( @@ -99,7 +99,7 @@ describe('isCommandAllowed', () => { it('should prioritize the blocklist over the allowlist', () => { config.getCoreTools = () => ['ShellTool(badCommand --danger)']; - config.getExcludeTools = () => ['ShellTool(badCommand --danger)']; + config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']); const result = isCommandAllowed('badCommand --danger', config); expect(result.allowed).toBe(false); expect(result.reason).toBe( @@ -114,7 +114,7 @@ describe('isCommandAllowed', () => { }); it('should block any command when a wildcard is in excludeTools', () => { - config.getExcludeTools = () => ['run_shell_command']; + config.getExcludeTools = () => new Set(['run_shell_command']); const result = isCommandAllowed('any random command', config); expect(result.allowed).toBe(false); expect(result.reason).toBe( @@ -124,7 +124,7 @@ describe('isCommandAllowed', () => { it('should block a command on the blocklist even with a wildcard allow', () => { config.getCoreTools = () => ['ShellTool']; - config.getExcludeTools = () => ['ShellTool(badCommand --danger)']; + config.getExcludeTools = () => new Set(['ShellTool(badCommand --danger)']); const result = isCommandAllowed('badCommand --danger', config); expect(result.allowed).toBe(false); expect(result.reason).toBe( @@ -145,7 +145,7 @@ describe('isCommandAllowed', () => { }); it('should block a chained command if any part is blocked', () => { - config.getExcludeTools = () => ['run_shell_command(badCommand)']; + config.getExcludeTools = () => new Set(['run_shell_command(badCommand)']); const result = isCommandAllowed( 'echo "hello" && badCommand --danger', config, @@ -159,7 +159,7 @@ describe('isCommandAllowed', () => { it('should block a command that redefines an allowed function to run an unlisted command', () => { config.getCoreTools = () => ['run_shell_command(echo)']; const result = isCommandAllowed( - 'echo () (curl google.com) ; echo Hello Wolrd', + 'echo () (curl google.com) ; echo Hello World', config, ); expect(result.allowed).toBe(false); @@ -355,7 +355,7 @@ describe('checkCommandPermissions', () => { }); it('should return a detailed failure object for a blocked command', () => { - config.getExcludeTools = () => ['ShellTool(badCommand)']; + config.getExcludeTools = () => new Set(['ShellTool(badCommand)']); const result = checkCommandPermissions('badCommand --danger', config); expect(result).toEqual({ allAllowed: false, @@ -424,7 +424,7 @@ describe('checkCommandPermissions', () => { }); it('should block a command on the sessionAllowlist if it is also globally blocked', () => { - config.getExcludeTools = () => ['run_shell_command(badCommand)']; + config.getExcludeTools = () => new Set(['run_shell_command(badCommand)']); const result = checkCommandPermissions( 'badCommand --danger', config, diff --git a/packages/core/src/utils/shell-utils.ts b/packages/core/src/utils/shell-utils.ts index ea136492cc..2528c9ebc4 100644 --- a/packages/core/src/utils/shell-utils.ts +++ b/packages/core/src/utils/shell-utils.ts @@ -605,9 +605,9 @@ export function checkCommandPermissions( } as AnyToolInvocation & { params: { command: string } }; // 1. Blocklist Check (Highest Priority) - const excludeTools = config.getExcludeTools() || []; + const excludeTools = config.getExcludeTools() || new Set([]); const isWildcardBlocked = SHELL_TOOL_NAMES.some((name) => - excludeTools.includes(name), + excludeTools.has(name), ); if (isWildcardBlocked) { @@ -622,7 +622,9 @@ export function checkCommandPermissions( for (const cmd of commandsToValidate) { invocation.params['command'] = cmd; if ( - doesToolInvocationMatch('run_shell_command', invocation, excludeTools) + doesToolInvocationMatch('run_shell_command', invocation, [ + ...excludeTools, + ]) ) { return { allAllowed: false,