diff --git a/packages/core/src/core/coreToolScheduler.test.ts b/packages/core/src/core/coreToolScheduler.test.ts index 9cfacc2358..f3a35319e9 100644 --- a/packages/core/src/core/coreToolScheduler.test.ts +++ b/packages/core/src/core/coreToolScheduler.test.ts @@ -6,6 +6,7 @@ import { describe, it, expect, vi } from 'vitest'; import type { Mock } from 'vitest'; +import type { CallableTool } from '@google/genai'; import { CoreToolScheduler } from './coreToolScheduler.js'; import type { ToolCall, @@ -41,6 +42,7 @@ import { import * as modifiableToolModule from '../tools/modifiable-tool.js'; import { DEFAULT_GEMINI_MODEL } from '../config/models.js'; import type { PolicyEngine } from '../policy/policy-engine.js'; +import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; vi.mock('fs/promises', () => ({ writeFile: vi.fn(), @@ -283,7 +285,10 @@ function createMockConfig(overrides: Partial = {}): Config { if (!overrides.getPolicyEngine) { finalConfig.getPolicyEngine = () => ({ - check: async (toolCall: { name: string; args: object }) => { + check: async ( + toolCall: { name: string; args: object }, + _serverName?: string, + ) => { // Mock simple policy logic for tests const mode = finalConfig.getApprovalMode(); if (mode === ApprovalMode.YOLO) { @@ -1834,4 +1839,69 @@ describe('CoreToolScheduler Sequential Execution', () => { modifyWithEditorSpy.mockRestore(); }); + + it('should pass serverName to policy engine for DiscoveredMCPTool', async () => { + const mockMcpTool = { + tool: async () => ({ functionDeclarations: [] }), + callTool: async () => [], + }; + const serverName = 'test-server'; + const toolName = 'test-tool'; + const mcpTool = new DiscoveredMCPTool( + mockMcpTool as unknown as CallableTool, + serverName, + toolName, + 'description', + { type: 'object', properties: {} }, + createMockMessageBus() as unknown as MessageBus, + ); + + const mockToolRegistry = { + getTool: () => mcpTool, + getFunctionDeclarations: () => [], + tools: new Map(), + discovery: {}, + registerTool: () => {}, + getToolByName: () => mcpTool, + getToolByDisplayName: () => mcpTool, + getTools: () => [], + discoverTools: async () => {}, + getAllTools: () => [], + getToolsByServer: () => [], + } as unknown as ToolRegistry; + + const mockPolicyEngineCheck = vi.fn().mockResolvedValue({ + decision: PolicyDecision.ALLOW, + }); + + const mockConfig = createMockConfig({ + getToolRegistry: () => mockToolRegistry, + getPolicyEngine: () => + ({ + check: mockPolicyEngineCheck, + }) as unknown as PolicyEngine, + isInteractive: () => false, + }); + + const scheduler = new CoreToolScheduler({ + config: mockConfig, + getPreferredEditor: () => 'vscode', + }); + + const abortController = new AbortController(); + const request = { + callId: '1', + name: toolName, + args: {}, + isClientInitiated: false, + prompt_id: 'prompt-id-1', + }; + + await scheduler.schedule(request, abortController.signal); + + expect(mockPolicyEngineCheck).toHaveBeenCalledWith( + expect.objectContaining({ name: toolName }), + serverName, + ); + }); }); diff --git a/packages/core/src/core/coreToolScheduler.ts b/packages/core/src/core/coreToolScheduler.ts index bec4eacd53..9b2b08c47f 100644 --- a/packages/core/src/core/coreToolScheduler.ts +++ b/packages/core/src/core/coreToolScheduler.ts @@ -44,6 +44,7 @@ import { type ToolCallResponseInfo, } from '../scheduler/types.js'; import { ToolExecutor } from '../scheduler/tool-executor.js'; +import { DiscoveredMCPTool } from '../tools/mcp-tool.js'; export type { ToolCall, @@ -591,9 +592,15 @@ export class CoreToolScheduler { name: toolCall.request.name, args: toolCall.request.args, }; + + const serverName = + toolCall.tool instanceof DiscoveredMCPTool + ? toolCall.tool.serverName + : undefined; + const { decision } = await this.config .getPolicyEngine() - .check(toolCallForPolicy, undefined); // Server name undefined for local tools + .check(toolCallForPolicy, serverName); if (decision === PolicyDecision.DENY) { const errorMessage = `Tool execution denied by policy.`; diff --git a/packages/core/src/policy/policy-engine.test.ts b/packages/core/src/policy/policy-engine.test.ts index 33dc77f00f..c30681a429 100644 --- a/packages/core/src/policy/policy-engine.test.ts +++ b/packages/core/src/policy/policy-engine.test.ts @@ -109,6 +109,37 @@ describe('PolicyEngine', () => { ); }); + it('should match unqualified tool names with qualified rules when serverName is provided', async () => { + const rules: PolicyRule[] = [ + { + toolName: 'my-server__tool', + decision: PolicyDecision.ALLOW, + }, + ]; + + engine = new PolicyEngine({ rules }); + + // Match with qualified name (standard) + expect( + (await engine.check({ name: 'my-server__tool' }, 'my-server')).decision, + ).toBe(PolicyDecision.ALLOW); + + // Match with unqualified name + serverName (the fix) + expect((await engine.check({ name: 'tool' }, 'my-server')).decision).toBe( + PolicyDecision.ALLOW, + ); + + // Should NOT match with unqualified name but NO serverName + expect((await engine.check({ name: 'tool' }, undefined)).decision).toBe( + PolicyDecision.ASK_USER, + ); + + // Should NOT match with unqualified name but WRONG serverName + expect( + (await engine.check({ name: 'tool' }, 'wrong-server')).decision, + ).toBe(PolicyDecision.ASK_USER); + }); + it('should match by args pattern', async () => { const rules: PolicyRule[] = [ { diff --git a/packages/core/src/policy/policy-engine.ts b/packages/core/src/policy/policy-engine.ts index f90b905938..3394dc5b30 100644 --- a/packages/core/src/policy/policy-engine.ts +++ b/packages/core/src/policy/policy-engine.ts @@ -310,16 +310,22 @@ export class PolicyEngine { let matchedRule: PolicyRule | undefined; let decision: PolicyDecision | undefined; + // For tools with a server name, we want to try matching both the + // original name and the fully qualified name (server__tool). + const toolCallsToTry: FunctionCall[] = [toolCall]; + if (serverName && toolCall.name && !toolCall.name.includes('__')) { + toolCallsToTry.push({ + ...toolCall, + name: `${serverName}__${toolCall.name}`, + }); + } + for (const rule of this.rules) { - if ( - ruleMatches( - rule, - toolCall, - stringifiedArgs, - serverName, - this.approvalMode, - ) - ) { + const match = toolCallsToTry.some((tc) => + ruleMatches(rule, tc, stringifiedArgs, serverName, this.approvalMode), + ); + + if (match) { debugLogger.debug( `[PolicyEngine.check] MATCHED rule: toolName=${rule.toolName}, decision=${rule.decision}, priority=${rule.priority}, argsPattern=${rule.argsPattern?.source || 'none'}`, );