feat(policy): support auto-add to policy by default and scoped persistence (#20361)

This commit is contained in:
Spencer
2026-03-10 13:01:41 -04:00
committed by GitHub
parent 49ea9b0457
commit a220874281
31 changed files with 929 additions and 498 deletions
+1
View File
@@ -125,6 +125,7 @@ they appear in the UI.
| ------------------------------------- | ----------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------- | | ------------------------------------- | ----------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------- |
| Disable YOLO Mode | `security.disableYoloMode` | Disable YOLO mode, even if enabled by a flag. | `false` | | Disable YOLO Mode | `security.disableYoloMode` | Disable YOLO mode, even if enabled by a flag. | `false` |
| Allow Permanent Tool Approval | `security.enablePermanentToolApproval` | Enable the "Allow for all future sessions" option in tool confirmation dialogs. | `false` | | Allow Permanent Tool Approval | `security.enablePermanentToolApproval` | Enable the "Allow for all future sessions" option in tool confirmation dialogs. | `false` |
| Auto-add to Policy by Default | `security.autoAddToPolicyByDefault` | When enabled, the "Allow for all future sessions" option becomes the default choice for low-risk tools in trusted workspaces. | `false` |
| Blocks extensions from Git | `security.blockGitExtensions` | Blocks installing and loading extensions from Git. | `false` | | Blocks extensions from Git | `security.blockGitExtensions` | Blocks installing and loading extensions from Git. | `false` |
| Extension Source Regex Allowlist | `security.allowedExtensions` | List of Regex patterns for allowed extensions. If nonempty, only extensions that match the patterns in this list are allowed. Overrides the blockGitExtensions setting. | `[]` | | Extension Source Regex Allowlist | `security.allowedExtensions` | List of Regex patterns for allowed extensions. If nonempty, only extensions that match the patterns in this list are allowed. Overrides the blockGitExtensions setting. | `[]` |
| Folder Trust | `security.folderTrust.enabled` | Setting to track whether Folder trust is enabled. | `true` | | Folder Trust | `security.folderTrust.enabled` | Setting to track whether Folder trust is enabled. | `true` |
+5
View File
@@ -872,6 +872,11 @@ their corresponding top-level category object in your `settings.json` file.
confirmation dialogs. confirmation dialogs.
- **Default:** `false` - **Default:** `false`
- **`security.autoAddToPolicyByDefault`** (boolean):
- **Description:** When enabled, the "Allow for all future sessions" option
becomes the default choice for low-risk tools in trusted workspaces.
- **Default:** `false`
- **`security.blockGitExtensions`** (boolean): - **`security.blockGitExtensions`** (boolean):
- **Description:** Blocks installing and loading extensions from Git. - **Description:** Blocks installing and loading extensions from Git.
- **Default:** `false` - **Default:** `false`
+12
View File
@@ -1496,6 +1496,18 @@ const SETTINGS_SCHEMA = {
'Enable the "Allow for all future sessions" option in tool confirmation dialogs.', 'Enable the "Allow for all future sessions" option in tool confirmation dialogs.',
showInDialog: true, showInDialog: true,
}, },
autoAddToPolicyByDefault: {
type: 'boolean',
label: 'Auto-add to Policy by Default',
category: 'Security',
requiresRestart: false,
default: false,
description: oneLine`
When enabled, the "Allow for all future sessions" option becomes the
default choice for low-risk tools in trusted workspaces.
`,
showInDialog: true,
},
blockGitExtensions: { blockGitExtensions: {
type: 'boolean', type: 'boolean',
label: 'Blocks extensions from Git', label: 'Blocks extensions from Git',
@@ -411,7 +411,7 @@ describe('ToolConfirmationMessage', () => {
unmount(); unmount();
}); });
it('should show "Allow for all future sessions" when setting is true', async () => { it('should show "Allow for all future sessions" when trusted', async () => {
const mockConfig = { const mockConfig = {
isTrustedFolder: () => true, isTrustedFolder: () => true,
getIdeMode: () => false, getIdeMode: () => false,
@@ -434,7 +434,10 @@ describe('ToolConfirmationMessage', () => {
); );
await waitUntilReady(); await waitUntilReady();
expect(lastFrame()).toContain('Allow for all future sessions'); const output = lastFrame();
expect(output).toContain('future sessions');
// Verify it is the default selection (matching the indicator in the snapshot)
expect(output).toMatchSnapshot();
unmount(); unmount();
}); });
}); });
@@ -246,9 +246,9 @@ export const ToolConfirmationMessage: React.FC<
}); });
if (allowPermanentApproval) { if (allowPermanentApproval) {
options.push({ options.push({
label: 'Allow for all future sessions', label: 'Allow for this file in all future sessions',
value: ToolConfirmationOutcome.ProceedAlwaysAndSave, value: ToolConfirmationOutcome.ProceedAlwaysAndSave,
key: 'Allow for all future sessions', key: 'Allow for this file in all future sessions',
}); });
} }
} }
@@ -282,7 +282,7 @@ export const ToolConfirmationMessage: React.FC<
}); });
if (allowPermanentApproval) { if (allowPermanentApproval) {
options.push({ options.push({
label: `Allow for all future sessions`, label: `Allow this command for all future sessions`,
value: ToolConfirmationOutcome.ProceedAlwaysAndSave, value: ToolConfirmationOutcome.ProceedAlwaysAndSave,
key: `Allow for all future sessions`, key: `Allow for all future sessions`,
}); });
@@ -388,17 +388,41 @@ export const ToolConfirmationMessage: React.FC<
return Math.max(availableTerminalHeight - surroundingElementsHeight, 1); return Math.max(availableTerminalHeight - surroundingElementsHeight, 1);
}, [availableTerminalHeight, getOptions, handlesOwnUI]); }, [availableTerminalHeight, getOptions, handlesOwnUI]);
const { question, bodyContent, options, securityWarnings } = useMemo<{ const { question, bodyContent, options, securityWarnings, initialIndex } =
useMemo<{
question: string; question: string;
bodyContent: React.ReactNode; bodyContent: React.ReactNode;
options: Array<RadioSelectItem<ToolConfirmationOutcome>>; options: Array<RadioSelectItem<ToolConfirmationOutcome>>;
securityWarnings: React.ReactNode; securityWarnings: React.ReactNode;
initialIndex: number;
}>(() => { }>(() => {
let bodyContent: React.ReactNode | null = null; let bodyContent: React.ReactNode | null = null;
let securityWarnings: React.ReactNode | null = null; let securityWarnings: React.ReactNode | null = null;
let question = ''; let question = '';
const options = getOptions(); const options = getOptions();
let initialIndex = 0;
if (isTrustedFolder && allowPermanentApproval) {
// It is safe to allow permanent approval for info, edit, and mcp tools
// in trusted folders because the generated policy rules are narrowed
// to specific files, patterns, or tools (rather than allowing all access).
const isSafeToPersist =
confirmationDetails.type === 'info' ||
confirmationDetails.type === 'edit' ||
confirmationDetails.type === 'mcp';
if (
isSafeToPersist &&
settings.merged.security.autoAddToPolicyByDefault
) {
const alwaysAndSaveIndex = options.findIndex(
(o) => o.value === ToolConfirmationOutcome.ProceedAlwaysAndSave,
);
if (alwaysAndSaveIndex !== -1) {
initialIndex = alwaysAndSaveIndex;
}
}
}
if (deceptiveUrlWarningText) { if (deceptiveUrlWarningText) {
securityWarnings = <WarningMessage text={deceptiveUrlWarningText} />; securityWarnings = <WarningMessage text={deceptiveUrlWarningText} />;
} }
@@ -422,6 +446,7 @@ export const ToolConfirmationMessage: React.FC<
bodyContent, bodyContent,
options: [], options: [],
securityWarnings: null, securityWarnings: null,
initialIndex: 0,
}; };
} }
@@ -449,7 +474,13 @@ export const ToolConfirmationMessage: React.FC<
availableHeight={availableBodyContentHeight()} availableHeight={availableBodyContentHeight()}
/> />
); );
return { question: '', bodyContent, options: [], securityWarnings: null }; return {
question: '',
bodyContent,
options: [],
securityWarnings: null,
initialIndex: 0,
};
} }
if (confirmationDetails.type === 'edit') { if (confirmationDetails.type === 'edit') {
@@ -504,12 +535,12 @@ export const ToolConfirmationMessage: React.FC<
if (containsRedirection) { if (containsRedirection) {
// Calculate lines needed for Note and Tip // Calculate lines needed for Note and Tip
const safeWidth = Math.max(terminalWidth, 1); const safeWidth = Math.max(terminalWidth, 1);
const tipText = `Toggle auto-edit (${formatCommand(Command.CYCLE_APPROVAL_MODE)}) to allow redirection in the future.`;
const noteLength = const noteLength =
REDIRECTION_WARNING_NOTE_LABEL.length + REDIRECTION_WARNING_NOTE_LABEL.length +
REDIRECTION_WARNING_NOTE_TEXT.length; REDIRECTION_WARNING_NOTE_TEXT.length;
const tipLength = REDIRECTION_WARNING_TIP_LABEL.length + tipText.length; const tipText = `Toggle auto-edit (${formatCommand(Command.CYCLE_APPROVAL_MODE)}) to allow redirection in the future.`;
const tipLength =
REDIRECTION_WARNING_TIP_LABEL.length + tipText.length;
const noteLines = Math.ceil(noteLength / safeWidth); const noteLines = Math.ceil(noteLength / safeWidth);
const tipLines = Math.ceil(tipLength / safeWidth); const tipLines = Math.ceil(tipLength / safeWidth);
@@ -574,7 +605,8 @@ export const ToolConfirmationMessage: React.FC<
const displayUrls = const displayUrls =
infoProps.urls && infoProps.urls &&
!( !(
infoProps.urls.length === 1 && infoProps.urls[0] === infoProps.prompt infoProps.urls.length === 1 &&
infoProps.urls[0] === infoProps.prompt
); );
bodyContent = ( bodyContent = (
@@ -618,7 +650,8 @@ export const ToolConfirmationMessage: React.FC<
{isMcpToolDetailsExpanded ? ( {isMcpToolDetailsExpanded ? (
<> <>
<Text color={theme.text.secondary}> <Text color={theme.text.secondary}>
(press {expandDetailsHintKey} to collapse MCP tool details) (press {expandDetailsHintKey} to collapse MCP tool
details)
</Text> </Text>
<Text color={theme.text.link}>{mcpToolDetailsText}</Text> <Text color={theme.text.link}>{mcpToolDetailsText}</Text>
</> </>
@@ -633,7 +666,7 @@ export const ToolConfirmationMessage: React.FC<
); );
} }
return { question, bodyContent, options, securityWarnings }; return { question, bodyContent, options, securityWarnings, initialIndex };
}, [ }, [
confirmationDetails, confirmationDetails,
getOptions, getOptions,
@@ -646,6 +679,8 @@ export const ToolConfirmationMessage: React.FC<
mcpToolDetailsText, mcpToolDetailsText,
expandDetailsHintKey, expandDetailsHintKey,
getPreferredEditor, getPreferredEditor,
isTrustedFolder,
allowPermanentApproval,
settings, settings,
]); ]);
@@ -710,6 +745,7 @@ export const ToolConfirmationMessage: React.FC<
items={options} items={options}
onSelect={handleSelect} onSelect={handleSelect}
isFocused={isFocused} isFocused={isFocused}
initialIndex={initialIndex}
/> />
</Box> </Box>
</> </>
@@ -1,5 +1,21 @@
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html // Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
exports[`ToolConfirmationMessage > enablePermanentToolApproval setting > should show "Allow for all future sessions" when trusted 1`] = `
"╭──────────────────────────────────────────────────────────────────────────────╮
│ │
│ No changes detected. │
│ │
╰──────────────────────────────────────────────────────────────────────────────╯
Apply this change?
● 1. Allow once
2. Allow for this session
3. Allow for this file in all future sessions
4. Modify with external editor
5. No, suggest changes (esc)
"
`;
exports[`ToolConfirmationMessage > should display multiple commands for exec type when provided 1`] = ` exports[`ToolConfirmationMessage > should display multiple commands for exec type when provided 1`] = `
"echo "hello" "echo "hello"
@@ -70,7 +70,7 @@ class McpToolInvocation extends BaseToolInvocation<
}; };
} }
protected override getPolicyUpdateOptions( override getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome, _outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined { ): PolicyUpdateOptions | undefined {
return { return {
@@ -177,7 +177,7 @@ class TypeTextInvocation extends BaseToolInvocation<
}; };
} }
protected override getPolicyUpdateOptions( override getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome, _outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined { ): PolicyUpdateOptions | undefined {
return { return {
+33 -30
View File
@@ -553,6 +553,7 @@ export interface ConfigParameters {
truncateToolOutputThreshold?: number; truncateToolOutputThreshold?: number;
eventEmitter?: EventEmitter; eventEmitter?: EventEmitter;
useWriteTodos?: boolean; useWriteTodos?: boolean;
workspacePoliciesDir?: string;
policyEngineConfig?: PolicyEngineConfig; policyEngineConfig?: PolicyEngineConfig;
directWebFetch?: boolean; directWebFetch?: boolean;
policyUpdateConfirmationRequest?: PolicyUpdateConfirmationRequest; policyUpdateConfirmationRequest?: PolicyUpdateConfirmationRequest;
@@ -746,6 +747,7 @@ export class Config implements McpContext, AgentLoopContext {
private readonly fileExclusions: FileExclusions; private readonly fileExclusions: FileExclusions;
private readonly eventEmitter?: EventEmitter; private readonly eventEmitter?: EventEmitter;
private readonly useWriteTodos: boolean; private readonly useWriteTodos: boolean;
private readonly workspacePoliciesDir: string | undefined;
private readonly _messageBus: MessageBus; private readonly _messageBus: MessageBus;
private readonly policyEngine: PolicyEngine; private readonly policyEngine: PolicyEngine;
private policyUpdateConfirmationRequest: private policyUpdateConfirmationRequest:
@@ -956,6 +958,7 @@ export class Config implements McpContext, AgentLoopContext {
this.useWriteTodos = isPreviewModel(this.model) this.useWriteTodos = isPreviewModel(this.model)
? false ? false
: (params.useWriteTodos ?? true); : (params.useWriteTodos ?? true);
this.workspacePoliciesDir = params.workspacePoliciesDir;
this.enableHooksUI = params.enableHooksUI ?? true; this.enableHooksUI = params.enableHooksUI ?? true;
this.enableHooks = params.enableHooks ?? true; this.enableHooks = params.enableHooks ?? true;
this.disabledHooks = params.disabledHooks ?? []; this.disabledHooks = params.disabledHooks ?? [];
@@ -1187,7 +1190,7 @@ export class Config implements McpContext, AgentLoopContext {
if (this.getSkillManager().getSkills().length > 0) { if (this.getSkillManager().getSkills().length > 0) {
this.getToolRegistry().unregisterTool(ActivateSkillTool.Name); this.getToolRegistry().unregisterTool(ActivateSkillTool.Name);
this.getToolRegistry().registerTool( this.getToolRegistry().registerTool(
new ActivateSkillTool(this, this._messageBus), new ActivateSkillTool(this, this.messageBus),
); );
} }
} }
@@ -1999,6 +2002,10 @@ export class Config implements McpContext, AgentLoopContext {
return this.geminiMdFilePaths; return this.geminiMdFilePaths;
} }
getWorkspacePoliciesDir(): string | undefined {
return this.workspacePoliciesDir;
}
setGeminiMdFilePaths(paths: string[]): void { setGeminiMdFilePaths(paths: string[]): void {
this.geminiMdFilePaths = paths; this.geminiMdFilePaths = paths;
} }
@@ -2621,7 +2628,7 @@ export class Config implements McpContext, AgentLoopContext {
if (this.getSkillManager().getSkills().length > 0) { if (this.getSkillManager().getSkills().length > 0) {
this.getToolRegistry().unregisterTool(ActivateSkillTool.Name); this.getToolRegistry().unregisterTool(ActivateSkillTool.Name);
this.getToolRegistry().registerTool( this.getToolRegistry().registerTool(
new ActivateSkillTool(this, this._messageBus), new ActivateSkillTool(this, this.messageBus),
); );
} else { } else {
this.getToolRegistry().unregisterTool(ActivateSkillTool.Name); this.getToolRegistry().unregisterTool(ActivateSkillTool.Name);
@@ -2805,7 +2812,7 @@ export class Config implements McpContext, AgentLoopContext {
} }
async createToolRegistry(): Promise<ToolRegistry> { async createToolRegistry(): Promise<ToolRegistry> {
const registry = new ToolRegistry(this, this._messageBus); const registry = new ToolRegistry(this, this.messageBus);
// helper to create & register core tools that are enabled // helper to create & register core tools that are enabled
const maybeRegister = ( const maybeRegister = (
@@ -2835,10 +2842,10 @@ export class Config implements McpContext, AgentLoopContext {
}; };
maybeRegister(LSTool, () => maybeRegister(LSTool, () =>
registry.registerTool(new LSTool(this, this._messageBus)), registry.registerTool(new LSTool(this, this.messageBus)),
); );
maybeRegister(ReadFileTool, () => maybeRegister(ReadFileTool, () =>
registry.registerTool(new ReadFileTool(this, this._messageBus)), registry.registerTool(new ReadFileTool(this, this.messageBus)),
); );
if (this.getUseRipgrep()) { if (this.getUseRipgrep()) {
@@ -2851,85 +2858,81 @@ export class Config implements McpContext, AgentLoopContext {
} }
if (useRipgrep) { if (useRipgrep) {
maybeRegister(RipGrepTool, () => maybeRegister(RipGrepTool, () =>
registry.registerTool(new RipGrepTool(this, this._messageBus)), registry.registerTool(new RipGrepTool(this, this.messageBus)),
); );
} else { } else {
logRipgrepFallback(this, new RipgrepFallbackEvent(errorString)); logRipgrepFallback(this, new RipgrepFallbackEvent(errorString));
maybeRegister(GrepTool, () => maybeRegister(GrepTool, () =>
registry.registerTool(new GrepTool(this, this._messageBus)), registry.registerTool(new GrepTool(this, this.messageBus)),
); );
} }
} else { } else {
maybeRegister(GrepTool, () => maybeRegister(GrepTool, () =>
registry.registerTool(new GrepTool(this, this._messageBus)), registry.registerTool(new GrepTool(this, this.messageBus)),
); );
} }
maybeRegister(GlobTool, () => maybeRegister(GlobTool, () =>
registry.registerTool(new GlobTool(this, this._messageBus)), registry.registerTool(new GlobTool(this, this.messageBus)),
); );
maybeRegister(ActivateSkillTool, () => maybeRegister(ActivateSkillTool, () =>
registry.registerTool(new ActivateSkillTool(this, this._messageBus)), registry.registerTool(new ActivateSkillTool(this, this.messageBus)),
); );
maybeRegister(EditTool, () => maybeRegister(EditTool, () =>
registry.registerTool(new EditTool(this, this._messageBus)), registry.registerTool(new EditTool(this, this.messageBus)),
); );
maybeRegister(WriteFileTool, () => maybeRegister(WriteFileTool, () =>
registry.registerTool(new WriteFileTool(this, this._messageBus)), registry.registerTool(new WriteFileTool(this, this.messageBus)),
); );
maybeRegister(WebFetchTool, () => maybeRegister(WebFetchTool, () =>
registry.registerTool(new WebFetchTool(this, this._messageBus)), registry.registerTool(new WebFetchTool(this, this.messageBus)),
); );
maybeRegister(ShellTool, () => maybeRegister(ShellTool, () =>
registry.registerTool(new ShellTool(this, this._messageBus)), registry.registerTool(new ShellTool(this, this.messageBus)),
); );
maybeRegister(MemoryTool, () => maybeRegister(MemoryTool, () =>
registry.registerTool(new MemoryTool(this._messageBus)), registry.registerTool(new MemoryTool(this.messageBus)),
); );
maybeRegister(WebSearchTool, () => maybeRegister(WebSearchTool, () =>
registry.registerTool(new WebSearchTool(this, this._messageBus)), registry.registerTool(new WebSearchTool(this, this.messageBus)),
); );
maybeRegister(AskUserTool, () => maybeRegister(AskUserTool, () =>
registry.registerTool(new AskUserTool(this._messageBus)), registry.registerTool(new AskUserTool(this.messageBus)),
); );
if (this.getUseWriteTodos()) { if (this.getUseWriteTodos()) {
maybeRegister(WriteTodosTool, () => maybeRegister(WriteTodosTool, () =>
registry.registerTool(new WriteTodosTool(this._messageBus)), registry.registerTool(new WriteTodosTool(this.messageBus)),
); );
} }
if (this.isPlanEnabled()) { if (this.isPlanEnabled()) {
maybeRegister(ExitPlanModeTool, () => maybeRegister(ExitPlanModeTool, () =>
registry.registerTool(new ExitPlanModeTool(this, this._messageBus)), registry.registerTool(new ExitPlanModeTool(this, this.messageBus)),
); );
maybeRegister(EnterPlanModeTool, () => maybeRegister(EnterPlanModeTool, () =>
registry.registerTool(new EnterPlanModeTool(this, this._messageBus)), registry.registerTool(new EnterPlanModeTool(this, this.messageBus)),
); );
} }
if (this.isTrackerEnabled()) { if (this.isTrackerEnabled()) {
maybeRegister(TrackerCreateTaskTool, () => maybeRegister(TrackerCreateTaskTool, () =>
registry.registerTool( registry.registerTool(new TrackerCreateTaskTool(this, this.messageBus)),
new TrackerCreateTaskTool(this, this._messageBus),
),
); );
maybeRegister(TrackerUpdateTaskTool, () => maybeRegister(TrackerUpdateTaskTool, () =>
registry.registerTool( registry.registerTool(new TrackerUpdateTaskTool(this, this.messageBus)),
new TrackerUpdateTaskTool(this, this._messageBus),
),
); );
maybeRegister(TrackerGetTaskTool, () => maybeRegister(TrackerGetTaskTool, () =>
registry.registerTool(new TrackerGetTaskTool(this, this._messageBus)), registry.registerTool(new TrackerGetTaskTool(this, this.messageBus)),
); );
maybeRegister(TrackerListTasksTool, () => maybeRegister(TrackerListTasksTool, () =>
registry.registerTool(new TrackerListTasksTool(this, this._messageBus)), registry.registerTool(new TrackerListTasksTool(this, this.messageBus)),
); );
maybeRegister(TrackerAddDependencyTool, () => maybeRegister(TrackerAddDependencyTool, () =>
registry.registerTool( registry.registerTool(
new TrackerAddDependencyTool(this, this._messageBus), new TrackerAddDependencyTool(this, this.messageBus),
), ),
); );
maybeRegister(TrackerVisualizeTool, () => maybeRegister(TrackerVisualizeTool, () =>
registry.registerTool(new TrackerVisualizeTool(this, this._messageBus)), registry.registerTool(new TrackerVisualizeTool(this, this.messageBus)),
); );
} }
+7
View File
@@ -172,6 +172,13 @@ export class Storage {
return path.join(this.getGeminiDir(), 'policies'); return path.join(this.getGeminiDir(), 'policies');
} }
getWorkspaceAutoSavedPolicyPath(): string {
return path.join(
this.getWorkspacePoliciesDir(),
AUTO_SAVED_POLICY_FILENAME,
);
}
getAutoSavedPolicyPath(): string { getAutoSavedPolicyPath(): string {
return path.join(Storage.getUserPoliciesDir(), AUTO_SAVED_POLICY_FILENAME); return path.join(Storage.getUserPoliciesDir(), AUTO_SAVED_POLICY_FILENAME);
} }
@@ -122,6 +122,7 @@ export interface UpdatePolicy {
type: MessageBusType.UPDATE_POLICY; type: MessageBusType.UPDATE_POLICY;
toolName: string; toolName: string;
persist?: boolean; persist?: boolean;
persistScope?: 'workspace' | 'user';
argsPattern?: string; argsPattern?: string;
commandPrefix?: string | string[]; commandPrefix?: string | string[];
mcpName?: string; mcpName?: string;
+63 -17
View File
@@ -29,7 +29,7 @@ import { type MessageBus } from '../confirmation-bus/message-bus.js';
import { coreEvents } from '../utils/events.js'; import { coreEvents } from '../utils/events.js';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import { SHELL_TOOL_NAMES } from '../utils/shell-utils.js'; import { SHELL_TOOL_NAMES } from '../utils/shell-utils.js';
import { SHELL_TOOL_NAME } from '../tools/tool-names.js'; import { SHELL_TOOL_NAME, SENSITIVE_TOOLS } from '../tools/tool-names.js';
import { isNodeError } from '../utils/errors.js'; import { isNodeError } from '../utils/errors.js';
import { MCP_TOOL_PREFIX } from '../tools/mcp-tool.js'; import { MCP_TOOL_PREFIX } from '../tools/mcp-tool.js';
@@ -46,13 +46,20 @@ export const WORKSPACE_POLICY_TIER = 3;
export const USER_POLICY_TIER = 4; export const USER_POLICY_TIER = 4;
export const ADMIN_POLICY_TIER = 5; export const ADMIN_POLICY_TIER = 5;
// Specific priority offsets and derived priorities for dynamic/settings rules. /**
// These are added to the tier base (e.g., USER_POLICY_TIER). * The fractional priority of "Always allow" rules (e.g., 950/1000).
* Higher fraction within a tier wins.
*/
export const ALWAYS_ALLOW_PRIORITY_FRACTION = 950;
// Workspace tier (3) + high priority (950/1000) = ALWAYS_ALLOW_PRIORITY /**
// This ensures user "always allow" selections are high priority * The fractional priority offset for "Always allow" rules (e.g., 0.95).
// within the workspace tier but still lose to user/admin policies. * This ensures consistency between in-memory rules and persisted rules.
export const ALWAYS_ALLOW_PRIORITY = WORKSPACE_POLICY_TIER + 0.95; */
export const ALWAYS_ALLOW_PRIORITY_OFFSET =
ALWAYS_ALLOW_PRIORITY_FRACTION / 1000;
// Specific priority offsets and derived priorities for dynamic/settings rules.
export const MCP_EXCLUDED_PRIORITY = USER_POLICY_TIER + 0.9; export const MCP_EXCLUDED_PRIORITY = USER_POLICY_TIER + 0.9;
export const EXCLUDE_TOOLS_FLAG_PRIORITY = USER_POLICY_TIER + 0.4; export const EXCLUDE_TOOLS_FLAG_PRIORITY = USER_POLICY_TIER + 0.4;
@@ -60,6 +67,18 @@ export const ALLOWED_TOOLS_FLAG_PRIORITY = USER_POLICY_TIER + 0.3;
export const TRUSTED_MCP_SERVER_PRIORITY = USER_POLICY_TIER + 0.2; export const TRUSTED_MCP_SERVER_PRIORITY = USER_POLICY_TIER + 0.2;
export const ALLOWED_MCP_SERVER_PRIORITY = USER_POLICY_TIER + 0.1; export const ALLOWED_MCP_SERVER_PRIORITY = USER_POLICY_TIER + 0.1;
// These are added to the tier base (e.g., USER_POLICY_TIER).
// Workspace tier (3) + high priority (950/1000) = ALWAYS_ALLOW_PRIORITY
export const ALWAYS_ALLOW_PRIORITY =
WORKSPACE_POLICY_TIER + ALWAYS_ALLOW_PRIORITY_OFFSET;
/**
* Returns the fractional priority of ALWAYS_ALLOW_PRIORITY scaled to 1000.
*/
export function getAlwaysAllowPriorityFraction(): number {
return Math.round((ALWAYS_ALLOW_PRIORITY % 1) * 1000);
}
/** /**
* Gets the list of directories to search for policy files, in order of increasing priority * Gets the list of directories to search for policy files, in order of increasing priority
* (Default -> Extension -> Workspace -> User -> Admin). * (Default -> Extension -> Workspace -> User -> Admin).
@@ -492,6 +511,19 @@ export function createPolicyUpdater(
if (message.commandPrefix) { if (message.commandPrefix) {
// Convert commandPrefix(es) to argsPatterns for in-memory rules // Convert commandPrefix(es) to argsPatterns for in-memory rules
const patterns = buildArgsPatterns(undefined, message.commandPrefix); const patterns = buildArgsPatterns(undefined, message.commandPrefix);
const tier =
message.persistScope === 'user'
? USER_POLICY_TIER
: WORKSPACE_POLICY_TIER;
const priority = tier + getAlwaysAllowPriorityFraction() / 1000;
if (SENSITIVE_TOOLS.has(toolName) && !message.commandPrefix) {
debugLogger.warn(
`Attempted to update policy for sensitive tool '${toolName}' without a commandPrefix. Skipping.`,
);
return;
}
for (const pattern of patterns) { for (const pattern of patterns) {
if (pattern) { if (pattern) {
// Note: patterns from buildArgsPatterns are derived from escapeRegex, // Note: patterns from buildArgsPatterns are derived from escapeRegex,
@@ -499,7 +531,7 @@ export function createPolicyUpdater(
policyEngine.addRule({ policyEngine.addRule({
toolName, toolName,
decision: PolicyDecision.ALLOW, decision: PolicyDecision.ALLOW,
priority: ALWAYS_ALLOW_PRIORITY, priority,
argsPattern: new RegExp(pattern), argsPattern: new RegExp(pattern),
source: 'Dynamic (Confirmed)', source: 'Dynamic (Confirmed)',
}); });
@@ -518,10 +550,23 @@ export function createPolicyUpdater(
? new RegExp(message.argsPattern) ? new RegExp(message.argsPattern)
: undefined; : undefined;
const tier =
message.persistScope === 'user'
? USER_POLICY_TIER
: WORKSPACE_POLICY_TIER;
const priority = tier + getAlwaysAllowPriorityFraction() / 1000;
if (SENSITIVE_TOOLS.has(toolName) && !message.argsPattern) {
debugLogger.warn(
`Attempted to update policy for sensitive tool '${toolName}' without an argsPattern. Skipping.`,
);
return;
}
policyEngine.addRule({ policyEngine.addRule({
toolName, toolName,
decision: PolicyDecision.ALLOW, decision: PolicyDecision.ALLOW,
priority: ALWAYS_ALLOW_PRIORITY, priority,
argsPattern, argsPattern,
source: 'Dynamic (Confirmed)', source: 'Dynamic (Confirmed)',
}); });
@@ -530,7 +575,10 @@ export function createPolicyUpdater(
if (message.persist) { if (message.persist) {
persistenceQueue = persistenceQueue.then(async () => { persistenceQueue = persistenceQueue.then(async () => {
try { try {
const policyFile = storage.getAutoSavedPolicyPath(); const policyFile =
message.persistScope === 'workspace'
? storage.getWorkspaceAutoSavedPolicyPath()
: storage.getAutoSavedPolicyPath();
await fs.mkdir(path.dirname(policyFile), { recursive: true }); await fs.mkdir(path.dirname(policyFile), { recursive: true });
// Read existing file // Read existing file
@@ -560,21 +608,19 @@ export function createPolicyUpdater(
} }
// Create new rule object // Create new rule object
const newRule: TomlRule = {}; const newRule: TomlRule = {
decision: 'allow',
priority: getAlwaysAllowPriorityFraction(),
};
if (message.mcpName) { if (message.mcpName) {
newRule.mcpName = message.mcpName; newRule.mcpName = message.mcpName;
// Extract simple tool name // Extract simple tool name
const simpleToolName = toolName.startsWith(`${message.mcpName}__`) newRule.toolName = toolName.startsWith(`${message.mcpName}__`)
? toolName.slice(message.mcpName.length + 2) ? toolName.slice(message.mcpName.length + 2)
: toolName; : toolName;
newRule.toolName = simpleToolName;
newRule.decision = 'allow';
newRule.priority = 200;
} else { } else {
newRule.toolName = toolName; newRule.toolName = toolName;
newRule.decision = 'allow';
newRule.priority = 100;
} }
if (message.commandPrefix) { if (message.commandPrefix) {
+116 -135
View File
@@ -4,25 +4,22 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import { import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
describe,
it,
expect,
vi,
beforeEach,
afterEach,
type Mock,
} from 'vitest';
import * as fs from 'node:fs/promises';
import * as path from 'node:path'; import * as path from 'node:path';
import { createPolicyUpdater, ALWAYS_ALLOW_PRIORITY } from './config.js'; import {
createPolicyUpdater,
getAlwaysAllowPriorityFraction,
} from './config.js';
import { PolicyEngine } from './policy-engine.js'; import { PolicyEngine } from './policy-engine.js';
import { MessageBus } from '../confirmation-bus/message-bus.js'; import { MessageBus } from '../confirmation-bus/message-bus.js';
import { MessageBusType } from '../confirmation-bus/types.js'; import { MessageBusType } from '../confirmation-bus/types.js';
import { Storage, AUTO_SAVED_POLICY_FILENAME } from '../config/storage.js'; import { Storage, AUTO_SAVED_POLICY_FILENAME } from '../config/storage.js';
import { ApprovalMode } from './types.js'; import { ApprovalMode } from './types.js';
import { vol, fs as memfs } from 'memfs';
// Use memfs for all fs operations in this test
vi.mock('node:fs/promises', () => import('memfs').then((m) => m.fs.promises));
vi.mock('node:fs/promises');
vi.mock('../config/storage.js'); vi.mock('../config/storage.js');
describe('createPolicyUpdater', () => { describe('createPolicyUpdater', () => {
@@ -31,6 +28,8 @@ describe('createPolicyUpdater', () => {
let mockStorage: Storage; let mockStorage: Storage;
beforeEach(() => { beforeEach(() => {
vi.useFakeTimers();
vol.reset();
policyEngine = new PolicyEngine({ policyEngine = new PolicyEngine({
rules: [], rules: [],
checkers: [], checkers: [],
@@ -43,202 +42,184 @@ describe('createPolicyUpdater', () => {
afterEach(() => { afterEach(() => {
vi.restoreAllMocks(); vi.restoreAllMocks();
vi.useRealTimers();
}); });
it('should persist policy when persist flag is true', async () => { it('should persist policy when persist flag is true', async () => {
createPolicyUpdater(policyEngine, messageBus, mockStorage); createPolicyUpdater(policyEngine, messageBus, mockStorage);
const userPoliciesDir = '/mock/user/.gemini/policies'; const policyFile = '/mock/user/.gemini/policies/auto-saved.toml';
const policyFile = path.join(userPoliciesDir, AUTO_SAVED_POLICY_FILENAME);
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile); vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile);
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
(fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'),
); // Simulate new file
const mockFileHandle = {
writeFile: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
};
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
const toolName = 'test_tool';
await messageBus.publish({ await messageBus.publish({
type: MessageBusType.UPDATE_POLICY, type: MessageBusType.UPDATE_POLICY,
toolName, toolName: 'test_tool',
persist: true, persist: true,
}); });
// Wait for async operations (microtasks) // Policy updater handles persistence asynchronously in a promise queue.
await new Promise((resolve) => setTimeout(resolve, 0)); // We use advanceTimersByTimeAsync to yield to the microtask queue.
await vi.advanceTimersByTimeAsync(100);
expect(fs.mkdir).toHaveBeenCalledWith(userPoliciesDir, { const fileExists = memfs.existsSync(policyFile);
recursive: true, expect(fileExists).toBe(true);
});
expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx'); const content = memfs.readFileSync(policyFile, 'utf-8') as string;
expect(content).toContain('toolName = "test_tool"');
// Check written content expect(content).toContain('decision = "allow"');
const expectedContent = expect.stringContaining(`toolName = "test_tool"`); const expectedPriority = getAlwaysAllowPriorityFraction();
expect(mockFileHandle.writeFile).toHaveBeenCalledWith( expect(content).toContain(`priority = ${expectedPriority}`);
expectedContent,
'utf-8',
);
expect(fs.rename).toHaveBeenCalledWith(
expect.stringMatching(/\.tmp$/),
policyFile,
);
}); });
it('should not persist policy when persist flag is false or undefined', async () => { it('should not persist policy when persist flag is false or undefined', async () => {
createPolicyUpdater(policyEngine, messageBus, mockStorage); createPolicyUpdater(policyEngine, messageBus, mockStorage);
const policyFile = '/mock/user/.gemini/policies/auto-saved.toml';
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile);
await messageBus.publish({ await messageBus.publish({
type: MessageBusType.UPDATE_POLICY, type: MessageBusType.UPDATE_POLICY,
toolName: 'test_tool', toolName: 'test_tool',
}); });
await new Promise((resolve) => setTimeout(resolve, 0)); await vi.advanceTimersByTimeAsync(100);
expect(fs.writeFile).not.toHaveBeenCalled(); expect(memfs.existsSync(policyFile)).toBe(false);
expect(fs.rename).not.toHaveBeenCalled();
}); });
it('should persist policy with commandPrefix when provided', async () => { it('should append to existing policy file', async () => {
createPolicyUpdater(policyEngine, messageBus, mockStorage); createPolicyUpdater(policyEngine, messageBus, mockStorage);
const userPoliciesDir = '/mock/user/.gemini/policies'; const policyFile = '/mock/user/.gemini/policies/auto-saved.toml';
const policyFile = path.join(userPoliciesDir, AUTO_SAVED_POLICY_FILENAME);
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile); vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile);
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
(fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'),
);
const mockFileHandle = { const existingContent =
writeFile: vi.fn().mockResolvedValue(undefined), '[[rule]]\ntoolName = "existing_tool"\ndecision = "allow"\n';
close: vi.fn().mockResolvedValue(undefined), const dir = path.dirname(policyFile);
}; memfs.mkdirSync(dir, { recursive: true });
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle); memfs.writeFileSync(policyFile, existingContent);
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
const toolName = 'run_shell_command';
const commandPrefix = 'git status';
await messageBus.publish({ await messageBus.publish({
type: MessageBusType.UPDATE_POLICY, type: MessageBusType.UPDATE_POLICY,
toolName, toolName: 'new_tool',
persist: true, persist: true,
commandPrefix,
}); });
await new Promise((resolve) => setTimeout(resolve, 0)); await vi.advanceTimersByTimeAsync(100);
// In-memory rule check (unchanged) const content = memfs.readFileSync(policyFile, 'utf-8') as string;
const rules = policyEngine.getRules(); expect(content).toContain('toolName = "existing_tool"');
const addedRule = rules.find((r) => r.toolName === toolName); expect(content).toContain('toolName = "new_tool"');
expect(addedRule).toBeDefined();
expect(addedRule?.priority).toBe(ALWAYS_ALLOW_PRIORITY);
expect(addedRule?.argsPattern).toEqual(
new RegExp(`"command":"git\\ status(?:[\\s"]|\\\\")`),
);
// Verify file written
expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx');
expect(mockFileHandle.writeFile).toHaveBeenCalledWith(
expect.stringContaining(`commandPrefix = "git status"`),
'utf-8',
);
}); });
it('should persist policy with mcpName and toolName when provided', async () => { it('should handle toml with multiple rules correctly', async () => {
createPolicyUpdater(policyEngine, messageBus, mockStorage); createPolicyUpdater(policyEngine, messageBus, mockStorage);
const userPoliciesDir = '/mock/user/.gemini/policies'; const policyFile = '/mock/user/.gemini/policies/auto-saved.toml';
const policyFile = path.join(userPoliciesDir, AUTO_SAVED_POLICY_FILENAME);
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile); vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile);
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
(fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'),
);
const mockFileHandle = { const existingContent = `
writeFile: vi.fn().mockResolvedValue(undefined), [[rule]]
close: vi.fn().mockResolvedValue(undefined), toolName = "tool1"
}; decision = "allow"
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
const mcpName = 'my-jira-server'; [[rule]]
const simpleToolName = 'search'; toolName = "tool2"
const toolName = `${mcpName}__${simpleToolName}`; decision = "deny"
`;
const dir = path.dirname(policyFile);
memfs.mkdirSync(dir, { recursive: true });
memfs.writeFileSync(policyFile, existingContent);
await messageBus.publish({ await messageBus.publish({
type: MessageBusType.UPDATE_POLICY, type: MessageBusType.UPDATE_POLICY,
toolName, toolName: 'tool3',
persist: true, persist: true,
mcpName,
}); });
await new Promise((resolve) => setTimeout(resolve, 0)); await vi.advanceTimersByTimeAsync(100);
// Verify file written const content = memfs.readFileSync(policyFile, 'utf-8') as string;
expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx'); expect(content).toContain('toolName = "tool1"');
const writeCall = mockFileHandle.writeFile.mock.calls[0]; expect(content).toContain('toolName = "tool2"');
const writtenContent = writeCall[0] as string; expect(content).toContain('toolName = "tool3"');
expect(writtenContent).toContain(`mcpName = "${mcpName}"`);
expect(writtenContent).toContain(`toolName = "${simpleToolName}"`);
expect(writtenContent).toContain('priority = 200');
}); });
it('should escape special characters in toolName and mcpName', async () => { it('should include argsPattern if provided', async () => {
createPolicyUpdater(policyEngine, messageBus, mockStorage); createPolicyUpdater(policyEngine, messageBus, mockStorage);
const userPoliciesDir = '/mock/user/.gemini/policies'; const policyFile = '/mock/user/.gemini/policies/auto-saved.toml';
const policyFile = path.join(userPoliciesDir, AUTO_SAVED_POLICY_FILENAME);
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile); vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile);
(fs.mkdir as unknown as Mock).mockResolvedValue(undefined);
(fs.readFile as unknown as Mock).mockRejectedValue(
new Error('File not found'),
);
const mockFileHandle = {
writeFile: vi.fn().mockResolvedValue(undefined),
close: vi.fn().mockResolvedValue(undefined),
};
(fs.open as unknown as Mock).mockResolvedValue(mockFileHandle);
(fs.rename as unknown as Mock).mockResolvedValue(undefined);
const mcpName = 'my"jira"server';
const toolName = `my"jira"server__search"tool"`;
await messageBus.publish({ await messageBus.publish({
type: MessageBusType.UPDATE_POLICY, type: MessageBusType.UPDATE_POLICY,
toolName, toolName: 'test_tool',
persist: true, persist: true,
mcpName, argsPattern: '^foo.*$',
}); });
await new Promise((resolve) => setTimeout(resolve, 0)); await vi.advanceTimersByTimeAsync(100);
expect(fs.open).toHaveBeenCalledWith(expect.stringMatching(/\.tmp$/), 'wx'); const content = memfs.readFileSync(policyFile, 'utf-8') as string;
const writeCall = mockFileHandle.writeFile.mock.calls[0]; expect(content).toContain('argsPattern = "^foo.*$"');
const writtenContent = writeCall[0] as string; });
// Verify escaping - should be valid TOML it('should include mcpName if provided', async () => {
createPolicyUpdater(policyEngine, messageBus, mockStorage);
const policyFile = '/mock/user/.gemini/policies/auto-saved.toml';
vi.spyOn(mockStorage, 'getAutoSavedPolicyPath').mockReturnValue(policyFile);
await messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: 'search"tool"',
persist: true,
mcpName: 'my"jira"server',
});
await vi.advanceTimersByTimeAsync(100);
const writtenContent = memfs.readFileSync(policyFile, 'utf-8') as string;
// Verify escaping - should be valid TOML and contain the values
// Note: @iarna/toml optimizes for shortest representation, so it may use single quotes 'foo"bar' // Note: @iarna/toml optimizes for shortest representation, so it may use single quotes 'foo"bar'
// instead of "foo\"bar\"" if there are no single quotes in the string. // instead of "foo\"bar\"" if there are no single quotes in the string.
try { try {
expect(writtenContent).toContain(`mcpName = "my\\"jira\\"server"`); expect(writtenContent).toContain('mcpName = "my\\"jira\\"server"');
} catch { } catch {
expect(writtenContent).toContain(`mcpName = 'my"jira"server'`); expect(writtenContent).toContain('mcpName = \'my"jira"server\'');
} }
try { try {
expect(writtenContent).toContain(`toolName = "search\\"tool\\""`); expect(writtenContent).toContain('toolName = "search\\"tool\\""');
} catch { } catch {
expect(writtenContent).toContain(`toolName = 'search"tool"'`); expect(writtenContent).toContain('toolName = \'search"tool"\'');
} }
}); });
it('should persist to workspace when persistScope is workspace', async () => {
createPolicyUpdater(policyEngine, messageBus, mockStorage);
const workspacePoliciesDir = '/mock/project/.gemini/policies';
const policyFile = path.join(
workspacePoliciesDir,
AUTO_SAVED_POLICY_FILENAME,
);
vi.spyOn(mockStorage, 'getWorkspaceAutoSavedPolicyPath').mockReturnValue(
policyFile,
);
await messageBus.publish({
type: MessageBusType.UPDATE_POLICY,
toolName: 'test_tool',
persist: true,
persistScope: 'workspace',
});
await vi.advanceTimersByTimeAsync(100);
expect(memfs.existsSync(policyFile)).toBe(true);
const content = memfs.readFileSync(policyFile, 'utf-8') as string;
expect(content).toContain('toolName = "test_tool"');
});
}); });
@@ -19,6 +19,7 @@ import {
type PolicyUpdateOptions, type PolicyUpdateOptions,
} from '../tools/tools.js'; } from '../tools/tools.js';
import * as shellUtils from '../utils/shell-utils.js'; import * as shellUtils from '../utils/shell-utils.js';
import { escapeRegex } from './utils.js';
vi.mock('node:fs/promises'); vi.mock('node:fs/promises');
vi.mock('../config/storage.js'); vi.mock('../config/storage.js');
@@ -75,7 +76,9 @@ describe('createPolicyUpdater', () => {
expect.objectContaining({ expect.objectContaining({
toolName: 'run_shell_command', toolName: 'run_shell_command',
priority: ALWAYS_ALLOW_PRIORITY, priority: ALWAYS_ALLOW_PRIORITY,
argsPattern: new RegExp('"command":"echo(?:[\\s"]|\\\\")'), argsPattern: new RegExp(
escapeRegex('"command":"echo') + '(?:[\\s"]|\\\\")',
),
}), }),
); );
expect(policyEngine.addRule).toHaveBeenNthCalledWith( expect(policyEngine.addRule).toHaveBeenNthCalledWith(
@@ -83,7 +86,9 @@ describe('createPolicyUpdater', () => {
expect.objectContaining({ expect.objectContaining({
toolName: 'run_shell_command', toolName: 'run_shell_command',
priority: ALWAYS_ALLOW_PRIORITY, priority: ALWAYS_ALLOW_PRIORITY,
argsPattern: new RegExp('"command":"ls(?:[\\s"]|\\\\")'), argsPattern: new RegExp(
escapeRegex('"command":"ls') + '(?:[\\s"]|\\\\")',
),
}), }),
); );
}); });
@@ -103,7 +108,9 @@ describe('createPolicyUpdater', () => {
expect.objectContaining({ expect.objectContaining({
toolName: 'run_shell_command', toolName: 'run_shell_command',
priority: ALWAYS_ALLOW_PRIORITY, priority: ALWAYS_ALLOW_PRIORITY,
argsPattern: new RegExp('"command":"git(?:[\\s"]|\\\\")'), argsPattern: new RegExp(
escapeRegex('"command":"git') + '(?:[\\s"]|\\\\")',
),
}), }),
); );
}); });
+17 -16
View File
@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import { describe, it, expect } from 'vitest'; import { expect, describe, it } from 'vitest';
import { escapeRegex, buildArgsPatterns, isSafeRegExp } from './utils.js'; import { escapeRegex, buildArgsPatterns, isSafeRegExp } from './utils.js';
describe('policy/utils', () => { describe('policy/utils', () => {
@@ -43,20 +43,20 @@ describe('policy/utils', () => {
}); });
it('should return false for invalid regexes', () => { it('should return false for invalid regexes', () => {
expect(isSafeRegExp('[')).toBe(false);
expect(isSafeRegExp('([a-z)')).toBe(false); expect(isSafeRegExp('([a-z)')).toBe(false);
expect(isSafeRegExp('*')).toBe(false); expect(isSafeRegExp('*')).toBe(false);
}); });
it('should return false for extremely long regexes', () => { it('should return false for long regexes', () => {
expect(isSafeRegExp('a'.repeat(2049))).toBe(false); expect(isSafeRegExp('a'.repeat(3000))).toBe(false);
}); });
it('should return false for nested quantifiers (potential ReDoS)', () => { it('should return false for nested quantifiers (ReDoS heuristic)', () => {
expect(isSafeRegExp('(a+)+')).toBe(false); expect(isSafeRegExp('(a+)+')).toBe(false);
expect(isSafeRegExp('(a+)*')).toBe(false); expect(isSafeRegExp('(a|b)*')).toBe(true);
expect(isSafeRegExp('(a*)+')).toBe(false); expect(isSafeRegExp('(.*)*')).toBe(false);
expect(isSafeRegExp('(a*)*')).toBe(false); expect(isSafeRegExp('([a-z]+)+')).toBe(false);
expect(isSafeRegExp('(a|b+)+')).toBe(false);
expect(isSafeRegExp('(.*)+')).toBe(false); expect(isSafeRegExp('(.*)+')).toBe(false);
}); });
}); });
@@ -69,14 +69,14 @@ describe('policy/utils', () => {
it('should build pattern from a single commandPrefix', () => { it('should build pattern from a single commandPrefix', () => {
const result = buildArgsPatterns(undefined, 'ls', undefined); const result = buildArgsPatterns(undefined, 'ls', undefined);
expect(result).toEqual(['"command":"ls(?:[\\s"]|\\\\")']); expect(result).toEqual(['\\"command\\":\\"ls(?:[\\s"]|\\\\")']);
}); });
it('should build patterns from an array of commandPrefixes', () => { it('should build patterns from an array of commandPrefixes', () => {
const result = buildArgsPatterns(undefined, ['ls', 'cd'], undefined); const result = buildArgsPatterns(undefined, ['echo', 'ls'], undefined);
expect(result).toEqual([ expect(result).toEqual([
'"command":"ls(?:[\\s"]|\\\\")', '\\"command\\":\\"echo(?:[\\s"]|\\\\")',
'"command":"cd(?:[\\s"]|\\\\")', '\\"command\\":\\"ls(?:[\\s"]|\\\\")',
]); ]);
}); });
@@ -87,7 +87,7 @@ describe('policy/utils', () => {
it('should prioritize commandPrefix over commandRegex and argsPattern', () => { it('should prioritize commandPrefix over commandRegex and argsPattern', () => {
const result = buildArgsPatterns('raw', 'prefix', 'regex'); const result = buildArgsPatterns('raw', 'prefix', 'regex');
expect(result).toEqual(['"command":"prefix(?:[\\s"]|\\\\")']); expect(result).toEqual(['\\"command\\":\\"prefix(?:[\\s"]|\\\\")']);
}); });
it('should prioritize commandRegex over argsPattern if no commandPrefix', () => { it('should prioritize commandRegex over argsPattern if no commandPrefix', () => {
@@ -98,14 +98,15 @@ describe('policy/utils', () => {
it('should escape characters in commandPrefix', () => { it('should escape characters in commandPrefix', () => {
const result = buildArgsPatterns(undefined, 'git checkout -b', undefined); const result = buildArgsPatterns(undefined, 'git checkout -b', undefined);
expect(result).toEqual([ expect(result).toEqual([
'"command":"git\\ checkout\\ \\-b(?:[\\s"]|\\\\")', '\\"command\\":\\"git\\ checkout\\ \\-b(?:[\\s"]|\\\\")',
]); ]);
}); });
it('should correctly escape quotes in commandPrefix', () => { it('should correctly escape quotes in commandPrefix', () => {
const result = buildArgsPatterns(undefined, 'git "fix"', undefined); const result = buildArgsPatterns(undefined, 'git "fix"', undefined);
expect(result).toEqual([ expect(result).toEqual([
'"command":"git\\ \\\\\\"fix\\\\\\"(?:[\\s"]|\\\\")', // eslint-disable-next-line no-useless-escape
'\\\"command\\\":\\\"git\\ \\\\\\\"fix\\\\\\\"(?:[\\s\"]|\\\\\")',
]); ]);
}); });
@@ -142,7 +143,7 @@ describe('policy/utils', () => {
const gitRegex = new RegExp(gitPatterns[0]!); const gitRegex = new RegExp(gitPatterns[0]!);
// git\status -> {"command":"git\\status"} // git\status -> {"command":"git\\status"}
const gitAttack = '{"command":"git\\\\status"}'; const gitAttack = '{"command":"git\\\\status"}';
expect(gitRegex.test(gitAttack)).toBe(false); expect(gitAttack).not.toMatch(gitRegex);
}); });
}); });
}); });
+39 -6
View File
@@ -63,16 +63,22 @@ export function buildArgsPatterns(
? commandPrefix ? commandPrefix
: [commandPrefix]; : [commandPrefix];
// Expand command prefixes to multiple patterns.
// We append [\\s"] to ensure we match whole words only (e.g., "git" but not
// "github"). Since we match against JSON stringified args, the value is
// always followed by a space or a closing quote.
return prefixes.map((prefix) => { return prefixes.map((prefix) => {
const jsonPrefix = JSON.stringify(prefix).slice(1, -1); // JSON.stringify safely encodes the prefix in quotes.
// We remove ONLY the trailing quote to match it as an open prefix string.
const encodedPrefix = JSON.stringify(prefix);
const openQuotePrefix = encodedPrefix.substring(
0,
encodedPrefix.length - 1,
);
// Escape the exact JSON literal segment we expect to see
const matchSegment = escapeRegex(`"command":${openQuotePrefix}`);
// We allow [\s], ["], or the specific sequence [\"] (for escaped quotes // We allow [\s], ["], or the specific sequence [\"] (for escaped quotes
// in JSON). We do NOT allow generic [\\], which would match "git\status" // in JSON). We do NOT allow generic [\\], which would match "git\status"
// -> "gitstatus". // -> "gitstatus".
return `"command":"${escapeRegex(jsonPrefix)}(?:[\\s"]|\\\\")`; return `${matchSegment}(?:[\\s"]|\\\\")`;
}); });
} }
@@ -82,3 +88,30 @@ export function buildArgsPatterns(
return [argsPattern]; return [argsPattern];
} }
/**
* Builds a regex pattern to match a specific file path in tool arguments.
* This is used to narrow tool approvals for edit tools to specific files.
*
* @param filePath The relative path to the file.
* @returns A regex string that matches "file_path":"<path>" in a JSON string.
*/
export function buildFilePathArgsPattern(filePath: string): string {
// JSON.stringify safely encodes the path (handling quotes, backslashes, etc)
// and wraps it in double quotes. We simply prepend the key name and escape
// the entire sequence for Regex matching without any slicing.
const encodedPath = JSON.stringify(filePath);
return escapeRegex(`"file_path":${encodedPath}`);
}
/**
* Builds a regex pattern to match a specific "pattern" in tool arguments.
* This is used to narrow tool approvals for search tools like glob/grep to specific patterns.
*
* @param pattern The pattern to match.
* @returns A regex string that matches "pattern":"<pattern>" in a JSON string.
*/
export function buildPatternArgsPattern(pattern: string): string {
const encodedPattern = JSON.stringify(pattern);
return escapeRegex(`"pattern":${encodedPattern}`);
}
+99 -1
View File
@@ -16,8 +16,12 @@ import {
import { checkPolicy, updatePolicy, getPolicyDenialError } from './policy.js'; import { checkPolicy, updatePolicy, getPolicyDenialError } from './policy.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { MessageBusType } from '../confirmation-bus/types.js'; import {
MessageBusType,
type SerializableConfirmationDetails,
} from '../confirmation-bus/types.js';
import { ApprovalMode, PolicyDecision } from '../policy/types.js'; import { ApprovalMode, PolicyDecision } from '../policy/types.js';
import { escapeRegex } from '../policy/utils.js';
import { import {
ToolConfirmationOutcome, ToolConfirmationOutcome,
type AnyDeclarativeTool, type AnyDeclarativeTool,
@@ -219,6 +223,8 @@ describe('policy.ts', () => {
it('should handle standard policy updates with persistence', async () => { it('should handle standard policy updates with persistence', async () => {
const mockConfig = { const mockConfig = {
isTrustedFolder: vi.fn().mockReturnValue(false),
getWorkspacePoliciesDir: vi.fn().mockReturnValue(undefined),
setApprovalMode: vi.fn(), setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>; } as unknown as Mocked<Config>;
@@ -453,6 +459,8 @@ describe('policy.ts', () => {
it('should handle MCP ProceedAlwaysAndSave (persist: true)', async () => { it('should handle MCP ProceedAlwaysAndSave (persist: true)', async () => {
const mockConfig = { const mockConfig = {
isTrustedFolder: vi.fn().mockReturnValue(false),
getWorkspacePoliciesDir: vi.fn().mockReturnValue(undefined),
setApprovalMode: vi.fn(), setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>; } as unknown as Mocked<Config>;
@@ -487,6 +495,96 @@ describe('policy.ts', () => {
}), }),
); );
}); });
it('should determine persistScope: workspace in trusted folders', async () => {
const mockConfig = {
isTrustedFolder: vi.fn().mockReturnValue(true),
getWorkspacePoliciesDir: vi
.fn()
.mockReturnValue('/mock/project/policies'),
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
const tool = { name: 'test-tool' } as AnyDeclarativeTool;
await updatePolicy(
tool,
ToolConfirmationOutcome.ProceedAlwaysAndSave,
undefined,
{ config: mockConfig, messageBus: mockMessageBus },
);
expect(mockMessageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
persistScope: 'workspace',
}),
);
});
it('should determine persistScope: user in untrusted folders', async () => {
const mockConfig = {
isTrustedFolder: vi.fn().mockReturnValue(false),
getWorkspacePoliciesDir: vi
.fn()
.mockReturnValue('/mock/project/policies'),
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
const tool = { name: 'test-tool' } as AnyDeclarativeTool;
await updatePolicy(
tool,
ToolConfirmationOutcome.ProceedAlwaysAndSave,
undefined,
{ config: mockConfig, messageBus: mockMessageBus },
);
expect(mockMessageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
persistScope: 'user',
}),
);
});
it('should narrow edit tools with argsPattern', async () => {
const mockConfig = {
isTrustedFolder: vi.fn().mockReturnValue(false),
getWorkspacePoliciesDir: vi.fn().mockReturnValue(undefined),
getTargetDir: vi.fn().mockReturnValue('/mock/dir'),
setApprovalMode: vi.fn(),
} as unknown as Mocked<Config>;
const mockMessageBus = {
publish: vi.fn(),
} as unknown as Mocked<MessageBus>;
const tool = { name: 'write_file' } as AnyDeclarativeTool;
const details: SerializableConfirmationDetails = {
type: 'edit',
title: 'Edit',
filePath: 'src/foo.ts',
fileName: 'foo.ts',
fileDiff: '--- foo.ts\n+++ foo.ts\n@@ -1 +1 @@\n-old\n+new',
originalContent: 'old',
newContent: 'new',
};
await updatePolicy(
tool,
ToolConfirmationOutcome.ProceedAlwaysAndSave,
details,
{ config: mockConfig, messageBus: mockMessageBus },
);
expect(mockMessageBus.publish).toHaveBeenCalledWith(
expect.objectContaining({
toolName: 'write_file',
argsPattern: escapeRegex('"file_path":"src/foo.ts"'),
}),
);
});
}); });
describe('getPolicyDenialError', () => { describe('getPolicyDenialError', () => {
+40 -3
View File
@@ -20,8 +20,11 @@ import {
import { import {
ToolConfirmationOutcome, ToolConfirmationOutcome,
type AnyDeclarativeTool, type AnyDeclarativeTool,
type AnyToolInvocation,
type PolicyUpdateOptions, type PolicyUpdateOptions,
} from '../tools/tools.js'; } from '../tools/tools.js';
import { buildFilePathArgsPattern } from '../policy/utils.js';
import { makeRelative } from '../utils/paths.js';
import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js';
import { EDIT_TOOL_NAMES } from '../tools/tool-names.js'; import { EDIT_TOOL_NAMES } from '../tools/tool-names.js';
import type { ValidatingToolCall } from './types.js'; import type { ValidatingToolCall } from './types.js';
@@ -94,7 +97,11 @@ export async function updatePolicy(
tool: AnyDeclarativeTool, tool: AnyDeclarativeTool,
outcome: ToolConfirmationOutcome, outcome: ToolConfirmationOutcome,
confirmationDetails: SerializableConfirmationDetails | undefined, confirmationDetails: SerializableConfirmationDetails | undefined,
deps: { config: Config; messageBus: MessageBus }, deps: {
config: Config;
messageBus: MessageBus;
toolInvocation?: AnyToolInvocation;
},
): Promise<void> { ): Promise<void> {
// Mode Transitions (AUTO_EDIT) // Mode Transitions (AUTO_EDIT)
if (isAutoEditTransition(tool, outcome)) { if (isAutoEditTransition(tool, outcome)) {
@@ -102,6 +109,20 @@ export async function updatePolicy(
return; return;
} }
// Determine persist scope if we are persisting.
let persistScope: 'workspace' | 'user' | undefined;
if (outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave) {
// If folder is trusted and workspace policies are enabled, we prefer workspace scope.
if (
deps.config.isTrustedFolder() &&
deps.config.getWorkspacePoliciesDir() !== undefined
) {
persistScope = 'workspace';
} else {
persistScope = 'user';
}
}
// Specialized Tools (MCP) // Specialized Tools (MCP)
if (confirmationDetails?.type === 'mcp') { if (confirmationDetails?.type === 'mcp') {
await handleMcpPolicyUpdate( await handleMcpPolicyUpdate(
@@ -109,6 +130,7 @@ export async function updatePolicy(
outcome, outcome,
confirmationDetails, confirmationDetails,
deps.messageBus, deps.messageBus,
persistScope,
); );
return; return;
} }
@@ -119,6 +141,9 @@ export async function updatePolicy(
outcome, outcome,
confirmationDetails, confirmationDetails,
deps.messageBus, deps.messageBus,
persistScope,
deps.toolInvocation,
deps.config,
); );
} }
@@ -148,21 +173,31 @@ async function handleStandardPolicyUpdate(
outcome: ToolConfirmationOutcome, outcome: ToolConfirmationOutcome,
confirmationDetails: SerializableConfirmationDetails | undefined, confirmationDetails: SerializableConfirmationDetails | undefined,
messageBus: MessageBus, messageBus: MessageBus,
persistScope?: 'workspace' | 'user',
toolInvocation?: AnyToolInvocation,
config?: Config,
): Promise<void> { ): Promise<void> {
if ( if (
outcome === ToolConfirmationOutcome.ProceedAlways || outcome === ToolConfirmationOutcome.ProceedAlways ||
outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave
) { ) {
const options: PolicyUpdateOptions = {}; const options: PolicyUpdateOptions =
toolInvocation?.getPolicyUpdateOptions?.(outcome) || {};
if (confirmationDetails?.type === 'exec') { if (!options.commandPrefix && confirmationDetails?.type === 'exec') {
options.commandPrefix = confirmationDetails.rootCommands; options.commandPrefix = confirmationDetails.rootCommands;
} else if (!options.argsPattern && confirmationDetails?.type === 'edit') {
const filePath = config
? makeRelative(confirmationDetails.filePath, config.getTargetDir())
: confirmationDetails.filePath;
options.argsPattern = buildFilePathArgsPattern(filePath);
} }
await messageBus.publish({ await messageBus.publish({
type: MessageBusType.UPDATE_POLICY, type: MessageBusType.UPDATE_POLICY,
toolName: tool.name, toolName: tool.name,
persist: outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave, persist: outcome === ToolConfirmationOutcome.ProceedAlwaysAndSave,
persistScope,
...options, ...options,
}); });
} }
@@ -180,6 +215,7 @@ async function handleMcpPolicyUpdate(
{ type: 'mcp' } { type: 'mcp' }
>, >,
messageBus: MessageBus, messageBus: MessageBus,
persistScope?: 'workspace' | 'user',
): Promise<void> { ): Promise<void> {
const isMcpAlways = const isMcpAlways =
outcome === ToolConfirmationOutcome.ProceedAlways || outcome === ToolConfirmationOutcome.ProceedAlways ||
@@ -204,5 +240,6 @@ async function handleMcpPolicyUpdate(
toolName, toolName,
mcpName: confirmationDetails.serverName, mcpName: confirmationDetails.serverName,
persist, persist,
persistScope,
}); });
} }
+1
View File
@@ -608,6 +608,7 @@ export class Scheduler {
await updatePolicy(toolCall.tool, outcome, lastDetails, { await updatePolicy(toolCall.tool, outcome, lastDetails, {
config: this.config, config: this.config,
messageBus: this.messageBus, messageBus: this.messageBus,
toolInvocation: toolCall.invocation,
}); });
} }
+58 -33
View File
@@ -20,11 +20,14 @@ import {
type ToolLocation, type ToolLocation,
type ToolResult, type ToolResult,
type ToolResultDisplay, type ToolResultDisplay,
type PolicyUpdateOptions,
} from './tools.js'; } from './tools.js';
import { buildFilePathArgsPattern } from '../policy/utils.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { ToolErrorType } from './tool-error.js'; import { ToolErrorType } from './tool-error.js';
import { makeRelative, shortenPath } from '../utils/paths.js'; import { makeRelative, shortenPath } from '../utils/paths.js';
import { isNodeError } from '../utils/errors.js'; import { isNodeError } from '../utils/errors.js';
import { correctPath } from '../utils/pathCorrector.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import { ApprovalMode } from '../policy/types.js'; import { ApprovalMode } from '../policy/types.js';
import { CoreToolCallStatus } from '../scheduler/types.js'; import { CoreToolCallStatus } from '../scheduler/types.js';
@@ -44,7 +47,6 @@ import {
logEditCorrectionEvent, logEditCorrectionEvent,
} from '../telemetry/loggers.js'; } from '../telemetry/loggers.js';
import { correctPath } from '../utils/pathCorrector.js';
import { import {
EDIT_TOOL_NAME, EDIT_TOOL_NAME,
READ_FILE_TOOL_NAME, READ_FILE_TOOL_NAME,
@@ -442,6 +444,8 @@ class EditToolInvocation
extends BaseToolInvocation<EditToolParams, ToolResult> extends BaseToolInvocation<EditToolParams, ToolResult>
implements ToolInvocation<EditToolParams, ToolResult> implements ToolInvocation<EditToolParams, ToolResult>
{ {
private readonly resolvedPath: string;
constructor( constructor(
private readonly config: Config, private readonly config: Config,
params: EditToolParams, params: EditToolParams,
@@ -450,10 +454,31 @@ class EditToolInvocation
displayName?: string, displayName?: string,
) { ) {
super(params, messageBus, toolName, displayName); super(params, messageBus, toolName, displayName);
if (!path.isAbsolute(this.params.file_path)) {
const result = correctPath(this.params.file_path, this.config);
if (result.success) {
this.resolvedPath = result.correctedPath;
} else {
this.resolvedPath = path.resolve(
this.config.getTargetDir(),
this.params.file_path,
);
}
} else {
this.resolvedPath = this.params.file_path;
}
} }
override toolLocations(): ToolLocation[] { override toolLocations(): ToolLocation[] {
return [{ path: this.params.file_path }]; return [{ path: this.resolvedPath }];
}
override getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined {
return {
argsPattern: buildFilePathArgsPattern(this.params.file_path),
};
} }
private async attemptSelfCorrection( private async attemptSelfCorrection(
@@ -471,7 +496,7 @@ class EditToolInvocation
const initialContentHash = hashContent(currentContent); const initialContentHash = hashContent(currentContent);
const onDiskContent = await this.config const onDiskContent = await this.config
.getFileSystemService() .getFileSystemService()
.readTextFile(params.file_path); .readTextFile(this.resolvedPath);
const onDiskContentHash = hashContent(onDiskContent.replace(/\r\n/g, '\n')); const onDiskContentHash = hashContent(onDiskContent.replace(/\r\n/g, '\n'));
if (initialContentHash !== onDiskContentHash) { if (initialContentHash !== onDiskContentHash) {
@@ -582,7 +607,7 @@ class EditToolInvocation
try { try {
currentContent = await this.config currentContent = await this.config
.getFileSystemService() .getFileSystemService()
.readTextFile(params.file_path); .readTextFile(this.resolvedPath);
originalLineEnding = detectLineEnding(currentContent); originalLineEnding = detectLineEnding(currentContent);
currentContent = currentContent.replace(/\r\n/g, '\n'); currentContent = currentContent.replace(/\r\n/g, '\n');
fileExists = true; fileExists = true;
@@ -615,7 +640,7 @@ class EditToolInvocation
isNewFile: false, isNewFile: false,
error: { error: {
display: `File not found. Cannot apply edit. Use an empty old_string to create a new file.`, display: `File not found. Cannot apply edit. Use an empty old_string to create a new file.`,
raw: `File not found: ${params.file_path}`, raw: `File not found: ${this.resolvedPath}`,
type: ToolErrorType.FILE_NOT_FOUND, type: ToolErrorType.FILE_NOT_FOUND,
}, },
originalLineEnding, originalLineEnding,
@@ -630,7 +655,7 @@ class EditToolInvocation
isNewFile: false, isNewFile: false,
error: { error: {
display: `Failed to read content of file.`, display: `Failed to read content of file.`,
raw: `Failed to read content of existing file: ${params.file_path}`, raw: `Failed to read content of existing file: ${this.resolvedPath}`,
type: ToolErrorType.READ_CONTENT_FAILURE, type: ToolErrorType.READ_CONTENT_FAILURE,
}, },
originalLineEnding, originalLineEnding,
@@ -645,7 +670,7 @@ class EditToolInvocation
isNewFile: false, isNewFile: false,
error: { error: {
display: `Failed to edit. Attempted to create a file that already exists.`, display: `Failed to edit. Attempted to create a file that already exists.`,
raw: `File already exists, cannot create: ${params.file_path}`, raw: `File already exists, cannot create: ${this.resolvedPath}`,
type: ToolErrorType.ATTEMPT_TO_CREATE_EXISTING_FILE, type: ToolErrorType.ATTEMPT_TO_CREATE_EXISTING_FILE,
}, },
originalLineEnding, originalLineEnding,
@@ -727,7 +752,7 @@ class EditToolInvocation
return false; return false;
} }
const fileName = path.basename(this.params.file_path); const fileName = path.basename(this.resolvedPath);
const fileDiff = Diff.createPatch( const fileDiff = Diff.createPatch(
fileName, fileName,
editData.currentContent ?? '', editData.currentContent ?? '',
@@ -739,14 +764,14 @@ class EditToolInvocation
const ideClient = await IdeClient.getInstance(); const ideClient = await IdeClient.getInstance();
const ideConfirmation = const ideConfirmation =
this.config.getIdeMode() && ideClient.isDiffingEnabled() this.config.getIdeMode() && ideClient.isDiffingEnabled()
? ideClient.openDiff(this.params.file_path, editData.newContent) ? ideClient.openDiff(this.resolvedPath, editData.newContent)
: undefined; : undefined;
const confirmationDetails: ToolEditConfirmationDetails = { const confirmationDetails: ToolEditConfirmationDetails = {
type: 'edit', type: 'edit',
title: `Confirm Edit: ${shortenPath(makeRelative(this.params.file_path, this.config.getTargetDir()))}`, title: `Confirm Edit: ${shortenPath(makeRelative(this.resolvedPath, this.config.getTargetDir()))}`,
fileName, fileName,
filePath: this.params.file_path, filePath: this.resolvedPath,
fileDiff, fileDiff,
originalContent: editData.currentContent, originalContent: editData.currentContent,
newContent: editData.newContent, newContent: editData.newContent,
@@ -771,7 +796,7 @@ class EditToolInvocation
getDescription(): string { getDescription(): string {
const relativePath = makeRelative( const relativePath = makeRelative(
this.params.file_path, this.resolvedPath,
this.config.getTargetDir(), this.config.getTargetDir(),
); );
if (this.params.old_string === '') { if (this.params.old_string === '') {
@@ -797,11 +822,7 @@ class EditToolInvocation
* @returns Result of the edit operation * @returns Result of the edit operation
*/ */
async execute(signal: AbortSignal): Promise<ToolResult> { async execute(signal: AbortSignal): Promise<ToolResult> {
const resolvedPath = path.resolve( const validationError = this.config.validatePathAccess(this.resolvedPath);
this.config.getTargetDir(),
this.params.file_path,
);
const validationError = this.config.validatePathAccess(resolvedPath);
if (validationError) { if (validationError) {
return { return {
llmContent: validationError, llmContent: validationError,
@@ -843,7 +864,7 @@ class EditToolInvocation
} }
try { try {
await this.ensureParentDirectoriesExistAsync(this.params.file_path); await this.ensureParentDirectoriesExistAsync(this.resolvedPath);
let finalContent = editData.newContent; let finalContent = editData.newContent;
// Restore original line endings if they were CRLF, or use OS default for new files // Restore original line endings if they were CRLF, or use OS default for new files
@@ -856,15 +877,15 @@ class EditToolInvocation
} }
await this.config await this.config
.getFileSystemService() .getFileSystemService()
.writeTextFile(this.params.file_path, finalContent); .writeTextFile(this.resolvedPath, finalContent);
let displayResult: ToolResultDisplay; let displayResult: ToolResultDisplay;
if (editData.isNewFile) { if (editData.isNewFile) {
displayResult = `Created ${shortenPath(makeRelative(this.params.file_path, this.config.getTargetDir()))}`; displayResult = `Created ${shortenPath(makeRelative(this.resolvedPath, this.config.getTargetDir()))}`;
} else { } else {
// Generate diff for display, even though core logic doesn't technically need it // Generate diff for display, even though core logic doesn't technically need it
// The CLI wrapper will use this part of the ToolResult // The CLI wrapper will use this part of the ToolResult
const fileName = path.basename(this.params.file_path); const fileName = path.basename(this.resolvedPath);
const fileDiff = Diff.createPatch( const fileDiff = Diff.createPatch(
fileName, fileName,
editData.currentContent ?? '', // Should not be null here if not isNewFile editData.currentContent ?? '', // Should not be null here if not isNewFile
@@ -883,7 +904,7 @@ class EditToolInvocation
displayResult = { displayResult = {
fileDiff, fileDiff,
fileName, fileName,
filePath: this.params.file_path, filePath: this.resolvedPath,
originalContent: editData.currentContent, originalContent: editData.currentContent,
newContent: editData.newContent, newContent: editData.newContent,
diffStat, diffStat,
@@ -893,8 +914,8 @@ class EditToolInvocation
const llmSuccessMessageParts = [ const llmSuccessMessageParts = [
editData.isNewFile editData.isNewFile
? `Created new file: ${this.params.file_path} with provided content.` ? `Created new file: ${this.resolvedPath} with provided content.`
: `Successfully modified file: ${this.params.file_path} (${editData.occurrences} replacements).`, : `Successfully modified file: ${this.resolvedPath} (${editData.occurrences} replacements).`,
]; ];
// Return a diff of the file before and after the write so that the agent // Return a diff of the file before and after the write so that the agent
@@ -985,16 +1006,20 @@ export class EditTool
return "The 'file_path' parameter must be non-empty."; return "The 'file_path' parameter must be non-empty.";
} }
let filePath = params.file_path; let resolvedPath: string;
if (!path.isAbsolute(filePath)) { if (!path.isAbsolute(params.file_path)) {
// Attempt to auto-correct to an absolute path const result = correctPath(params.file_path, this.config);
const result = correctPath(filePath, this.config); if (result.success) {
if (!result.success) { resolvedPath = result.correctedPath;
return result.error; } else {
resolvedPath = path.resolve(
this.config.getTargetDir(),
params.file_path,
);
} }
filePath = result.correctedPath; } else {
resolvedPath = params.file_path;
} }
params.file_path = filePath;
const newPlaceholders = detectOmissionPlaceholders(params.new_string); const newPlaceholders = detectOmissionPlaceholders(params.new_string);
if (newPlaceholders.length > 0) { if (newPlaceholders.length > 0) {
@@ -1009,7 +1034,7 @@ export class EditTool
} }
} }
return this.config.validatePathAccess(params.file_path); return this.config.validatePathAccess(resolvedPath);
} }
protected createInvocation( protected createInvocation(
+11
View File
@@ -14,12 +14,15 @@ import {
Kind, Kind,
type ToolInvocation, type ToolInvocation,
type ToolResult, type ToolResult,
type PolicyUpdateOptions,
type ToolConfirmationOutcome,
} from './tools.js'; } from './tools.js';
import { shortenPath, makeRelative } from '../utils/paths.js'; import { shortenPath, makeRelative } from '../utils/paths.js';
import { type Config } from '../config/config.js'; import { type Config } from '../config/config.js';
import { DEFAULT_FILE_FILTERING_OPTIONS } from '../config/constants.js'; import { DEFAULT_FILE_FILTERING_OPTIONS } from '../config/constants.js';
import { ToolErrorType } from './tool-error.js'; import { ToolErrorType } from './tool-error.js';
import { GLOB_TOOL_NAME, GLOB_DISPLAY_NAME } from './tool-names.js'; import { GLOB_TOOL_NAME, GLOB_DISPLAY_NAME } from './tool-names.js';
import { buildPatternArgsPattern } from '../policy/utils.js';
import { getErrorMessage } from '../utils/errors.js'; import { getErrorMessage } from '../utils/errors.js';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import { GLOB_DEFINITION } from './definitions/coreTools.js'; import { GLOB_DEFINITION } from './definitions/coreTools.js';
@@ -118,6 +121,14 @@ class GlobToolInvocation extends BaseToolInvocation<
return description; return description;
} }
override getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined {
return {
argsPattern: buildPatternArgsPattern(this.params.pattern),
};
}
async execute(signal: AbortSignal): Promise<ToolResult> { async execute(signal: AbortSignal): Promise<ToolResult> {
try { try {
const workspaceContext = this.config.getWorkspaceContext(); const workspaceContext = this.config.getWorkspaceContext();
+11
View File
@@ -21,6 +21,8 @@ import {
Kind, Kind,
type ToolInvocation, type ToolInvocation,
type ToolResult, type ToolResult,
type PolicyUpdateOptions,
type ToolConfirmationOutcome,
} from './tools.js'; } from './tools.js';
import { makeRelative, shortenPath } from '../utils/paths.js'; import { makeRelative, shortenPath } from '../utils/paths.js';
import { getErrorMessage, isNodeError } from '../utils/errors.js'; import { getErrorMessage, isNodeError } from '../utils/errors.js';
@@ -29,6 +31,7 @@ import type { Config } from '../config/config.js';
import type { FileExclusions } from '../utils/ignorePatterns.js'; import type { FileExclusions } from '../utils/ignorePatterns.js';
import { ToolErrorType } from './tool-error.js'; import { ToolErrorType } from './tool-error.js';
import { GREP_TOOL_NAME } from './tool-names.js'; import { GREP_TOOL_NAME } from './tool-names.js';
import { buildPatternArgsPattern } from '../policy/utils.js';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import { GREP_DEFINITION } from './definitions/coreTools.js'; import { GREP_DEFINITION } from './definitions/coreTools.js';
import { resolveToolDeclaration } from './definitions/resolver.js'; import { resolveToolDeclaration } from './definitions/resolver.js';
@@ -285,6 +288,14 @@ class GrepToolInvocation extends BaseToolInvocation<
} }
} }
override getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined {
return {
argsPattern: buildPatternArgsPattern(this.params.pattern),
};
}
/** /**
* Checks if a command is available in the system's PATH. * Checks if a command is available in the system's PATH.
* @param {string} command The command name (e.g., 'git', 'grep'). * @param {string} command The command name (e.g., 'git', 'grep').
+11
View File
@@ -13,12 +13,15 @@ import {
Kind, Kind,
type ToolInvocation, type ToolInvocation,
type ToolResult, type ToolResult,
type PolicyUpdateOptions,
type ToolConfirmationOutcome,
} from './tools.js'; } from './tools.js';
import { makeRelative, shortenPath } from '../utils/paths.js'; import { makeRelative, shortenPath } from '../utils/paths.js';
import type { Config } from '../config/config.js'; import type { Config } from '../config/config.js';
import { DEFAULT_FILE_FILTERING_OPTIONS } from '../config/constants.js'; import { DEFAULT_FILE_FILTERING_OPTIONS } from '../config/constants.js';
import { ToolErrorType } from './tool-error.js'; import { ToolErrorType } from './tool-error.js';
import { LS_TOOL_NAME } from './tool-names.js'; import { LS_TOOL_NAME } from './tool-names.js';
import { buildFilePathArgsPattern } from '../policy/utils.js';
import { debugLogger } from '../utils/debugLogger.js'; import { debugLogger } from '../utils/debugLogger.js';
import { LS_DEFINITION } from './definitions/coreTools.js'; import { LS_DEFINITION } from './definitions/coreTools.js';
import { resolveToolDeclaration } from './definitions/resolver.js'; import { resolveToolDeclaration } from './definitions/resolver.js';
@@ -123,6 +126,14 @@ class LSToolInvocation extends BaseToolInvocation<LSToolParams, ToolResult> {
return shortenPath(relativePath); return shortenPath(relativePath);
} }
override getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined {
return {
argsPattern: buildFilePathArgsPattern(this.params.dir_path),
};
}
// Helper for consistent error formatting // Helper for consistent error formatting
private errorResult( private errorResult(
llmContent: string, llmContent: string,
+1 -1
View File
@@ -184,7 +184,7 @@ export class DiscoveredMCPToolInvocation extends BaseToolInvocation<
); );
} }
protected override getPolicyUpdateOptions( override getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome, _outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined { ): PolicyUpdateOptions | undefined {
return { mcpName: this.serverName }; return { mcpName: this.serverName };
+11
View File
@@ -14,8 +14,11 @@ import {
type ToolInvocation, type ToolInvocation,
type ToolLocation, type ToolLocation,
type ToolResult, type ToolResult,
type PolicyUpdateOptions,
type ToolConfirmationOutcome,
} from './tools.js'; } from './tools.js';
import { ToolErrorType } from './tool-error.js'; import { ToolErrorType } from './tool-error.js';
import { buildFilePathArgsPattern } from '../policy/utils.js';
import type { PartUnion } from '@google/genai'; import type { PartUnion } from '@google/genai';
import { import {
@@ -88,6 +91,14 @@ class ReadFileToolInvocation extends BaseToolInvocation<
]; ];
} }
override getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined {
return {
argsPattern: buildFilePathArgsPattern(this.params.file_path),
};
}
async execute(): Promise<ToolResult> { async execute(): Promise<ToolResult> {
const validationError = this.config.validatePathAccess( const validationError = this.config.validatePathAccess(
this.resolvedPath, this.resolvedPath,
@@ -11,11 +11,14 @@ import {
Kind, Kind,
type ToolInvocation, type ToolInvocation,
type ToolResult, type ToolResult,
type PolicyUpdateOptions,
type ToolConfirmationOutcome,
} from './tools.js'; } from './tools.js';
import { getErrorMessage } from '../utils/errors.js'; import { getErrorMessage } from '../utils/errors.js';
import * as fsPromises from 'node:fs/promises'; import * as fsPromises from 'node:fs/promises';
import * as path from 'node:path'; import * as path from 'node:path';
import { glob, escape } from 'glob'; import { glob, escape } from 'glob';
import { buildPatternArgsPattern } from '../policy/utils.js';
import { import {
detectFileType, detectFileType,
processSingleFileContent, processSingleFileContent,
@@ -155,6 +158,16 @@ ${finalExclusionPatternsForDescription
)}".`; )}".`;
} }
override getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined {
// We join the include patterns to match the JSON stringified arguments.
// buildPatternArgsPattern handles JSON stringification.
return {
argsPattern: buildPatternArgsPattern(JSON.stringify(this.params.include)),
};
}
async execute(signal: AbortSignal): Promise<ToolResult> { async execute(signal: AbortSignal): Promise<ToolResult> {
const { include, exclude = [], useDefaultExcludes = true } = this.params; const { include, exclude = [], useDefaultExcludes = true } = this.params;
+1 -1
View File
@@ -90,7 +90,7 @@ export class ShellToolInvocation extends BaseToolInvocation<
return description; return description;
} }
protected override getPolicyUpdateOptions( override getPolicyUpdateOptions(
outcome: ToolConfirmationOutcome, outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined { ): PolicyUpdateOptions | undefined {
if ( if (
+23 -6
View File
@@ -154,12 +154,22 @@ export const LS_TOOL_NAME_LEGACY = 'list_directory'; // Just to be safe if anyth
export const EDIT_TOOL_NAMES = new Set([EDIT_TOOL_NAME, WRITE_FILE_TOOL_NAME]); export const EDIT_TOOL_NAMES = new Set([EDIT_TOOL_NAME, WRITE_FILE_TOOL_NAME]);
// Tool Display Names /**
export const WRITE_FILE_DISPLAY_NAME = 'WriteFile'; * Tools that can access local files or remote resources and should be
export const EDIT_DISPLAY_NAME = 'Edit'; * treated with extra caution when updating policies.
export const ASK_USER_DISPLAY_NAME = 'Ask User'; */
export const READ_FILE_DISPLAY_NAME = 'ReadFile'; export const SENSITIVE_TOOLS = new Set([
export const GLOB_DISPLAY_NAME = 'FindFiles'; GLOB_TOOL_NAME,
GREP_TOOL_NAME,
READ_MANY_FILES_TOOL_NAME,
WEB_FETCH_TOOL_NAME,
READ_FILE_TOOL_NAME,
LS_TOOL_NAME,
WRITE_FILE_TOOL_NAME,
EDIT_TOOL_NAME,
SHELL_TOOL_NAME,
]);
export const TRACKER_CREATE_TASK_TOOL_NAME = 'tracker_create_task'; export const TRACKER_CREATE_TASK_TOOL_NAME = 'tracker_create_task';
export const TRACKER_UPDATE_TASK_TOOL_NAME = 'tracker_update_task'; export const TRACKER_UPDATE_TASK_TOOL_NAME = 'tracker_update_task';
export const TRACKER_GET_TASK_TOOL_NAME = 'tracker_get_task'; export const TRACKER_GET_TASK_TOOL_NAME = 'tracker_get_task';
@@ -167,6 +177,13 @@ export const TRACKER_LIST_TASKS_TOOL_NAME = 'tracker_list_tasks';
export const TRACKER_ADD_DEPENDENCY_TOOL_NAME = 'tracker_add_dependency'; export const TRACKER_ADD_DEPENDENCY_TOOL_NAME = 'tracker_add_dependency';
export const TRACKER_VISUALIZE_TOOL_NAME = 'tracker_visualize'; export const TRACKER_VISUALIZE_TOOL_NAME = 'tracker_visualize';
// Tool Display Names
export const WRITE_FILE_DISPLAY_NAME = 'WriteFile';
export const EDIT_DISPLAY_NAME = 'Edit';
export const ASK_USER_DISPLAY_NAME = 'Ask User';
export const READ_FILE_DISPLAY_NAME = 'ReadFile';
export const GLOB_DISPLAY_NAME = 'FindFiles';
/** /**
* Mapping of legacy tool names to their current names. * Mapping of legacy tool names to their current names.
* This ensures backward compatibility for user-defined policies, skills, and hooks. * This ensures backward compatibility for user-defined policies, skills, and hooks.
+10 -1
View File
@@ -68,12 +68,21 @@ export interface ToolInvocation<
updateOutput?: (output: ToolLiveOutput) => void, updateOutput?: (output: ToolLiveOutput) => void,
shellExecutionConfig?: ShellExecutionConfig, shellExecutionConfig?: ShellExecutionConfig,
): Promise<TResult>; ): Promise<TResult>;
/**
* Returns tool-specific options for policy updates.
* This is used by the scheduler to narrow policy rules when a tool is approved.
*/
getPolicyUpdateOptions?(
outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined;
} }
/** /**
* Options for policy updates that can be customized by tool invocations. * Options for policy updates that can be customized by tool invocations.
*/ */
export interface PolicyUpdateOptions { export interface PolicyUpdateOptions {
argsPattern?: string;
commandPrefix?: string | string[]; commandPrefix?: string | string[];
mcpName?: string; mcpName?: string;
} }
@@ -130,7 +139,7 @@ export abstract class BaseToolInvocation<
* Subclasses can override this to provide additional options like * Subclasses can override this to provide additional options like
* commandPrefix (for shell) or mcpName (for MCP tools). * commandPrefix (for shell) or mcpName (for MCP tools).
*/ */
protected getPolicyUpdateOptions( getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome, _outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined { ): PolicyUpdateOptions | undefined {
return undefined; return undefined;
+18
View File
@@ -12,7 +12,9 @@ import {
type ToolInvocation, type ToolInvocation,
type ToolResult, type ToolResult,
type ToolConfirmationOutcome, type ToolConfirmationOutcome,
type PolicyUpdateOptions,
} from './tools.js'; } from './tools.js';
import { buildPatternArgsPattern } from '../policy/utils.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { ToolErrorType } from './tool-error.js'; import { ToolErrorType } from './tool-error.js';
import { getErrorMessage } from '../utils/errors.js'; import { getErrorMessage } from '../utils/errors.js';
@@ -291,6 +293,22 @@ ${textContent}
return `Processing URLs and instructions from prompt: "${displayPrompt}"`; return `Processing URLs and instructions from prompt: "${displayPrompt}"`;
} }
override getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined {
if (this.params.url) {
return {
argsPattern: buildPatternArgsPattern(this.params.url),
};
}
if (this.params.prompt) {
return {
argsPattern: buildPatternArgsPattern(this.params.prompt),
};
}
return undefined;
}
protected override async getConfirmationDetails( protected override async getConfirmationDetails(
_abortSignal: AbortSignal, _abortSignal: AbortSignal,
): Promise<ToolCallConfirmationDetails | false> { ): Promise<ToolCallConfirmationDetails | false> {
+10
View File
@@ -24,7 +24,9 @@ import {
type ToolLocation, type ToolLocation,
type ToolResult, type ToolResult,
type ToolConfirmationOutcome, type ToolConfirmationOutcome,
type PolicyUpdateOptions,
} from './tools.js'; } from './tools.js';
import { buildFilePathArgsPattern } from '../policy/utils.js';
import { ToolErrorType } from './tool-error.js'; import { ToolErrorType } from './tool-error.js';
import { makeRelative, shortenPath } from '../utils/paths.js'; import { makeRelative, shortenPath } from '../utils/paths.js';
import { getErrorMessage, isNodeError } from '../utils/errors.js'; import { getErrorMessage, isNodeError } from '../utils/errors.js';
@@ -164,6 +166,14 @@ class WriteFileToolInvocation extends BaseToolInvocation<
return [{ path: this.resolvedPath }]; return [{ path: this.resolvedPath }];
} }
override getPolicyUpdateOptions(
_outcome: ToolConfirmationOutcome,
): PolicyUpdateOptions | undefined {
return {
argsPattern: buildFilePathArgsPattern(this.params.file_path),
};
}
override getDescription(): string { override getDescription(): string {
const relativePath = makeRelative( const relativePath = makeRelative(
this.resolvedPath, this.resolvedPath,
+7
View File
@@ -1461,6 +1461,13 @@
"default": false, "default": false,
"type": "boolean" "type": "boolean"
}, },
"autoAddToPolicyByDefault": {
"title": "Auto-add to Policy by Default",
"description": "When enabled, the \"Allow for all future sessions\" option becomes the default choice for low-risk tools in trusted workspaces.",
"markdownDescription": "When enabled, the \"Allow for all future sessions\" option becomes the default choice for low-risk tools in trusted workspaces.\n\n- Category: `Security`\n- Requires restart: `no`\n- Default: `false`",
"default": false,
"type": "boolean"
},
"blockGitExtensions": { "blockGitExtensions": {
"title": "Blocks extensions from Git", "title": "Blocks extensions from Git",
"description": "Blocks installing and loading extensions from Git.", "description": "Blocks installing and loading extensions from Git.",