feat: Persistent "Always Allow" policies with granular shell & MCP support (#14737)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Allen Hutchison
2025-12-12 13:45:39 -08:00
committed by GitHub
parent d2a1a45646
commit 5f298c17d7
18 changed files with 431 additions and 21 deletions
@@ -39,6 +39,10 @@ export interface ToolConfirmationResponse {
export interface UpdatePolicy {
type: MessageBusType.UPDATE_POLICY;
toolName: string;
persist?: boolean;
argsPattern?: string;
commandPrefix?: string;
mcpName?: string;
}
export interface ToolPolicyRejection {
+5
View File
@@ -7,6 +7,11 @@
import { describe, it, expect, vi, afterEach, beforeEach } from 'vitest';
import { detectIde, IDE_DEFINITIONS } from './detect-ide.js';
beforeEach(() => {
// Ensure Antigravity detection doesn't interfere with other tests
vi.stubEnv('ANTIGRAVITY_CLI_ALIAS', '');
});
describe('detectIde', () => {
const ideProcessInfo = { pid: 123, command: 'some/path/to/code' };
const ideProcessInfoNoCode = { pid: 123, command: 'some/path/to/fork' };
+99 -2
View File
@@ -4,6 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import * as fs from 'node:fs/promises';
import * as path from 'node:path';
import { fileURLToPath } from 'node:url';
import { Storage } from '../config/storage.js';
@@ -15,7 +16,12 @@ import {
type PolicySettings,
} from './types.js';
import type { PolicyEngine } from './policy-engine.js';
import { loadPoliciesFromToml, type PolicyFileError } from './toml-loader.js';
import {
loadPoliciesFromToml,
type PolicyFileError,
escapeRegex,
} from './toml-loader.js';
import toml from '@iarna/toml';
import {
MessageBusType,
type UpdatePolicy,
@@ -233,14 +239,35 @@ export async function createPolicyEngineConfig(
};
}
interface TomlRule {
toolName?: string;
mcpName?: string;
decision?: string;
priority?: number;
commandPrefix?: string;
argsPattern?: string;
// Index signature to satisfy Record type if needed for toml.stringify
[key: string]: unknown;
}
export function createPolicyUpdater(
policyEngine: PolicyEngine,
messageBus: MessageBus,
) {
messageBus.subscribe(
MessageBusType.UPDATE_POLICY,
(message: UpdatePolicy) => {
async (message: UpdatePolicy) => {
const toolName = message.toolName;
let argsPattern = message.argsPattern
? new RegExp(message.argsPattern)
: undefined;
if (message.commandPrefix) {
// Convert commandPrefix to argsPattern for in-memory rule
// This mimics what toml-loader does
const escapedPrefix = escapeRegex(message.commandPrefix);
argsPattern = new RegExp(`"command":"${escapedPrefix}`);
}
policyEngine.addRule({
toolName,
@@ -249,7 +276,77 @@ export function createPolicyUpdater(
// This ensures user "always allow" selections are high priority
// but still lose to admin policies (3.xxx) and settings excludes (200)
priority: 2.95,
argsPattern,
});
if (message.persist) {
try {
const userPoliciesDir = Storage.getUserPoliciesDir();
await fs.mkdir(userPoliciesDir, { recursive: true });
const policyFile = path.join(userPoliciesDir, 'auto-saved.toml');
// Read existing file
let existingData: { rule?: TomlRule[] } = {};
try {
const fileContent = await fs.readFile(policyFile, 'utf-8');
existingData = toml.parse(fileContent) as { rule?: TomlRule[] };
} catch (error) {
if ((error as NodeJS.ErrnoException).code !== 'ENOENT') {
console.warn(
`Failed to parse ${policyFile}, overwriting with new policy.`,
error,
);
}
}
// Initialize rule array if needed
if (!existingData.rule) {
existingData.rule = [];
}
// Create new rule object
const newRule: TomlRule = {};
if (message.mcpName) {
newRule.mcpName = message.mcpName;
// Extract simple tool name
const simpleToolName = toolName.startsWith(`${message.mcpName}__`)
? toolName.slice(message.mcpName.length + 2)
: toolName;
newRule.toolName = simpleToolName;
newRule.decision = 'allow';
newRule.priority = 200;
} else {
newRule.toolName = toolName;
newRule.decision = 'allow';
newRule.priority = 100;
}
if (message.commandPrefix) {
newRule.commandPrefix = message.commandPrefix;
} else if (message.argsPattern) {
newRule.argsPattern = message.argsPattern;
}
// Add to rules
existingData.rule.push(newRule);
// Serialize back to TOML
// @iarna/toml stringify might not produce beautiful output but it handles escaping correctly
const newContent = toml.stringify(existingData as toml.JsonMap);
// Atomic write: write to tmp then rename
const tmpFile = `${policyFile}.tmp`;
await fs.writeFile(tmpFile, newContent, 'utf-8');
await fs.rename(tmpFile, policyFile);
} catch (error) {
coreEvents.emitFeedback(
'error',
`Failed to persist policy for ${toolName}`,
error,
);
}
}
},
);
}
@@ -0,0 +1,209 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
describe,
it,
expect,
vi,
beforeEach,
afterEach,
type Mock,
} from 'vitest';
import * as fs from 'node:fs/promises';
import * as path from 'node:path';
import { createPolicyUpdater } from './config.js';
import { PolicyEngine } from './policy-engine.js';
import { MessageBus } from '../confirmation-bus/message-bus.js';
import { MessageBusType } from '../confirmation-bus/types.js';
import { Storage } from '../config/storage.js';
vi.mock('node:fs/promises');
vi.mock('../config/storage.js');
describe('createPolicyUpdater', () => {
let policyEngine: PolicyEngine;
let messageBus: MessageBus;
beforeEach(() => {
policyEngine = new PolicyEngine({ rules: [], checkers: [] });
messageBus = new MessageBus(policyEngine);
vi.clearAllMocks();
});
afterEach(() => {
vi.restoreAllMocks();
});
it('should persist policy when persist flag is true', async () => {
createPolicyUpdater(policyEngine, messageBus);
const userPoliciesDir = '/mock/user/policies';
vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(userPoliciesDir);
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
(fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'),
); // Simulate new file
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
const toolName = 'test_tool';
await messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName,
persist: true,
});
// Wait for async operations (microtasks)
await new Promise((resolve) => setTimeout(resolve, 0));
expect(Storage.getUserPoliciesDir).toHaveBeenCalled();
expect(fs.mkdir).toHaveBeenCalledWith(userPoliciesDir, {
recursive: true,
});
// Check written content
const expectedContent = expect.stringContaining(`toolName = "test_tool"`);
expect(fs.writeFile).toHaveBeenCalledWith(
expect.stringMatching(/\.tmp$/),
expectedContent,
'utf-8',
);
expect(fs.rename).toHaveBeenCalledWith(
expect.stringMatching(/\.tmp$/),
path.join(userPoliciesDir, 'auto-saved.toml'),
);
});
it('should not persist policy when persist flag is false or undefined', async () => {
createPolicyUpdater(policyEngine, messageBus);
await messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: 'test_tool',
});
await new Promise((resolve) => setTimeout(resolve, 0));
expect(fs.writeFile).not.toHaveBeenCalled();
expect(fs.rename).not.toHaveBeenCalled();
});
it('should persist policy with commandPrefix when provided', async () => {
createPolicyUpdater(policyEngine, messageBus);
const userPoliciesDir = '/mock/user/policies';
vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(userPoliciesDir);
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
(fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'),
);
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
const toolName = 'run_shell_command';
const commandPrefix = 'git status';
await messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName,
persist: true,
commandPrefix,
});
await new Promise((resolve) => setTimeout(resolve, 0));
// In-memory rule check (unchanged)
const rules = policyEngine.getRules();
const addedRule = rules.find((r) => r.toolName === toolName);
expect(addedRule).toBeDefined();
expect(addedRule?.priority).toBe(2.95);
expect(addedRule?.argsPattern).toEqual(new RegExp(`"command":"git status`));
// Verify file written
expect(fs.writeFile).toHaveBeenCalledWith(
expect.stringMatching(/\.tmp$/),
expect.stringContaining(`commandPrefix = "git status"`),
'utf-8',
);
});
it('should persist policy with mcpName and toolName when provided', async () => {
createPolicyUpdater(policyEngine, messageBus);
const userPoliciesDir = '/mock/user/policies';
vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(userPoliciesDir);
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
(fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'),
);
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
const mcpName = 'my-jira-server';
const simpleToolName = 'search';
const toolName = `${mcpName}__${simpleToolName}`;
await messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName,
persist: true,
mcpName,
});
await new Promise((resolve) => setTimeout(resolve, 0));
// Verify file written
const writeCall = (fs.writeFile as unknown as Mock).mock.calls[0];
const writtenContent = writeCall[1] as string;
expect(writtenContent).toContain(`mcpName = "${mcpName}"`);
expect(writtenContent).toContain(`toolName = "${simpleToolName}"`);
expect(writtenContent).toContain('priority = 200');
});
it('should escape special characters in toolName and mcpName', async () => {
createPolicyUpdater(policyEngine, messageBus);
const userPoliciesDir = '/mock/user/policies';
vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(userPoliciesDir);
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
(fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'),
);
(fs.writeFile as unknown as Mock).mockResolvedValue(undefined);
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
const mcpName = 'my"jira"server';
const toolName = `my"jira"server__search"tool"`;
await messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName,
persist: true,
mcpName,
});
await new Promise((resolve) => setTimeout(resolve, 0));
const writeCall = (fs.writeFile as unknown as Mock).mock.calls[0];
const writtenContent = writeCall[1] as string;
// Verify escaping - should be valid TOML
// Note: @iarna/toml optimizes for shortest representation, so it may use single quotes 'foo"bar'
// instead of "foo\"bar\"" if there are no single quotes in the string.
try {
expect(writtenContent).toContain(`mcpName = "my\\"jira\\"server"`);
} catch {
expect(writtenContent).toContain(`mcpName = 'my"jira"server'`);
}
try {
expect(writtenContent).toContain(`toolName = "search\\"tool\\""`);
} catch {
expect(writtenContent).toContain(`toolName = 'search"tool"'`);
}
});
});
+1 -1
View File
@@ -126,7 +126,7 @@ export interface PolicyLoadResult {
* @param str The string to escape
* @returns The escaped string safe for use in a regex
*/
function escapeRegex(str: string): string {
export function escapeRegex(str: string): string {
return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
}
@@ -100,6 +100,11 @@ vi.mock('../../utils/installationManager.js');
const mockUserAccount = vi.mocked(UserAccountManager.prototype);
const mockInstallMgr = vi.mocked(InstallationManager.prototype);
beforeEach(() => {
// Ensure Antigravity detection doesn't interfere with other tests
vi.stubEnv('ANTIGRAVITY_CLI_ALIAS', '');
});
// TODO(richieforeman): Consider moving this to test setup globally.
beforeAll(() => {
server.listen({});
+1
View File
@@ -313,6 +313,7 @@ class EditToolInvocation
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
}
await this.publishPolicyUpdate(outcome);
if (ideConfirmation) {
const result = await ideConfirmation;
+10
View File
@@ -16,6 +16,7 @@ import {
BaseToolInvocation,
Kind,
ToolConfirmationOutcome,
type PolicyUpdateOptions,
} from './tools.js';
import type { CallableTool, FunctionCall, Part } from '@google/genai';
import { ToolErrorType } from './tool-error.js';
@@ -87,6 +88,12 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
);
}
protected override getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined {
return { mcpName: this.serverName };
}
protected override async getConfirmationDetails(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
@@ -115,6 +122,9 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation<
DiscoveredMCPToolInvocation.allowlist.add(serverAllowListKey);
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) {
DiscoveredMCPToolInvocation.allowlist.add(toolAllowListKey);
} else if (outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave) {
DiscoveredMCPToolInvocation.allowlist.add(toolAllowListKey);
await this.publishPolicyUpdate(outcome);
}
},
};
+1
View File
@@ -226,6 +226,7 @@ class MemoryToolInvocation extends BaseToolInvocation<
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
MemoryToolInvocation.allowlist.add(allowlistKey);
}
await this.publishPolicyUpdate(outcome);
},
};
return confirmationDetails;
+11
View File
@@ -22,6 +22,7 @@ import {
BaseToolInvocation,
ToolConfirmationOutcome,
Kind,
type PolicyUpdateOptions,
} from './tools.js';
import { ApprovalMode } from '../policy/types.js';
@@ -83,6 +84,15 @@ export class ShellToolInvocation extends BaseToolInvocation<
return description;
}
protected override getPolicyUpdateOptions(
outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined {
if (outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave) {
return { commandPrefix: this.params.command };
}
return undefined;
}
protected override async getConfirmationDetails(
_abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> {
@@ -124,6 +134,7 @@ export class ShellToolInvocation extends BaseToolInvocation<
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
commandsToConfirm.forEach((command) => this.allowlist.add(command));
}
await this.publishPolicyUpdate(outcome);
},
};
return confirmationDetails;
+1
View File
@@ -683,6 +683,7 @@ class EditToolInvocation
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
}
await this.publishPolicyUpdate(outcome);
if (ideConfirmation) {
const result = await ideConfirmation;
+44 -9
View File
@@ -65,6 +65,14 @@ export interface ToolInvocation<
): Promise<TResult>;
}
/**
* Options for policy updates that can be customized by tool invocations.
*/
export interface PolicyUpdateOptions {
commandPrefix?: string;
mcpName?: string;
}
/**
* A convenience base class for ToolInvocation.
*/
@@ -112,6 +120,40 @@ export abstract class BaseToolInvocation<
return this.getConfirmationDetails(abortSignal);
}
/**
* Returns tool-specific options for policy updates.
* Subclasses can override this to provide additional options like
* commandPrefix (for shell) or mcpName (for MCP tools).
*/
protected getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined {
return undefined;
}
/**
* Helper method to publish a policy update when user selects
* ProceedAlways or ProceedAlwaysAndSave.
*/
protected async publishPolicyUpdate(
outcome: ToolConfirmationOutcome,
): Promise<void> {
if (
outcome === ToolConfirmationOutcome.ProceedAlways ||
outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave
) {
if (this.messageBus && this._toolName) {
const options = this.getPolicyUpdateOptions(outcome);
await this.messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: this._toolName,
persist: outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave,
...options,
});
}
}
}
/**
* Subclasses should override this method to provide custom confirmation UI
* when the policy engine's decision is 'ASK_USER'.
@@ -129,15 +171,7 @@ export abstract class BaseToolInvocation<
title: `Confirm: ${this._toolDisplayName || this._toolName}`,
prompt: this.getDescription(),
onConfirm: async (outcome: ToolConfirmationOutcome) => {
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
if (this.messageBus && this._toolName) {
// eslint-disable-next-line @typescript-eslint/no-floating-promises
this.messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: this._toolName,
});
}
}
await this.publishPolicyUpdate(outcome);
},
};
return confirmationDetails;
@@ -686,6 +720,7 @@ export type ToolCallConfirmationDetails =
export enum ToolConfirmationOutcome {
ProceedOnce = 'proceed_once',
ProceedAlways = 'proceed_always',
ProceedAlwaysAndSave = 'proceed_always_and_save',
ProceedAlwaysServer = 'proceed_always_server',
ProceedAlwaysTool = 'proceed_always_tool',
ModifyWithEditor = 'modify_with_editor',
+1
View File
@@ -244,6 +244,7 @@ ${textContent}
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
}
await this.publishPolicyUpdate(outcome);
},
};
return confirmationDetails;
+1
View File
@@ -224,6 +224,7 @@ class WriteFileToolInvocation extends BaseToolInvocation<
if (outcome === ToolConfirmationOutcome.ProceedAlways) {
this.config.setApprovalMode(ApprovalMode.AUTO_EDIT);
}
await this.publishPolicyUpdate(outcome);
if (ideConfirmation) {
const result = await ideConfirmation;