From 5f298c17d7f6c3a1eabbeac249904cd1cf141352 Mon Sep 17 00:00:00 2001 From: Allen Hutchison Date: Fri, 12 Dec 2025 13:45:39 -0800 Subject: [PATCH] feat: Persistent "Always Allow" policies with granular shell & MCP support (#14737) Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../messages/ToolConfirmationMessage.tsx | 20 ++ .../ToolConfirmationMessage.test.tsx.snap | 20 +- .../ToolGroupMessage.test.tsx.snap | 6 +- .../cli/src/zed-integration/zedIntegration.ts | 1 + packages/core/src/confirmation-bus/types.ts | 4 + packages/core/src/ide/detect-ide.test.ts | 5 + packages/core/src/policy/config.ts | 101 ++++++++- packages/core/src/policy/persistence.test.ts | 209 ++++++++++++++++++ packages/core/src/policy/toml-loader.ts | 2 +- .../clearcut-logger/clearcut-logger.test.ts | 5 + packages/core/src/tools/edit.ts | 1 + packages/core/src/tools/mcp-tool.ts | 10 + packages/core/src/tools/memoryTool.ts | 1 + packages/core/src/tools/shell.ts | 11 + packages/core/src/tools/smart-edit.ts | 1 + packages/core/src/tools/tools.ts | 53 ++++- packages/core/src/tools/web-fetch.ts | 1 + packages/core/src/tools/write-file.ts | 1 + 18 files changed, 431 insertions(+), 21 deletions(-) create mode 100644 packages/core/src/policy/persistence.test.ts diff --git a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx index 934ab4dda4..78ec581e72 100644 --- a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx @@ -114,6 +114,11 @@ export const ToolConfirmationMessage: React.FC< value: ToolConfirmationOutcome.ProceedAlways, key: 'Yes, allow always', }); + options.push({ + label: 'Yes, allow always and save to policy', + value: ToolConfirmationOutcome.ProceedAlwaysAndSave, + key: 'Yes, allow always and save to policy', + }); } if (!config.getIdeMode() || !isDiffingEnabled) { options.push({ @@ -145,6 +150,11 @@ export const ToolConfirmationMessage: React.FC< value: ToolConfirmationOutcome.ProceedAlways, key: `Yes, allow always ...`, }); + options.push({ + label: `Yes, allow always and save to policy`, + value: ToolConfirmationOutcome.ProceedAlwaysAndSave, + key: `Yes, allow always and save to policy`, + }); } options.push({ label: 'No, suggest changes (esc)', @@ -164,6 +174,11 @@ export const ToolConfirmationMessage: React.FC< value: ToolConfirmationOutcome.ProceedAlways, key: 'Yes, allow always', }); + options.push({ + label: 'Yes, allow always and save to policy', + value: ToolConfirmationOutcome.ProceedAlwaysAndSave, + key: 'Yes, allow always and save to policy', + }); } options.push({ label: 'No, suggest changes (esc)', @@ -190,6 +205,11 @@ export const ToolConfirmationMessage: React.FC< value: ToolConfirmationOutcome.ProceedAlwaysServer, key: `Yes, always allow all tools from server "${mcpProps.serverName}"`, }); + options.push({ + label: `Yes, allow always tool "${mcpProps.toolName}" and save to policy`, + value: ToolConfirmationOutcome.ProceedAlwaysAndSave, + key: `Yes, allow always tool "${mcpProps.toolName}" and save to policy`, + }); } options.push({ label: 'No, suggest changes (esc)', diff --git a/packages/cli/src/ui/components/messages/__snapshots__/ToolConfirmationMessage.test.tsx.snap b/packages/cli/src/ui/components/messages/__snapshots__/ToolConfirmationMessage.test.tsx.snap index 95aa1fca13..577432734e 100644 --- a/packages/cli/src/ui/components/messages/__snapshots__/ToolConfirmationMessage.test.tsx.snap +++ b/packages/cli/src/ui/components/messages/__snapshots__/ToolConfirmationMessage.test.tsx.snap @@ -10,7 +10,8 @@ Do you want to proceed? ● 1. Yes, allow once 2. Yes, allow always - 3. No, suggest changes (esc) + 3. Yes, allow always and save to policy + 4. No, suggest changes (esc) " `; @@ -21,7 +22,8 @@ Do you want to proceed? ● 1. Yes, allow once 2. Yes, allow always - 3. No, suggest changes (esc) + 3. Yes, allow always and save to policy + 4. No, suggest changes (esc) " `; @@ -51,8 +53,9 @@ Apply this change? ● 1. Yes, allow once 2. Yes, allow always - 3. Modify with external editor - 4. No, suggest changes (esc) + 3. Yes, allow always and save to policy + 4. Modify with external editor + 5. No, suggest changes (esc) " `; @@ -73,7 +76,8 @@ Allow execution of: 'echo'? ● 1. Yes, allow once 2. Yes, allow always ... - 3. No, suggest changes (esc) + 3. Yes, allow always and save to policy + 4. No, suggest changes (esc) " `; @@ -94,7 +98,8 @@ Do you want to proceed? ● 1. Yes, allow once 2. Yes, allow always - 3. No, suggest changes (esc) + 3. Yes, allow always and save to policy + 4. No, suggest changes (esc) " `; @@ -118,6 +123,7 @@ Allow execution of MCP tool "test-tool" from server "test-server"? ● 1. Yes, allow once 2. Yes, always allow tool "test-tool" from server "test-server" 3. Yes, always allow all tools from server "test-server" - 4. No, suggest changes (esc) + 4. Yes, allow always tool "test-tool" and save to policy + 5. No, suggest changes (esc) " `; diff --git a/packages/cli/src/ui/components/messages/__snapshots__/ToolGroupMessage.test.tsx.snap b/packages/cli/src/ui/components/messages/__snapshots__/ToolGroupMessage.test.tsx.snap index 038c60e1f9..af18637dbd 100644 --- a/packages/cli/src/ui/components/messages/__snapshots__/ToolGroupMessage.test.tsx.snap +++ b/packages/cli/src/ui/components/messages/__snapshots__/ToolGroupMessage.test.tsx.snap @@ -39,7 +39,8 @@ exports[` > Confirmation Handling > shows confirmation dialo │ │ │ ● 1. Yes, allow once │ │ 2. Yes, allow always │ -│ 3. No, suggest changes (esc) │ +│ 3. Yes, allow always and save to policy │ +│ 4. No, suggest changes (esc) │ │ │ │ │ │ ? second-confirm A tool for testing │ @@ -122,7 +123,8 @@ exports[` > Golden Snapshots > renders tool call awaiting co │ │ │ ● 1. Yes, allow once │ │ 2. Yes, allow always │ -│ 3. No, suggest changes (esc) │ +│ 3. Yes, allow always and save to policy │ +│ 4. No, suggest changes (esc) │ │ │ ╰──────────────────────────────────────────────────────────────────────────────╯" `; diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index 93e4121705..fce417f095 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts @@ -449,6 +449,7 @@ export class Session { ); case ToolConfirmationOutcome.ProceedOnce: case ToolConfirmationOutcome.ProceedAlways: + case ToolConfirmationOutcome.ProceedAlwaysAndSave: case ToolConfirmationOutcome.ProceedAlwaysServer: case ToolConfirmationOutcome.ProceedAlwaysTool: case ToolConfirmationOutcome.ModifyWithEditor: diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts index 7c1d010934..6f2f9a2e12 100644 --- a/packages/core/src/confirmation-bus/types.ts +++ b/packages/core/src/confirmation-bus/types.ts @@ -39,6 +39,10 @@ export interface ToolConfirmationResponse { export interface UpdatePolicy { type: MessageBusType.UPDATE_POLICY; toolName: string; + persist?: boolean; + argsPattern?: string; + commandPrefix?: string; + mcpName?: string; } export interface ToolPolicyRejection { diff --git a/packages/core/src/ide/detect-ide.test.ts b/packages/core/src/ide/detect-ide.test.ts index bb278a8c14..ebeb522933 100644 --- a/packages/core/src/ide/detect-ide.test.ts +++ b/packages/core/src/ide/detect-ide.test.ts @@ -7,6 +7,11 @@ import { describe, it, expect, vi, afterEach, beforeEach } from 'vitest'; import { detectIde, IDE_DEFINITIONS } from './detect-ide.js'; +beforeEach(() => { + // Ensure Antigravity detection doesn't interfere with other tests + vi.stubEnv('ANTIGRAVITY_CLI_ALIAS', ''); +}); + describe('detectIde', () => { const ideProcessInfo = { pid: 123, command: 'some/path/to/code' }; const ideProcessInfoNoCode = { pid: 123, command: 'some/path/to/fork' }; diff --git a/packages/core/src/policy/config.ts b/packages/core/src/policy/config.ts index 6ea78d30ca..3c6016f086 100644 --- a/packages/core/src/policy/config.ts +++ b/packages/core/src/policy/config.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import * as fs from 'node:fs/promises'; import * as path from 'node:path'; import { fileURLToPath } from 'node:url'; import { Storage } from '../config/storage.js'; @@ -15,7 +16,12 @@ import { type PolicySettings, } from './types.js'; import type { PolicyEngine } from './policy-engine.js'; -import { loadPoliciesFromToml, type PolicyFileError } from './toml-loader.js'; +import { + loadPoliciesFromToml, + type PolicyFileError, + escapeRegex, +} from './toml-loader.js'; +import toml from '@iarna/toml'; import { MessageBusType, type UpdatePolicy, @@ -233,14 +239,35 @@ export async function createPolicyEngineConfig( }; } +interface TomlRule { + toolName?: string; + mcpName?: string; + decision?: string; + priority?: number; + commandPrefix?: string; + argsPattern?: string; + // Index signature to satisfy Record type if needed for toml.stringify + [key: string]: unknown; +} + export function createPolicyUpdater( policyEngine: PolicyEngine, messageBus: MessageBus, ) { messageBus.subscribe( MessageBusType.UPDATE_POLICY, - (message: UpdatePolicy) => { + async (message: UpdatePolicy) => { const toolName = message.toolName; + let argsPattern = message.argsPattern + ? new RegExp(message.argsPattern) + : undefined; + + if (message.commandPrefix) { + // Convert commandPrefix to argsPattern for in-memory rule + // This mimics what toml-loader does + const escapedPrefix = escapeRegex(message.commandPrefix); + argsPattern = new RegExp(`"command":"${escapedPrefix}`); + } policyEngine.addRule({ toolName, @@ -249,7 +276,77 @@ export function createPolicyUpdater( // This ensures user "always allow" selections are high priority // but still lose to admin policies (3.xxx) and settings excludes (200) priority: 2.95, + argsPattern, }); + + if (message.persist) { + try { + const userPoliciesDir = Storage.getUserPoliciesDir(); + await fs.mkdir(userPoliciesDir, { recursive: true }); + const policyFile = path.join(userPoliciesDir, 'auto-saved.toml'); + + // Read existing file + let existingData: { rule?: TomlRule[] } = {}; + try { + const fileContent = await fs.readFile(policyFile, 'utf-8'); + existingData = toml.parse(fileContent) as { rule?: TomlRule[] }; + } catch (error) { + if ((error as NodeJS.ErrnoException).code !== 'ENOENT') { + console.warn( + `Failed to parse ${policyFile}, overwriting with new policy.`, + error, + ); + } + } + + // Initialize rule array if needed + if (!existingData.rule) { + existingData.rule = []; + } + + // Create new rule object + const newRule: TomlRule = {}; + + if (message.mcpName) { + newRule.mcpName = message.mcpName; + // Extract simple tool name + const simpleToolName = toolName.startsWith(`${message.mcpName}__`) + ? toolName.slice(message.mcpName.length + 2) + : toolName; + newRule.toolName = simpleToolName; + newRule.decision = 'allow'; + newRule.priority = 200; + } else { + newRule.toolName = toolName; + newRule.decision = 'allow'; + newRule.priority = 100; + } + + if (message.commandPrefix) { + newRule.commandPrefix = message.commandPrefix; + } else if (message.argsPattern) { + newRule.argsPattern = message.argsPattern; + } + + // Add to rules + existingData.rule.push(newRule); + + // Serialize back to TOML + // @iarna/toml stringify might not produce beautiful output but it handles escaping correctly + const newContent = toml.stringify(existingData as toml.JsonMap); + + // Atomic write: write to tmp then rename + const tmpFile = `${policyFile}.tmp`; + await fs.writeFile(tmpFile, newContent, 'utf-8'); + await fs.rename(tmpFile, policyFile); + } catch (error) { + coreEvents.emitFeedback( + 'error', + `Failed to persist policy for ${toolName}`, + error, + ); + } + } }, ); } diff --git a/packages/core/src/policy/persistence.test.ts b/packages/core/src/policy/persistence.test.ts new file mode 100644 index 0000000000..e7916b8644 --- /dev/null +++ b/packages/core/src/policy/persistence.test.ts @@ -0,0 +1,209 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mock, +} from 'vitest'; +import * as fs from 'node:fs/promises'; +import * as path from 'node:path'; +import { createPolicyUpdater } from './config.js'; +import { PolicyEngine } from './policy-engine.js'; +import { MessageBus } from '../confirmation-bus/message-bus.js'; +import { MessageBusType } from '../confirmation-bus/types.js'; +import { Storage } from '../config/storage.js'; + +vi.mock('node:fs/promises'); +vi.mock('../config/storage.js'); + +describe('createPolicyUpdater', () => { + let policyEngine: PolicyEngine; + let messageBus: MessageBus; + + beforeEach(() => { + policyEngine = new PolicyEngine({ rules: [], checkers: [] }); + messageBus = new MessageBus(policyEngine); + vi.clearAllMocks(); + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + it('should persist policy when persist flag is true', async () => { + createPolicyUpdater(policyEngine, messageBus); + + const userPoliciesDir = '/mock/user/policies'; + vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(userPoliciesDir); + (fs.mkdir as unknown as Mock).mockResolvedValue(undefined); + (fs.readFile as unknown as Mock).mockRejectedValue( + new Error('File not found'), + ); // Simulate new file + (fs.writeFile as unknown as Mock).mockResolvedValue(undefined); + (fs.rename as unknown as Mock).mockResolvedValue(undefined); + + const toolName = 'test_tool'; + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName, + persist: true, + }); + + // Wait for async operations (microtasks) + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(Storage.getUserPoliciesDir).toHaveBeenCalled(); + expect(fs.mkdir).toHaveBeenCalledWith(userPoliciesDir, { + recursive: true, + }); + + // Check written content + const expectedContent = expect.stringContaining(`toolName = "test_tool"`); + expect(fs.writeFile).toHaveBeenCalledWith( + expect.stringMatching(/\.tmp$/), + expectedContent, + 'utf-8', + ); + expect(fs.rename).toHaveBeenCalledWith( + expect.stringMatching(/\.tmp$/), + path.join(userPoliciesDir, 'auto-saved.toml'), + ); + }); + + it('should not persist policy when persist flag is false or undefined', async () => { + createPolicyUpdater(policyEngine, messageBus); + + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'test_tool', + }); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(fs.writeFile).not.toHaveBeenCalled(); + expect(fs.rename).not.toHaveBeenCalled(); + }); + + it('should persist policy with commandPrefix when provided', async () => { + createPolicyUpdater(policyEngine, messageBus); + + const userPoliciesDir = '/mock/user/policies'; + vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(userPoliciesDir); + (fs.mkdir as unknown as Mock).mockResolvedValue(undefined); + (fs.readFile as unknown as Mock).mockRejectedValue( + new Error('File not found'), + ); + (fs.writeFile as unknown as Mock).mockResolvedValue(undefined); + (fs.rename as unknown as Mock).mockResolvedValue(undefined); + + const toolName = 'run_shell_command'; + const commandPrefix = 'git status'; + + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName, + persist: true, + commandPrefix, + }); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + // In-memory rule check (unchanged) + const rules = policyEngine.getRules(); + const addedRule = rules.find((r) => r.toolName === toolName); + expect(addedRule).toBeDefined(); + expect(addedRule?.priority).toBe(2.95); + expect(addedRule?.argsPattern).toEqual(new RegExp(`"command":"git status`)); + + // Verify file written + expect(fs.writeFile).toHaveBeenCalledWith( + expect.stringMatching(/\.tmp$/), + expect.stringContaining(`commandPrefix = "git status"`), + 'utf-8', + ); + }); + + it('should persist policy with mcpName and toolName when provided', async () => { + createPolicyUpdater(policyEngine, messageBus); + + const userPoliciesDir = '/mock/user/policies'; + vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(userPoliciesDir); + (fs.mkdir as unknown as Mock).mockResolvedValue(undefined); + (fs.readFile as unknown as Mock).mockRejectedValue( + new Error('File not found'), + ); + (fs.writeFile as unknown as Mock).mockResolvedValue(undefined); + (fs.rename as unknown as Mock).mockResolvedValue(undefined); + + const mcpName = 'my-jira-server'; + const simpleToolName = 'search'; + const toolName = `${mcpName}__${simpleToolName}`; + + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName, + persist: true, + mcpName, + }); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + // Verify file written + const writeCall = (fs.writeFile as unknown as Mock).mock.calls[0]; + const writtenContent = writeCall[1] as string; + expect(writtenContent).toContain(`mcpName = "${mcpName}"`); + expect(writtenContent).toContain(`toolName = "${simpleToolName}"`); + expect(writtenContent).toContain('priority = 200'); + }); + + it('should escape special characters in toolName and mcpName', async () => { + createPolicyUpdater(policyEngine, messageBus); + + const userPoliciesDir = '/mock/user/policies'; + vi.spyOn(Storage, 'getUserPoliciesDir').mockReturnValue(userPoliciesDir); + (fs.mkdir as unknown as Mock).mockResolvedValue(undefined); + (fs.readFile as unknown as Mock).mockRejectedValue( + new Error('File not found'), + ); + (fs.writeFile as unknown as Mock).mockResolvedValue(undefined); + (fs.rename as unknown as Mock).mockResolvedValue(undefined); + + const mcpName = 'my"jira"server'; + const toolName = `my"jira"server__search"tool"`; + + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName, + persist: true, + mcpName, + }); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + const writeCall = (fs.writeFile as unknown as Mock).mock.calls[0]; + const writtenContent = writeCall[1] as string; + + // Verify escaping - should be valid TOML + // Note: @iarna/toml optimizes for shortest representation, so it may use single quotes 'foo"bar' + // instead of "foo\"bar\"" if there are no single quotes in the string. + try { + expect(writtenContent).toContain(`mcpName = "my\\"jira\\"server"`); + } catch { + expect(writtenContent).toContain(`mcpName = 'my"jira"server'`); + } + + try { + expect(writtenContent).toContain(`toolName = "search\\"tool\\""`); + } catch { + expect(writtenContent).toContain(`toolName = 'search"tool"'`); + } + }); +}); diff --git a/packages/core/src/policy/toml-loader.ts b/packages/core/src/policy/toml-loader.ts index ed63db0929..4f5ca8b976 100644 --- a/packages/core/src/policy/toml-loader.ts +++ b/packages/core/src/policy/toml-loader.ts @@ -126,7 +126,7 @@ export interface PolicyLoadResult { * @param str The string to escape * @returns The escaped string safe for use in a regex */ -function escapeRegex(str: string): string { +export function escapeRegex(str: string): string { return str.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); } diff --git a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts index 7f8b46b71b..d66cf67b00 100644 --- a/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts +++ b/packages/core/src/telemetry/clearcut-logger/clearcut-logger.test.ts @@ -100,6 +100,11 @@ vi.mock('../../utils/installationManager.js'); const mockUserAccount = vi.mocked(UserAccountManager.prototype); const mockInstallMgr = vi.mocked(InstallationManager.prototype); +beforeEach(() => { + // Ensure Antigravity detection doesn't interfere with other tests + vi.stubEnv('ANTIGRAVITY_CLI_ALIAS', ''); +}); + // TODO(richieforeman): Consider moving this to test setup globally. beforeAll(() => { server.listen({}); diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index 4ffd0dbf72..2ea6bed52d 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -313,6 +313,7 @@ class EditToolInvocation if (outcome === ToolConfirmationOutcome.ProceedAlways) { this.config.setApprovalMode(ApprovalMode.AUTO_EDIT); } + await this.publishPolicyUpdate(outcome); if (ideConfirmation) { const result = await ideConfirmation; diff --git a/packages/core/src/tools/mcp-tool.ts b/packages/core/src/tools/mcp-tool.ts index 176c6332f0..280927a6e0 100644 --- a/packages/core/src/tools/mcp-tool.ts +++ b/packages/core/src/tools/mcp-tool.ts @@ -16,6 +16,7 @@ import { BaseToolInvocation, Kind, ToolConfirmationOutcome, + type PolicyUpdateOptions, } from './tools.js'; import type { CallableTool, FunctionCall, Part } from '@google/genai'; import { ToolErrorType } from './tool-error.js'; @@ -87,6 +88,12 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< ); } + protected override getPolicyUpdateOptions( + _outcome: ToolConfirmationOutcome, + ): PolicyUpdateOptions | undefined { + return { mcpName: this.serverName }; + } + protected override async getConfirmationDetails( _abortSignal: AbortSignal, ): Promise { @@ -115,6 +122,9 @@ class DiscoveredMCPToolInvocation extends BaseToolInvocation< DiscoveredMCPToolInvocation.allowlist.add(serverAllowListKey); } else if (outcome === ToolConfirmationOutcome.ProceedAlwaysTool) { DiscoveredMCPToolInvocation.allowlist.add(toolAllowListKey); + } else if (outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave) { + DiscoveredMCPToolInvocation.allowlist.add(toolAllowListKey); + await this.publishPolicyUpdate(outcome); } }, }; diff --git a/packages/core/src/tools/memoryTool.ts b/packages/core/src/tools/memoryTool.ts index da4a54c115..875dc8152f 100644 --- a/packages/core/src/tools/memoryTool.ts +++ b/packages/core/src/tools/memoryTool.ts @@ -226,6 +226,7 @@ class MemoryToolInvocation extends BaseToolInvocation< if (outcome === ToolConfirmationOutcome.ProceedAlways) { MemoryToolInvocation.allowlist.add(allowlistKey); } + await this.publishPolicyUpdate(outcome); }, }; return confirmationDetails; diff --git a/packages/core/src/tools/shell.ts b/packages/core/src/tools/shell.ts index b0d6324ad8..feca545ef9 100644 --- a/packages/core/src/tools/shell.ts +++ b/packages/core/src/tools/shell.ts @@ -22,6 +22,7 @@ import { BaseToolInvocation, ToolConfirmationOutcome, Kind, + type PolicyUpdateOptions, } from './tools.js'; import { ApprovalMode } from '../policy/types.js'; @@ -83,6 +84,15 @@ export class ShellToolInvocation extends BaseToolInvocation< return description; } + protected override getPolicyUpdateOptions( + outcome: ToolConfirmationOutcome, + ): PolicyUpdateOptions | undefined { + if (outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave) { + return { commandPrefix: this.params.command }; + } + return undefined; + } + protected override async getConfirmationDetails( _abortSignal: AbortSignal, ): Promise { @@ -124,6 +134,7 @@ export class ShellToolInvocation extends BaseToolInvocation< if (outcome === ToolConfirmationOutcome.ProceedAlways) { commandsToConfirm.forEach((command) => this.allowlist.add(command)); } + await this.publishPolicyUpdate(outcome); }, }; return confirmationDetails; diff --git a/packages/core/src/tools/smart-edit.ts b/packages/core/src/tools/smart-edit.ts index 86d412b407..5c1e776db6 100644 --- a/packages/core/src/tools/smart-edit.ts +++ b/packages/core/src/tools/smart-edit.ts @@ -683,6 +683,7 @@ class EditToolInvocation if (outcome === ToolConfirmationOutcome.ProceedAlways) { this.config.setApprovalMode(ApprovalMode.AUTO_EDIT); } + await this.publishPolicyUpdate(outcome); if (ideConfirmation) { const result = await ideConfirmation; diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 4455ace265..629961026b 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -65,6 +65,14 @@ export interface ToolInvocation< ): Promise; } +/** + * Options for policy updates that can be customized by tool invocations. + */ +export interface PolicyUpdateOptions { + commandPrefix?: string; + mcpName?: string; +} + /** * A convenience base class for ToolInvocation. */ @@ -112,6 +120,40 @@ export abstract class BaseToolInvocation< return this.getConfirmationDetails(abortSignal); } + /** + * Returns tool-specific options for policy updates. + * Subclasses can override this to provide additional options like + * commandPrefix (for shell) or mcpName (for MCP tools). + */ + protected getPolicyUpdateOptions( + _outcome: ToolConfirmationOutcome, + ): PolicyUpdateOptions | undefined { + return undefined; + } + + /** + * Helper method to publish a policy update when user selects + * ProceedAlways or ProceedAlwaysAndSave. + */ + protected async publishPolicyUpdate( + outcome: ToolConfirmationOutcome, + ): Promise { + if ( + outcome === ToolConfirmationOutcome.ProceedAlways || + outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave + ) { + if (this.messageBus && this._toolName) { + const options = this.getPolicyUpdateOptions(outcome); + await this.messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName: this._toolName, + persist: outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave, + ...options, + }); + } + } + } + /** * Subclasses should override this method to provide custom confirmation UI * when the policy engine's decision is 'ASK_USER'. @@ -129,15 +171,7 @@ export abstract class BaseToolInvocation< title: `Confirm: ${this._toolDisplayName || this._toolName}`, prompt: this.getDescription(), onConfirm: async (outcome: ToolConfirmationOutcome) => { - if (outcome === ToolConfirmationOutcome.ProceedAlways) { - if (this.messageBus && this._toolName) { - // eslint-disable-next-line @typescript-eslint/no-floating-promises - this.messageBus.publish({ - type: MessageBusType.UPDATE_POLICY, - toolName: this._toolName, - }); - } - } + await this.publishPolicyUpdate(outcome); }, }; return confirmationDetails; @@ -686,6 +720,7 @@ export type ToolCallConfirmationDetails = export enum ToolConfirmationOutcome { ProceedOnce = 'proceed_once', ProceedAlways = 'proceed_always', + ProceedAlwaysAndSave = 'proceed_always_and_save', ProceedAlwaysServer = 'proceed_always_server', ProceedAlwaysTool = 'proceed_always_tool', ModifyWithEditor = 'modify_with_editor', diff --git a/packages/core/src/tools/web-fetch.ts b/packages/core/src/tools/web-fetch.ts index a1836c37ef..57591343f3 100644 --- a/packages/core/src/tools/web-fetch.ts +++ b/packages/core/src/tools/web-fetch.ts @@ -244,6 +244,7 @@ ${textContent} if (outcome === ToolConfirmationOutcome.ProceedAlways) { this.config.setApprovalMode(ApprovalMode.AUTO_EDIT); } + await this.publishPolicyUpdate(outcome); }, }; return confirmationDetails; diff --git a/packages/core/src/tools/write-file.ts b/packages/core/src/tools/write-file.ts index ad9a6a7588..f9a6d2559d 100644 --- a/packages/core/src/tools/write-file.ts +++ b/packages/core/src/tools/write-file.ts @@ -224,6 +224,7 @@ class WriteFileToolInvocation extends BaseToolInvocation< if (outcome === ToolConfirmationOutcome.ProceedAlways) { this.config.setApprovalMode(ApprovalMode.AUTO_EDIT); } + await this.publishPolicyUpdate(outcome); if (ideConfirmation) { const result = await ideConfirmation;