diff --git a/packages/cli/src/config/settingsSchema.ts b/packages/cli/src/config/settingsSchema.ts index fb0520d334..73c3411f87 100644 --- a/packages/cli/src/config/settingsSchema.ts +++ b/packages/cli/src/config/settingsSchema.ts @@ -1475,6 +1475,18 @@ const SETTINGS_SCHEMA = { 'Enable the "Allow for all future sessions" option in tool confirmation dialogs.', showInDialog: true, }, + autoAddToPolicyByDefault: { + type: 'boolean', + label: 'Auto-add to Policy by Default', + category: 'Security', + requiresRestart: false, + default: true, + description: oneLine` + When enabled, the "Allow for all future sessions" option becomes the + default choice for low-risk tools in trusted workspaces. + `, + showInDialog: true, + }, blockGitExtensions: { type: 'boolean', label: 'Blocks extensions from Git', diff --git a/packages/cli/src/ui/commands/policiesCommand.test.ts b/packages/cli/src/ui/commands/policiesCommand.test.ts index 554d5cd53d..8ed1ab2456 100644 --- a/packages/cli/src/ui/commands/policiesCommand.test.ts +++ b/packages/cli/src/ui/commands/policiesCommand.test.ts @@ -9,12 +9,15 @@ import { policiesCommand } from './policiesCommand.js'; import { CommandKind } from './types.js'; import { MessageType } from '../types.js'; import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; +import * as fs from 'node:fs/promises'; import { type Config, PolicyDecision, ApprovalMode, } from '@google/gemini-cli-core'; +vi.mock('node:fs/promises'); + describe('policiesCommand', () => { let mockContext: ReturnType; @@ -26,8 +29,9 @@ describe('policiesCommand', () => { expect(policiesCommand.name).toBe('policies'); expect(policiesCommand.description).toBe('Manage policies'); expect(policiesCommand.kind).toBe(CommandKind.BUILT_IN); - expect(policiesCommand.subCommands).toHaveLength(1); + expect(policiesCommand.subCommands).toHaveLength(2); expect(policiesCommand.subCommands![0].name).toBe('list'); + expect(policiesCommand.subCommands![1].name).toBe('undo'); }); describe('list subcommand', () => { @@ -160,4 +164,63 @@ describe('policiesCommand', () => { expect(content).toContain('**ALLOW** tool: `shell` [Priority: 50]'); }); }); + + describe('undo subcommand', () => { + it('should show error if config is missing', async () => { + mockContext.services.config = null; + const undoCommand = policiesCommand.subCommands![1]; + await undoCommand.action!(mockContext, ''); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageType.ERROR, + text: 'Error: Config not available.', + }), + expect.any(Number), + ); + }); + + it('should show message if no backups found', async () => { + const mockStorage = { + getAutoSavedPolicyPath: vi.fn().mockReturnValue('user.toml'), + getWorkspaceAutoSavedPolicyPath: vi.fn().mockReturnValue('ws.toml'), + }; + mockContext.services.config = { + storage: mockStorage, + } as unknown as Config; + + vi.mocked(fs.access).mockRejectedValue(new Error('no backup')); + const undoCommand = policiesCommand.subCommands![1]; + await undoCommand.action!(mockContext, ''); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageType.WARNING, + text: 'No policy backups found to restore.', + }), + expect.any(Number), + ); + }); + + it('should restore backups if found', async () => { + const mockStorage = { + getAutoSavedPolicyPath: vi.fn().mockReturnValue('user.toml'), + getWorkspaceAutoSavedPolicyPath: vi.fn().mockReturnValue('ws.toml'), + }; + mockContext.services.config = { + storage: mockStorage, + } as unknown as Config; + + vi.mocked(fs.access).mockResolvedValue(undefined); + vi.mocked(fs.copyFile).mockResolvedValue(undefined); + const undoCommand = policiesCommand.subCommands![1]; + await undoCommand.action!(mockContext, ''); + expect(fs.copyFile).toHaveBeenCalled(); + expect(mockContext.ui.addItem).toHaveBeenCalledWith( + expect.objectContaining({ + type: MessageType.INFO, + text: expect.stringContaining('Successfully restored'), + }), + expect.any(Number), + ); + }); + }); }); diff --git a/packages/cli/src/ui/commands/policiesCommand.ts b/packages/cli/src/ui/commands/policiesCommand.ts index f4bd13de28..9a6b0ee28a 100644 --- a/packages/cli/src/ui/commands/policiesCommand.ts +++ b/packages/cli/src/ui/commands/policiesCommand.ts @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +import * as fs from 'node:fs/promises'; import { ApprovalMode, type PolicyRule } from '@google/gemini-cli-core'; import { CommandKind, type SlashCommand } from './types.js'; import { MessageType } from '../types.js'; @@ -111,10 +112,66 @@ const listPoliciesCommand: SlashCommand = { }, }; +const undoPoliciesCommand: SlashCommand = { + name: 'undo', + description: 'Undo the last auto-saved policy update', + kind: CommandKind.BUILT_IN, + autoExecute: true, + action: async (context) => { + const { config } = context.services; + if (!config) { + context.ui.addItem( + { + type: MessageType.ERROR, + text: 'Error: Config not available.', + }, + Date.now(), + ); + return; + } + + const storage = config.storage; + const paths = [ + storage.getAutoSavedPolicyPath(), + storage.getWorkspaceAutoSavedPolicyPath(), + ]; + + let restoredCount = 0; + for (const p of paths) { + const bak = `${p}.bak`; + try { + await fs.access(bak); + await fs.copyFile(bak, p); + restoredCount++; + } catch { + // No backup or failed to restore + } + } + + if (restoredCount > 0) { + context.ui.addItem( + { + type: MessageType.INFO, + text: `Successfully restored ${restoredCount} policy file(s) from backup. Please restart the CLI to apply changes.`, + }, + Date.now(), + ); + } else { + context.ui.addItem( + { + type: MessageType.WARNING, + text: 'No policy backups found to restore.', + }, + Date.now(), + ); + } + }, +}; + export const policiesCommand: SlashCommand = { name: 'policies', description: 'Manage policies', kind: CommandKind.BUILT_IN, autoExecute: false, - subCommands: [listPoliciesCommand], + subCommands: [listPoliciesCommand, undoPoliciesCommand], }; diff --git a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.test.tsx b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.test.tsx index b3b34ae0a8..e109045301 100644 --- a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.test.tsx +++ b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.test.tsx @@ -406,6 +406,41 @@ describe('ToolConfirmationMessage', () => { expect(lastFrame()).toContain('Allow for all future sessions'); unmount(); }); + + it('should default to "Allow for all future sessions" when autoAddToPolicyByDefault is true', async () => { + const mockConfig = { + isTrustedFolder: () => true, + getIdeMode: () => false, + } as unknown as Config; + + const { lastFrame, waitUntilReady, unmount } = renderWithProviders( + , + { + settings: createMockSettings({ + security: { + enablePermanentToolApproval: true, + autoAddToPolicyByDefault: true, + }, + }), + }, + ); + await waitUntilReady(); + + const output = lastFrame(); + // In Ink, the selected item is usually highlighted with a cursor or different color. + // We can't easily check colors in text output, but we can verify it's NOT the first option + // if we could see the selection indicator. + // Instead, we'll verify the snapshot which should show the selection. + expect(output).toMatchSnapshot(); + unmount(); + }); }); describe('Modify with external editor option', () => { diff --git a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx index 022a68e953..8215bde45a 100644 --- a/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx +++ b/packages/cli/src/ui/components/messages/ToolConfirmationMessage.tsx @@ -386,255 +386,292 @@ export const ToolConfirmationMessage: React.FC< return Math.max(availableTerminalHeight - surroundingElementsHeight, 1); }, [availableTerminalHeight, getOptions, handlesOwnUI]); - const { question, bodyContent, options, securityWarnings } = useMemo<{ - question: string; - bodyContent: React.ReactNode; - options: Array>; - securityWarnings: React.ReactNode; - }>(() => { - let bodyContent: React.ReactNode | null = null; - let securityWarnings: React.ReactNode | null = null; - let question = ''; - const options = getOptions(); + const { question, bodyContent, options, securityWarnings, initialIndex } = + useMemo<{ + question: string; + bodyContent: React.ReactNode; + options: Array>; + securityWarnings: React.ReactNode; + initialIndex: number; + }>(() => { + let bodyContent: React.ReactNode | null = null; + let securityWarnings: React.ReactNode | null = null; + let question = ''; + const options = getOptions(); - if (deceptiveUrlWarningText) { - securityWarnings = ; - } + let initialIndex = 0; + if ( + settings.merged.security.autoAddToPolicyByDefault && + isTrustedFolder && + allowPermanentApproval + ) { + const isSafeToPersist = + confirmationDetails.type === 'info' || + confirmationDetails.type === 'edit' || + (confirmationDetails.type === 'exec' && + confirmationDetails.rootCommand) || + confirmationDetails.type === 'mcp'; - if (confirmationDetails.type === 'ask_user') { - bodyContent = ( - { - handleConfirm(ToolConfirmationOutcome.ProceedOnce, { answers }); - }} - onCancel={() => { - handleConfirm(ToolConfirmationOutcome.Cancel); - }} - width={terminalWidth} - availableHeight={availableBodyContentHeight()} - /> - ); - return { - question: '', - bodyContent, - options: [], - securityWarnings: null, - }; - } - - if (confirmationDetails.type === 'exit_plan_mode') { - bodyContent = ( - { - handleConfirm(ToolConfirmationOutcome.ProceedOnce, { - approved: true, - approvalMode, - }); - }} - onFeedback={(feedback) => { - handleConfirm(ToolConfirmationOutcome.ProceedOnce, { - approved: false, - feedback, - }); - }} - onCancel={() => { - handleConfirm(ToolConfirmationOutcome.Cancel); - }} - width={terminalWidth} - availableHeight={availableBodyContentHeight()} - /> - ); - return { question: '', bodyContent, options: [], securityWarnings: null }; - } - - if (confirmationDetails.type === 'edit') { - if (!confirmationDetails.isModifying) { - question = `Apply this change?`; + if (isSafeToPersist) { + const alwaysAndSaveIndex = options.findIndex( + (o) => o.value === ToolConfirmationOutcome.ProceedAlwaysAndSave, + ); + if (alwaysAndSaveIndex !== -1) { + initialIndex = alwaysAndSaveIndex; + } + } } - } else if (confirmationDetails.type === 'exec') { - const executionProps = confirmationDetails; - if (executionProps.commands && executionProps.commands.length > 1) { - question = `Allow execution of ${executionProps.commands.length} commands?`; - } else { - question = `Allow execution of: '${sanitizeForDisplay(executionProps.rootCommand)}'?`; + if (deceptiveUrlWarningText) { + securityWarnings = ; } - } else if (confirmationDetails.type === 'info') { - question = `Do you want to proceed?`; - } else if (confirmationDetails.type === 'mcp') { - // mcp tool confirmation - const mcpProps = confirmationDetails; - question = `Allow execution of MCP tool "${sanitizeForDisplay(mcpProps.toolName)}" from server "${sanitizeForDisplay(mcpProps.serverName)}"?`; - } - if (confirmationDetails.type === 'edit') { - if (!confirmationDetails.isModifying) { + if (confirmationDetails.type === 'ask_user') { bodyContent = ( - { + handleConfirm(ToolConfirmationOutcome.ProceedOnce, { answers }); + }} + onCancel={() => { + handleConfirm(ToolConfirmationOutcome.Cancel); + }} + width={terminalWidth} + availableHeight={availableBodyContentHeight()} /> ); - } - } else if (confirmationDetails.type === 'exec') { - const executionProps = confirmationDetails; - - const commandsToDisplay = - executionProps.commands && executionProps.commands.length > 1 - ? executionProps.commands - : [executionProps.command]; - const containsRedirection = commandsToDisplay.some((cmd) => - hasRedirection(cmd), - ); - - let bodyContentHeight = availableBodyContentHeight(); - let warnings: React.ReactNode = null; - - if (bodyContentHeight !== undefined) { - bodyContentHeight -= 2; // Account for padding; + return { + question: '', + bodyContent, + options: [], + securityWarnings: null, + initialIndex: 0, + }; } - if (containsRedirection) { - // Calculate lines needed for Note and Tip - const safeWidth = Math.max(terminalWidth, 1); - const noteLength = - REDIRECTION_WARNING_NOTE_LABEL.length + - REDIRECTION_WARNING_NOTE_TEXT.length; - const tipLength = - REDIRECTION_WARNING_TIP_LABEL.length + - REDIRECTION_WARNING_TIP_TEXT.length; + if (confirmationDetails.type === 'exit_plan_mode') { + bodyContent = ( + { + handleConfirm(ToolConfirmationOutcome.ProceedOnce, { + approved: true, + approvalMode, + }); + }} + onFeedback={(feedback) => { + handleConfirm(ToolConfirmationOutcome.ProceedOnce, { + approved: false, + feedback, + }); + }} + onCancel={() => { + handleConfirm(ToolConfirmationOutcome.Cancel); + }} + width={terminalWidth} + availableHeight={availableBodyContentHeight()} + /> + ); + return { + question: '', + bodyContent, + options: [], + securityWarnings: null, + initialIndex: 0, + }; + } - const noteLines = Math.ceil(noteLength / safeWidth); - const tipLines = Math.ceil(tipLength / safeWidth); - const spacerLines = 1; - const warningHeight = noteLines + tipLines + spacerLines; + if (confirmationDetails.type === 'edit') { + if (!confirmationDetails.isModifying) { + question = `Apply this change?`; + } + } else if (confirmationDetails.type === 'exec') { + const executionProps = confirmationDetails; + + if (executionProps.commands && executionProps.commands.length > 1) { + question = `Allow execution of ${executionProps.commands.length} commands?`; + } else { + question = `Allow execution of: '${sanitizeForDisplay(executionProps.rootCommand)}'?`; + } + } else if (confirmationDetails.type === 'info') { + question = `Do you want to proceed?`; + } else if (confirmationDetails.type === 'mcp') { + // mcp tool confirmation + const mcpProps = confirmationDetails; + question = `Allow execution of MCP tool "${sanitizeForDisplay(mcpProps.toolName)}" from server "${sanitizeForDisplay(mcpProps.serverName)}"?`; + } + + if (confirmationDetails.type === 'edit') { + if (!confirmationDetails.isModifying) { + bodyContent = ( + + ); + } + } else if (confirmationDetails.type === 'exec') { + const executionProps = confirmationDetails; + + const commandsToDisplay = + executionProps.commands && executionProps.commands.length > 1 + ? executionProps.commands + : [executionProps.command]; + const containsRedirection = commandsToDisplay.some((cmd) => + hasRedirection(cmd), + ); + + let bodyContentHeight = availableBodyContentHeight(); + let warnings: React.ReactNode = null; if (bodyContentHeight !== undefined) { - bodyContentHeight = Math.max( - bodyContentHeight - warningHeight, - MINIMUM_MAX_HEIGHT, + bodyContentHeight -= 2; // Account for padding; + } + + if (containsRedirection) { + // Calculate lines needed for Note and Tip + const safeWidth = Math.max(terminalWidth, 1); + const noteLength = + REDIRECTION_WARNING_NOTE_LABEL.length + + REDIRECTION_WARNING_NOTE_TEXT.length; + const tipLength = + REDIRECTION_WARNING_TIP_LABEL.length + + REDIRECTION_WARNING_TIP_TEXT.length; + + const noteLines = Math.ceil(noteLength / safeWidth); + const tipLines = Math.ceil(tipLength / safeWidth); + const spacerLines = 1; + const warningHeight = noteLines + tipLines + spacerLines; + + if (bodyContentHeight !== undefined) { + bodyContentHeight = Math.max( + bodyContentHeight - warningHeight, + MINIMUM_MAX_HEIGHT, + ); + } + + warnings = ( + <> + + + + {REDIRECTION_WARNING_NOTE_LABEL} + {REDIRECTION_WARNING_NOTE_TEXT} + + + + + {REDIRECTION_WARNING_TIP_LABEL} + {REDIRECTION_WARNING_TIP_TEXT} + + + ); } - warnings = ( - <> - - - - {REDIRECTION_WARNING_NOTE_LABEL} - {REDIRECTION_WARNING_NOTE_TEXT} + bodyContent = ( + + + + {commandsToDisplay.map((cmd, idx) => ( + + {sanitizeForDisplay(cmd)} + + ))} + + + {warnings} + + ); + } else if (confirmationDetails.type === 'info') { + const infoProps = confirmationDetails; + const displayUrls = + infoProps.urls && + !( + infoProps.urls.length === 1 && + infoProps.urls[0] === infoProps.prompt + ); + + bodyContent = ( + + + + + {displayUrls && infoProps.urls && infoProps.urls.length > 0 && ( + + URLs to fetch: + {infoProps.urls.map((urlString) => ( + + {' '} + - + + ))} + + )} + + ); + } else if (confirmationDetails.type === 'mcp') { + // mcp tool confirmation + const mcpProps = confirmationDetails; + + bodyContent = ( + + <> + + MCP Server: {sanitizeForDisplay(mcpProps.serverName)} - - - - {REDIRECTION_WARNING_TIP_LABEL} - {REDIRECTION_WARNING_TIP_TEXT} + + Tool: {sanitizeForDisplay(mcpProps.toolName)} - - + + {hasMcpToolDetails && ( + + MCP Tool Details: + {isMcpToolDetailsExpanded ? ( + <> + + (press {expandDetailsHintKey} to collapse MCP tool + details) + + {mcpToolDetailsText} + + ) : ( + + (press {expandDetailsHintKey} to expand MCP tool details) + + )} + + )} + ); } - bodyContent = ( - - - - {commandsToDisplay.map((cmd, idx) => ( - - {sanitizeForDisplay(cmd)} - - ))} - - - {warnings} - - ); - } else if (confirmationDetails.type === 'info') { - const infoProps = confirmationDetails; - const displayUrls = - infoProps.urls && - !( - infoProps.urls.length === 1 && infoProps.urls[0] === infoProps.prompt - ); - - bodyContent = ( - - - - - {displayUrls && infoProps.urls && infoProps.urls.length > 0 && ( - - URLs to fetch: - {infoProps.urls.map((urlString) => ( - - {' '} - - - - ))} - - )} - - ); - } else if (confirmationDetails.type === 'mcp') { - // mcp tool confirmation - const mcpProps = confirmationDetails; - - bodyContent = ( - - <> - - MCP Server: {sanitizeForDisplay(mcpProps.serverName)} - - - Tool: {sanitizeForDisplay(mcpProps.toolName)} - - - {hasMcpToolDetails && ( - - MCP Tool Details: - {isMcpToolDetailsExpanded ? ( - <> - - (press {expandDetailsHintKey} to collapse MCP tool details) - - {mcpToolDetailsText} - - ) : ( - - (press {expandDetailsHintKey} to expand MCP tool details) - - )} - - )} - - ); - } - - return { question, bodyContent, options, securityWarnings }; - }, [ - confirmationDetails, - getOptions, - availableBodyContentHeight, - terminalWidth, - handleConfirm, - deceptiveUrlWarningText, - isMcpToolDetailsExpanded, - hasMcpToolDetails, - mcpToolDetailsText, - expandDetailsHintKey, - getPreferredEditor, - ]); + return { question, bodyContent, options, securityWarnings, initialIndex }; + }, [ + confirmationDetails, + getOptions, + availableBodyContentHeight, + terminalWidth, + handleConfirm, + deceptiveUrlWarningText, + isMcpToolDetailsExpanded, + hasMcpToolDetails, + mcpToolDetailsText, + expandDetailsHintKey, + getPreferredEditor, + settings.merged.security.autoAddToPolicyByDefault, + isTrustedFolder, + allowPermanentApproval, + ]); const bodyOverflowDirection: 'top' | 'bottom' = confirmationDetails.type === 'mcp' && isMcpToolDetailsExpanded @@ -697,6 +734,7 @@ export const ToolConfirmationMessage: React.FC< items={options} onSelect={handleSelect} isFocused={isFocused} + initialIndex={initialIndex} /> diff --git a/packages/cli/src/ui/components/messages/__snapshots__/ToolConfirmationMessage.test.tsx.snap b/packages/cli/src/ui/components/messages/__snapshots__/ToolConfirmationMessage.test.tsx.snap index 9e8dfe3a15..bb1d4a7670 100644 --- a/packages/cli/src/ui/components/messages/__snapshots__/ToolConfirmationMessage.test.tsx.snap +++ b/packages/cli/src/ui/components/messages/__snapshots__/ToolConfirmationMessage.test.tsx.snap @@ -1,5 +1,21 @@ // Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html +exports[`ToolConfirmationMessage > enablePermanentToolApproval setting > should default to "Allow for all future sessions" when autoAddToPolicyByDefault is true 1`] = ` +"╭──────────────────────────────────────────────────────────────────────────────╮ +│ │ +│ No changes detected. │ +│ │ +╰──────────────────────────────────────────────────────────────────────────────╯ +Apply this change? + + 1. Allow once + 2. Allow for this session +● 3. Allow for all future sessions + 4. Modify with external editor + 5. No, suggest changes (esc) +" +`; + exports[`ToolConfirmationMessage > should display multiple commands for exec type when provided 1`] = ` "echo "hello" ls -la diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index cff1eb2714..2248061b30 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -548,6 +548,7 @@ export interface ConfigParameters { truncateToolOutputThreshold?: number; eventEmitter?: EventEmitter; useWriteTodos?: boolean; + workspacePoliciesDir?: string; policyEngineConfig?: PolicyEngineConfig; directWebFetch?: boolean; policyUpdateConfirmationRequest?: PolicyUpdateConfirmationRequest; @@ -740,6 +741,7 @@ export class Config implements McpContext { private readonly fileExclusions: FileExclusions; private readonly eventEmitter?: EventEmitter; private readonly useWriteTodos: boolean; + private readonly workspacePoliciesDir: string | undefined; private readonly messageBus: MessageBus; private readonly policyEngine: PolicyEngine; private policyUpdateConfirmationRequest: @@ -951,6 +953,7 @@ export class Config implements McpContext { this.useWriteTodos = isPreviewModel(this.model) ? false : (params.useWriteTodos ?? true); + this.workspacePoliciesDir = params.workspacePoliciesDir; this.enableHooksUI = params.enableHooksUI ?? true; this.enableHooks = params.enableHooks ?? true; this.disabledHooks = params.disabledHooks ?? []; @@ -1956,6 +1959,10 @@ export class Config implements McpContext { return this.geminiMdFilePaths; } + getWorkspacePoliciesDir(): string | undefined { + return this.workspacePoliciesDir; + } + setGeminiMdFilePaths(paths: string[]): void { this.geminiMdFilePaths = paths; } diff --git a/packages/core/src/config/storage.ts b/packages/core/src/config/storage.ts index 10e88543ba..72fe2757db 100644 --- a/packages/core/src/config/storage.ts +++ b/packages/core/src/config/storage.ts @@ -168,6 +168,13 @@ export class Storage { return path.join(this.getGeminiDir(), 'policies'); } + getWorkspaceAutoSavedPolicyPath(): string { + return path.join( + this.getWorkspacePoliciesDir(), + AUTO_SAVED_POLICY_FILENAME, + ); + } + getAutoSavedPolicyPath(): string { return path.join(Storage.getUserPoliciesDir(), AUTO_SAVED_POLICY_FILENAME); } diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts index aefafe0fa0..2b77f87d81 100644 --- a/packages/core/src/confirmation-bus/types.ts +++ b/packages/core/src/confirmation-bus/types.ts @@ -118,6 +118,7 @@ export interface UpdatePolicy { type: MessageBusType.UPDATE_POLICY; toolName: string; persist?: boolean; + persistScope?: 'workspace' | 'user'; argsPattern?: string; commandPrefix?: string | string[]; mcpName?: string; diff --git a/packages/core/src/policy/config.ts b/packages/core/src/policy/config.ts index a1e337436e..9b2a4ca258 100644 --- a/packages/core/src/policy/config.ts +++ b/packages/core/src/policy/config.ts @@ -520,9 +520,21 @@ export function createPolicyUpdater( if (message.persist) { persistenceQueue = persistenceQueue.then(async () => { try { - const policyFile = storage.getAutoSavedPolicyPath(); + const policyFile = + message.persistScope === 'workspace' + ? storage.getWorkspaceAutoSavedPolicyPath() + : storage.getAutoSavedPolicyPath(); await fs.mkdir(path.dirname(policyFile), { recursive: true }); + // Backup existing file if it exists + try { + await fs.copyFile(policyFile, `${policyFile}.bak`); + } catch (error) { + if (!isNodeError(error) || error.code !== 'ENOENT') { + debugLogger.warn(`Failed to backup ${policyFile}`, error); + } + } + // Read existing file let existingData: { rule?: TomlRule[] } = {}; try { diff --git a/packages/core/src/policy/persistence.test.ts b/packages/core/src/policy/persistence.test.ts index c5a71fdd93..c9c3e046d1 100644 --- a/packages/core/src/policy/persistence.test.ts +++ b/packages/core/src/policy/persistence.test.ts @@ -230,15 +230,87 @@ describe('createPolicyUpdater', () => { // Note: @iarna/toml optimizes for shortest representation, so it may use single quotes 'foo"bar' // instead of "foo\"bar\"" if there are no single quotes in the string. try { - expect(writtenContent).toContain(`mcpName = "my\\"jira\\"server"`); + expect(writtenContent).toContain('mcpName = "my\\"jira\\"server"'); } catch { - expect(writtenContent).toContain(`mcpName = 'my"jira"server'`); + expect(writtenContent).toContain('mcpName = \'my"jira"server\''); } try { - expect(writtenContent).toContain(`toolName = "search\\"tool\\""`); + expect(writtenContent).toContain('toolName = "search\\"tool\\""'); } catch { - expect(writtenContent).toContain(`toolName = 'search"tool"'`); + expect(writtenContent).toContain('toolName = \'search"tool"\''); } }); + + it('should persist to workspace when persistScope is workspace', async () => { + createPolicyUpdater(policyEngine, messageBus, mockStorage); + + const workspacePoliciesDir = '/mock/project/.gemini/policies'; + const policyFile = path.join( + workspacePoliciesDir, + AUTO_SAVED_POLICY_FILENAME, + ); + vi.spyOn(mockStorage, 'getWorkspaceAutoSavedPolicyPath').mockReturnValue( + policyFile, + ); + (fs.mkdir as unknown as Mock).mockResolvedValue(undefined); + (fs.readFile as unknown as Mock).mockRejectedValue( + new Error('File not found'), + ); + (fs.copyFile as unknown as Mock).mockResolvedValue(undefined); + + const mockFileHandle = { + writeFile: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + }; + (fs.open as unknown as Mock).mockResolvedValue(mockFileHandle); + (fs.rename as unknown as Mock).mockResolvedValue(undefined); + + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'test_tool', + persist: true, + persistScope: 'workspace', + }); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(mockStorage.getWorkspaceAutoSavedPolicyPath).toHaveBeenCalled(); + expect(fs.mkdir).toHaveBeenCalledWith(workspacePoliciesDir, { + recursive: true, + }); + expect(fs.rename).toHaveBeenCalledWith( + expect.stringMatching(/\.tmp$/), + policyFile, + ); + }); + + it('should backup existing policy file before writing', async () => { + createPolicyUpdater(policyEngine, messageBus, mockStorage); + + const policyFile = '/mock/user/.gemini/policies/auto-saved.toml'; + vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile); + (fs.mkdir as unknown as Mock).mockResolvedValue(undefined); + (fs.readFile as unknown as Mock).mockResolvedValue( + '[[rule]]\ntoolName = "existing"', + ); + (fs.copyFile as unknown as Mock).mockResolvedValue(undefined); + + const mockFileHandle = { + writeFile: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + }; + (fs.open as unknown as Mock).mockResolvedValue(mockFileHandle); + (fs.rename as unknown as Mock).mockResolvedValue(undefined); + + await messageBus.publish({ + type: MessageBusType.UPDATE_POLICY, + toolName: 'new_tool', + persist: true, + }); + + await new Promise((resolve) => setTimeout(resolve, 0)); + + expect(fs.copyFile).toHaveBeenCalledWith(policyFile, `${policyFile}.bak`); + }); }); diff --git a/packages/core/src/policy/utils.ts b/packages/core/src/policy/utils.ts index 3742ba3ed6..cfe57f18ff 100644 --- a/packages/core/src/policy/utils.ts +++ b/packages/core/src/policy/utils.ts @@ -82,3 +82,15 @@ export function buildArgsPatterns( return [argsPattern]; } + +/** + * Builds a regex pattern to match a specific file path in tool arguments. + * This is used to narrow tool approvals for edit tools to specific files. + * + * @param filePath The relative path to the file. + * @returns A regex string that matches "file_path":"" in a JSON string. + */ +export function buildFilePathArgsPattern(filePath: string): string { + const jsonPath = JSON.stringify(filePath).slice(1, -1); + return `"file_path":"${escapeRegex(jsonPath)}"`; +} diff --git a/packages/core/src/scheduler/policy.test.ts b/packages/core/src/scheduler/policy.test.ts index 05f5b08a2f..93b1a56421 100644 --- a/packages/core/src/scheduler/policy.test.ts +++ b/packages/core/src/scheduler/policy.test.ts @@ -16,7 +16,10 @@ import { import { checkPolicy, updatePolicy, getPolicyDenialError } from './policy.js'; import type { Config } from '../config/config.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; -import { MessageBusType } from '../confirmation-bus/types.js'; +import { + MessageBusType, + type SerializableConfirmationDetails, +} from '../confirmation-bus/types.js'; import { ApprovalMode, PolicyDecision } from '../policy/types.js'; import { ToolConfirmationOutcome, @@ -198,6 +201,8 @@ describe('policy.ts', () => { it('should handle standard policy updates with persistence', async () => { const mockConfig = { + isTrustedFolder: vi.fn().mockReturnValue(false), + getWorkspacePoliciesDir: vi.fn().mockReturnValue(undefined), setApprovalMode: vi.fn(), } as unknown as Mocked; const mockMessageBus = { @@ -408,6 +413,8 @@ describe('policy.ts', () => { it('should handle MCP ProceedAlwaysAndSave (persist: true)', async () => { const mockConfig = { + isTrustedFolder: vi.fn().mockReturnValue(false), + getWorkspacePoliciesDir: vi.fn().mockReturnValue(undefined), setApprovalMode: vi.fn(), } as unknown as Mocked; const mockMessageBus = { @@ -439,6 +446,92 @@ describe('policy.ts', () => { }), ); }); + + it('should determine persistScope: workspace in trusted folders', async () => { + const mockConfig = { + isTrustedFolder: vi.fn().mockReturnValue(true), + getWorkspacePoliciesDir: vi + .fn() + .mockReturnValue('/mock/project/policies'), + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'test-tool' } as AnyDeclarativeTool; + + await updatePolicy( + tool, + ToolConfirmationOutcome.ProceedAlwaysAndSave, + undefined, + { config: mockConfig, messageBus: mockMessageBus }, + ); + + expect(mockMessageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + persistScope: 'workspace', + }), + ); + }); + + it('should determine persistScope: user in untrusted folders', async () => { + const mockConfig = { + isTrustedFolder: vi.fn().mockReturnValue(false), + getWorkspacePoliciesDir: vi + .fn() + .mockReturnValue('/mock/project/policies'), + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'test-tool' } as AnyDeclarativeTool; + + await updatePolicy( + tool, + ToolConfirmationOutcome.ProceedAlwaysAndSave, + undefined, + { config: mockConfig, messageBus: mockMessageBus }, + ); + + expect(mockMessageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + persistScope: 'user', + }), + ); + }); + + it('should narrow edit tools with argsPattern', async () => { + const mockConfig = { + isTrustedFolder: vi.fn().mockReturnValue(false), + getWorkspacePoliciesDir: vi.fn().mockReturnValue(undefined), + setApprovalMode: vi.fn(), + } as unknown as Mocked; + const mockMessageBus = { + publish: vi.fn(), + } as unknown as Mocked; + const tool = { name: 'write_file' } as AnyDeclarativeTool; + const details = { + type: 'edit', + filePath: 'src/foo.ts', + title: 'Edit', + onConfirm: vi.fn(), + } as unknown as SerializableConfirmationDetails; + + await updatePolicy( + tool, + ToolConfirmationOutcome.ProceedAlwaysAndSave, + details, + { config: mockConfig, messageBus: mockMessageBus }, + ); + + expect(mockMessageBus.publish).toHaveBeenCalledWith( + expect.objectContaining({ + toolName: 'write_file', + argsPattern: '"file_path":"src/foo\\.ts"', + }), + ); + }); }); describe('getPolicyDenialError', () => { diff --git a/packages/core/src/scheduler/policy.ts b/packages/core/src/scheduler/policy.ts index ad4aa745bb..cdf9894108 100644 --- a/packages/core/src/scheduler/policy.ts +++ b/packages/core/src/scheduler/policy.ts @@ -22,6 +22,7 @@ import { type AnyDeclarativeTool, type PolicyUpdateOptions, } from '../tools/tools.js'; +import { buildFilePathArgsPattern } from '../policy/utils.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; import { EDIT_TOOL_NAMES } from '../tools/tool-names.js'; import type { ValidatingToolCall } from './types.js'; @@ -102,6 +103,20 @@ export async function updatePolicy( return; } + // Determine persist scope if we are persisting. + let persistScope: 'workspace' | 'user' | undefined; + if (outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave) { + // If folder is trusted and workspace policies are enabled, we prefer workspace scope. + if ( + deps.config.isTrustedFolder() && + deps.config.getWorkspacePoliciesDir() + ) { + persistScope = 'workspace'; + } else { + persistScope = 'user'; + } + } + // Specialized Tools (MCP) if (confirmationDetails?.type === 'mcp') { await handleMcpPolicyUpdate( @@ -109,6 +124,7 @@ export async function updatePolicy( outcome, confirmationDetails, deps.messageBus, + persistScope, ); return; } @@ -119,6 +135,7 @@ export async function updatePolicy( outcome, confirmationDetails, deps.messageBus, + persistScope, ); } @@ -148,6 +165,7 @@ async function handleStandardPolicyUpdate( outcome: ToolConfirmationOutcome, confirmationDetails: SerializableConfirmationDetails | undefined, messageBus: MessageBus, + persistScope?: 'workspace' | 'user', ): Promise { if ( outcome === ToolConfirmationOutcome.ProceedAlways || @@ -157,12 +175,17 @@ async function handleStandardPolicyUpdate( if (confirmationDetails?.type === 'exec') { options.commandPrefix = confirmationDetails.rootCommands; + } else if (confirmationDetails?.type === 'edit') { + options.argsPattern = buildFilePathArgsPattern( + confirmationDetails.filePath, + ); } await messageBus.publish({ type: MessageBusType.UPDATE_POLICY, toolName: tool.name, persist: outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave, + persistScope, ...options, }); } @@ -180,6 +203,7 @@ async function handleMcpPolicyUpdate( { type: 'mcp' } >, messageBus: MessageBus, + persistScope?: 'workspace' | 'user', ): Promise { const isMcpAlways = outcome === ToolConfirmationOutcome.ProceedAlways || @@ -204,5 +228,6 @@ async function handleMcpPolicyUpdate( toolName, mcpName: confirmationDetails.serverName, persist, + persistScope, }); } diff --git a/packages/core/src/tools/edit.ts b/packages/core/src/tools/edit.ts index a7169e99f2..736baa9038 100644 --- a/packages/core/src/tools/edit.ts +++ b/packages/core/src/tools/edit.ts @@ -20,7 +20,9 @@ import { type ToolLocation, type ToolResult, type ToolResultDisplay, + type PolicyUpdateOptions, } from './tools.js'; +import { buildFilePathArgsPattern } from '../policy/utils.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { ToolErrorType } from './tool-error.js'; import { makeRelative, shortenPath } from '../utils/paths.js'; @@ -442,6 +444,14 @@ class EditToolInvocation return [{ path: this.params.file_path }]; } + protected override getPolicyUpdateOptions( + _outcome: ToolConfirmationOutcome, + ): PolicyUpdateOptions | undefined { + return { + argsPattern: buildFilePathArgsPattern(this.params.file_path), + }; + } + private async attemptSelfCorrection( params: EditToolParams, currentContent: string, diff --git a/packages/core/src/tools/tools.ts b/packages/core/src/tools/tools.ts index 0a82cc1510..79bb599ff8 100644 --- a/packages/core/src/tools/tools.ts +++ b/packages/core/src/tools/tools.ts @@ -74,6 +74,7 @@ export interface ToolInvocation< * Options for policy updates that can be customized by tool invocations. */ export interface PolicyUpdateOptions { + argsPattern?: string; commandPrefix?: string | string[]; mcpName?: string; } diff --git a/packages/core/src/tools/write-file.ts b/packages/core/src/tools/write-file.ts index f78821f0e1..40c11094e4 100644 --- a/packages/core/src/tools/write-file.ts +++ b/packages/core/src/tools/write-file.ts @@ -24,7 +24,9 @@ import { type ToolLocation, type ToolResult, type ToolConfirmationOutcome, + type PolicyUpdateOptions, } from './tools.js'; +import { buildFilePathArgsPattern } from '../policy/utils.js'; import { ToolErrorType } from './tool-error.js'; import { makeRelative, shortenPath } from '../utils/paths.js'; import { getErrorMessage, isNodeError } from '../utils/errors.js'; @@ -150,6 +152,14 @@ class WriteFileToolInvocation extends BaseToolInvocation< return [{ path: this.resolvedPath }]; } + protected override getPolicyUpdateOptions( + _outcome: ToolConfirmationOutcome, + ): PolicyUpdateOptions | undefined { + return { + argsPattern: buildFilePathArgsPattern(this.params.file_path), + }; + } + override getDescription(): string { const relativePath = makeRelative( this.resolvedPath,