fix: integrate DiscoveredTool with Policy Engine (#12646)

This commit is contained in:
Allen Hutchison
2025-11-06 15:51:16 -08:00
committed by GitHub
parent 445a5eac33
commit c81a02f8d2
4 changed files with 146 additions and 16 deletions

View File

@@ -641,4 +641,33 @@ priority = 150
vi.doUnmock('node:fs/promises');
});
it('should have default ASK_USER rule for discovered tools', async () => {
vi.resetModules();
vi.doUnmock('node:fs/promises');
const { createPolicyEngineConfig: createConfig } = await import(
'./config.js'
);
// Re-mock Storage after resetModules because it was reloaded
const { Storage: FreshStorage } = await import('../config/storage.js');
vi.spyOn(FreshStorage, 'getUserPoliciesDir').mockReturnValue(
'/non/existent/user/policies',
);
vi.spyOn(FreshStorage, 'getSystemPoliciesDir').mockReturnValue(
'/non/existent/system/policies',
);
const settings: PolicySettings = {};
// Use default policy dir to load real discovered.toml
const config = await createConfig(settings, ApprovalMode.DEFAULT);
const discoveredRule = config.rules?.find(
(r) =>
r.toolName === 'discovered_tool_*' &&
r.decision === PolicyDecision.ASK_USER,
);
expect(discoveredRule).toBeDefined();
// Priority 10 in default tier → 1.010
expect(discoveredRule?.priority).toBeCloseTo(1.01, 5);
});
});

View File

@@ -0,0 +1,8 @@
# Default policy for tools discovered via toolDiscoveryCommand.
# These tools are potentially dangerous as they are arbitrary scripts.
# We default them to ASK_USER for safety.
[[rule]]
toolName = "discovered_tool_*"
decision = "ask_user"
priority = 10

View File

