diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index bf06e4179c..4cefb7011c 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -4,7 +4,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -import * as GenAiLib from '@google/genai'; import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; @@ -71,6 +70,21 @@ describe('mcp-client', () => { registerCapabilities: vi.fn(), setRequestHandler: vi.fn(), getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: 'testFunction', + inputSchema: { + type: 'object', + properties: {}, + }, + }, + ], + }), + listPrompts: vi.fn().mockResolvedValue({ + prompts: [], + }), + request: vi.fn().mockResolvedValue({}), }; vi.mocked(ClientLib.Client).mockReturnValue( mockedClient as unknown as ClientLib.Client, @@ -78,15 +92,6 @@ describe('mcp-client', () => { vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( {} as SdkClientStdioLib.StdioClientTransport, ); - const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => ({ - functionDeclarations: [ - { - name: 'testFunction', - }, - ], - }), - } as unknown as GenAiLib.CallableTool); const mockedToolRegistry = { registerTool: vi.fn(), sortTools: vi.fn(), @@ -104,7 +109,7 @@ describe('mcp-client', () => { ); await client.connect(); await client.discover({} as Config); - expect(mockedMcpToTool).toHaveBeenCalledOnce(); + expect(mockedClient.listTools).toHaveBeenCalledWith({}); }); it('should not skip tools even if a parameter is missing a type', async () => { @@ -119,7 +124,33 @@ describe('mcp-client', () => { registerCapabilities: vi.fn(), setRequestHandler: vi.fn(), getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), - tool: vi.fn(), + + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: 'validTool', + inputSchema: { + type: 'object', + properties: { + param1: { type: 'string' }, + }, + }, + }, + { + name: 'invalidTool', + inputSchema: { + type: 'object', + properties: { + param1: { description: 'a param with no type' }, + }, + }, + }, + ], + }), + listPrompts: vi.fn().mockResolvedValue({ + prompts: [], + }), + request: vi.fn().mockResolvedValue({}), }; vi.mocked(ClientLib.Client).mockReturnValue( mockedClient as unknown as ClientLib.Client, @@ -127,31 +158,6 @@ describe('mcp-client', () => { vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( {} as SdkClientStdioLib.StdioClientTransport, ); - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'validTool', - parametersJsonSchema: { - type: 'object', - properties: { - param1: { type: 'string' }, - }, - }, - }, - { - name: 'invalidTool', - parametersJsonSchema: { - type: 'object', - properties: { - param1: { description: 'a param with no type' }, - }, - }, - }, - ], - }), - } as unknown as GenAiLib.CallableTool); const mockedToolRegistry = { registerTool: vi.fn(), sortTools: vi.fn(), @@ -183,7 +189,9 @@ describe('mcp-client', () => { registerCapabilities: vi.fn(), setRequestHandler: vi.fn(), getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }), - request: vi.fn().mockRejectedValue(new Error('Test error')), + listTools: vi.fn().mockResolvedValue({ tools: [] }), + listPrompts: vi.fn().mockRejectedValue(new Error('Test error')), + request: vi.fn().mockResolvedValue({}), }; vi.mocked(ClientLib.Client).mockReturnValue( mockedClient as unknown as ClientLib.Client, @@ -191,9 +199,6 @@ describe('mcp-client', () => { vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( {} as SdkClientStdioLib.StdioClientTransport, ); - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => Promise.resolve({ functionDeclarations: [] }), - } as unknown as GenAiLib.CallableTool); const mockedToolRegistry = { registerTool: vi.fn(), getMessageBus: vi.fn().mockReturnValue(undefined), @@ -228,7 +233,8 @@ describe('mcp-client', () => { registerCapabilities: vi.fn(), setRequestHandler: vi.fn(), getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }), - request: vi.fn().mockResolvedValue({ prompts: [] }), + listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + request: vi.fn().mockResolvedValue({}), }; vi.mocked(ClientLib.Client).mockReturnValue( mockedClient as unknown as ClientLib.Client, @@ -236,7 +242,6 @@ describe('mcp-client', () => { vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( {} as SdkClientStdioLib.StdioClientTransport, ); - const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool); const mockedToolRegistry = { registerTool: vi.fn(), sortTools: vi.fn(), @@ -256,7 +261,6 @@ describe('mcp-client', () => { await expect(client.discover({} as Config)).rejects.toThrow( 'No prompts or tools found on the server.', ); - expect(mockedMcpToTool).not.toHaveBeenCalled(); }); it('should discover tools if server supports them', async () => { @@ -268,7 +272,17 @@ describe('mcp-client', () => { registerCapabilities: vi.fn(), setRequestHandler: vi.fn(), getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), - request: vi.fn().mockResolvedValue({ prompts: [] }), + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: 'testTool', + description: 'A test tool', + inputSchema: { type: 'object', properties: {} }, + }, + ], + }), + listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + request: vi.fn().mockResolvedValue({}), }; vi.mocked(ClientLib.Client).mockReturnValue( mockedClient as unknown as ClientLib.Client, @@ -276,17 +290,6 @@ describe('mcp-client', () => { vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( {} as SdkClientStdioLib.StdioClientTransport, ); - const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'testTool', - description: 'A test tool', - }, - ], - }), - } as unknown as GenAiLib.CallableTool); const mockedToolRegistry = { registerTool: vi.fn(), sortTools: vi.fn(), @@ -304,10 +307,87 @@ describe('mcp-client', () => { ); await client.connect(); await client.discover({} as Config); - expect(mockedMcpToTool).toHaveBeenCalledOnce(); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); }); + it('should discover tools with $defs and $ref in schema', async () => { + const mockedClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: 'toolWithDefs', + description: 'A tool using $defs', + inputSchema: { + type: 'object', + properties: { + param1: { + $ref: '#/$defs/MyType', + }, + }, + $defs: { + MyType: { + type: 'string', + description: 'A defined type', + }, + }, + }, + }, + ], + }), + listPrompts: vi.fn().mockResolvedValue({ + prompts: [], + }), + request: vi.fn().mockResolvedValue({}), + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + const mockedToolRegistry = { + registerTool: vi.fn(), + sortTools: vi.fn(), + getMessageBus: vi.fn().mockReturnValue(undefined), + } as unknown as ToolRegistry; + const client = new McpClient( + 'test-server', + { + command: 'test-command', + }, + mockedToolRegistry, + {} as PromptRegistry, + workspaceContext, + false, + ); + await client.connect(); + await client.discover({} as Config); + expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); + const registeredTool = vi.mocked(mockedToolRegistry.registerTool).mock + .calls[0][0]; + expect(registeredTool.schema.parametersJsonSchema).toEqual({ + type: 'object', + properties: { + param1: { + $ref: '#/$defs/MyType', + }, + }, + $defs: { + MyType: { + type: 'string', + description: 'A defined type', + }, + }, + }); + }); + it('should remove tools and prompts on disconnect', async () => { const mockedClient = { connect: vi.fn(), @@ -318,9 +398,19 @@ describe('mcp-client', () => { getServerCapabilities: vi .fn() .mockReturnValue({ tools: {}, prompts: {} }), - request: vi.fn().mockResolvedValue({ + listPrompts: vi.fn().mockResolvedValue({ prompts: [{ id: 'prompt1', text: 'a prompt' }], }), + request: vi.fn().mockResolvedValue({}), + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: 'testTool', + description: 'A test tool', + inputSchema: { type: 'object', properties: {} }, + }, + ], + }), }; vi.mocked(ClientLib.Client).mockReturnValue( mockedClient as unknown as ClientLib.Client, @@ -328,17 +418,6 @@ describe('mcp-client', () => { vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( {} as SdkClientStdioLib.StdioClientTransport, ); - vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ - tool: () => - Promise.resolve({ - functionDeclarations: [ - { - name: 'testTool', - description: 'A test tool', - }, - ], - }), - } as unknown as GenAiLib.CallableTool); const mockedToolRegistry = { registerTool: vi.fn(), unregisterTool: vi.fn(), diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 7b4ce0b1fd..dc135e0ce4 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -16,9 +16,8 @@ import type { Prompt, } from '@modelcontextprotocol/sdk/types.js'; import { - GetPromptResultSchema, - ListPromptsResultSchema, ListRootsRequestSchema, + type Tool as McpTool, } from '@modelcontextprotocol/sdk/types.js'; import { parse } from 'shell-quote'; import type { Config, MCPServerConfig } from '../config/config.js'; @@ -27,8 +26,7 @@ import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; -import type { FunctionDeclaration } from '@google/genai'; -import { mcpToTool } from '@google/genai'; +import type { CallableTool, FunctionCall, Part, Tool } from '@google/genai'; import { basename } from 'node:path'; import { pathToFileURL } from 'node:url'; import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; @@ -621,29 +619,26 @@ export async function discoverTools( // Only request tools if the server supports them. if (mcpClient.getServerCapabilities()?.tools == null) return []; - const mcpCallableTool = mcpToTool(mcpClient, { - timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, - }); - const tool = await mcpCallableTool.tool(); - - if (!Array.isArray(tool.functionDeclarations)) { - // This is a valid case for a prompt-only server - return []; - } - + const response = await mcpClient.listTools({}); const discoveredTools: DiscoveredMCPTool[] = []; - for (const funcDecl of tool.functionDeclarations) { + for (const toolDef of response.tools) { try { - if (!isEnabled(funcDecl, mcpServerName, mcpServerConfig)) { + if (!isEnabled(toolDef, mcpServerName, mcpServerConfig)) { continue; } + const mcpCallableTool = new McpCallableTool( + mcpClient, + toolDef, + mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + ); + const tool = new DiscoveredMCPTool( mcpCallableTool, mcpServerName, - funcDecl.name!, - funcDecl.description ?? '', - funcDecl.parametersJsonSchema ?? { type: 'object', properties: {} }, + toolDef.name, + toolDef.description ?? '', + toolDef.inputSchema ?? { type: 'object', properties: {} }, mcpServerConfig.trust, undefined, cliConfig, @@ -657,7 +652,7 @@ export async function discoverTools( coreEvents.emitFeedback( 'error', `Error discovering tool: '${ - funcDecl.name + toolDef.name }' from MCP server '${mcpServerName}': ${(error as Error).message}`, error, ); @@ -681,6 +676,69 @@ export async function discoverTools( } } +class McpCallableTool implements CallableTool { + constructor( + private readonly client: Client, + private readonly toolDef: McpTool, + private readonly timeout: number, + ) {} + + async tool(): Promise { + return { + functionDeclarations: [ + { + name: this.toolDef.name, + description: this.toolDef.description, + parametersJsonSchema: this.toolDef.inputSchema, + }, + ], + }; + } + + async callTool(functionCalls: FunctionCall[]): Promise { + // We only expect one function call at a time for MCP tools in this context + if (functionCalls.length !== 1) { + throw new Error('McpCallableTool only supports single function call'); + } + const call = functionCalls[0]; + + try { + const result = await this.client.callTool( + { + name: call.name!, + arguments: call.args as Record, + }, + undefined, + { timeout: this.timeout }, + ); + + return [ + { + functionResponse: { + name: call.name, + response: result, + }, + }, + ]; + } catch (error) { + // Return error in the format expected by DiscoveredMCPTool + return [ + { + functionResponse: { + name: call.name, + response: { + error: { + message: error instanceof Error ? error.message : String(error), + isError: true, + }, + }, + }, + }, + ]; + } + } +} + /** * Discovers and logs prompts from a connected MCP client. * It retrieves prompt declarations from the client and logs their names. @@ -697,10 +755,7 @@ export async function discoverPrompts( // Only request prompts if the server supports them. if (mcpClient.getServerCapabilities()?.prompts == null) return []; - const response = await mcpClient.request( - { method: 'prompts/list', params: {} }, - ListPromptsResultSchema, - ); + const response = await mcpClient.listPrompts({}); for (const prompt of response.prompts) { promptRegistry.registerPrompt({ @@ -746,16 +801,17 @@ export async function invokeMcpPrompt( promptParams: Record, ): Promise { try { - const response = await mcpClient.request( - { - method: 'prompts/get', - params: { - name: promptName, - arguments: promptParams, - }, - }, - GetPromptResultSchema, - ); + const sanitizedParams: Record = {}; + for (const [key, value] of Object.entries(promptParams)) { + if (value !== undefined && value !== null) { + sanitizedParams[key] = String(value); + } + } + + const response = await mcpClient.getPrompt({ + name: promptName, + arguments: sanitizedParams, + }); return response; } catch (error) { @@ -1339,9 +1395,13 @@ export async function createTransport( ); } +interface NamedTool { + name?: string; +} + /** Visible for testing */ export function isEnabled( - funcDecl: FunctionDeclaration, + funcDecl: NamedTool, mcpServerName: string, mcpServerConfig: MCPServerConfig, ): boolean {