rework MCP tool discovery and invocation (#13160)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
cornmander
2025-11-17 12:03:48 -05:00
committed by GitHub
parent 394a7ea019
commit 8c78fe4f10
2 changed files with 243 additions and 104 deletions

View File

@@ -4,7 +4,6 @@
* SPDX-License-Identifier: Apache-2.0
*/
import * as GenAiLib from '@google/genai';
import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
@@ -71,6 +70,21 @@ describe('mcp-client', () => {
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
listTools: vi.fn().mockResolvedValue({
tools: [
{
name: 'testFunction',
inputSchema: {
type: 'object',
properties: {},
},
},
],
}),
listPrompts: vi.fn().mockResolvedValue({
prompts: [],
}),
request: vi.fn().mockResolvedValue({}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
@@ -78,15 +92,6 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () => ({
functionDeclarations: [
{
name: 'testFunction',
},
],
}),
} as unknown as GenAiLib.CallableTool);
const mockedToolRegistry = {
registerTool: vi.fn(),
sortTools: vi.fn(),
@@ -104,7 +109,7 @@ describe('mcp-client', () => {
);
await client.connect();
await client.discover({} as Config);
expect(mockedMcpToTool).toHaveBeenCalledOnce();
expect(mockedClient.listTools).toHaveBeenCalledWith({});
});
it('should not skip tools even if a parameter is missing a type', async () => {
@@ -119,7 +124,33 @@ describe('mcp-client', () => {
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
tool: vi.fn(),
listTools: vi.fn().mockResolvedValue({
tools: [
{
name: 'validTool',
inputSchema: {
type: 'object',
properties: {
param1: { type: 'string' },
},
},
},
{
name: 'invalidTool',
inputSchema: {
type: 'object',
properties: {
param1: { description: 'a param with no type' },
},
},
},
],
}),
listPrompts: vi.fn().mockResolvedValue({
prompts: [],
}),
request: vi.fn().mockResolvedValue({}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
@@ -127,31 +158,6 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'validTool',
parametersJsonSchema: {
type: 'object',
properties: {
param1: { type: 'string' },
},
},
},
{
name: 'invalidTool',
parametersJsonSchema: {
type: 'object',
properties: {
param1: { description: 'a param with no type' },
},
},
},
],
}),
} as unknown as GenAiLib.CallableTool);
const mockedToolRegistry = {
registerTool: vi.fn(),
sortTools: vi.fn(),
@@ -183,7 +189,9 @@ describe('mcp-client', () => {
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
request: vi.fn().mockRejectedValue(new Error('Test error')),
listTools: vi.fn().mockResolvedValue({ tools: [] }),
listPrompts: vi.fn().mockRejectedValue(new Error('Test error')),
request: vi.fn().mockResolvedValue({}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
@@ -191,9 +199,6 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () => Promise.resolve({ functionDeclarations: [] }),
} as unknown as GenAiLib.CallableTool);
const mockedToolRegistry = {
registerTool: vi.fn(),
getMessageBus: vi.fn().mockReturnValue(undefined),
@@ -228,7 +233,8 @@ describe('mcp-client', () => {
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
request: vi.fn().mockResolvedValue({ prompts: [] }),
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
request: vi.fn().mockResolvedValue({}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
@@ -236,7 +242,6 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool);
const mockedToolRegistry = {
registerTool: vi.fn(),
sortTools: vi.fn(),
@@ -256,7 +261,6 @@ describe('mcp-client', () => {
await expect(client.discover({} as Config)).rejects.toThrow(
'No prompts or tools found on the server.',
);
expect(mockedMcpToTool).not.toHaveBeenCalled();
});
it('should discover tools if server supports them', async () => {
@@ -268,7 +272,17 @@ describe('mcp-client', () => {
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
request: vi.fn().mockResolvedValue({ prompts: [] }),
listTools: vi.fn().mockResolvedValue({
tools: [
{
name: 'testTool',
description: 'A test tool',
inputSchema: { type: 'object', properties: {} },
},
],
}),
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
request: vi.fn().mockResolvedValue({}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
@@ -276,17 +290,6 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'testTool',
description: 'A test tool',
},
],
}),
} as unknown as GenAiLib.CallableTool);
const mockedToolRegistry = {
registerTool: vi.fn(),
sortTools: vi.fn(),
@@ -304,10 +307,87 @@ describe('mcp-client', () => {
);
await client.connect();
await client.discover({} as Config);
expect(mockedMcpToTool).toHaveBeenCalledOnce();
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
});
it('should discover tools with $defs and $ref in schema', async () => {
const mockedClient = {
connect: vi.fn(),
discover: vi.fn(),
disconnect: vi.fn(),
getStatus: vi.fn(),
registerCapabilities: vi.fn(),
setRequestHandler: vi.fn(),
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
listTools: vi.fn().mockResolvedValue({
tools: [
{
name: 'toolWithDefs',
description: 'A tool using $defs',
inputSchema: {
type: 'object',
properties: {
param1: {
$ref: '#/$defs/MyType',
},
},
$defs: {
MyType: {
type: 'string',
description: 'A defined type',
},
},
},
},
],
}),
listPrompts: vi.fn().mockResolvedValue({
prompts: [],
}),
request: vi.fn().mockResolvedValue({}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
);
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
const mockedToolRegistry = {
registerTool: vi.fn(),
sortTools: vi.fn(),
getMessageBus: vi.fn().mockReturnValue(undefined),
} as unknown as ToolRegistry;
const client = new McpClient(
'test-server',
{
command: 'test-command',
},
mockedToolRegistry,
{} as PromptRegistry,
workspaceContext,
false,
);
await client.connect();
await client.discover({} as Config);
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
const registeredTool = vi.mocked(mockedToolRegistry.registerTool).mock
.calls[0][0];
expect(registeredTool.schema.parametersJsonSchema).toEqual({
type: 'object',
properties: {
param1: {
$ref: '#/$defs/MyType',
},
},
$defs: {
MyType: {
type: 'string',
description: 'A defined type',
},
},
});
});
it('should remove tools and prompts on disconnect', async () => {
const mockedClient = {
connect: vi.fn(),
@@ -318,9 +398,19 @@ describe('mcp-client', () => {
getServerCapabilities: vi
.fn()
.mockReturnValue({ tools: {}, prompts: {} }),
request: vi.fn().mockResolvedValue({
listPrompts: vi.fn().mockResolvedValue({
prompts: [{ id: 'prompt1', text: 'a prompt' }],
}),
request: vi.fn().mockResolvedValue({}),
listTools: vi.fn().mockResolvedValue({
tools: [
{
name: 'testTool',
description: 'A test tool',
inputSchema: { type: 'object', properties: {} },
},
],
}),
};
vi.mocked(ClientLib.Client).mockReturnValue(
mockedClient as unknown as ClientLib.Client,
@@ -328,17 +418,6 @@ describe('mcp-client', () => {
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
{} as SdkClientStdioLib.StdioClientTransport,
);
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
tool: () =>
Promise.resolve({
functionDeclarations: [
{
name: 'testTool',
description: 'A test tool',
},
],
}),
} as unknown as GenAiLib.CallableTool);
const mockedToolRegistry = {
registerTool: vi.fn(),
unregisterTool: vi.fn(),

View File

@@ -16,9 +16,8 @@ import type {
Prompt,
} from '@modelcontextprotocol/sdk/types.js';
import {
GetPromptResultSchema,
ListPromptsResultSchema,
ListRootsRequestSchema,
type Tool as McpTool,
} from '@modelcontextprotocol/sdk/types.js';
import { parse } from 'shell-quote';
import type { Config, MCPServerConfig } from '../config/config.js';
@@ -27,8 +26,7 @@ import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js';
import { DiscoveredMCPTool } from './mcp-tool.js';
import type { FunctionDeclaration } from '@google/genai';
import { mcpToTool } from '@google/genai';
import type { CallableTool, FunctionCall, Part, Tool } from '@google/genai';
import { basename } from 'node:path';
import { pathToFileURL } from 'node:url';
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
@@ -621,29 +619,26 @@ export async function discoverTools(
// Only request tools if the server supports them.
if (mcpClient.getServerCapabilities()?.tools == null) return [];
const mcpCallableTool = mcpToTool(mcpClient, {
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
});
const tool = await mcpCallableTool.tool();
if (!Array.isArray(tool.functionDeclarations)) {
// This is a valid case for a prompt-only server
return [];
}
const response = await mcpClient.listTools({});
const discoveredTools: DiscoveredMCPTool[] = [];
for (const funcDecl of tool.functionDeclarations) {
for (const toolDef of response.tools) {
try {
if (!isEnabled(funcDecl, mcpServerName, mcpServerConfig)) {
if (!isEnabled(toolDef, mcpServerName, mcpServerConfig)) {
continue;
}
const mcpCallableTool = new McpCallableTool(
mcpClient,
toolDef,
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
);
const tool = new DiscoveredMCPTool(
mcpCallableTool,
mcpServerName,
funcDecl.name!,
funcDecl.description ?? '',
funcDecl.parametersJsonSchema ?? { type: 'object', properties: {} },
toolDef.name,
toolDef.description ?? '',
toolDef.inputSchema ?? { type: 'object', properties: {} },
mcpServerConfig.trust,
undefined,
cliConfig,
@@ -657,7 +652,7 @@ export async function discoverTools(
coreEvents.emitFeedback(
'error',
`Error discovering tool: '${
funcDecl.name
toolDef.name
}' from MCP server '${mcpServerName}': ${(error as Error).message}`,
error,
);
@@ -681,6 +676,69 @@ export async function discoverTools(
}
}
class McpCallableTool implements CallableTool {
constructor(
private readonly client: Client,
private readonly toolDef: McpTool,
private readonly timeout: number,
) {}
async tool(): Promise<Tool> {
return {
functionDeclarations: [
{
name: this.toolDef.name,
description: this.toolDef.description,
parametersJsonSchema: this.toolDef.inputSchema,
},
],
};
}
async callTool(functionCalls: FunctionCall[]): Promise<Part[]> {
// We only expect one function call at a time for MCP tools in this context
if (functionCalls.length !== 1) {
throw new Error('McpCallableTool only supports single function call');
}
const call = functionCalls[0];
try {
const result = await this.client.callTool(
{
name: call.name!,
arguments: call.args as Record<string, unknown>,
},
undefined,
{ timeout: this.timeout },
);
return [
{
functionResponse: {
name: call.name,
response: result,
},
},
];
} catch (error) {
// Return error in the format expected by DiscoveredMCPTool
return [
{
functionResponse: {
name: call.name,
response: {
error: {
message: error instanceof Error ? error.message : String(error),
isError: true,
},
},
},
},
];
}
}
}
/**
* Discovers and logs prompts from a connected MCP client.
* It retrieves prompt declarations from the client and logs their names.
@@ -697,10 +755,7 @@ export async function discoverPrompts(
// Only request prompts if the server supports them.
if (mcpClient.getServerCapabilities()?.prompts == null) return [];
const response = await mcpClient.request(
{ method: 'prompts/list', params: {} },
ListPromptsResultSchema,
);
const response = await mcpClient.listPrompts({});
for (const prompt of response.prompts) {
promptRegistry.registerPrompt({
@@ -746,16 +801,17 @@ export async function invokeMcpPrompt(
promptParams: Record<string, unknown>,
): Promise<GetPromptResult> {
try {
const response = await mcpClient.request(
{
method: 'prompts/get',
params: {
name: promptName,
arguments: promptParams,
},
},
GetPromptResultSchema,
);
const sanitizedParams: Record<string, string> = {};
for (const [key, value] of Object.entries(promptParams)) {
if (value !== undefined && value !== null) {
sanitizedParams[key] = String(value);
}
}
const response = await mcpClient.getPrompt({
name: promptName,
arguments: sanitizedParams,
});
return response;
} catch (error) {
@@ -1339,9 +1395,13 @@ export async function createTransport(
);
}
interface NamedTool {
name?: string;
}
/** Visible for testing */
export function isEnabled(
funcDecl: FunctionDeclaration,
funcDecl: NamedTool,
mcpServerName: string,
mcpServerConfig: MCPServerConfig,
): boolean {