@@ -11,7 +11,11 @@ import type { ConfigParameters } from '../config/config.js';
import { Config } from '../config/config.js';
import { ApprovalMode } from '../policy/types.js';
import { ToolRegistry, DiscoveredTool } from './tool-registry.js';
import {
ToolRegistry,
DiscoveredTool,
DISCOVERED_TOOL_PREFIX,
} from './tool-registry.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import type { FunctionDeclaration, CallableTool } from '@google/genai';
import { mcpToTool } from '@google/genai';
@@ -20,6 +24,7 @@ import { spawn } from 'node:child_process';
import fs from 'node:fs';
import { MockTool } from '../test-utils/mock-tool.js';
import { ToolErrorType } from './tool-error.js';
import type { MessageBus } from '../confirmation-bus/message-bus.js';
vi.mock('node:fs');
@@ -257,6 +262,7 @@ describe('ToolRegistry', () => {
const discovered1 = new DiscoveredTool(
config,
'discovered-1',
DISCOVERED_TOOL_PREFIX + 'discovered-1',
'desc',
{},
);
@@ -288,7 +294,7 @@ describe('ToolRegistry', () => {
expect(toolRegistry.getAllToolNames()).toEqual([
'builtin-1',
'builtin-2',
'discovered-1',
DISCOVERED_TOOL_PREFIX + 'discovered-1',
'mcp-apple',
'mcp-zebra',
]);
@@ -346,7 +352,9 @@ describe('ToolRegistry', () => {
await toolRegistry.discoverAllTools();
const discoveredTool = toolRegistry.getTool('tool-with-bad-format');
const discoveredTool = toolRegistry.getTool(
DISCOVERED_TOOL_PREFIX + 'tool-with-bad-format',
);
expect(discoveredTool).toBeDefined();
const registeredParams = (discoveredTool as DiscoveredTool).schema
@@ -401,7 +409,9 @@ describe('ToolRegistry', () => {
});
await toolRegistry.discoverAllTools();
const discoveredTool = toolRegistry.getTool('failing-tool');
const discoveredTool = toolRegistry.getTool(
DISCOVERED_TOOL_PREFIX + 'failing-tool',
);
expect(discoveredTool).toBeDefined();
// --- Execution Mock ---
@@ -436,11 +446,74 @@ describe('ToolRegistry', () => {
expect(result.llmContent).toContain('Stderr: Something went wrong');
expect(result.llmContent).toContain('Exit Code: 1');
});
it('should pass MessageBus to DiscoveredTool and its invocations', async () => {
const discoveryCommand = 'my-discovery-command';
mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand);
// Mock MessageBus
const mockMessageBus = {
publish: vi.fn(),
subscribe: vi.fn(),
unsubscribe: vi.fn(),
} as unknown as MessageBus;
toolRegistry.setMessageBus(mockMessageBus);
const toolDeclaration: FunctionDeclaration = {
name: 'policy-test-tool',
description: 'tests policy',
parametersJsonSchema: { type: 'object', properties: {} },
};
const mockSpawn = vi.mocked(spawn);
const discoveryProcess = {
stdout: { on: vi.fn(), removeListener: vi.fn() },
stderr: { on: vi.fn(), removeListener: vi.fn() },
on: vi.fn(),
kill: vi.fn(),
};
mockSpawn.mockReturnValueOnce(discoveryProcess as any);
discoveryProcess.stdout.on.mockImplementation((event, callback) => {
if (event === 'data') {
callback(
Buffer.from(
JSON.stringify([{ functionDeclarations: [toolDeclaration] }]),
),
);
}
});
discoveryProcess.on.mockImplementation((event, callback) => {
if (event === 'close') {
callback(0);
}
});
await toolRegistry.discoverAllTools();
const tool = toolRegistry.getTool(
DISCOVERED_TOOL_PREFIX + 'policy-test-tool',
);
expect(tool).toBeDefined();
// Verify DiscoveredTool has the message bus
expect((tool as any).messageBus).toBe(mockMessageBus);
const invocation = tool!.build({});
// Verify DiscoveredToolInvocation has the message bus
expect((invocation as any).messageBus).toBe(mockMessageBus);
});
});
describe('DiscoveredToolInvocation', () => {
it('should return the stringified params from getDescription', () => {
const tool = new DiscoveredTool(config, 'test-tool', 'A test tool', {});
const tool = new DiscoveredTool(
config,
'test-tool',
DISCOVERED_TOOL_PREFIX + 'test-tool',
'A test tool',
{},
);
const params = { param: 'testValue' };
const invocation = tool.build(params);
const description = invocation.getDescription();

View File

@@ -22,6 +22,8 @@ import type { MessageBus } from '../confirmation-bus/message-bus.js';
import { debugLogger } from '../utils/debugLogger.js';
import { coreEvents } from '../utils/events.js';
export const DISCOVERED_TOOL_PREFIX = 'discovered_tool_';
type ToolParams = Record<string, unknown>;
class DiscoveredToolInvocation extends BaseToolInvocation<
@@ -30,10 +32,12 @@ class DiscoveredToolInvocation extends BaseToolInvocation<
> {
constructor(
private readonly config: Config,
private readonly toolName: string,
private readonly originalToolName: string,
prefixedToolName: string,
params: ToolParams,
messageBus?: MessageBus,
) {
super(params);
super(params, messageBus, prefixedToolName);
}
getDescription(): string {
@@ -45,7 +49,7 @@ class DiscoveredToolInvocation extends BaseToolInvocation<
_updateOutput?: (output: string) => void,
): Promise<ToolResult> {
const callCommand = this.config.getToolCallCommand()!;
const child = spawn(callCommand, [this.toolName]);
const child = spawn(callCommand, [this.originalToolName]);
child.stdin.write(JSON.stringify(this.params));
child.stdin.end();
@@ -124,18 +128,24 @@ export class DiscoveredTool extends BaseDeclarativeTool<
ToolParams,
ToolResult
> {
private readonly originalName: string;
constructor(
private readonly config: Config,
name: string,
override readonly description: string,
originalName: string,
prefixedName: string,
description: string,
override readonly parameterSchema: Record<string, unknown>,
messageBus?: MessageBus,
) {
const discoveryCmd = config.getToolDiscoveryCommand()!;
const callCommand = config.getToolCallCommand()!;
description += `
const fullDescription =
description +
`
This tool was discovered from the project by executing the command \`${discoveryCmd}\` on project root.
When called, this tool will execute the command \`${callCommand} ${name}\` on project root.
When called, this tool will execute the command \`${callCommand} ${originalName}\` on project root.
Tool discovery and call commands can be configured in project or user settings.
When called, the tool call command is executed as a subprocess.
@@ -149,14 +159,16 @@ Exit Code: Exit code or \`(none)\` if terminated by signal.
Signal: Signal number or \`(none)\` if no signal was received.
`;
super(
name,
name,
description,
prefixedName,
prefixedName,
fullDescription,
Kind.Other,
parameterSchema,
false, // isOutputMarkdown
false, // canUpdateOutput
messageBus,
);
this.originalName = originalName;
}
protected createInvocation(
@@ -165,7 +177,13 @@ Signal: Signal number or \`(none)\` if no signal was received.
_toolName?: string,
_displayName?: string,
): ToolInvocation<ToolParams, ToolResult> {
return new DiscoveredToolInvocation(this.config, this.name, params);
return new DiscoveredToolInvocation(
this.config,
this.originalName,
this.name,
params,
_messageBus,
);
}
}
@@ -385,8 +403,10 @@ export class ToolRegistry {
new DiscoveredTool(
this.config,
func.name,
DISCOVERED_TOOL_PREFIX + func.name,
func.description ?? '',
parameters as Record<string, unknown>,
this.messageBus,
),
);
}