mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-04-27 13:34:15 -07:00
326 lines
9.7 KiB
TypeScript
326 lines
9.7 KiB
TypeScript
/**
|
|
* @license
|
|
* Copyright 2025 Google LLC
|
|
* SPDX-License-Identifier: Apache-2.0
|
|
*/
|
|
|
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
|
import * as fs from 'node:fs/promises';
|
|
import { createPolicyUpdater, ALWAYS_ALLOW_PRIORITY } from './config.js';
|
|
import { PolicyEngine } from './policy-engine.js';
|
|
import { MessageBus } from '../confirmation-bus/message-bus.js';
|
|
import { MessageBusType } from '../confirmation-bus/types.js';
|
|
import { Storage } from '../config/storage.js';
|
|
import toml from '@iarna/toml';
|
|
import { ShellToolInvocation } from '../tools/shell.js';
|
|
import { type Config } from '../config/config.js';
|
|
import {
|
|
ToolConfirmationOutcome,
|
|
type PolicyUpdateOptions,
|
|
} from '../tools/tools.js';
|
|
import * as shellUtils from '../utils/shell-utils.js';
|
|
import { escapeRegex } from './utils.js';
|
|
|
|
vi.mock('node:fs/promises');
|
|
vi.mock('../config/storage.js');
|
|
vi.mock('../utils/shell-utils.js', () => ({
|
|
getCommandRoots: vi.fn(),
|
|
stripShellWrapper: vi.fn(),
|
|
hasRedirection: vi.fn(),
|
|
}));
|
|
interface ParsedPolicy {
|
|
rule?: Array<{
|
|
commandPrefix?: string | string[];
|
|
mcpName?: string;
|
|
toolName?: string;
|
|
}>;
|
|
}
|
|
|
|
interface TestableShellToolInvocation {
|
|
getPolicyUpdateOptions(
|
|
outcome: ToolConfirmationOutcome,
|
|
): PolicyUpdateOptions | undefined;
|
|
}
|
|
|
|
describe('createPolicyUpdater', () => {
|
|
let policyEngine: PolicyEngine;
|
|
let messageBus: MessageBus;
|
|
let mockStorage: Storage;
|
|
|
|
beforeEach(() => {
|
|
vi.resetAllMocks();
|
|
policyEngine = new PolicyEngine({});
|
|
vi.spyOn(policyEngine, 'addRule');
|
|
|
|
messageBus = new MessageBus(policyEngine);
|
|
mockStorage = new Storage('/mock/project');
|
|
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(
|
|
'/mock/user/.gemini/policies/auto-saved.toml',
|
|
);
|
|
});
|
|
|
|
afterEach(() => {
|
|
vi.restoreAllMocks();
|
|
});
|
|
|
|
it('should add multiple rules when commandPrefix is an array', async () => {
|
|
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
|
|
|
await messageBus.publish({
|
|
type: MessageBusType.UPDATE_POLICY,
|
|
toolName: 'run_shell_command',
|
|
commandPrefix: ['echo', 'ls'],
|
|
mcpName: 'test-mcp',
|
|
persist: false,
|
|
});
|
|
|
|
expect(policyEngine.addRule).toHaveBeenCalledTimes(2);
|
|
expect(policyEngine.addRule).toHaveBeenNthCalledWith(
|
|
1,
|
|
expect.objectContaining({
|
|
toolName: 'run_shell_command',
|
|
priority: ALWAYS_ALLOW_PRIORITY,
|
|
mcpName: 'test-mcp',
|
|
argsPattern: new RegExp(
|
|
escapeRegex('"command":"echo') + '(?:[\\s"]|\\\\")',
|
|
),
|
|
}),
|
|
);
|
|
expect(policyEngine.addRule).toHaveBeenNthCalledWith(
|
|
2,
|
|
expect.objectContaining({
|
|
toolName: 'run_shell_command',
|
|
priority: ALWAYS_ALLOW_PRIORITY,
|
|
mcpName: 'test-mcp',
|
|
argsPattern: new RegExp(
|
|
escapeRegex('"command":"ls') + '(?:[\\s"]|\\\\")',
|
|
),
|
|
}),
|
|
);
|
|
});
|
|
|
|
it('should pass mcpName to policyEngine.addRule for argsPattern updates', async () => {
|
|
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
|
|
|
await messageBus.publish({
|
|
type: MessageBusType.UPDATE_POLICY,
|
|
toolName: 'test_tool',
|
|
argsPattern: '"foo":"bar"',
|
|
mcpName: 'test-mcp',
|
|
persist: false,
|
|
});
|
|
|
|
expect(policyEngine.addRule).toHaveBeenCalledWith(
|
|
expect.objectContaining({
|
|
toolName: 'test_tool',
|
|
mcpName: 'test-mcp',
|
|
argsPattern: /"foo":"bar"/,
|
|
}),
|
|
);
|
|
});
|
|
|
|
it('should persist mcpName to TOML', async () => {
|
|
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
|
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
|
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
|
|
|
|
const mockFileHandle = {
|
|
writeFile: vi.fn().mockResolvedValue(undefined),
|
|
close: vi.fn().mockResolvedValue(undefined),
|
|
};
|
|
vi.mocked(fs.open).mockResolvedValue(
|
|
mockFileHandle as unknown as fs.FileHandle,
|
|
);
|
|
vi.mocked(fs.rename).mockResolvedValue(undefined);
|
|
|
|
await messageBus.publish({
|
|
type: MessageBusType.UPDATE_POLICY,
|
|
toolName: 'mcp_test-mcp_tool',
|
|
mcpName: 'test-mcp',
|
|
commandPrefix: 'ls',
|
|
persist: true,
|
|
});
|
|
|
|
// Wait for the async listener to complete
|
|
await new Promise((resolve) => setTimeout(resolve, 0));
|
|
|
|
expect(fs.open).toHaveBeenCalled();
|
|
const [content] = mockFileHandle.writeFile.mock.calls[0] as [
|
|
string,
|
|
string,
|
|
];
|
|
const parsed = toml.parse(content) as unknown as ParsedPolicy;
|
|
|
|
expect(parsed.rule).toHaveLength(1);
|
|
expect(parsed.rule![0].mcpName).toBe('test-mcp');
|
|
expect(parsed.rule![0].toolName).toBe('tool'); // toolName should be stripped of MCP prefix
|
|
});
|
|
|
|
it('should add a single rule when commandPrefix is a string', async () => {
|
|
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
|
|
|
await messageBus.publish({
|
|
type: MessageBusType.UPDATE_POLICY,
|
|
toolName: 'run_shell_command',
|
|
commandPrefix: 'git',
|
|
persist: false,
|
|
});
|
|
|
|
expect(policyEngine.addRule).toHaveBeenCalledTimes(1);
|
|
expect(policyEngine.addRule).toHaveBeenCalledWith(
|
|
expect.objectContaining({
|
|
toolName: 'run_shell_command',
|
|
priority: ALWAYS_ALLOW_PRIORITY,
|
|
argsPattern: new RegExp(
|
|
escapeRegex('"command":"git') + '(?:[\\s"]|\\\\")',
|
|
),
|
|
}),
|
|
);
|
|
});
|
|
|
|
it('should pass allowRedirection to policyEngine.addRule', async () => {
|
|
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
|
|
|
await messageBus.publish({
|
|
type: MessageBusType.UPDATE_POLICY,
|
|
toolName: 'run_shell_command',
|
|
commandPrefix: 'ls',
|
|
persist: false,
|
|
allowRedirection: true,
|
|
});
|
|
|
|
expect(policyEngine.addRule).toHaveBeenCalledWith(
|
|
expect.objectContaining({
|
|
toolName: 'run_shell_command',
|
|
allowRedirection: true,
|
|
}),
|
|
);
|
|
});
|
|
|
|
it('should persist multiple rules correctly to TOML', async () => {
|
|
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
|
vi.mocked(fs.readFile).mockRejectedValue({ code: 'ENOENT' });
|
|
vi.mocked(fs.mkdir).mockResolvedValue(undefined);
|
|
|
|
const mockFileHandle = {
|
|
writeFile: vi.fn().mockResolvedValue(undefined),
|
|
close: vi.fn().mockResolvedValue(undefined),
|
|
};
|
|
vi.mocked(fs.open).mockResolvedValue(
|
|
mockFileHandle as unknown as fs.FileHandle,
|
|
);
|
|
vi.mocked(fs.rename).mockResolvedValue(undefined);
|
|
|
|
await messageBus.publish({
|
|
type: MessageBusType.UPDATE_POLICY,
|
|
toolName: 'run_shell_command',
|
|
commandPrefix: ['echo', 'ls'],
|
|
persist: true,
|
|
});
|
|
|
|
// Wait for the async listener to complete
|
|
await new Promise((resolve) => setTimeout(resolve, 0));
|
|
|
|
expect(fs.open).toHaveBeenCalled();
|
|
const [content] = mockFileHandle.writeFile.mock.calls[0] as [
|
|
string,
|
|
string,
|
|
];
|
|
const parsed = toml.parse(content) as unknown as ParsedPolicy;
|
|
|
|
expect(parsed.rule).toHaveLength(1);
|
|
expect(parsed.rule![0].commandPrefix).toEqual(['echo', 'ls']);
|
|
});
|
|
|
|
it('should reject unsafe regex patterns', async () => {
|
|
createPolicyUpdater(policyEngine, messageBus, mockStorage);
|
|
|
|
await messageBus.publish({
|
|
type: MessageBusType.UPDATE_POLICY,
|
|
toolName: 'test_tool',
|
|
argsPattern: '(a+)+',
|
|
persist: false,
|
|
});
|
|
|
|
expect(policyEngine.addRule).not.toHaveBeenCalled();
|
|
});
|
|
});
|
|
|
|
describe('ShellToolInvocation Policy Update', () => {
|
|
let mockConfig: Config;
|
|
let mockMessageBus: MessageBus;
|
|
|
|
beforeEach(() => {
|
|
vi.resetAllMocks();
|
|
mockConfig = {} as Config;
|
|
mockMessageBus = {} as MessageBus;
|
|
|
|
vi.mocked(shellUtils.stripShellWrapper).mockImplementation(
|
|
(c: string) => c,
|
|
);
|
|
vi.mocked(shellUtils.hasRedirection).mockReturnValue(false);
|
|
});
|
|
|
|
it('should extract multiple root commands for chained commands', () => {
|
|
vi.mocked(shellUtils.getCommandRoots).mockReturnValue(['git', 'npm']);
|
|
|
|
const invocation = new ShellToolInvocation(
|
|
mockConfig,
|
|
{ command: 'git status && npm test' },
|
|
mockMessageBus,
|
|
'run_shell_command',
|
|
'Shell',
|
|
);
|
|
|
|
// Accessing protected method for testing
|
|
const options = (
|
|
invocation as unknown as TestableShellToolInvocation
|
|
).getPolicyUpdateOptions(ToolConfirmationOutcome.ProceedAlways);
|
|
expect(options!.commandPrefix).toEqual(['git', 'npm']);
|
|
expect(shellUtils.getCommandRoots).toHaveBeenCalledWith(
|
|
'git status && npm test',
|
|
);
|
|
});
|
|
|
|
it('should extract a single root command', () => {
|
|
vi.mocked(shellUtils.getCommandRoots).mockReturnValue(['ls']);
|
|
|
|
const invocation = new ShellToolInvocation(
|
|
mockConfig,
|
|
{ command: 'ls -la /tmp' },
|
|
mockMessageBus,
|
|
'run_shell_command',
|
|
'Shell',
|
|
);
|
|
|
|
// Accessing protected method for testing
|
|
const options = (
|
|
invocation as unknown as TestableShellToolInvocation
|
|
).getPolicyUpdateOptions(ToolConfirmationOutcome.ProceedAlways);
|
|
expect(options!.commandPrefix).toEqual(['ls']);
|
|
expect(shellUtils.getCommandRoots).toHaveBeenCalledWith('ls -la /tmp');
|
|
});
|
|
|
|
it('should include allowRedirection if command has redirection', () => {
|
|
vi.mocked(shellUtils.getCommandRoots).mockReturnValue(['echo']);
|
|
vi.mocked(shellUtils.hasRedirection).mockReturnValue(true);
|
|
|
|
const invocation = new ShellToolInvocation(
|
|
mockConfig,
|
|
{ command: 'echo "hello" > file.txt' },
|
|
mockMessageBus,
|
|
'run_shell_command',
|
|
'Shell',
|
|
);
|
|
|
|
const options = (
|
|
invocation as unknown as TestableShellToolInvocation
|
|
).getPolicyUpdateOptions(ToolConfirmationOutcome.ProceedAlways);
|
|
expect(options!.commandPrefix).toEqual(['echo']);
|
|
expect(options!.allowRedirection).toBe(true);
|
|
expect(shellUtils.hasRedirection).toHaveBeenCalledWith(
|
|
'echo "hello" > file.txt',
|
|
);
|
|
});
|
|
});
|