From 527074b50a8427502c154baccdbc0ebeeb3e5309 Mon Sep 17 00:00:00 2001 From: AK Date: Mon, 9 Mar 2026 12:22:46 -0700 Subject: [PATCH] feat(policy): support subagent-specific policies in TOML (#21431) --- docs/reference/policy-engine.md | 4 ++++ .../core/src/agents/agent-scheduler.test.ts | 1 + packages/core/src/agents/agent-scheduler.ts | 3 ++- packages/core/src/agents/local-executor.ts | 20 ++++++++++++++++++- .../src/confirmation-bus/message-bus.test.ts | 1 + .../core/src/confirmation-bus/message-bus.ts | 1 + packages/core/src/confirmation-bus/types.ts | 4 ++++ packages/core/src/policy/policy-engine.ts | 17 +++++++++++++++- packages/core/src/policy/toml-loader.ts | 2 ++ packages/core/src/policy/types.ts | 6 ++++++ 10 files changed, 56 insertions(+), 3 deletions(-) diff --git a/docs/reference/policy-engine.md b/docs/reference/policy-engine.md index 38a0b4d50c..c0a331d99d 100644 --- a/docs/reference/policy-engine.md +++ b/docs/reference/policy-engine.md @@ -219,6 +219,10 @@ Here is a breakdown of the fields available in a TOML policy rule: # A unique name for the tool, or an array of names. toolName = "run_shell_command" +# (Optional) The name of a subagent. If provided, the rule only applies to tool calls +# made by this specific subagent. +subagent = "generalist" + # (Optional) The name of an MCP server. Can be combined with toolName # to form a composite name like "mcpName__toolName". mcpName = "my-custom-server" diff --git a/packages/core/src/agents/agent-scheduler.test.ts b/packages/core/src/agents/agent-scheduler.test.ts index dd6749d3a0..451fb276a2 100644 --- a/packages/core/src/agents/agent-scheduler.test.ts +++ b/packages/core/src/agents/agent-scheduler.test.ts @@ -27,6 +27,7 @@ describe('agent-scheduler', () => { mockMessageBus = {} as Mocked; mockToolRegistry = { getTool: vi.fn(), + getMessageBus: vi.fn().mockReturnValue(mockMessageBus), } as unknown as Mocked; mockConfig = { getMessageBus: vi.fn().mockReturnValue(mockMessageBus), diff --git a/packages/core/src/agents/agent-scheduler.ts b/packages/core/src/agents/agent-scheduler.ts index ecb4ed960a..983f814b0f 100644 --- a/packages/core/src/agents/agent-scheduler.ts +++ b/packages/core/src/agents/agent-scheduler.ts @@ -57,10 +57,11 @@ export async function scheduleAgentTools( // eslint-disable-next-line @typescript-eslint/no-unsafe-assignment const agentConfig: Config = Object.create(config); agentConfig.getToolRegistry = () => toolRegistry; + agentConfig.getMessageBus = () => toolRegistry.getMessageBus(); const scheduler = new Scheduler({ config: agentConfig, - messageBus: config.getMessageBus(), + messageBus: toolRegistry.getMessageBus(), getPreferredEditor: getPreferredEditor ?? (() => undefined), schedulerId, parentCallId, diff --git a/packages/core/src/agents/local-executor.ts b/packages/core/src/agents/local-executor.ts index fd450c5efa..dd5b78a9a6 100644 --- a/packages/core/src/agents/local-executor.ts +++ b/packages/core/src/agents/local-executor.ts @@ -19,6 +19,7 @@ import { ToolRegistry } from '../tools/tool-registry.js'; import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; import { CompressionStatus } from '../core/turn.js'; import { type ToolCallRequestInfo } from '../scheduler/types.js'; +import { type Message } from '../confirmation-bus/types.js'; import { ChatCompressionService } from '../services/chatCompressionService.js'; import { getDirectoryContextString } from '../utils/environmentContext.js'; import { promptIdContext } from '../utils/promptIdContext.js'; @@ -113,10 +114,27 @@ export class LocalAgentExecutor { runtimeContext: Config, onActivity?: ActivityCallback, ): Promise> { + const parentMessageBus = runtimeContext.getMessageBus(); + + // Create an override object to inject the subagent name into tool confirmation requests + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + const subagentMessageBus = Object.create( + parentMessageBus, + ) as typeof parentMessageBus; + subagentMessageBus.publish = async (message: Message) => { + if (message.type === 'tool-confirmation-request') { + return parentMessageBus.publish({ + ...message, + subagent: definition.name, + }); + } + return parentMessageBus.publish(message); + }; + // Create an isolated tool registry for this agent instance. const agentToolRegistry = new ToolRegistry( runtimeContext, - runtimeContext.getMessageBus(), + subagentMessageBus, ); const parentToolRegistry = runtimeContext.getToolRegistry(); const allAgentNames = new Set( diff --git a/packages/core/src/confirmation-bus/message-bus.test.ts b/packages/core/src/confirmation-bus/message-bus.test.ts index 8342d53b1b..34e36167a9 100644 --- a/packages/core/src/confirmation-bus/message-bus.test.ts +++ b/packages/core/src/confirmation-bus/message-bus.test.ts @@ -160,6 +160,7 @@ describe('MessageBus', () => { { name: 'test-tool', args: {} }, 'test-server', annotations, + undefined, ); }); diff --git a/packages/core/src/confirmation-bus/message-bus.ts b/packages/core/src/confirmation-bus/message-bus.ts index 3dd61995ab..33aa10355b 100644 --- a/packages/core/src/confirmation-bus/message-bus.ts +++ b/packages/core/src/confirmation-bus/message-bus.ts @@ -56,6 +56,7 @@ export class MessageBus extends EventEmitter { message.toolCall, message.serverName, message.toolAnnotations, + message.subagent, ); switch (decision) { diff --git a/packages/core/src/confirmation-bus/types.ts b/packages/core/src/confirmation-bus/types.ts index aefafe0fa0..277c821da3 100644 --- a/packages/core/src/confirmation-bus/types.ts +++ b/packages/core/src/confirmation-bus/types.ts @@ -38,6 +38,10 @@ export interface ToolConfirmationRequest { * Optional tool annotations (e.g., readOnlyHint, destructiveHint) from MCP. */ toolAnnotations?: Record; + /** + * Optional subagent name, if this tool call was initiated by a subagent. + */ + subagent?: string; /** * Optional rich details for the confirmation UI (diffs, counts, etc.) */ diff --git a/packages/core/src/policy/policy-engine.ts b/packages/core/src/policy/policy-engine.ts index 0d6a043da0..a2f64bf356 100644 --- a/packages/core/src/policy/policy-engine.ts +++ b/packages/core/src/policy/policy-engine.ts @@ -74,6 +74,7 @@ function ruleMatches( serverName: string | undefined, currentApprovalMode: ApprovalMode, toolAnnotations?: Record, + subagent?: string, ): boolean { // Check if rule applies to current approval mode if (rule.modes && rule.modes.length > 0) { @@ -82,6 +83,13 @@ function ruleMatches( } } + // Check subagent if specified (only for PolicyRule, SafetyCheckerRule doesn't have it) + if ('subagent' in rule && rule.subagent) { + if (rule.subagent !== subagent) { + return false; + } + } + // Strictly enforce mcpName identity if the rule dictates it if (rule.mcpName) { if (rule.mcpName === '*') { @@ -203,6 +211,7 @@ export class PolicyEngine { allowRedirection?: boolean, rule?: PolicyRule, toolAnnotations?: Record, + subagent?: string, ): Promise { if (!command) { return { @@ -294,6 +303,7 @@ export class PolicyEngine { { name: toolName, args: { command: subCmd, dir_path } }, serverName, toolAnnotations, + subagent, ); // subResult.decision is already filtered through applyNonInteractiveMode by this.check() @@ -352,6 +362,7 @@ export class PolicyEngine { toolCall: FunctionCall, serverName: string | undefined, toolAnnotations?: Record, + subagent?: string, ): Promise { // Case 1: Metadata injection is the primary and safest way to identify an MCP server. // If we have explicit `_serverName` metadata (usually injected by tool-registry for active tools), use it. @@ -419,6 +430,7 @@ export class PolicyEngine { serverName, this.approvalMode, toolAnnotations, + subagent, ), ); @@ -437,6 +449,7 @@ export class PolicyEngine { rule.allowRedirection, rule, toolAnnotations, + subagent, ); decision = shellResult.decision; if (shellResult.rule) { @@ -463,9 +476,10 @@ export class PolicyEngine { this.defaultDecision, serverName, shellDirPath, - undefined, + false, undefined, toolAnnotations, + subagent, ); decision = shellResult.decision; matchedRule = shellResult.rule; @@ -485,6 +499,7 @@ export class PolicyEngine { serverName, this.approvalMode, toolAnnotations, + subagent, ) ) { debugLogger.debug( diff --git a/packages/core/src/policy/toml-loader.ts b/packages/core/src/policy/toml-loader.ts index c91930a21d..83dda26e9e 100644 --- a/packages/core/src/policy/toml-loader.ts +++ b/packages/core/src/policy/toml-loader.ts @@ -38,6 +38,7 @@ const MAX_TYPO_DISTANCE = 3; */ const PolicyRuleSchema = z.object({ toolName: z.union([z.string(), z.array(z.string())]).optional(), + subagent: z.string().optional(), mcpName: z.string().optional(), argsPattern: z.string().optional(), commandPrefix: z.union([z.string(), z.array(z.string())]).optional(), @@ -464,6 +465,7 @@ export async function loadPoliciesFromToml( const policyRule: PolicyRule = { toolName: effectiveToolName, + subagent: rule.subagent, mcpName: rule.mcpName, decision: rule.decision, priority: transformPriority(rule.priority, tier), diff --git a/packages/core/src/policy/types.ts b/packages/core/src/policy/types.ts index f59821b093..53a0433a15 100644 --- a/packages/core/src/policy/types.ts +++ b/packages/core/src/policy/types.ts @@ -110,6 +110,12 @@ export interface PolicyRule { */ toolName?: string; + /** + * The name of the subagent this rule applies to. + * If undefined, the rule applies regardless of whether it's the main agent or a subagent. + */ + subagent?: string; + /** * Identifies the MCP server this rule applies to. * Enables precise rule matching against `serverName` metadata instead