diff --git a/integration-tests/run_shell_command.test.ts b/integration-tests/run_shell_command.test.ts index c0b06bdb2a..a0aab14193 100644 --- a/integration-tests/run_shell_command.test.ts +++ b/integration-tests/run_shell_command.test.ts @@ -13,15 +13,11 @@ const { shell } = getShellConfiguration(); function getLineCountCommand(): { command: string; tool: string } { switch (shell) { case 'powershell': - return { - command: `(Get-Content test.txt).Length`, - tool: 'Get-Content', - }; case 'cmd': - return { command: `find /c /v "" test.txt`, tool: 'find' }; + return { command: `find /c /v`, tool: 'find' }; case 'bash': default: - return { command: `wc -l test.txt`, tool: 'wc' }; + return { command: `wc -l`, tool: 'wc' }; } } @@ -91,8 +87,8 @@ describe('run_shell_command', () => { await rig.setup('should run allowed sub-command in non-interactive mode'); const testFile = rig.createFile('test.txt', 'Lorem\nIpsum\nDolor\n'); - const { tool } = getLineCountCommand(); - const prompt = `use ${tool} to tell me how many lines there are in ${testFile}`; + const { tool, command } = getLineCountCommand(); + const prompt = `use ${command} to tell me how many lines there are in ${testFile}`; // Provide the prompt via stdin to simulate non-interactive mode const result = await rig.run( @@ -129,8 +125,8 @@ describe('run_shell_command', () => { await rig.setup('should succeed with no parens in non-interactive mode'); const testFile = rig.createFile('test.txt', 'Lorem\nIpsum\nDolor\n'); - const { tool } = getLineCountCommand(); - const prompt = `use ${tool} to tell me how many lines there are in ${testFile}`; + const { command } = getLineCountCommand(); + const prompt = `use ${command} to tell me how many lines there are in ${testFile}`; const result = await rig.run( { @@ -166,8 +162,8 @@ describe('run_shell_command', () => { await rig.setup('should succeed with --yolo mode'); const testFile = rig.createFile('test.txt', 'Lorem\nIpsum\nDolor\n'); - const { tool } = getLineCountCommand(); - const prompt = `use ${tool} to tell me how many lines there are in ${testFile}`; + const { command } = getLineCountCommand(); + const prompt = `use ${command} to tell me how many lines there are in ${testFile}`; const result = await rig.run({ prompt: prompt, @@ -200,8 +196,8 @@ describe('run_shell_command', () => { await rig.setup('should work with ShellTool alias'); const testFile = rig.createFile('test.txt', 'Lorem\nIpsum\nDolor\n'); - const { tool } = getLineCountCommand(); - const prompt = `use ${tool} to tell me how many lines there are in ${testFile}`; + const { tool, command } = getLineCountCommand(); + const prompt = `use ${command} to tell me how many lines there are in ${testFile}`; const result = await rig.run( { @@ -238,9 +234,9 @@ describe('run_shell_command', () => { const rig = new TestRig(); await rig.setup('should combine multiple --allowed-tools flags'); - const { tool } = getLineCountCommand(); + const { tool, command } = getLineCountCommand(); const prompt = - `use both ${tool} and ls to count the number of lines in files in this ` + + `use both ${command} and ls to count the number of lines in files in this ` + `directory. Do not pipe these commands into each other, run them separately.`; const result = await rig.run( diff --git a/packages/core/src/tools/shell.test.ts b/packages/core/src/tools/shell.test.ts index d5854df49f..7c861eeac4 100644 --- a/packages/core/src/tools/shell.test.ts +++ b/packages/core/src/tools/shell.test.ts @@ -51,6 +51,8 @@ describe('ShellTool', () => { vi.clearAllMocks(); mockConfig = { + getAllowedTools: vi.fn().mockReturnValue([]), + getApprovalMode: vi.fn().mockReturnValue('strict'), getCoreTools: vi.fn().mockReturnValue([]), getExcludeTools: vi.fn().mockReturnValue([]), getDebugMode: vi.fn().mockReturnValue(false), @@ -410,6 +412,52 @@ describe('ShellTool', () => { it('should throw an error if validation fails', () => { expect(() => shellTool.build({ command: '' })).toThrow(); }); + + describe('in non-interactive mode', () => { + beforeEach(() => { + (mockConfig.isInteractive as Mock).mockReturnValue(false); + }); + + it('should not throw an error or block for an allowed command', async () => { + (mockConfig.getAllowedTools as Mock).mockReturnValue(['ShellTool(wc)']); + const invocation = shellTool.build({ command: 'wc -l foo.txt' }); + const confirmation = await invocation.shouldConfirmExecute( + new AbortController().signal, + ); + expect(confirmation).toBe(false); + }); + + it('should not throw an error or block for an allowed command with arguments', async () => { + (mockConfig.getAllowedTools as Mock).mockReturnValue([ + 'ShellTool(wc -l)', + ]); + const invocation = shellTool.build({ command: 'wc -l foo.txt' }); + const confirmation = await invocation.shouldConfirmExecute( + new AbortController().signal, + ); + expect(confirmation).toBe(false); + }); + + it('should throw an error for command that is not allowed', async () => { + (mockConfig.getAllowedTools as Mock).mockReturnValue([ + 'ShellTool(wc -l)', + ]); + const invocation = shellTool.build({ command: 'madeupcommand' }); + await expect( + invocation.shouldConfirmExecute(new AbortController().signal), + ).rejects.toThrow('madeupcommand'); + }); + + it('should throw an error for a command that is a prefix of an allowed command', async () => { + (mockConfig.getAllowedTools as Mock).mockReturnValue([ + 'ShellTool(wc -l)', + ]); + const invocation = shellTool.build({ command: 'wc' }); + await expect( + invocation.shouldConfirmExecute(new AbortController().signal), + ).rejects.toThrow('wc'); + }); + }); }); describe('getDescription', () => { diff --git a/packages/core/src/tools/shell.ts b/packages/core/src/tools/shell.ts index f9d017b626..29a75b024b 100644 --- a/packages/core/src/tools/shell.ts +++ b/packages/core/src/tools/shell.ts @@ -38,55 +38,10 @@ import { SHELL_TOOL_NAMES, stripShellWrapper, } from '../utils/shell-utils.js'; +import { doesToolInvocationMatch } from '../utils/tool-utils.js'; export const OUTPUT_UPDATE_INTERVAL_MS = 1000; -/** - * Parses the `--allowed-tools` flag to determine which sub-commands of the - * ShellTool are allowed. The flag can be provided multiple times. - * - * @param allowedTools The list of allowed tools from the config. - * @returns A Set of allowed sub-commands, or null if all commands are allowed. - * - `null`: All sub-commands are allowed (e.g., --allowed-tools="ShellTool"). - * - `Set`: A set of specifically allowed sub-commands (e.g., --allowed-tools="ShellTool(wc)" --allowed-tools="ShellTool(ls)"). - * - `Set<>` (empty): No sub-commands are allowed (e.g., --allowed-tools="ShellTool()"). - */ -function parseAllowedSubcommands( - allowedTools: readonly string[], -): Set | null { - const shellToolEntries = allowedTools.filter((tool) => - SHELL_TOOL_NAMES.some((name) => tool.startsWith(name)), - ); - - if (shellToolEntries.length === 0) { - return new Set(); // ShellTool not mentioned, so no subcommands are allowed. - } - - // If any entry is just "run_shell_command" or "ShellTool", all subcommands are allowed. - if (shellToolEntries.some((entry) => SHELL_TOOL_NAMES.includes(entry))) { - return null; - } - - const allSubcommands = new Set(); - const toolNamePattern = SHELL_TOOL_NAMES.join('|'); - const regex = new RegExp(`^(${toolNamePattern})\\((.*)\\)$`); - - for (const entry of shellToolEntries) { - const match = entry.match(regex); - if (match) { - const subcommands = match[2]; - if (subcommands) { - subcommands - .split(',') - .map((s) => s.trim()) - .forEach((s) => s && allSubcommands.add(s)); - } - } - } - - return allSubcommands; -} - export interface ShellToolParams { command: string; description?: string; @@ -133,19 +88,16 @@ export class ShellToolInvocation extends BaseToolInvocation< !this.config.isInteractive() && this.config.getApprovalMode() !== ApprovalMode.YOLO ) { - const allowed = this.config.getAllowedTools() || []; - const allowedSubcommands = parseAllowedSubcommands(allowed); - if (allowedSubcommands !== null) { - // Not all commands are allowed, so we need to check. - const allCommandsAllowed = rootCommands.every((cmd) => - allowedSubcommands.has(cmd), - ); - if (!allCommandsAllowed) { - throw new Error( - `Command "${command}" is not in the list of allowed tools for non-interactive mode.`, - ); - } + const allowedTools = this.config.getAllowedTools() || []; + const [SHELL_TOOL_NAME] = SHELL_TOOL_NAMES; + if (doesToolInvocationMatch(SHELL_TOOL_NAME, command, allowedTools)) { + // If it's an allowed shell command, we don't need to confirm execution. + return false; } + + throw new Error( + `Command "${command}" is not in the list of allowed tools for non-interactive mode.`, + ); } const commandsToConfirm = rootCommands.filter( diff --git a/packages/core/src/utils/tool-utils.ts b/packages/core/src/utils/tool-utils.ts index fe3da9856e..ebf73f90bd 100644 --- a/packages/core/src/utils/tool-utils.ts +++ b/packages/core/src/utils/tool-utils.ts @@ -12,7 +12,7 @@ import { SHELL_TOOL_NAMES } from './shell-utils.js'; * Checks if a tool invocation matches any of a list of patterns. * * @param toolOrToolName The tool object or the name of the tool being invoked. - * @param invocation The invocation object for the tool. + * @param invocation The invocation object for the tool or the command invoked. * @param patterns A list of patterns to match against. * Patterns can be: * - A tool name (e.g., "ReadFileTool") to match any invocation of that tool. @@ -22,7 +22,7 @@ import { SHELL_TOOL_NAMES } from './shell-utils.js'; */ export function doesToolInvocationMatch( toolOrToolName: AnyDeclarativeTool | string, - invocation: AnyToolInvocation, + invocation: AnyToolInvocation | string, patterns: string[], ): boolean { let toolNames: string[]; @@ -58,14 +58,19 @@ export function doesToolInvocationMatch( const argPattern = pattern.substring(openParen + 1, pattern.length - 1); - if ( - 'command' in invocation.params && - toolNames.some((name) => SHELL_TOOL_NAMES.includes(name)) - ) { - const argValue = String( - (invocation.params as { command: string }).command, - ); - if (argValue === argPattern || argValue.startsWith(argPattern + ' ')) { + let command: string; + if (typeof invocation === 'string') { + command = invocation; + } else { + if (!('command' in invocation.params)) { + // This invocation has no command - nothing to check. + continue; + } + command = String((invocation.params as { command: string }).command); + } + + if (toolNames.some((name) => SHELL_TOOL_NAMES.includes(name))) { + if (command === argPattern || command.startsWith(argPattern + ' ')) { return true; } }