diff --git a/packages/core/src/policy/config.test.ts b/packages/core/src/policy/config.test.ts index 31c20e85dc..460087639b 100644 --- a/packages/core/src/policy/config.test.ts +++ b/packages/core/src/policy/config.test.ts @@ -641,4 +641,33 @@ priority = 150 vi.doUnmock('node:fs/promises'); }); + + it('should have default ASK_USER rule for discovered tools', async () => { + vi.resetModules(); + vi.doUnmock('node:fs/promises'); + const { createPolicyEngineConfig: createConfig } = await import( + './config.js' + ); + // Re-mock Storage after resetModules because it was reloaded + const { Storage: FreshStorage } = await import('../config/storage.js'); + vi.spyOn(FreshStorage, 'getUserPoliciesDir').mockReturnValue( + '/non/existent/user/policies', + ); + vi.spyOn(FreshStorage, 'getSystemPoliciesDir').mockReturnValue( + '/non/existent/system/policies', + ); + + const settings: PolicySettings = {}; + // Use default policy dir to load real discovered.toml + const config = await createConfig(settings, ApprovalMode.DEFAULT); + + const discoveredRule = config.rules?.find( + (r) => + r.toolName === 'discovered_tool_*' && + r.decision === PolicyDecision.ASK_USER, + ); + expect(discoveredRule).toBeDefined(); + // Priority 10 in default tier → 1.010 + expect(discoveredRule?.priority).toBeCloseTo(1.01, 5); + }); }); diff --git a/packages/core/src/policy/policies/discovered.toml b/packages/core/src/policy/policies/discovered.toml new file mode 100644 index 0000000000..b343a1807f --- /dev/null +++ b/packages/core/src/policy/policies/discovered.toml @@ -0,0 +1,8 @@ +# Default policy for tools discovered via toolDiscoveryCommand. +# These tools are potentially dangerous as they are arbitrary scripts. +# We default them to ASK_USER for safety. + +[[rule]] +toolName = "discovered_tool_*" +decision = "ask_user" +priority = 10 diff --git a/packages/core/src/tools/tool-registry.test.ts b/packages/core/src/tools/tool-registry.test.ts index 80e9390cce..1d3ddb786b 100644 --- a/packages/core/src/tools/tool-registry.test.ts +++ b/packages/core/src/tools/tool-registry.test.ts @@ -11,7 +11,11 @@ import type { ConfigParameters } from '../config/config.js'; import { Config } from '../config/config.js'; import { ApprovalMode } from '../policy/types.js'; -import { ToolRegistry, DiscoveredTool } from './tool-registry.js'; +import { + ToolRegistry, + DiscoveredTool, + DISCOVERED_TOOL_PREFIX, +} from './tool-registry.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; import type { FunctionDeclaration, CallableTool } from '@google/genai'; import { mcpToTool } from '@google/genai'; @@ -20,6 +24,7 @@ import { spawn } from 'node:child_process'; import fs from 'node:fs'; import { MockTool } from '../test-utils/mock-tool.js'; import { ToolErrorType } from './tool-error.js'; +import type { MessageBus } from '../confirmation-bus/message-bus.js'; vi.mock('node:fs'); @@ -257,6 +262,7 @@ describe('ToolRegistry', () => { const discovered1 = new DiscoveredTool( config, 'discovered-1', + DISCOVERED_TOOL_PREFIX + 'discovered-1', 'desc', {}, ); @@ -288,7 +294,7 @@ describe('ToolRegistry', () => { expect(toolRegistry.getAllToolNames()).toEqual([ 'builtin-1', 'builtin-2', - 'discovered-1', + DISCOVERED_TOOL_PREFIX + 'discovered-1', 'mcp-apple', 'mcp-zebra', ]); @@ -346,7 +352,9 @@ describe('ToolRegistry', () => { await toolRegistry.discoverAllTools(); - const discoveredTool = toolRegistry.getTool('tool-with-bad-format'); + const discoveredTool = toolRegistry.getTool( + DISCOVERED_TOOL_PREFIX + 'tool-with-bad-format', + ); expect(discoveredTool).toBeDefined(); const registeredParams = (discoveredTool as DiscoveredTool).schema @@ -401,7 +409,9 @@ describe('ToolRegistry', () => { }); await toolRegistry.discoverAllTools(); - const discoveredTool = toolRegistry.getTool('failing-tool'); + const discoveredTool = toolRegistry.getTool( + DISCOVERED_TOOL_PREFIX + 'failing-tool', + ); expect(discoveredTool).toBeDefined(); // --- Execution Mock --- @@ -436,11 +446,74 @@ describe('ToolRegistry', () => { expect(result.llmContent).toContain('Stderr: Something went wrong'); expect(result.llmContent).toContain('Exit Code: 1'); }); + + it('should pass MessageBus to DiscoveredTool and its invocations', async () => { + const discoveryCommand = 'my-discovery-command'; + mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand); + + // Mock MessageBus + const mockMessageBus = { + publish: vi.fn(), + subscribe: vi.fn(), + unsubscribe: vi.fn(), + } as unknown as MessageBus; + toolRegistry.setMessageBus(mockMessageBus); + + const toolDeclaration: FunctionDeclaration = { + name: 'policy-test-tool', + description: 'tests policy', + parametersJsonSchema: { type: 'object', properties: {} }, + }; + + const mockSpawn = vi.mocked(spawn); + const discoveryProcess = { + stdout: { on: vi.fn(), removeListener: vi.fn() }, + stderr: { on: vi.fn(), removeListener: vi.fn() }, + on: vi.fn(), + kill: vi.fn(), + }; + mockSpawn.mockReturnValueOnce(discoveryProcess as any); + + discoveryProcess.stdout.on.mockImplementation((event, callback) => { + if (event === 'data') { + callback( + Buffer.from( + JSON.stringify([{ functionDeclarations: [toolDeclaration] }]), + ), + ); + } + }); + discoveryProcess.on.mockImplementation((event, callback) => { + if (event === 'close') { + callback(0); + } + }); + + await toolRegistry.discoverAllTools(); + const tool = toolRegistry.getTool( + DISCOVERED_TOOL_PREFIX + 'policy-test-tool', + ); + expect(tool).toBeDefined(); + + // Verify DiscoveredTool has the message bus + expect((tool as any).messageBus).toBe(mockMessageBus); + + const invocation = tool!.build({}); + + // Verify DiscoveredToolInvocation has the message bus + expect((invocation as any).messageBus).toBe(mockMessageBus); + }); }); describe('DiscoveredToolInvocation', () => { it('should return the stringified params from getDescription', () => { - const tool = new DiscoveredTool(config, 'test-tool', 'A test tool', {}); + const tool = new DiscoveredTool( + config, + 'test-tool', + DISCOVERED_TOOL_PREFIX + 'test-tool', + 'A test tool', + {}, + ); const params = { param: 'testValue' }; const invocation = tool.build(params); const description = invocation.getDescription(); diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index c350abfbd2..59f45a6826 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -22,6 +22,8 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js'; import { debugLogger } from '../utils/debugLogger.js'; import { coreEvents } from '../utils/events.js'; +export const DISCOVERED_TOOL_PREFIX = 'discovered_tool_'; + type ToolParams = Record; class DiscoveredToolInvocation extends BaseToolInvocation< @@ -30,10 +32,12 @@ class DiscoveredToolInvocation extends BaseToolInvocation< > { constructor( private readonly config: Config, - private readonly toolName: string, + private readonly originalToolName: string, + prefixedToolName: string, params: ToolParams, + messageBus?: MessageBus, ) { - super(params); + super(params, messageBus, prefixedToolName); } getDescription(): string { @@ -45,7 +49,7 @@ class DiscoveredToolInvocation extends BaseToolInvocation< _updateOutput?: (output: string) => void, ): Promise { const callCommand = this.config.getToolCallCommand()!; - const child = spawn(callCommand, [this.toolName]); + const child = spawn(callCommand, [this.originalToolName]); child.stdin.write(JSON.stringify(this.params)); child.stdin.end(); @@ -124,18 +128,24 @@ export class DiscoveredTool extends BaseDeclarativeTool< ToolParams, ToolResult > { + private readonly originalName: string; + constructor( private readonly config: Config, - name: string, - override readonly description: string, + originalName: string, + prefixedName: string, + description: string, override readonly parameterSchema: Record, + messageBus?: MessageBus, ) { const discoveryCmd = config.getToolDiscoveryCommand()!; const callCommand = config.getToolCallCommand()!; - description += ` + const fullDescription = + description + + ` This tool was discovered from the project by executing the command \`${discoveryCmd}\` on project root. -When called, this tool will execute the command \`${callCommand} ${name}\` on project root. +When called, this tool will execute the command \`${callCommand} ${originalName}\` on project root. Tool discovery and call commands can be configured in project or user settings. When called, the tool call command is executed as a subprocess. @@ -149,14 +159,16 @@ Exit Code: Exit code or \`(none)\` if terminated by signal. Signal: Signal number or \`(none)\` if no signal was received. `; super( - name, - name, - description, + prefixedName, + prefixedName, + fullDescription, Kind.Other, parameterSchema, false, // isOutputMarkdown false, // canUpdateOutput + messageBus, ); + this.originalName = originalName; } protected createInvocation( @@ -165,7 +177,13 @@ Signal: Signal number or \`(none)\` if no signal was received. _toolName?: string, _displayName?: string, ): ToolInvocation { - return new DiscoveredToolInvocation(this.config, this.name, params); + return new DiscoveredToolInvocation( + this.config, + this.originalName, + this.name, + params, + _messageBus, + ); } } @@ -385,8 +403,10 @@ export class ToolRegistry { new DiscoveredTool( this.config, func.name, + DISCOVERED_TOOL_PREFIX + func.name, func.description ?? '', parameters as Record, + this.messageBus, ), ); }