rework MCP tool discovery and invocation (#13160)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
cornmander
2025-11-17 12:03:48 -05:00
committed by GitHub
parent 394a7ea019
commit 8c78fe4f10
2 changed files with 243 additions and 104 deletions
+137 -58
View File
@@ -4,7 +4,6 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
import * as GenAiLib from '@google/genai';
import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js'; import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -71,6 +70,21 @@ describe('mcp-client', () => {
registerCapabilities: vi.fn(), registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(), setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
listTools: vi.fn().mockResolvedValue({
tools: [
{
name: 'testFunction',
inputSchema: {
type: 'object',
properties: {},
},
},
],
}),
listPrompts: vi.fn().mockResolvedValue({
prompts: [],
}),
request: vi.fn().mockResolvedValue({}),
}; };
vi.mocked(ClientLib.Client).mockReturnValue( vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client, mockedClient as unknown as ClientLib.Client,
@@ -78,15 +92,6 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport, {} as SdkClientStdioLib.StdioClientTransport,
); );
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () => ({
functionDeclarations: [
{
name: 'testFunction',
},
],
}),
} as unknown as GenAiLib.CallableTool);
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
@@ -104,7 +109,7 @@ describe('mcp-client', () => {
); );
await client.connect(); await client.connect();
await client.discover({} as Config); await client.discover({} as Config);
expect(mockedMcpToTool).toHaveBeenCalledOnce(); expect(mockedClient.listTools).toHaveBeenCalledWith({});
}); });
it('should not skip tools even if a parameter is missing a type', async () => { it('should not skip tools even if a parameter is missing a type', async () => {
@@ -119,21 +124,12 @@ describe('mcp-client', () => {
registerCapabilities: vi.fn(), registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(), setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
tool: vi.fn(),
}; listTools: vi.fn().mockResolvedValue({
vi.mocked(ClientLib.Client).mockReturnValue( tools: [
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{ {
name: 'validTool', name: 'validTool',
parametersJsonSchema: { inputSchema: {
type: 'object', type: 'object',
properties: { properties: {
param1: { type: 'string' }, param1: { type: 'string' },
@@ -142,7 +138,7 @@ describe('mcp-client', () => {
}, },
{ {
name: 'invalidTool', name: 'invalidTool',
parametersJsonSchema: { inputSchema: {
type: 'object', type: 'object',
properties: { properties: {
param1: { description: 'a param with no type' }, param1: { description: 'a param with no type' },
@@ -151,7 +147,17 @@ describe('mcp-client', () => {
}, },
], ],
}), }),
} as unknown as GenAiLib.CallableTool); listPrompts: vi.fn().mockResolvedValue({
prompts: [],
}),
request: vi.fn().mockResolvedValue({}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
@@ -183,7 +189,9 @@ describe('mcp-client', () => {
registerCapabilities: vi.fn(), registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(), setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }), getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
request: vi.fn().mockRejectedValue(new Error('Test error')), listTools: vi.fn().mockResolvedValue({ tools: [] }),
listPrompts: vi.fn().mockRejectedValue(new Error('Test error')),
request: vi.fn().mockResolvedValue({}),
}; };
vi.mocked(ClientLib.Client).mockReturnValue( vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client, mockedClient as unknown as ClientLib.Client,
@@ -191,9 +199,6 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport, {} as SdkClientStdioLib.StdioClientTransport,
); );
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () => Promise.resolve({ functionDeclarations: [] }),
} as unknown as GenAiLib.CallableTool);
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
getMessageBus: vi.fn().mockReturnValue(undefined), getMessageBus: vi.fn().mockReturnValue(undefined),
@@ -228,7 +233,8 @@ describe('mcp-client', () => {
registerCapabilities: vi.fn(), registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(), setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }), getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
request: vi.fn().mockResolvedValue({ prompts: [] }), listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
request: vi.fn().mockResolvedValue({}),
}; };
vi.mocked(ClientLib.Client).mockReturnValue( vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client, mockedClient as unknown as ClientLib.Client,
@@ -236,7 +242,6 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport, {} as SdkClientStdioLib.StdioClientTransport,
); );
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool);
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
@@ -256,7 +261,6 @@ describe('mcp-client', () => {
await expect(client.discover({} as Config)).rejects.toThrow( await expect(client.discover({} as Config)).rejects.toThrow(
'No prompts or tools found on the server.', 'No prompts or tools found on the server.',
); );
expect(mockedMcpToTool).not.toHaveBeenCalled();
}); });
it('should discover tools if server supports them', async () => { it('should discover tools if server supports them', async () => {
@@ -268,7 +272,17 @@ describe('mcp-client', () => {
registerCapabilities: vi.fn(), registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(), setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
request: vi.fn().mockResolvedValue({ prompts: [] }), listTools: vi.fn().mockResolvedValue({
tools: [
{
name: 'testTool',
description: 'A test tool',
inputSchema: { type: 'object', properties: {} },
},
],
}),
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
request: vi.fn().mockResolvedValue({}),
}; };
vi.mocked(ClientLib.Client).mockReturnValue( vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client, mockedClient as unknown as ClientLib.Client,
@@ -276,17 +290,6 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport, {} as SdkClientStdioLib.StdioClientTransport,
); );
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'testTool',
description: 'A test tool',
},
],
}),
} as unknown as GenAiLib.CallableTool);
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
sortTools: vi.fn(), sortTools: vi.fn(),
@@ -304,10 +307,87 @@ describe('mcp-client', () => {
); );
await client.connect(); await client.connect();
await client.discover({} as Config); await client.discover({} as Config);
expect(mockedMcpToTool).toHaveBeenCalledOnce();
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
}); });
it('should discover tools with $defs and $ref in schema', async () => {
const mockedClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
listTools: vi.fn().mockResolvedValue({
tools: [
{
name: 'toolWithDefs',
description: 'A tool using $defs',
inputSchema: {
type: 'object',
properties: {
param1: {
$ref: '#/$defs/MyType',
},
},
$defs: {
MyType: {
type: 'string',
description: 'A defined type',
},
},
},
},
],
}),
listPrompts: vi.fn().mockResolvedValue({
prompts: [],
}),
request: vi.fn().mockResolvedValue({}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockedToolRegistry = {
registerTool: vi.fn(),
sortTools: vi.fn(),
getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry;
const client = new McpClient(
'test-server',
{
command: 'test-command',
},
mockedToolRegistry,
{} as PromptRegistry,
workspaceContext,
false,
);
await client.connect();
await client.discover({} as Config);
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
const registeredTool = vi.mocked(mockedToolRegistry.registerTool).mock
.calls[0][0];
expect(registeredTool.schema.parametersJsonSchema).toEqual({
type: 'object',
properties: {
param1: {
$ref: '#/$defs/MyType',
},
},
$defs: {
MyType: {
type: 'string',
description: 'A defined type',
},
},
});
});
it('should remove tools and prompts on disconnect', async () => { it('should remove tools and prompts on disconnect', async () => {
const mockedClient = { const mockedClient = {
connect: vi.fn(), connect: vi.fn(),
@@ -318,9 +398,19 @@ describe('mcp-client', () => {
getServerCapabilities: vi getServerCapabilities: vi
.fn() .fn()
.mockReturnValue({ tools: {}, prompts: {} }), .mockReturnValue({ tools: {}, prompts: {} }),
request: vi.fn().mockResolvedValue({ listPrompts: vi.fn().mockResolvedValue({
prompts: [{ id: 'prompt1', text: 'a prompt' }], prompts: [{ id: 'prompt1', text: 'a prompt' }],
}), }),
request: vi.fn().mockResolvedValue({}),
listTools: vi.fn().mockResolvedValue({
tools: [
{
name: 'testTool',
description: 'A test tool',
inputSchema: { type: 'object', properties: {} },
},
],
}),
}; };
vi.mocked(ClientLib.Client).mockReturnValue( vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client, mockedClient as unknown as ClientLib.Client,
@@ -328,17 +418,6 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport, {} as SdkClientStdioLib.StdioClientTransport,
); );
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'testTool',
description: 'A test tool',
},
],
}),
} as unknown as GenAiLib.CallableTool);
const mockedToolRegistry = { const mockedToolRegistry = {
registerTool: vi.fn(), registerTool: vi.fn(),
unregisterTool: vi.fn(), unregisterTool: vi.fn(),
+94 -34
View File
@@ -16,9 +16,8 @@ import type {
Prompt, Prompt,
} from '@modelcontextprotocol/sdk/types.js'; } from '@modelcontextprotocol/sdk/types.js';
import { import {
GetPromptResultSchema,
ListPromptsResultSchema,
ListRootsRequestSchema, ListRootsRequestSchema,
type Tool as McpTool,
} from '@modelcontextprotocol/sdk/types.js'; } from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote'; import { parse } from 'shell-quote';
import type { Config, MCPServerConfig } from '../config/config.js'; import type { Config, MCPServerConfig } from '../config/config.js';
@@ -27,8 +26,7 @@ import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js'; import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js';
import { DiscoveredMCPTool } from './mcp-tool.js'; import { DiscoveredMCPTool } from './mcp-tool.js';
import type { FunctionDeclaration } from '@google/genai'; import type { CallableTool, FunctionCall, Part, Tool } from '@google/genai';
import { mcpToTool } from '@google/genai';
import { basename } from 'node:path'; import { basename } from 'node:path';
import { pathToFileURL } from 'node:url'; import { pathToFileURL } from 'node:url';
import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
@@ -621,29 +619,26 @@ export async function discoverTools(
// Only request tools if the server supports them. // Only request tools if the server supports them.
if (mcpClient.getServerCapabilities()?.tools == null) return []; if (mcpClient.getServerCapabilities()?.tools == null) return [];
const mcpCallableTool = mcpToTool(mcpClient, { const response = await mcpClient.listTools({});
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
const tool = await mcpCallableTool.tool();
if (!Array.isArray(tool.functionDeclarations)) {
// This is a valid case for a prompt-only server
return [];
}
const discoveredTools: DiscoveredMCPTool[] = []; const discoveredTools: DiscoveredMCPTool[] = [];
for (const funcDecl of tool.functionDeclarations) { for (const toolDef of response.tools) {
try { try {
if (!isEnabled(funcDecl, mcpServerName, mcpServerConfig)) { if (!isEnabled(toolDef, mcpServerName, mcpServerConfig)) {
continue; continue;
} }
const mcpCallableTool = new McpCallableTool(
mcpClient,
toolDef,
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
);
const tool = new DiscoveredMCPTool( const tool = new DiscoveredMCPTool(
mcpCallableTool, mcpCallableTool,
mcpServerName, mcpServerName,
funcDecl.name!, toolDef.name,
funcDecl.description ?? '', toolDef.description ?? '',
funcDecl.parametersJsonSchema ?? { type: 'object', properties: {} }, toolDef.inputSchema ?? { type: 'object', properties: {} },
mcpServerConfig.trust, mcpServerConfig.trust,
undefined, undefined,
cliConfig, cliConfig,
@@ -657,7 +652,7 @@ export async function discoverTools(
coreEvents.emitFeedback( coreEvents.emitFeedback(
'error', 'error',
`Error discovering tool: '${ `Error discovering tool: '${
funcDecl.name toolDef.name
}' from MCP server '${mcpServerName}': ${(error as Error).message}`, }' from MCP server '${mcpServerName}': ${(error as Error).message}`,
error, error,
); );
@@ -681,6 +676,69 @@ export async function discoverTools(
} }
} }
class McpCallableTool implements CallableTool {
constructor(
private readonly client: Client,
private readonly toolDef: McpTool,
private readonly timeout: number,
) {}
async tool(): Promise<Tool> {
return {
functionDeclarations: [
{
name: this.toolDef.name,
description: this.toolDef.description,
parametersJsonSchema: this.toolDef.inputSchema,
},
],
};
}
async callTool(functionCalls: FunctionCall[]): Promise<Part[]> {
// We only expect one function call at a time for MCP tools in this context
if (functionCalls.length !== 1) {
throw new Error('McpCallableTool only supports single function call');
}
const call = functionCalls[0];
try {
const result = await this.client.callTool(
{
name: call.name!,
arguments: call.args as Record<string, unknown>,
},
undefined,
{ timeout: this.timeout },
);
return [
{
functionResponse: {
name: call.name,
response: result,
},
},
];
} catch (error) {
// Return error in the format expected by DiscoveredMCPTool
return [
{
functionResponse: {
name: call.name,
response: {
error: {
message: error instanceof Error ? error.message : String(error),
isError: true,
},
},
},
},
];
}
}
}
/** /**
* Discovers and logs prompts from a connected MCP client. * Discovers and logs prompts from a connected MCP client.
* It retrieves prompt declarations from the client and logs their names. * It retrieves prompt declarations from the client and logs their names.
@@ -697,10 +755,7 @@ export async function discoverPrompts(
// Only request prompts if the server supports them. // Only request prompts if the server supports them.
if (mcpClient.getServerCapabilities()?.prompts == null) return []; if (mcpClient.getServerCapabilities()?.prompts == null) return [];
const response = await mcpClient.request( const response = await mcpClient.listPrompts({});
{ method: 'prompts/list', params: {} },
ListPromptsResultSchema,
);
for (const prompt of response.prompts) { for (const prompt of response.prompts) {
promptRegistry.registerPrompt({ promptRegistry.registerPrompt({
@@ -746,16 +801,17 @@ export async function invokeMcpPrompt(
promptParams: Record<string, unknown>, promptParams: Record<string, unknown>,
): Promise<GetPromptResult> { ): Promise<GetPromptResult> {
try { try {
const response = await mcpClient.request( const sanitizedParams: Record<string, string> = {};
{ for (const [key, value] of Object.entries(promptParams)) {
method: 'prompts/get', if (value !== undefined && value !== null) {
params: { sanitizedParams[key] = String(value);
}
}
const response = await mcpClient.getPrompt({
name: promptName, name: promptName,
arguments: promptParams, arguments: sanitizedParams,
}, });
},
GetPromptResultSchema,
);
return response; return response;
} catch (error) { } catch (error) {
@@ -1339,9 +1395,13 @@ export async function createTransport(
); );
} }
interface NamedTool {
name?: string;
}
/** Visible for testing */ /** Visible for testing */
export function isEnabled( export function isEnabled(
funcDecl: FunctionDeclaration, funcDecl: NamedTool,
mcpServerName: string, mcpServerName: string,
mcpServerConfig: MCPServerConfig, mcpServerConfig: MCPServerConfig,
): boolean { ): boolean {