Unify shell security policy and remove legacy logic (#15770)

This commit is contained in:
Abhi
2026-01-04 00:19:00 -05:00
committed by GitHub
parent f0a039f7c0
commit d3c206c677
14 changed files with 770 additions and 222 deletions
+37
View File
@@ -858,4 +858,41 @@ name = "invalid-name"
// Priority 10 in default tier → 1.010
expect(discoveredRule?.priority).toBeCloseTo(1.01, 5);
});
it('should normalize legacy "ShellTool" alias to "run_shell_command"', async () => {
  vi.resetModules();

  // Mock fs so that readdir reports an empty directory: only the rule derived
  // from `settings.tools.allowed` is relevant to this test.
  const actualFs =
    await vi.importActual<typeof import('node:fs/promises')>(
      'node:fs/promises',
    );
  const mockReaddir = vi.fn(
    async () => [] as unknown as Awaited<ReturnType<typeof actualFs.readdir>>,
  );
  vi.doMock('node:fs/promises', () => ({
    ...actualFs,
    default: { ...actualFs, readdir: mockReaddir },
    readdir: mockReaddir,
  }));

  try {
    // Import after doMock so config.js is loaded against the mocked fs.
    const { createPolicyEngineConfig } = await import('./config.js');
    const settings: PolicySettings = {
      tools: { allowed: ['ShellTool'] },
    };
    const config = await createPolicyEngineConfig(
      settings,
      ApprovalMode.DEFAULT,
      '/tmp/mock/default/policies',
    );

    // The legacy alias must surface under the canonical shell tool name.
    const rule = config.rules?.find(
      (r) =>
        r.toolName === 'run_shell_command' &&
        r.decision === PolicyDecision.ALLOW,
    );
    expect(rule).toBeDefined();
    expect(rule?.priority).toBeCloseTo(2.3, 5); // Command line allow
  } finally {
    // Always unregister the fs mock, even when an expectation above throws;
    // otherwise it would stay registered for later dynamic imports.
    vi.doUnmock('node:fs/promises');
  }
});
});
+59 -30
View File
@@ -16,11 +16,8 @@ import {
type PolicySettings,
} from './types.js';
import type { PolicyEngine } from './policy-engine.js';
import {
loadPoliciesFromToml,
type PolicyFileError,
escapeRegex,
} from './toml-loader.js';
import { loadPoliciesFromToml, type PolicyFileError } from './toml-loader.js';
import { buildArgsPatterns } from './utils.js';
import toml from '@iarna/toml';
import {
MessageBusType,
@@ -29,6 +26,8 @@ import {
import { type MessageBus } from '../confirmation-bus/message-bus.js';
import { coreEvents } from '../utils/events.js';
import { debugLogger } from '../utils/debugLogger.js';
import { SHELL_TOOL_NAMES } from '../utils/shell-utils.js';
import { SHELL_TOOL_NAME } from '../tools/tool-names.js';
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
@@ -195,11 +194,48 @@ export async function createPolicyEngineConfig(
// Priority: 2.3 (user tier - explicit temporary allows)
if (settings.tools?.allowed) {
for (const tool of settings.tools.allowed) {
rules.push({
toolName: tool,
decision: PolicyDecision.ALLOW,
priority: 2.3,
});
// Check for legacy format: toolName(args)
const match = tool.match(/^([a-zA-Z0-9_-]+)\((.*)\)$/);
if (match) {
const [, rawToolName, args] = match;
// Normalize shell tool aliases
const toolName = SHELL_TOOL_NAMES.includes(rawToolName)
? SHELL_TOOL_NAME
: rawToolName;
// Treat args as a command prefix for shell tool
if (toolName === SHELL_TOOL_NAME) {
const patterns = buildArgsPatterns(undefined, args);
for (const pattern of patterns) {
if (pattern) {
rules.push({
toolName,
decision: PolicyDecision.ALLOW,
priority: 2.3,
argsPattern: new RegExp(pattern),
});
}
}
} else {
// For non-shell tools, we allow the tool itself but ignore args
// as args matching was only supported for shell tools historically.
rules.push({
toolName,
decision: PolicyDecision.ALLOW,
priority: 2.3,
});
}
} else {
// Standard tool name
const toolName = SHELL_TOOL_NAMES.includes(tool)
? SHELL_TOOL_NAME
: tool;
rules.push({
toolName,
decision: PolicyDecision.ALLOW,
priority: 2.3,
});
}
}
}
@@ -263,26 +299,19 @@ export function createPolicyUpdater(
if (message.commandPrefix) {
// Convert commandPrefix(es) to argsPatterns for in-memory rules
const prefixes = Array.isArray(message.commandPrefix)
? message.commandPrefix
: [message.commandPrefix];
for (const prefix of prefixes) {
const escapedPrefix = escapeRegex(prefix);
// Use robust regex to match whole words (e.g. "git" but not "github")
const argsPattern = new RegExp(
`"command":"${escapedPrefix}(?:[\\s"]|$)`,
);
policyEngine.addRule({
toolName,
decision: PolicyDecision.ALLOW,
// User tier (2) + high priority (950/1000) = 2.95
// This ensures user "always allow" selections are high priority
// but still lose to admin policies (3.xxx) and settings excludes (200)
priority: 2.95,
argsPattern,
});
const patterns = buildArgsPatterns(undefined, message.commandPrefix);
for (const pattern of patterns) {
if (pattern) {
policyEngine.addRule({
toolName,
decision: PolicyDecision.ALLOW,
// User tier (2) + high priority (950/1000) = 2.95
// This ensures user "always allow" selections are high priority
// but still lose to admin policies (3.xxx) and settings excludes (200)
priority: 2.95,
argsPattern: new RegExp(pattern),
});
}
}
} else {
const argsPattern = message.argsPattern
+1 -1
View File
@@ -127,7 +127,7 @@ describe('createPolicyUpdater', () => {
expect(addedRule).toBeDefined();
expect(addedRule?.priority).toBe(2.95);
expect(addedRule?.argsPattern).toEqual(
new RegExp(`"command":"git status(?:[\\s"]|$)`),
new RegExp(`"command":"git\\ status(?:[\\s"]|$)`),
);
// Verify file written
+275 -1
View File
@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, beforeEach, vi } from 'vitest';
import { describe, it, expect, beforeEach, beforeAll, vi } from 'vitest';
import { PolicyEngine } from './policy-engine.js';
import {
PolicyDecision,
@@ -17,11 +17,40 @@ import {
import type { FunctionCall } from '@google/genai';
import { SafetyCheckDecision } from '../safety/protocol.js';
import type { CheckerRunner } from '../safety/checker-runner.js';
import { initializeShellParsers } from '../utils/shell-utils.js';
import { buildArgsPatterns } from './utils.js';
// Swap the shell parsing helpers for deterministic stand-ins so results do not
// vary across platforms (especially Windows CI). The subject under test is the
// PolicyEngine's decision logic, not the real shell parser.
vi.mock('../utils/shell-utils.js', async (importOriginal) => {
  const realShellUtils =
    await importOriginal<typeof import('../utils/shell-utils.js')>();

  // Test-only splitter: breaks on '&&' and trims each piece; anything else is
  // returned as a single command.
  const splitCommands = vi.fn().mockImplementation((command: string) => {
    if (!command.includes('&&')) {
      return [command];
    }
    return command.split('&&').map((part) => part.trim());
  });

  // Test-only redirection check: any '>' counts, unless the command contains
  // the literal "-> arrow" marker.
  const hasRedirection = vi.fn().mockImplementation((command: string) => {
    if (command.includes('-> arrow')) {
      return false;
    }
    return command.includes('>');
  });

  return {
    ...realShellUtils,
    initializeShellParsers: vi.fn().mockResolvedValue(undefined),
    splitCommands,
    hasRedirection,
  };
});
describe('PolicyEngine', () => {
let engine: PolicyEngine;
let mockCheckerRunner: CheckerRunner;
beforeAll(async () => {
await initializeShellParsers();
});
beforeEach(() => {
mockCheckerRunner = {
runChecker: vi.fn(),
@@ -457,6 +486,29 @@ describe('PolicyEngine', () => {
);
});
it('should correctly match commands with quotes in commandPrefix', async () => {
  // The same text is used as the allowed prefix and as the executed command.
  const quotedCommand = 'git commit -m "fix"';
  const [firstPattern] = buildArgsPatterns(undefined, quotedCommand);

  engine = new PolicyEngine({
    rules: [
      {
        toolName: 'run_shell_command',
        argsPattern: new RegExp(firstPattern!),
        decision: PolicyDecision.ALLOW,
      },
    ],
  });

  const { decision } = await engine.check(
    {
      name: 'run_shell_command',
      args: { command: quotedCommand },
    },
    undefined,
  );

  expect(decision).toBe(PolicyDecision.ALLOW);
});
it('should handle tools with no args', async () => {
const rules: PolicyRule[] = [
{
@@ -952,6 +1004,228 @@ describe('PolicyEngine', () => {
).toBe(PolicyDecision.ALLOW);
});
it('should preserve dir_path during recursive shell command checks', async () => {
  // Only allows echo inside one specific directory.
  // Note: stableStringify sorts keys alphabetically and has no spaces:
  // {"command":"echo hello","dir_path":"/safe/path"}
  const echoInSafeDir: PolicyRule = {
    toolName: 'run_shell_command',
    argsPattern: /"command":"echo hello".*"dir_path":"\/safe\/path"/,
    decision: PolicyDecision.ALLOW,
  };
  // Catch-all ALLOW for the shell tool, deliberately at a low priority.
  const lowPriorityShellAllow: PolicyRule = {
    toolName: 'run_shell_command',
    decision: PolicyDecision.ALLOW,
    priority: -100,
  };
  engine = new PolicyEngine({ rules: [echoInSafeDir, lowPriorityShellAllow] });

  // Compound command: decomposition re-checks "echo hello" on its own, and that
  // sub-check can only hit the directory-specific rule if dir_path is carried
  // along with it.
  const { decision } = await engine.check(
    {
      name: 'run_shell_command',
      args: { command: 'echo hello && pwd', dir_path: '/safe/path' },
    },
    undefined,
  );

  expect(decision).toBe(PolicyDecision.ALLOW);
});
it('should upgrade ASK_USER to ALLOW if all sub-commands are allowed', async () => {
  // One explicit ALLOW rule per sub-command, both at priority 20.
  const allowRules = [/"command":"git status/, /"command":"ls/].map(
    (argsPattern): PolicyRule => ({
      toolName: 'run_shell_command',
      argsPattern,
      decision: PolicyDecision.ALLOW,
      priority: 20,
    }),
  );
  // Catch-all ASK_USER for shell, below the explicit allows.
  const askUserFallback: PolicyRule = {
    toolName: 'run_shell_command',
    decision: PolicyDecision.ASK_USER,
    priority: 10,
  };
  engine = new PolicyEngine({ rules: [...allowRules, askUserFallback] });

  // Per the scenario under test, the full string "git status && ls" starts out
  // on the catch-all ASK_USER rule; because each part is explicitly ALLOWed the
  // final decision is expected to be upgraded to ALLOW.
  const { decision } = await engine.check(
    {
      name: 'run_shell_command',
      args: { command: 'git status && ls' },
    },
    undefined,
  );

  expect(decision).toBe(PolicyDecision.ALLOW);
});
it('should respect explicit DENY for compound commands even if parts are allowed', async () => {
  // Highest priority: the compound command as a whole is explicitly denied.
  const denyCompound: PolicyRule = {
    toolName: 'run_shell_command',
    argsPattern: /"command":"git status && ls"/,
    decision: PolicyDecision.DENY,
    priority: 30,
  };
  // Lower priority: each individual part is allowed.
  const allowGitStatus: PolicyRule = {
    toolName: 'run_shell_command',
    argsPattern: /"command":"git status/,
    decision: PolicyDecision.ALLOW,
    priority: 20,
  };
  const allowLs: PolicyRule = {
    toolName: 'run_shell_command',
    argsPattern: /"command":"ls/,
    decision: PolicyDecision.ALLOW,
    priority: 20,
  };
  engine = new PolicyEngine({ rules: [denyCompound, allowGitStatus, allowLs] });

  const { decision } = await engine.check(
    {
      name: 'run_shell_command',
      args: { command: 'git status && ls' },
    },
    undefined,
  );

  expect(decision).toBe(PolicyDecision.DENY);
});
it('should propagate DENY from any sub-command', async () => {
  engine = new PolicyEngine({
    rules: [
      {
        // rm is denied outright
        toolName: 'run_shell_command',
        argsPattern: /"command":"rm/,
        decision: PolicyDecision.DENY,
        priority: 20,
      },
      {
        // echo is allowed
        toolName: 'run_shell_command',
        argsPattern: /"command":"echo/,
        decision: PolicyDecision.ALLOW,
        priority: 20,
      },
      {
        // everything else asks the user
        toolName: 'run_shell_command',
        decision: PolicyDecision.ASK_USER,
        priority: 10,
      },
    ],
  });

  // "echo hello && rm -rf /" -> echo is ALLOW, rm is DENY -> Result DENY
  const { decision } = await engine.check(
    {
      name: 'run_shell_command',
      args: { command: 'echo hello && rm -rf /' },
    },
    undefined,
  );

  expect(decision).toBe(PolicyDecision.DENY);
});
it('should DENY redirected shell commands in non-interactive mode', async () => {
  const allowAllShell: PolicyRule = {
    toolName: 'run_shell_command',
    decision: PolicyDecision.ALLOW,
  };
  const config: PolicyEngineConfig = {
    nonInteractive: true,
    rules: [allowAllShell],
  };
  engine = new PolicyEngine(config);

  // A redirected command would normally become ASK_USER; in non-interactive
  // mode ASK_USER turns into DENY, so DENY is the expected outcome.
  const outcome = await engine.check(
    {
      name: 'run_shell_command',
      args: { command: 'echo "hello" > file.txt' },
    },
    undefined,
  );

  expect(outcome.decision).toBe(PolicyDecision.DENY);
});
it('should default to ASK_USER for atomic commands when matching a wildcard ASK_USER rule', async () => {
  // Regression test: atomic commands were auto-allowing because of optimistic
  // initialization.
  const wildcardAskUser: PolicyRule = {
    toolName: 'run_shell_command',
    decision: PolicyDecision.ASK_USER,
  };
  engine = new PolicyEngine({ rules: [wildcardAskUser] });

  // "whoami" is a single (atomic) command that matches the wildcard ASK_USER
  // rule; the decision must stay ASK_USER rather than being upgraded to ALLOW.
  const outcome = await engine.check(
    {
      name: 'run_shell_command',
      args: { command: 'whoami' },
    },
    undefined,
  );

  expect(outcome.decision).toBe(PolicyDecision.ASK_USER);
});
it('should allow redirected shell commands in non-interactive mode if allowRedirection is true', async () => {
  const allowShellWithRedirection: PolicyRule = {
    toolName: 'run_shell_command',
    decision: PolicyDecision.ALLOW,
    allowRedirection: true,
  };
  const config: PolicyEngineConfig = {
    nonInteractive: true,
    rules: [allowShellWithRedirection],
  };
  engine = new PolicyEngine(config);

  // With allowRedirection set on the matching rule, the redirected command
  // should stay ALLOW even in non-interactive mode.
  const outcome = await engine.check(
    {
      name: 'run_shell_command',
      args: { command: 'echo "hello" > file.txt' },
    },
    undefined,
  );

  expect(outcome.decision).toBe(PolicyDecision.ALLOW);
});
it('should avoid infinite recursion for commands with substitution', async () => {
const rules: PolicyRule[] = [
{
+25 -6
View File
@@ -172,8 +172,13 @@ export class PolicyEngine {
`[PolicyEngine.check] Validating shell command: ${subCommands.length} parts`,
);
// Start with the decision for the full command string
let aggregateDecision = ruleDecision;
if (ruleDecision === PolicyDecision.DENY) {
return PolicyDecision.DENY;
}
// Start optimistically. If all parts are ALLOW, the whole is ALLOW.
// We will downgrade if any part is ASK_USER or DENY.
let aggregateDecision = PolicyDecision.ALLOW;
for (const subCmd of subCommands) {
// Prevent infinite recursion for the root command
@@ -186,6 +191,17 @@ export class PolicyEngine {
if (aggregateDecision === PolicyDecision.ALLOW) {
aggregateDecision = PolicyDecision.ASK_USER;
}
} else {
// If the command is atomic (cannot be split further) and didn't
// trigger infinite recursion checks, we must respect the decision
// of the rule that triggered this check. If the rule was ASK_USER
// (e.g. wildcard), we must downgrade.
if (
ruleDecision === PolicyDecision.ASK_USER &&
aggregateDecision === PolicyDecision.ALLOW
) {
aggregateDecision = PolicyDecision.ASK_USER;
}
}
continue;
}
@@ -195,13 +211,16 @@ export class PolicyEngine {
serverName,
);
// subResult.decision is already filtered through applyNonInteractiveMode by this.check()
const subDecision = subResult.decision;
// If any part is DENIED, the whole command is DENIED
if (subResult.decision === PolicyDecision.DENY) {
if (subDecision === PolicyDecision.DENY) {
return PolicyDecision.DENY;
}
// If any part requires ASK_USER, the whole command requires ASK_USER (unless already DENY)
if (subResult.decision === PolicyDecision.ASK_USER) {
// If any part requires ASK_USER, the whole command requires ASK_USER
if (subDecision === PolicyDecision.ASK_USER) {
if (aggregateDecision === PolicyDecision.ALLOW) {
aggregateDecision = PolicyDecision.ASK_USER;
}
@@ -209,7 +228,7 @@ export class PolicyEngine {
// Check for redirection in allowed sub-commands
if (
subResult.decision === PolicyDecision.ALLOW &&
subDecision === PolicyDecision.ALLOW &&
!allowRedirection &&
hasRedirection(subCmd)
) {
@@ -157,6 +157,21 @@ modes = ["yolo"]
expect(result.errors).toHaveLength(0);
});
it('should parse and transform allow_redirection property', async () => {
  // snake_case `allow_redirection` in TOML should surface as `allowRedirection`
  // on the loaded rule.
  const policyToml = `
[[rule]]
toolName = "run_shell_command"
commandPrefix = "echo"
decision = "allow"
priority = 100
allow_redirection = true
`;
  const { rules, errors } = await runLoadPoliciesFromToml(policyToml);

  expect(rules).toHaveLength(1);
  expect(rules[0].allowRedirection).toBe(true);
  expect(errors).toHaveLength(0);
});
it('should return error if modes property is used for Tier 2 and Tier 3 policies', async () => {
await fs.writeFile(
path.join(tempDir, 'tier2.toml'),
+13 -49
View File
@@ -12,6 +12,7 @@ import {
type SafetyCheckerRule,
InProcessCheckerType,
} from './types.js';
import { buildArgsPatterns } from './utils.js';
import fs from 'node:fs/promises';
import path from 'node:path';
import toml from '@iarna/toml';
@@ -44,6 +45,7 @@ const PolicyRuleSchema = z.object({
'priority must be <= 999 to prevent tier overflow. Priorities >= 1000 would jump to the next tier.',
}),
modes: z.array(z.nativeEnum(ApprovalMode)).optional(),
allow_redirection: z.boolean().optional(),
});
/**
@@ -119,17 +121,6 @@ export interface PolicyLoadResult {
errors: PolicyFileError[];
}
/**
* Escapes special regex characters in a string for use in a regex pattern.
* This is used for commandPrefix to ensure literal string matching.
*
* @param str The string to escape
* @returns The escaped string safe for use in a regex
*/
export function escapeRegex(str: string): string {
return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
}
/**
* Converts a tier number to a human-readable tier name.
*/
@@ -354,26 +345,11 @@ export async function loadPoliciesFromToml(
// Transform rules
const parsedRules: PolicyRule[] = (validationResult.data.rule ?? [])
.flatMap((rule) => {
// Transform commandPrefix/commandRegex to argsPattern
let effectiveArgsPattern = rule.argsPattern;
const commandPrefixes: string[] = [];
if (rule.commandPrefix) {
const prefixes = Array.isArray(rule.commandPrefix)
? rule.commandPrefix
: [rule.commandPrefix];
commandPrefixes.push(...prefixes);
} else if (rule.commandRegex) {
effectiveArgsPattern = `"command":"${rule.commandRegex}`;
}
// Expand command prefixes to multiple patterns
const argsPatterns: Array<string | undefined> =
commandPrefixes.length > 0
? commandPrefixes.map(
(prefix) => `"command":"${escapeRegex(prefix)}(?:[\\s"]|$)`,
)
: [effectiveArgsPattern];
const argsPatterns = buildArgsPatterns(
rule.argsPattern,
rule.commandPrefix,
rule.commandRegex,
);
// For each argsPattern, expand toolName arrays
return argsPatterns.flatMap((argsPattern) => {
@@ -400,6 +376,7 @@ export async function loadPoliciesFromToml(
decision: rule.decision,
priority: transformPriority(rule.priority, tier),
modes: tier === 1 ? rule.modes : undefined,
allowRedirection: rule.allow_redirection,
};
// Compile regex pattern
@@ -436,24 +413,11 @@ export async function loadPoliciesFromToml(
validationResult.data.safety_checker ?? []
)
.flatMap((checker) => {
let effectiveArgsPattern = checker.argsPattern;
const commandPrefixes: string[] = [];
if (checker.commandPrefix) {
const prefixes = Array.isArray(checker.commandPrefix)
? checker.commandPrefix
: [checker.commandPrefix];
commandPrefixes.push(...prefixes);
} else if (checker.commandRegex) {
effectiveArgsPattern = `"command":"${checker.commandRegex}`;
}
const argsPatterns: Array<string | undefined> =
commandPrefixes.length > 0
? commandPrefixes.map(
(prefix) => `"command":"${escapeRegex(prefix)}(?:[\\s"]|$)`,
)
: [effectiveArgsPattern];
const argsPatterns = buildArgsPatterns(
checker.argsPattern,
checker.commandPrefix,
checker.commandRegex,
);
return argsPatterns.flatMap((argsPattern) => {
const toolNames: Array<string | undefined> = checker.toolName
+77
View File
@@ -0,0 +1,77 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect } from 'vitest';
import { escapeRegex, buildArgsPatterns } from './utils.js';
describe('policy/utils', () => {
  describe('escapeRegex', () => {
    it('should escape special regex characters', () => {
      // Regex metacharacters plus '-', a space and a double quote.
      expect(escapeRegex('.-*+?^${}()|[]\\ "')).toBe(
        '\\.\\-\\*\\+\\?\\^\\$\\{\\}\\(\\)\\|\\[\\]\\\\\\ \\"',
      );
    });

    it('should return the same string if no special characters are present', () => {
      const plain = 'abcABC123';
      expect(escapeRegex(plain)).toBe(plain);
    });
  });

  describe('buildArgsPatterns', () => {
    it('should return argsPattern if provided and no commandPrefix/regex', () => {
      expect(buildArgsPatterns('my-pattern')).toEqual(['my-pattern']);
    });

    it('should build pattern from a single commandPrefix', () => {
      expect(buildArgsPatterns(undefined, 'ls')).toEqual([
        '"command":"ls(?:[\\s"]|$)',
      ]);
    });

    it('should build patterns from an array of commandPrefixes', () => {
      // One pattern per prefix, in input order.
      expect(buildArgsPatterns(undefined, ['ls', 'cd'])).toEqual([
        '"command":"ls(?:[\\s"]|$)',
        '"command":"cd(?:[\\s"]|$)',
      ]);
    });

    it('should build pattern from commandRegex', () => {
      // commandRegex is embedded verbatim, without escaping.
      expect(buildArgsPatterns(undefined, undefined, 'rm -rf .*')).toEqual([
        '"command":"rm -rf .*',
      ]);
    });

    it('should prioritize commandPrefix over commandRegex and argsPattern', () => {
      expect(buildArgsPatterns('raw', 'prefix', 'regex')).toEqual([
        '"command":"prefix(?:[\\s"]|$)',
      ]);
    });

    it('should prioritize commandRegex over argsPattern if no commandPrefix', () => {
      expect(buildArgsPatterns('raw', undefined, 'regex')).toEqual([
        '"command":"regex',
      ]);
    });

    it('should escape characters in commandPrefix', () => {
      expect(buildArgsPatterns(undefined, 'git checkout -b')).toEqual([
        '"command":"git\\ checkout\\ \\-b(?:[\\s"]|$)',
      ]);
    });

    it('should correctly escape quotes in commandPrefix', () => {
      // Quotes are JSON-escaped first and regex-escaped afterwards.
      expect(buildArgsPatterns(undefined, 'git "fix"')).toEqual([
        '"command":"git\\ \\\\\\"fix\\\\\\"(?:[\\s"]|$)',
      ]);
    });

    it('should handle undefined correctly when no inputs are provided', () => {
      expect(buildArgsPatterns(undefined, undefined, undefined)).toEqual([
        undefined,
      ]);
    });
  });
});
+50
View File
@@ -0,0 +1,50 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
/**
 * Backslash-escapes characters in `text` so it can be embedded as literal text
 * inside a regular-expression source string.
 *
 * In addition to the usual regex metacharacters, the escaped set includes `-`,
 * `,`, `#`, any whitespace character and `"`.
 *
 * NOTE(review): escapes such as `\ ` and `\"` are identity escapes; the
 * call sites visible in this change compile the result with `new RegExp(...)`
 * and no flags — confirm before compiling these patterns with the `u`/`v` flag.
 *
 * @param text The raw text to escape.
 * @returns `text` with each character of the escape set prefixed by a backslash.
 */
export function escapeRegex(text: string): string {
  return text.replace(
    /[-[\]{}()*+?.,\\^$|#\s"]/g,
    (specialChar) => `\\${specialChar}`,
  );
}
/**
 * Produces the args-pattern source strings used for policy matching.
 *
 * Exactly one of the three inputs is used, in this order of precedence:
 *  1. `commandPrefix` — one pattern per prefix. Each prefix is JSON-escaped and
 *     then regex-escaped, so it is matched literally at the start of the
 *     `"command"` value, followed by whitespace, a `"` or end of input.
 *  2. `commandRegex` — embedded verbatim after `"command":"` (not escaped).
 *  3. `argsPattern` — returned unchanged, possibly `undefined`.
 *
 * An empty-string `commandPrefix` is falsy and therefore falls through to the
 * next option; an empty array yields an empty result (no patterns at all).
 *
 * @param argsPattern An optional raw regex string for arguments.
 * @param commandPrefix An optional command prefix, or list of prefixes.
 * @param commandRegex An optional regex string for the command value.
 * @returns An array of pattern strings (or a single `undefined` entry when no
 *   input was supplied).
 */
export function buildArgsPatterns(
  argsPattern?: string,
  commandPrefix?: string | string[],
  commandRegex?: string,
): Array<string | undefined> {
  // No usable prefix: fall back to the regex, then to the raw args pattern.
  if (!commandPrefix) {
    return commandRegex ? [`"command":"${commandRegex}`] : [argsPattern];
  }

  const toPrefixPattern = (prefix: string): string => {
    // Patterns are matched against JSON-stringified args, so escape the prefix
    // the way JSON would (dropping the surrounding quotes) before escaping it
    // for the regex.
    const quoted = JSON.stringify(prefix);
    const jsonBody = quoted.slice(1, -1);
    // The trailing group enforces a whole-word match (e.g. "git" but not
    // "github").
    return `"command":"${escapeRegex(jsonBody)}(?:[\\s"]|$)`;
  };

  const prefixList = Array.isArray(commandPrefix)
    ? commandPrefix
    : [commandPrefix];
  return prefixList.map((prefix) => toPrefixPattern(prefix));
}
@@ -24,6 +24,7 @@ export class MockMessageBus {
publishedMessages: Message[] = [];
hookRequests: HookExecutionRequest[] = [];
hookResponses: HookExecutionResponse[] = [];
defaultToolDecision: 'allow' | 'deny' | 'ask_user' = 'allow';
/**
* Mock publish method that captures messages and simulates responses
@@ -50,6 +51,34 @@ export class MockMessageBus {
// Emit response to subscribers
this.emit(MessageBusType.HOOK_EXECUTION_RESPONSE, response);
}
// Handle tool confirmation requests: reply immediately with a response whose
// shape is dictated by `defaultToolDecision`.
if (message.type === MessageBusType.TOOL_CONFIRMATION_REQUEST) {
  const decision = this.defaultToolDecision;
  // 'allow' confirms, 'deny' rejects, anything else (ask_user) rejects and
  // additionally flags that user confirmation is required.
  const verdict =
    decision === 'allow'
      ? { confirmed: true }
      : decision === 'deny'
        ? { confirmed: false }
        : { confirmed: false, requiresUserConfirmation: true };

  this.emit(MessageBusType.TOOL_CONFIRMATION_RESPONSE, {
    type: MessageBusType.TOOL_CONFIRMATION_RESPONSE,
    correlationId: message.correlationId,
    ...verdict,
  });
}
// Emit the message to subscribers (mimicking real MessageBus behavior)
this.emit(message.type, message);
});
/**
+62 -82
View File
@@ -37,7 +37,6 @@ vi.mock('crypto');
vi.mock('../utils/summarizer.js');
import { initializeShellParsers } from '../utils/shell-utils.js';
import { isCommandAllowed } from '../utils/shell-permissions.js';
import { ShellTool } from './shell.js';
import { type Config } from '../config/config.js';
import {
@@ -51,9 +50,23 @@ import * as path from 'node:path';
import * as crypto from 'node:crypto';
import * as summarizer from '../utils/summarizer.js';
import { ToolErrorType } from './tool-error.js';
import { ToolConfirmationOutcome } from './tools.js';
import { OUTPUT_UPDATE_INTERVAL_MS } from './shell.js';
import { SHELL_TOOL_NAME } from './tool-names.js';
import { WorkspaceContext } from '../utils/workspaceContext.js';
import {
createMockMessageBus,
getMockMessageBusInstance,
} from '../test-utils/mock-message-bus.js';
import {
MessageBusType,
type UpdatePolicy,
} from '../confirmation-bus/types.js';
import { type MessageBus } from '../confirmation-bus/message-bus.js';
/**
 * Test-only view of the mock message bus: a `MessageBus` that additionally
 * exposes `defaultToolDecision`. The tests in this file cast the result of
 * `getMockMessageBusInstance(...)` to this type in order to set that field.
 */
interface TestableMockMessageBus extends MessageBus {
// Decision the tests assign to steer the mock bus: 'allow', 'deny' or 'ask_user'.
defaultToolDecision: 'allow' | 'deny' | 'ask_user';
}
const originalComSpec = process.env['ComSpec'];
const itWindowsOnly = process.platform === 'win32' ? it : it.skip;
@@ -92,7 +105,29 @@ describe('ShellTool', () => {
getShellToolInactivityTimeout: vi.fn().mockReturnValue(300000),
} as unknown as Config;
shellTool = new ShellTool(mockConfig);
const bus = createMockMessageBus();
const mockBus = getMockMessageBusInstance(
bus,
) as unknown as TestableMockMessageBus;
mockBus.defaultToolDecision = 'ask_user';
// Simulate policy update
bus.subscribe(MessageBusType.UPDATE_POLICY, (msg: UpdatePolicy) => {
if (msg.commandPrefix) {
const prefixes = Array.isArray(msg.commandPrefix)
? msg.commandPrefix
: [msg.commandPrefix];
const current = mockConfig.getAllowedTools() || [];
(mockConfig.getAllowedTools as Mock).mockReturnValue([
...current,
...prefixes,
]);
// Simulate Policy Engine allowing the tool after update
mockBus.defaultToolDecision = 'allow';
}
});
shellTool = new ShellTool(mockConfig, bus);
mockPlatform.mockReturnValue('linux');
(vi.mocked(crypto.randomBytes) as Mock).mockReturnValue(
@@ -124,25 +159,6 @@ describe('ShellTool', () => {
}
});
describe('isCommandAllowed', () => {
it('should allow a command if no restrictions are provided', () => {
(mockConfig.getCoreTools as Mock).mockReturnValue(undefined);
(mockConfig.getExcludeTools as Mock).mockReturnValue(undefined);
expect(isCommandAllowed('goodCommand --safe', mockConfig).allowed).toBe(
true,
);
});
it('should allow a command with command substitution using $()', () => {
const evaluation = isCommandAllowed(
'echo $(goodCommand --safe)',
mockConfig,
);
expect(evaluation.allowed).toBe(true);
expect(evaluation.reason).toBeUndefined();
});
});
describe('build', () => {
it('should return an invocation for a valid command', () => {
const invocation = shellTool.build({ command: 'goodCommand --safe' });
@@ -471,90 +487,54 @@ describe('ShellTool', () => {
});
describe('shouldConfirmExecute', () => {
it('should return confirmation details when PolicyEngine delegates', async () => {
it('should request confirmation for a new command and allowlist it on "Always"', async () => {
const params = { command: 'npm install' };
const invocation = shellTool.build(params);
// Accessing protected messageBus for testing purposes
const bus = (shellTool as unknown as { messageBus: MessageBus })
.messageBus;
const mockBus = getMockMessageBusInstance(
bus,
) as unknown as TestableMockMessageBus;
// Initially needs confirmation
mockBus.defaultToolDecision = 'ask_user';
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).not.toBe(false);
expect(confirmation && confirmation.type).toBe('exec');
if (confirmation && confirmation.type === 'exec') {
await confirmation.onConfirm(ToolConfirmationOutcome.ProceedAlways);
}
// After "Always", it should be allowlisted in the mock engine
mockBus.defaultToolDecision = 'allow';
const secondInvocation = shellTool.build({ command: 'npm test' });
const secondConfirmation = await secondInvocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(secondConfirmation).toBe(false);
});
it('should throw an error if validation fails', () => {
expect(() => shellTool.build({ command: '' })).toThrow();
});
describe('in non-interactive mode', () => {
beforeEach(() => {
(mockConfig.isInteractive as Mock).mockReturnValue(false);
});
it('should not throw an error or block for an allowed command', async () => {
(mockConfig.getAllowedTools as Mock).mockReturnValue(['ShellTool(wc)']);
const invocation = shellTool.build({ command: 'wc -l foo.txt' });
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).toBe(false);
});
it('should not throw an error or block for an allowed command with arguments', async () => {
(mockConfig.getAllowedTools as Mock).mockReturnValue([
'ShellTool(wc -l)',
]);
const invocation = shellTool.build({ command: 'wc -l foo.txt' });
const confirmation = await invocation.shouldConfirmExecute(
new AbortController().signal,
);
expect(confirmation).toBe(false);
});
it('should throw an error for command that is not allowed', async () => {
(mockConfig.getAllowedTools as Mock).mockReturnValue([
'ShellTool(wc -l)',
]);
const invocation = shellTool.build({ command: 'madeupcommand' });
await expect(
invocation.shouldConfirmExecute(new AbortController().signal),
).rejects.toThrow('madeupcommand');
});
it('should throw an error for a command that is a prefix of an allowed command', async () => {
(mockConfig.getAllowedTools as Mock).mockReturnValue([
'ShellTool(wc -l)',
]);
const invocation = shellTool.build({ command: 'wc' });
await expect(
invocation.shouldConfirmExecute(new AbortController().signal),
).rejects.toThrow('wc');
});
it('should require all segments of a chained command to be allowlisted', async () => {
(mockConfig.getAllowedTools as Mock).mockReturnValue([
'ShellTool(echo)',
]);
const invocation = shellTool.build({ command: 'echo "foo" && ls -l' });
await expect(
invocation.shouldConfirmExecute(new AbortController().signal),
).rejects.toThrow(
'Command "echo "foo" && ls -l" is not in the list of allowed tools for non-interactive mode.',
);
});
});
});
describe('getDescription', () => {
it('should return the windows description when on windows', () => {
mockPlatform.mockReturnValue('win32');
const shellTool = new ShellTool(mockConfig);
const shellTool = new ShellTool(mockConfig, createMockMessageBus());
expect(shellTool.description).toMatchSnapshot();
});
it('should return the non-windows description when not on windows', () => {
mockPlatform.mockReturnValue('linux');
const shellTool = new ShellTool(mockConfig);
const shellTool = new ShellTool(mockConfig, createMockMessageBus());
expect(shellTool.description).toMatchSnapshot();
});
});
+8 -45
View File
@@ -9,7 +9,7 @@ import path from 'node:path';
import os, { EOL } from 'node:os';
import crypto from 'node:crypto';
import type { Config } from '../config/config.js';
import { debugLogger, type AnyToolInvocation } from '../index.js';
import { debugLogger } from '../index.js';
import { ToolErrorType } from './tool-error.js';
import type {
ToolInvocation,
@@ -24,7 +24,6 @@ import {
Kind,
type PolicyUpdateOptions,
} from './tools.js';
import { ApprovalMode } from '../policy/types.js';
import { getErrorMessage } from '../utils/errors.js';
import { summarizeToolOutput } from '../utils/summarizer.js';
@@ -40,10 +39,6 @@ import {
initializeShellParsers,
stripShellWrapper,
} from '../utils/shell-utils.js';
import {
isCommandAllowed,
isShellInvocationAllowlisted,
} from '../utils/shell-permissions.js';
import { SHELL_TOOL_NAME } from './tool-names.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
@@ -106,24 +101,15 @@ export class ShellToolInvocation extends BaseToolInvocation<
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
const command = stripShellWrapper(this.params.command);
const rootCommands = [...new Set(getCommandRoots(command))];
let rootCommands = [...new Set(getCommandRoots(command))];
// In non-interactive mode, we need to prevent the tool from hanging while
// waiting for user input. If a tool is not fully allowed (e.g. via
// --allowed-tools="ShellTool(wc)"), we should throw an error instead of
// prompting for confirmation. This check is skipped in YOLO mode.
if (
!this.config.isInteractive() &&
this.config.getApprovalMode() !== ApprovalMode.YOLO
) {
if (this.isInvocationAllowlisted(command)) {
// If it's an allowed shell command, we don't need to confirm execution.
return false;
// Fallback for UI display if parser fails or returns no commands (e.g.
// variable assignments only)
if (rootCommands.length === 0 && command.trim()) {
const fallback = command.trim().split(/\s+/)[0];
if (fallback) {
rootCommands = [fallback];
}
throw new Error(
`Command "${command}" is not in the list of allowed tools for non-interactive mode.`,
);
}
// Rely entirely on PolicyEngine for interactive confirmation.
@@ -394,16 +380,6 @@ export class ShellToolInvocation extends BaseToolInvocation<
}
}
}
private isInvocationAllowlisted(command: string): boolean {
const allowedTools = this.config.getAllowedTools() || [];
if (allowedTools.length === 0) {
return false;
}
const invocation = { params: { command } } as unknown as AnyToolInvocation;
return isShellInvocationAllowlisted(invocation, allowedTools);
}
}
function getShellToolDescription(): string {
@@ -487,19 +463,6 @@ export class ShellTool extends BaseDeclarativeTool<
return 'Command cannot be empty.';
}
const commandCheck = isCommandAllowed(params.command, this.config);
if (!commandCheck.allowed) {
if (!commandCheck.reason) {
debugLogger.error(
'Unexpected: isCommandAllowed returned false without a reason',
);
return `Command is not allowed: ${params.command}`;
}
return commandCheck.reason;
}
if (getCommandRoots(params.command).length === 0) {
return 'Could not identify command root to obtain permission from user.';
}
if (params.dir_path) {
const resolvedPath = path.resolve(
this.config.getTargetDir(),