mirror of
https://github.com/google-gemini/gemini-cli.git
synced 2026-03-10 22:21:22 -07:00
rework MCP tool discovery and invocation (#13160)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as GenAiLib from '@google/genai';
|
||||
import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js';
|
||||
import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js';
|
||||
@@ -71,6 +70,21 @@ describe('mcp-client', () => {
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
|
||||
listTools: vi.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'testFunction',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {},
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
listPrompts: vi.fn().mockResolvedValue({
|
||||
prompts: [],
|
||||
}),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
@@ -78,15 +92,6 @@ describe('mcp-client', () => {
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () => ({
|
||||
functionDeclarations: [
|
||||
{
|
||||
name: 'testFunction',
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
@@ -104,7 +109,7 @@ describe('mcp-client', () => {
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover({} as Config);
|
||||
expect(mockedMcpToTool).toHaveBeenCalledOnce();
|
||||
expect(mockedClient.listTools).toHaveBeenCalledWith({});
|
||||
});
|
||||
|
||||
it('should not skip tools even if a parameter is missing a type', async () => {
|
||||
@@ -119,7 +124,33 @@ describe('mcp-client', () => {
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
|
||||
tool: vi.fn(),
|
||||
|
||||
listTools: vi.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'validTool',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
param1: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'invalidTool',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
param1: { description: 'a param with no type' },
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
listPrompts: vi.fn().mockResolvedValue({
|
||||
prompts: [],
|
||||
}),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
@@ -127,31 +158,6 @@ describe('mcp-client', () => {
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () =>
|
||||
Promise.resolve({
|
||||
functionDeclarations: [
|
||||
{
|
||||
name: 'validTool',
|
||||
parametersJsonSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
param1: { type: 'string' },
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: 'invalidTool',
|
||||
parametersJsonSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
param1: { description: 'a param with no type' },
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
@@ -183,7 +189,9 @@ describe('mcp-client', () => {
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
|
||||
request: vi.fn().mockRejectedValue(new Error('Test error')),
|
||||
listTools: vi.fn().mockResolvedValue({ tools: [] }),
|
||||
listPrompts: vi.fn().mockRejectedValue(new Error('Test error')),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
@@ -191,9 +199,6 @@ describe('mcp-client', () => {
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () => Promise.resolve({ functionDeclarations: [] }),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
@@ -228,7 +233,8 @@ describe('mcp-client', () => {
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }),
|
||||
request: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
@@ -236,7 +242,6 @@ describe('mcp-client', () => {
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool);
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
@@ -256,7 +261,6 @@ describe('mcp-client', () => {
|
||||
await expect(client.discover({} as Config)).rejects.toThrow(
|
||||
'No prompts or tools found on the server.',
|
||||
);
|
||||
expect(mockedMcpToTool).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should discover tools if server supports them', async () => {
|
||||
@@ -268,7 +272,17 @@ describe('mcp-client', () => {
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
|
||||
request: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
listTools: vi.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'testTool',
|
||||
description: 'A test tool',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
],
|
||||
}),
|
||||
listPrompts: vi.fn().mockResolvedValue({ prompts: [] }),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
@@ -276,17 +290,6 @@ describe('mcp-client', () => {
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () =>
|
||||
Promise.resolve({
|
||||
functionDeclarations: [
|
||||
{
|
||||
name: 'testTool',
|
||||
description: 'A test tool',
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
@@ -304,10 +307,87 @@ describe('mcp-client', () => {
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover({} as Config);
|
||||
expect(mockedMcpToTool).toHaveBeenCalledOnce();
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it('should discover tools with $defs and $ref in schema', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
discover: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getStatus: vi.fn(),
|
||||
registerCapabilities: vi.fn(),
|
||||
setRequestHandler: vi.fn(),
|
||||
getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }),
|
||||
listTools: vi.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'toolWithDefs',
|
||||
description: 'A tool using $defs',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
param1: {
|
||||
$ref: '#/$defs/MyType',
|
||||
},
|
||||
},
|
||||
$defs: {
|
||||
MyType: {
|
||||
type: 'string',
|
||||
description: 'A defined type',
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
listPrompts: vi.fn().mockResolvedValue({
|
||||
prompts: [],
|
||||
}),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
);
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
sortTools: vi.fn(),
|
||||
getMessageBus: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as ToolRegistry;
|
||||
const client = new McpClient(
|
||||
'test-server',
|
||||
{
|
||||
command: 'test-command',
|
||||
},
|
||||
mockedToolRegistry,
|
||||
{} as PromptRegistry,
|
||||
workspaceContext,
|
||||
false,
|
||||
);
|
||||
await client.connect();
|
||||
await client.discover({} as Config);
|
||||
expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce();
|
||||
const registeredTool = vi.mocked(mockedToolRegistry.registerTool).mock
|
||||
.calls[0][0];
|
||||
expect(registeredTool.schema.parametersJsonSchema).toEqual({
|
||||
type: 'object',
|
||||
properties: {
|
||||
param1: {
|
||||
$ref: '#/$defs/MyType',
|
||||
},
|
||||
},
|
||||
$defs: {
|
||||
MyType: {
|
||||
type: 'string',
|
||||
description: 'A defined type',
|
||||
},
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should remove tools and prompts on disconnect', async () => {
|
||||
const mockedClient = {
|
||||
connect: vi.fn(),
|
||||
@@ -318,9 +398,19 @@ describe('mcp-client', () => {
|
||||
getServerCapabilities: vi
|
||||
.fn()
|
||||
.mockReturnValue({ tools: {}, prompts: {} }),
|
||||
request: vi.fn().mockResolvedValue({
|
||||
listPrompts: vi.fn().mockResolvedValue({
|
||||
prompts: [{ id: 'prompt1', text: 'a prompt' }],
|
||||
}),
|
||||
request: vi.fn().mockResolvedValue({}),
|
||||
listTools: vi.fn().mockResolvedValue({
|
||||
tools: [
|
||||
{
|
||||
name: 'testTool',
|
||||
description: 'A test tool',
|
||||
inputSchema: { type: 'object', properties: {} },
|
||||
},
|
||||
],
|
||||
}),
|
||||
};
|
||||
vi.mocked(ClientLib.Client).mockReturnValue(
|
||||
mockedClient as unknown as ClientLib.Client,
|
||||
@@ -328,17 +418,6 @@ describe('mcp-client', () => {
|
||||
vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue(
|
||||
{} as SdkClientStdioLib.StdioClientTransport,
|
||||
);
|
||||
vi.mocked(GenAiLib.mcpToTool).mockReturnValue({
|
||||
tool: () =>
|
||||
Promise.resolve({
|
||||
functionDeclarations: [
|
||||
{
|
||||
name: 'testTool',
|
||||
description: 'A test tool',
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as unknown as GenAiLib.CallableTool);
|
||||
const mockedToolRegistry = {
|
||||
registerTool: vi.fn(),
|
||||
unregisterTool: vi.fn(),
|
||||
|
||||
@@ -16,9 +16,8 @@ import type {
|
||||
Prompt,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import {
|
||||
GetPromptResultSchema,
|
||||
ListPromptsResultSchema,
|
||||
ListRootsRequestSchema,
|
||||
type Tool as McpTool,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import type { Config, MCPServerConfig } from '../config/config.js';
|
||||
@@ -27,8 +26,7 @@ import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
|
||||
import type { FunctionDeclaration } from '@google/genai';
|
||||
import { mcpToTool } from '@google/genai';
|
||||
import type { CallableTool, FunctionCall, Part, Tool } from '@google/genai';
|
||||
import { basename } from 'node:path';
|
||||
import { pathToFileURL } from 'node:url';
|
||||
import { MCPOAuthProvider } from '../mcp/oauth-provider.js';
|
||||
@@ -621,29 +619,26 @@ export async function discoverTools(
|
||||
// Only request tools if the server supports them.
|
||||
if (mcpClient.getServerCapabilities()?.tools == null) return [];
|
||||
|
||||
const mcpCallableTool = mcpToTool(mcpClient, {
|
||||
timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
});
|
||||
const tool = await mcpCallableTool.tool();
|
||||
|
||||
if (!Array.isArray(tool.functionDeclarations)) {
|
||||
// This is a valid case for a prompt-only server
|
||||
return [];
|
||||
}
|
||||
|
||||
const response = await mcpClient.listTools({});
|
||||
const discoveredTools: DiscoveredMCPTool[] = [];
|
||||
for (const funcDecl of tool.functionDeclarations) {
|
||||
for (const toolDef of response.tools) {
|
||||
try {
|
||||
if (!isEnabled(funcDecl, mcpServerName, mcpServerConfig)) {
|
||||
if (!isEnabled(toolDef, mcpServerName, mcpServerConfig)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const mcpCallableTool = new McpCallableTool(
|
||||
mcpClient,
|
||||
toolDef,
|
||||
mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
);
|
||||
|
||||
const tool = new DiscoveredMCPTool(
|
||||
mcpCallableTool,
|
||||
mcpServerName,
|
||||
funcDecl.name!,
|
||||
funcDecl.description ?? '',
|
||||
funcDecl.parametersJsonSchema ?? { type: 'object', properties: {} },
|
||||
toolDef.name,
|
||||
toolDef.description ?? '',
|
||||
toolDef.inputSchema ?? { type: 'object', properties: {} },
|
||||
mcpServerConfig.trust,
|
||||
undefined,
|
||||
cliConfig,
|
||||
@@ -657,7 +652,7 @@ export async function discoverTools(
|
||||
coreEvents.emitFeedback(
|
||||
'error',
|
||||
`Error discovering tool: '${
|
||||
funcDecl.name
|
||||
toolDef.name
|
||||
}' from MCP server '${mcpServerName}': ${(error as Error).message}`,
|
||||
error,
|
||||
);
|
||||
@@ -681,6 +676,69 @@ export async function discoverTools(
|
||||
}
|
||||
}
|
||||
|
||||
class McpCallableTool implements CallableTool {
|
||||
constructor(
|
||||
private readonly client: Client,
|
||||
private readonly toolDef: McpTool,
|
||||
private readonly timeout: number,
|
||||
) {}
|
||||
|
||||
async tool(): Promise<Tool> {
|
||||
return {
|
||||
functionDeclarations: [
|
||||
{
|
||||
name: this.toolDef.name,
|
||||
description: this.toolDef.description,
|
||||
parametersJsonSchema: this.toolDef.inputSchema,
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
async callTool(functionCalls: FunctionCall[]): Promise<Part[]> {
|
||||
// We only expect one function call at a time for MCP tools in this context
|
||||
if (functionCalls.length !== 1) {
|
||||
throw new Error('McpCallableTool only supports single function call');
|
||||
}
|
||||
const call = functionCalls[0];
|
||||
|
||||
try {
|
||||
const result = await this.client.callTool(
|
||||
{
|
||||
name: call.name!,
|
||||
arguments: call.args as Record<string, unknown>,
|
||||
},
|
||||
undefined,
|
||||
{ timeout: this.timeout },
|
||||
);
|
||||
|
||||
return [
|
||||
{
|
||||
functionResponse: {
|
||||
name: call.name,
|
||||
response: result,
|
||||
},
|
||||
},
|
||||
];
|
||||
} catch (error) {
|
||||
// Return error in the format expected by DiscoveredMCPTool
|
||||
return [
|
||||
{
|
||||
functionResponse: {
|
||||
name: call.name,
|
||||
response: {
|
||||
error: {
|
||||
message: error instanceof Error ? error.message : String(error),
|
||||
isError: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Discovers and logs prompts from a connected MCP client.
|
||||
* It retrieves prompt declarations from the client and logs their names.
|
||||
@@ -697,10 +755,7 @@ export async function discoverPrompts(
|
||||
// Only request prompts if the server supports them.
|
||||
if (mcpClient.getServerCapabilities()?.prompts == null) return [];
|
||||
|
||||
const response = await mcpClient.request(
|
||||
{ method: 'prompts/list', params: {} },
|
||||
ListPromptsResultSchema,
|
||||
);
|
||||
const response = await mcpClient.listPrompts({});
|
||||
|
||||
for (const prompt of response.prompts) {
|
||||
promptRegistry.registerPrompt({
|
||||
@@ -746,16 +801,17 @@ export async function invokeMcpPrompt(
|
||||
promptParams: Record<string, unknown>,
|
||||
): Promise<GetPromptResult> {
|
||||
try {
|
||||
const response = await mcpClient.request(
|
||||
{
|
||||
method: 'prompts/get',
|
||||
params: {
|
||||
name: promptName,
|
||||
arguments: promptParams,
|
||||
},
|
||||
},
|
||||
GetPromptResultSchema,
|
||||
);
|
||||
const sanitizedParams: Record<string, string> = {};
|
||||
for (const [key, value] of Object.entries(promptParams)) {
|
||||
if (value !== undefined && value !== null) {
|
||||
sanitizedParams[key] = String(value);
|
||||
}
|
||||
}
|
||||
|
||||
const response = await mcpClient.getPrompt({
|
||||
name: promptName,
|
||||
arguments: sanitizedParams,
|
||||
});
|
||||
|
||||
return response;
|
||||
} catch (error) {
|
||||
@@ -1339,9 +1395,13 @@ export async function createTransport(
|
||||
);
|
||||
}
|
||||
|
||||
interface NamedTool {
|
||||
name?: string;
|
||||
}
|
||||
|
||||
/** Visible for testing */
|
||||
export function isEnabled(
|
||||
funcDecl: FunctionDeclaration,
|
||||
funcDecl: NamedTool,
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
): boolean {
|
||||
|
||||
Reference in New Issue
Block a user