/** * @license * Copyright 2025 Google LLC * SPDX-License-Identifier: Apache-2.0 */ /* eslint-disable @typescript-eslint/no-explicit-any */ import { describe, it, expect, vi, beforeEach, afterEach, type Mocked, type MockInstance, } from 'vitest'; import { Config, type ConfigParameters } from '../config/config.js'; import { ApprovalMode } from '../policy/types.js'; import { ToolRegistry, DiscoveredTool } from './tool-registry.js'; import { DISCOVERED_TOOL_PREFIX } from './tool-names.js'; import { DiscoveredMCPTool, MCP_QUALIFIED_NAME_SEPARATOR } from './mcp-tool.js'; import { mcpToTool, type FunctionDeclaration, type CallableTool, } from '@google/genai'; import { spawn } from 'node:child_process'; import fs from 'node:fs'; import { MockTool } from '../test-utils/mock-tool.js'; import { ToolErrorType } from './tool-error.js'; import type { MessageBus } from '../confirmation-bus/message-bus.js'; vi.mock('node:fs'); // Mock node:child_process vi.mock('node:child_process', async () => { const actual = await vi.importActual('node:child_process'); return { ...actual, execSync: vi.fn(), spawn: vi.fn(), }; }); // Mock MCP SDK Client and Transports const mockMcpClientConnect = vi.fn(); const mockMcpClientOnError = vi.fn(); const mockStdioTransportClose = vi.fn(); const mockSseTransportClose = vi.fn(); vi.mock('@modelcontextprotocol/sdk/client/index.js', () => { const MockClient = vi.fn().mockImplementation(() => ({ connect: mockMcpClientConnect, set onerror(handler: any) { mockMcpClientOnError(handler); }, })); return { Client: MockClient }; }); vi.mock('@modelcontextprotocol/sdk/client/stdio.js', () => { const MockStdioClientTransport = vi.fn().mockImplementation(() => ({ stderr: { on: vi.fn(), }, close: mockStdioTransportClose, })); return { StdioClientTransport: MockStdioClientTransport }; }); vi.mock('@modelcontextprotocol/sdk/client/sse.js', () => { const MockSSEClientTransport = vi.fn().mockImplementation(() => ({ close: mockSseTransportClose, })); return { SSEClientTransport: MockSSEClientTransport }; }); // Mock @google/genai mcpToTool vi.mock('@google/genai', async () => { const actualGenai = await vi.importActual('@google/genai'); return { ...actualGenai, mcpToTool: vi.fn().mockImplementation(() => ({ tool: vi.fn().mockResolvedValue({ functionDeclarations: [] }), callTool: vi.fn(), })), }; }); // Mock tool-names to provide a consistent alias for testing vi.mock('./tool-names.js', async (importOriginal) => { const actual = await importOriginal(); const mockedAliases: Record = { ...actual.TOOL_LEGACY_ALIASES, legacy_test_tool: 'current_test_tool', }; return { ...actual, TOOL_LEGACY_ALIASES: mockedAliases, // Override getToolAliases to use the mocked aliases map getToolAliases: (name: string): string[] => { const aliases = new Set([name]); const canonicalName = mockedAliases[name] ?? name; aliases.add(canonicalName); for (const [legacyName, currentName] of Object.entries(mockedAliases)) { if (currentName === canonicalName) { aliases.add(legacyName); } } return Array.from(aliases); }, }; }); // Helper to create a mock CallableTool for specific test needs const createMockCallableTool = ( toolDeclarations: FunctionDeclaration[], ): Mocked => ({ tool: vi.fn().mockResolvedValue({ functionDeclarations: toolDeclarations }), callTool: vi.fn(), }); // Helper to create a DiscoveredMCPTool const mockMessageBusForHelper = { publish: vi.fn(), subscribe: vi.fn(), unsubscribe: vi.fn(), } as unknown as MessageBus; const createMCPTool = ( serverName: string, toolName: string, description: string, mockCallable: CallableTool = {} as CallableTool, ) => new DiscoveredMCPTool( mockCallable, serverName, toolName, description, {}, mockMessageBusForHelper, ); // Helper to create a mock spawn process for tool discovery const createDiscoveryProcess = (toolDeclarations: FunctionDeclaration[]) => { const mockProcess = { stdout: { on: vi.fn(), removeListener: vi.fn() }, stderr: { on: vi.fn(), removeListener: vi.fn() }, on: vi.fn(), }; mockProcess.stdout.on.mockImplementation((event, callback) => { if (event === 'data') { callback( Buffer.from( JSON.stringify([{ functionDeclarations: toolDeclarations }]), ), ); } return mockProcess as any; }); mockProcess.on.mockImplementation((event, callback) => { if (event === 'close') { callback(0); } return mockProcess as any; }); return mockProcess; }; // Helper to create a mock spawn process for tool execution const createExecutionProcess = (exitCode: number, stderrMessage?: string) => { const mockProcess = { stdout: { on: vi.fn(), removeListener: vi.fn() }, stderr: { on: vi.fn(), removeListener: vi.fn() }, stdin: { write: vi.fn(), end: vi.fn() }, on: vi.fn(), connected: true, disconnect: vi.fn(), removeListener: vi.fn(), }; if (stderrMessage) { mockProcess.stderr.on.mockImplementation((event, callback) => { if (event === 'data') { callback(Buffer.from(stderrMessage)); } }); } mockProcess.on.mockImplementation((event, callback) => { if (event === 'close') { callback(exitCode); } }); return mockProcess; }; const baseConfigParams: ConfigParameters = { cwd: '/tmp', model: 'test-model', embeddingModel: 'test-embedding-model', sandbox: undefined, targetDir: '/test/dir', debugMode: false, userMemory: '', geminiMdFileCount: 0, approvalMode: ApprovalMode.DEFAULT, sessionId: 'test-session-id', }; describe('ToolRegistry', () => { let config: Config; let toolRegistry: ToolRegistry; const mockMessageBus = { publish: vi.fn(), subscribe: vi.fn(), unsubscribe: vi.fn(), } as unknown as MessageBus; let mockConfigGetToolDiscoveryCommand: ReturnType; let mockConfigGetExcludedTools: MockInstance< typeof Config.prototype.getExcludeTools >; beforeEach(() => { vi.mocked(fs.existsSync).mockReturnValue(true); vi.mocked(fs.statSync).mockReturnValue({ isDirectory: () => true, } as fs.Stats); config = new Config(baseConfigParams); toolRegistry = new ToolRegistry(config, mockMessageBus); vi.spyOn(console, 'warn').mockImplementation(() => {}); vi.spyOn(console, 'error').mockImplementation(() => {}); vi.spyOn(console, 'debug').mockImplementation(() => {}); vi.spyOn(console, 'log').mockImplementation(() => {}); mockMcpClientConnect.mockReset().mockResolvedValue(undefined); mockStdioTransportClose.mockReset(); mockSseTransportClose.mockReset(); vi.mocked(mcpToTool).mockClear(); vi.mocked(mcpToTool).mockReturnValue(createMockCallableTool([])); mockConfigGetToolDiscoveryCommand = vi.spyOn( config, 'getToolDiscoveryCommand', ); mockConfigGetExcludedTools = vi.spyOn(config, 'getExcludeTools'); vi.spyOn(config, 'getMcpServers'); vi.spyOn(config, 'getMcpServerCommand'); vi.spyOn(config, 'getPromptRegistry').mockReturnValue({ clear: vi.fn(), removePromptsByServer: vi.fn(), } as any); }); afterEach(() => { vi.restoreAllMocks(); }); describe('registerTool', () => { it('should register a new tool', () => { const tool = new MockTool({ name: 'mock-tool' }); toolRegistry.registerTool(tool); expect(toolRegistry.getTool('mock-tool')).toBe(tool); }); it('should pass modelId to getSchema when getting function declarations', () => { const tool = new MockTool({ name: 'mock-tool' }); const getSchemaSpy = vi.spyOn(tool, 'getSchema'); toolRegistry.registerTool(tool); const modelId = 'test-model-id'; toolRegistry.getFunctionDeclarations(modelId); expect(getSchemaSpy).toHaveBeenCalledWith(modelId); }); }); describe('excluded tools', () => { const simpleTool = new MockTool({ name: 'tool-a', displayName: 'Tool a', }); const excludedTool = new ExcludedMockTool({ name: 'excluded-tool-class', displayName: 'Excluded Tool Class', }); const mcpTool = createMCPTool( 'mcp-server', 'excluded-mcp-tool', 'description', ); const allowedTool = new MockTool({ name: 'allowed-tool', displayName: 'Allowed Tool', }); it.each([ { name: 'should match simple names', tools: [simpleTool], excludedTools: ['tool-a'], }, { name: 'should match simple MCP tool names, when qualified or unqualified', tools: [mcpTool, mcpTool.asFullyQualifiedTool()], excludedTools: [mcpTool.name], }, { name: 'should match qualified MCP tool names when qualified or unqualified', tools: [mcpTool, mcpTool.asFullyQualifiedTool()], excludedTools: [`${mcpTool.getFullyQualifiedPrefix()}${mcpTool.name}`], }, { name: 'should match class names', tools: [excludedTool], excludedTools: ['ExcludedMockTool'], }, { name: 'should exclude a tool when its legacy alias is in excludeTools', tools: [ new MockTool({ name: 'current_test_tool', displayName: 'Current Test Tool', }), ], excludedTools: ['legacy_test_tool'], }, { name: 'should exclude a tool when its current name is in excludeTools and tool is registered under current name', tools: [ new MockTool({ name: 'current_test_tool', displayName: 'Current Test Tool', }), ], excludedTools: ['current_test_tool'], }, ])('$name', ({ tools, excludedTools }) => { toolRegistry.registerTool(allowedTool); for (const tool of tools) { toolRegistry.registerTool(tool); } mockConfigGetExcludedTools.mockReturnValue(new Set(excludedTools)); expect(toolRegistry.getAllTools()).toEqual([allowedTool]); expect(toolRegistry.getAllToolNames()).toEqual([allowedTool.name]); expect(toolRegistry.getFunctionDeclarations()).toEqual( toolRegistry.getFunctionDeclarationsFiltered([allowedTool.name]), ); for (const tool of tools) { expect(toolRegistry.getTool(tool.name)).toBeUndefined(); expect( toolRegistry.getFunctionDeclarationsFiltered([tool.name]), ).toHaveLength(0); if (tool instanceof DiscoveredMCPTool) { expect(toolRegistry.getToolsByServer(tool.serverName)).toHaveLength( 0, ); } } }); }); describe('getAllTools', () => { it('should return all registered tools sorted alphabetically by displayName', () => { // Register tools with displayNames in non-alphabetical order const toolC = new MockTool({ name: 'c-tool', displayName: 'Tool C' }); const toolA = new MockTool({ name: 'a-tool', displayName: 'Tool A' }); const toolB = new MockTool({ name: 'b-tool', displayName: 'Tool B' }); toolRegistry.registerTool(toolC); toolRegistry.registerTool(toolA); toolRegistry.registerTool(toolB); const allTools = toolRegistry.getAllTools(); const displayNames = allTools.map((t) => t.displayName); // Assert that the returned array is sorted by displayName expect(displayNames).toEqual(['Tool A', 'Tool B', 'Tool C']); }); }); describe('getAllToolNames', () => { it('should return all registered tool names with qualified names for MCP tools', () => { // Register tools with displayNames in non-alphabetical order const toolC = new MockTool({ name: 'c-tool', displayName: 'Tool C' }); const toolA = new MockTool({ name: 'a-tool', displayName: 'Tool A' }); const mcpTool = createMCPTool('my-server', 'my-tool', 'desc'); toolRegistry.registerTool(toolC); toolRegistry.registerTool(toolA); toolRegistry.registerTool(mcpTool); const toolNames = toolRegistry.getAllToolNames(); // Assert that the returned array contains all tool names, with MCP qualified expect(toolNames).toContain('c-tool'); expect(toolNames).toContain('a-tool'); expect(toolNames).toContain('my-server__my-tool'); expect(toolNames).toHaveLength(3); }); it('should deduplicate tool names', () => { const serverName = 'my-server'; const toolName = 'my-tool'; const mcpTool = createMCPTool(serverName, toolName, 'desc'); // Register same MCP tool twice (one as alias, one as qualified) toolRegistry.registerTool(mcpTool); toolRegistry.registerTool(mcpTool.asFullyQualifiedTool()); const toolNames = toolRegistry.getAllToolNames(); expect(toolNames).toEqual([`${serverName}__${toolName}`]); }); }); describe('getToolsByServer', () => { it('should return an empty array if no tools match the server name', () => { toolRegistry.registerTool(new MockTool({ name: 'mock-tool' })); expect(toolRegistry.getToolsByServer('any-mcp-server')).toEqual([]); }); it('should return only tools matching the server name, sorted by name', async () => { const server1Name = 'mcp-server-uno'; const server2Name = 'mcp-server-dos'; const mcpTool1_c = createMCPTool(server1Name, 'zebra-tool', 'd1'); const mcpTool1_a = createMCPTool(server1Name, 'apple-tool', 'd2'); const mcpTool1_b = createMCPTool(server1Name, 'banana-tool', 'd3'); const mcpTool2 = createMCPTool(server2Name, 'tool-on-server2', 'd4'); const nonMcpTool = new MockTool({ name: 'regular-tool' }); toolRegistry.registerTool(mcpTool1_c); toolRegistry.registerTool(mcpTool1_a); toolRegistry.registerTool(mcpTool1_b); toolRegistry.registerTool(mcpTool2); toolRegistry.registerTool(nonMcpTool); const toolsFromServer1 = toolRegistry.getToolsByServer(server1Name); const toolNames = toolsFromServer1.map((t) => t.name); // Assert that the array has the correct tools and is sorted by name expect(toolsFromServer1).toHaveLength(3); expect(toolNames).toEqual(['apple-tool', 'banana-tool', 'zebra-tool']); // Assert that all returned tools are indeed from the correct server for (const tool of toolsFromServer1) { expect((tool as DiscoveredMCPTool).serverName).toBe(server1Name); } // Assert that the other server's tools are returned correctly const toolsFromServer2 = toolRegistry.getToolsByServer(server2Name); expect(toolsFromServer2).toHaveLength(1); expect(toolsFromServer2[0].name).toBe(mcpTool2.name); }); }); describe('sortTools', () => { it('should sort tools by priority: built-in, discovered, then MCP (by server name)', () => { const builtIn1 = new MockTool({ name: 'builtin-1' }); const builtIn2 = new MockTool({ name: 'builtin-2' }); const discovered1 = new DiscoveredTool( config, 'discovered-1', DISCOVERED_TOOL_PREFIX + 'discovered-1', 'desc', {}, mockMessageBus, ); const mcpZebra = createMCPTool('zebra-server', 'mcp-zebra', 'desc'); const mcpApple = createMCPTool('apple-server', 'mcp-apple', 'desc'); // Register in mixed order toolRegistry.registerTool(mcpZebra); toolRegistry.registerTool(discovered1); toolRegistry.registerTool(builtIn1); toolRegistry.registerTool(mcpApple); toolRegistry.registerTool(builtIn2); toolRegistry.sortTools(); expect(toolRegistry.getAllToolNames()).toEqual([ 'builtin-1', 'builtin-2', DISCOVERED_TOOL_PREFIX + 'discovered-1', 'apple-server__mcp-apple', 'zebra-server__mcp-zebra', ]); }); }); describe('discoverTools', () => { it('should will preserve tool parametersJsonSchema during discovery from command', async () => { const discoveryCommand = 'my-discovery-command'; mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand); const unsanitizedToolDeclaration: FunctionDeclaration = { name: 'tool-with-bad-format', description: 'A tool with an invalid format property', parametersJsonSchema: { type: 'object', properties: { some_string: { type: 'string', format: 'uuid', // This is an unsupported format }, }, }, }; const mockSpawn = vi.mocked(spawn); mockSpawn.mockReturnValue( createDiscoveryProcess([unsanitizedToolDeclaration]) as any, ); await toolRegistry.discoverAllTools(); const discoveredTool = toolRegistry.getTool( DISCOVERED_TOOL_PREFIX + 'tool-with-bad-format', ); expect(discoveredTool).toBeDefined(); const registeredParams = (discoveredTool as DiscoveredTool).schema .parametersJsonSchema; expect(registeredParams).toStrictEqual({ type: 'object', properties: { some_string: { type: 'string', format: 'uuid', }, }, }); }); it('should return a DISCOVERED_TOOL_EXECUTION_ERROR on tool failure', async () => { const discoveryCommand = 'my-discovery-command'; mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand); vi.spyOn(config, 'getToolCallCommand').mockReturnValue('my-call-command'); const toolDeclaration: FunctionDeclaration = { name: 'failing-tool', description: 'A tool that fails', parametersJsonSchema: { type: 'object', properties: {}, }, }; const mockSpawn = vi.mocked(spawn); mockSpawn.mockReturnValueOnce( createDiscoveryProcess([toolDeclaration]) as any, ); await toolRegistry.discoverAllTools(); const discoveredTool = toolRegistry.getTool( DISCOVERED_TOOL_PREFIX + 'failing-tool', ); expect(discoveredTool).toBeDefined(); mockSpawn.mockReturnValueOnce( createExecutionProcess(1, 'Something went wrong') as any, ); const invocation = (discoveredTool as DiscoveredTool).build({}); const result = await invocation.execute(new AbortController().signal); expect(result.error?.type).toBe( ToolErrorType.DISCOVERED_TOOL_EXECUTION_ERROR, ); expect(result.llmContent).toContain('Stderr: Something went wrong'); expect(result.llmContent).toContain('Exit Code: 1'); }); it('should pass MessageBus to DiscoveredTool and its invocations', async () => { const discoveryCommand = 'my-discovery-command'; mockConfigGetToolDiscoveryCommand.mockReturnValue(discoveryCommand); const toolDeclaration: FunctionDeclaration = { name: 'policy-test-tool', description: 'tests policy', parametersJsonSchema: { type: 'object', properties: {} }, }; const mockSpawn = vi.mocked(spawn); mockSpawn.mockReturnValueOnce( createDiscoveryProcess([toolDeclaration]) as any, ); await toolRegistry.discoverAllTools(); const tool = toolRegistry.getTool( DISCOVERED_TOOL_PREFIX + 'policy-test-tool', ); expect(tool).toBeDefined(); expect((tool as any).messageBus).toBe(mockMessageBus); const invocation = tool!.build({}); expect((invocation as any).messageBus).toBe(mockMessageBus); }); }); describe('getTool', () => { it('should retrieve an MCP tool by its fully qualified name even if registered with simple name', () => { const serverName = 'my-server'; const toolName = 'my-tool'; const mcpTool = createMCPTool(serverName, toolName, 'description'); // Register tool (will be registered as 'my-tool' since no conflict) toolRegistry.registerTool(mcpTool); // Verify it is available as 'my-tool' expect(toolRegistry.getTool('my-tool')).toBeDefined(); expect(toolRegistry.getTool('my-tool')?.name).toBe('my-tool'); // Verify it is available as 'my-server__my-tool' const fullyQualifiedName = `${serverName}__${toolName}`; const retrievedTool = toolRegistry.getTool(fullyQualifiedName); expect(retrievedTool).toBeDefined(); // The returned tool object is the same, so its name property is still 'my-tool' expect(retrievedTool?.name).toBe('my-tool'); }); it('should retrieve an MCP tool by its fully qualified name when tool name has special characters', () => { const serverName = 'my-server'; // Use a space which is invalid and will be replaced by underscore const toolName = 'my tool'; const validToolName = 'my_tool'; const mcpTool = createMCPTool(serverName, toolName, 'description'); // Register tool (will be registered as sanitized name) toolRegistry.registerTool(mcpTool); // Verify it is available as sanitized name expect(toolRegistry.getTool(validToolName)).toBeDefined(); expect(toolRegistry.getTool(validToolName)?.name).toBe(validToolName); // Verify it is available as 'my-server__my_tool' const fullyQualifiedName = `${serverName}__${validToolName}`; const retrievedTool = toolRegistry.getTool(fullyQualifiedName); expect(retrievedTool).toBeDefined(); expect(retrievedTool?.name).toBe(validToolName); }); it('should resolve qualified names in getFunctionDeclarationsFiltered', () => { const serverName = 'my-server'; const toolName = 'my-tool'; const mcpTool = createMCPTool(serverName, toolName, 'description'); toolRegistry.registerTool(mcpTool); const fullyQualifiedName = `${serverName}${MCP_QUALIFIED_NAME_SEPARATOR}${toolName}`; const declarations = toolRegistry.getFunctionDeclarationsFiltered([ fullyQualifiedName, ]); expect(declarations).toHaveLength(1); expect(declarations[0].name).toBe(toolName); }); it('should retrieve a tool using its legacy alias', async () => { const legacyName = 'legacy_test_tool'; const currentName = 'current_test_tool'; const mockTool = new MockTool({ name: currentName, description: 'Test Tool', messageBus: mockMessageBus, }); toolRegistry.registerTool(mockTool); const retrievedTool = toolRegistry.getTool(legacyName); expect(retrievedTool).toBeDefined(); expect(retrievedTool?.name).toBe(currentName); }); }); describe('getFunctionDeclarations', () => { it('should use fully qualified names for MCP tools in declarations', () => { const serverName = 'my-server'; const toolName = 'my-tool'; const mcpTool = createMCPTool(serverName, toolName, 'description'); toolRegistry.registerTool(mcpTool); const declarations = toolRegistry.getFunctionDeclarations(); expect(declarations).toHaveLength(1); expect(declarations[0].name).toBe(`${serverName}__${toolName}`); }); it('should deduplicate MCP tools in declarations', () => { const serverName = 'my-server'; const toolName = 'my-tool'; const mcpTool = createMCPTool(serverName, toolName, 'description'); // Register both alias and qualified toolRegistry.registerTool(mcpTool); toolRegistry.registerTool(mcpTool.asFullyQualifiedTool()); const declarations = toolRegistry.getFunctionDeclarations(); expect(declarations).toHaveLength(1); expect(declarations[0].name).toBe(`${serverName}__${toolName}`); }); }); describe('plan mode', () => { it('should only return policy-allowed tools in plan mode', () => { // Register several tools const globTool = new MockTool({ name: 'glob', displayName: 'Glob' }); const readFileTool = new MockTool({ name: 'read_file', displayName: 'ReadFile', }); const shellTool = new MockTool({ name: 'shell', displayName: 'Shell' }); const writeTool = new MockTool({ name: 'write_file', displayName: 'WriteFile', }); toolRegistry.registerTool(globTool); toolRegistry.registerTool(readFileTool); toolRegistry.registerTool(shellTool); toolRegistry.registerTool(writeTool); // Mock config in PLAN mode: exclude shell and write_file mockConfigGetExcludedTools.mockReturnValue( new Set(['shell', 'write_file']), ); const allTools = toolRegistry.getAllTools(); const toolNames = allTools.map((t) => t.name); expect(toolNames).toContain('glob'); expect(toolNames).toContain('read_file'); expect(toolNames).not.toContain('shell'); expect(toolNames).not.toContain('write_file'); }); it('should include read-only MCP tools when allowed by policy in plan mode', () => { const readOnlyMcp = createMCPTool( 'test-server', 'read-only-tool', 'A read-only MCP tool', ); // Set readOnlyHint to true via toolAnnotations Object.defineProperty(readOnlyMcp, 'isReadOnly', { value: true }); toolRegistry.registerTool(readOnlyMcp); // Policy allows this tool (not in excluded set) mockConfigGetExcludedTools.mockReturnValue(new Set()); const allTools = toolRegistry.getAllTools(); const toolNames = allTools.map((t) => t.name); expect(toolNames).toContain('read-only-tool'); }); it('should exclude non-read-only MCP tools when denied by policy in plan mode', () => { const writeMcp = createMCPTool( 'test-server', 'write-mcp-tool', 'A write MCP tool', ); toolRegistry.registerTool(writeMcp); // Policy excludes this tool mockConfigGetExcludedTools.mockReturnValue(new Set(['write-mcp-tool'])); const allTools = toolRegistry.getAllTools(); const toolNames = allTools.map((t) => t.name); expect(toolNames).not.toContain('write-mcp-tool'); }); }); describe('DiscoveredToolInvocation', () => { it('should return the stringified params from getDescription', () => { const tool = new DiscoveredTool( config, 'test-tool', DISCOVERED_TOOL_PREFIX + 'test-tool', 'A test tool', {}, mockMessageBus, ); const params = { param: 'testValue' }; const invocation = tool.build(params); const description = invocation.getDescription(); expect(description).toBe(JSON.stringify(params)); }); }); }); /** * Used for tests that exclude by class name. */ class ExcludedMockTool extends MockTool { constructor(options: ConstructorParameters[0]) { super(options); } }