/** * @license * Copyright 2025 Google LLC * SPDX-License-Identifier: Apache-2.0 */ import * as GenAiLib from '@google/genai'; import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { AuthProviderType, type Config } from '../config/config.js'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; import { OAuthUtils } from '../mcp/oauth-utils.js'; import type { PromptRegistry } from '../prompts/prompt-registry.js'; import { WorkspaceContext } from '../utils/workspaceContext.js'; import { connectToMcpServer, createTransport, hasNetworkTransport, isEnabled, McpClient, populateMcpServerCommand, } from './mcp-client.js'; import type { ToolRegistry } from './tool-registry.js'; import * as fs from 'node:fs'; import * as os from 'node:os'; import * as path from 'node:path'; vi.mock('@modelcontextprotocol/sdk/client/stdio.js'); vi.mock('@modelcontextprotocol/sdk/client/index.js'); vi.mock('@google/genai'); vi.mock('../mcp/oauth-provider.js'); vi.mock('../mcp/oauth-token-storage.js'); vi.mock('../mcp/oauth-utils.js'); describe('mcp-client', () => { let workspaceContext: WorkspaceContext; let testWorkspace: string; beforeEach(() => { // create a tmp dir for this test // Create a unique temporary directory for the workspace to avoid conflicts testWorkspace = fs.mkdtempSync( path.join(os.tmpdir(), 'gemini-agent-test-'), ); workspaceContext = new WorkspaceContext(testWorkspace); }); afterEach(() => { vi.restoreAllMocks(); }); describe('McpClient', () => { it('should discover tools', async () => { const mockedClient = { connect: vi.fn(), discover: vi.fn(), disconnect: vi.fn(), getStatus: vi.fn(), registerCapabilities: vi.fn(), setRequestHandler: vi.fn(), getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), }; vi.mocked(ClientLib.Client).mockReturnValue( mockedClient as unknown as ClientLib.Client, ); vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( {} as SdkClientStdioLib.StdioClientTransport, ); const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ tool: () => ({ functionDeclarations: [ { name: 'testFunction', }, ], }), } as unknown as GenAiLib.CallableTool); const mockedToolRegistry = { registerTool: vi.fn(), } as unknown as ToolRegistry; const client = new McpClient( 'test-server', { command: 'test-command', }, mockedToolRegistry, {} as PromptRegistry, workspaceContext, false, ); await client.connect(); await client.discover({} as Config); expect(mockedMcpToTool).toHaveBeenCalledOnce(); }); it('should not skip tools even if a parameter is missing a type', async () => { const consoleWarnSpy = vi .spyOn(console, 'warn') .mockImplementation(() => {}); const mockedClient = { connect: vi.fn(), discover: vi.fn(), disconnect: vi.fn(), getStatus: vi.fn(), registerCapabilities: vi.fn(), setRequestHandler: vi.fn(), getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), tool: vi.fn(), }; vi.mocked(ClientLib.Client).mockReturnValue( mockedClient as unknown as ClientLib.Client, ); vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( {} as SdkClientStdioLib.StdioClientTransport, ); vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ tool: () => Promise.resolve({ functionDeclarations: [ { name: 'validTool', parametersJsonSchema: { type: 'object', properties: { param1: { type: 'string' }, }, }, }, { name: 'invalidTool', parametersJsonSchema: { type: 'object', properties: { param1: { description: 'a param with no type' }, }, }, }, ], }), } as unknown as GenAiLib.CallableTool); const mockedToolRegistry = { registerTool: vi.fn(), } as unknown as ToolRegistry; const client = new McpClient( 'test-server', { command: 'test-command', }, mockedToolRegistry, {} as PromptRegistry, workspaceContext, false, ); await client.connect(); await client.discover({} as Config); expect(mockedToolRegistry.registerTool).toHaveBeenCalledTimes(2); expect(consoleWarnSpy).not.toHaveBeenCalled(); consoleWarnSpy.mockRestore(); }); it('should handle errors when discovering prompts', async () => { const consoleErrorSpy = vi .spyOn(console, 'error') .mockImplementation(() => {}); const mockedClient = { connect: vi.fn(), discover: vi.fn(), disconnect: vi.fn(), getStatus: vi.fn(), registerCapabilities: vi.fn(), setRequestHandler: vi.fn(), getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }), request: vi.fn().mockRejectedValue(new Error('Test error')), }; vi.mocked(ClientLib.Client).mockReturnValue( mockedClient as unknown as ClientLib.Client, ); vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( {} as SdkClientStdioLib.StdioClientTransport, ); vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ tool: () => Promise.resolve({ functionDeclarations: [] }), } as unknown as GenAiLib.CallableTool); const client = new McpClient( 'test-server', { command: 'test-command', }, {} as ToolRegistry, {} as PromptRegistry, workspaceContext, false, ); await client.connect(); await expect(client.discover({} as Config)).rejects.toThrow( 'No prompts or tools found on the server.', ); expect(consoleErrorSpy).toHaveBeenCalledWith( `Error discovering prompts from test-server: Test error`, ); consoleErrorSpy.mockRestore(); }); it('should not discover tools if server does not support them', async () => { const mockedClient = { connect: vi.fn(), discover: vi.fn(), disconnect: vi.fn(), getStatus: vi.fn(), registerCapabilities: vi.fn(), setRequestHandler: vi.fn(), getServerCapabilities: vi.fn().mockReturnValue({ prompts: {} }), request: vi.fn().mockResolvedValue({ prompts: [] }), }; vi.mocked(ClientLib.Client).mockReturnValue( mockedClient as unknown as ClientLib.Client, ); vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( {} as SdkClientStdioLib.StdioClientTransport, ); const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool); const mockedToolRegistry = { registerTool: vi.fn(), } as unknown as ToolRegistry; const client = new McpClient( 'test-server', { command: 'test-command', }, mockedToolRegistry, {} as PromptRegistry, workspaceContext, false, ); await client.connect(); await expect(client.discover({} as Config)).rejects.toThrow( 'No prompts or tools found on the server.', ); expect(mockedMcpToTool).not.toHaveBeenCalled(); }); it('should discover tools if server supports them', async () => { const mockedClient = { connect: vi.fn(), discover: vi.fn(), disconnect: vi.fn(), getStatus: vi.fn(), registerCapabilities: vi.fn(), setRequestHandler: vi.fn(), getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), request: vi.fn().mockResolvedValue({ prompts: [] }), }; vi.mocked(ClientLib.Client).mockReturnValue( mockedClient as unknown as ClientLib.Client, ); vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( {} as SdkClientStdioLib.StdioClientTransport, ); const mockedMcpToTool = vi.mocked(GenAiLib.mcpToTool).mockReturnValue({ tool: () => Promise.resolve({ functionDeclarations: [ { name: 'testTool', description: 'A test tool', }, ], }), } as unknown as GenAiLib.CallableTool); const mockedToolRegistry = { registerTool: vi.fn(), } as unknown as ToolRegistry; const client = new McpClient( 'test-server', { command: 'test-command', }, mockedToolRegistry, {} as PromptRegistry, workspaceContext, false, ); await client.connect(); await client.discover({} as Config); expect(mockedMcpToTool).toHaveBeenCalledOnce(); expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); }); }); describe('appendMcpServerCommand', () => { it('should do nothing if no MCP servers or command are configured', () => { const out = populateMcpServerCommand({}, undefined); expect(out).toEqual({}); }); it('should discover tools via mcpServerCommand', () => { const commandString = 'command --arg1 value1'; const out = populateMcpServerCommand({}, commandString); expect(out).toEqual({ mcp: { command: 'command', args: ['--arg1', 'value1'], }, }); }); it('should handle error if mcpServerCommand parsing fails', () => { expect(() => populateMcpServerCommand({}, 'derp && herp')).toThrowError(); }); }); describe('createTransport', () => { describe('should connect via httpUrl', () => { it('without headers', async () => { const transport = await createTransport( 'test-server', { httpUrl: 'http://test-server', }, false, ); expect(transport).toEqual( new StreamableHTTPClientTransport(new URL('http://test-server'), {}), ); }); it('with headers', async () => { const transport = await createTransport( 'test-server', { httpUrl: 'http://test-server', headers: { Authorization: 'derp' }, }, false, ); expect(transport).toEqual( new StreamableHTTPClientTransport(new URL('http://test-server'), { requestInit: { headers: { Authorization: 'derp' }, }, }), ); }); }); describe('should connect via url', () => { it('without headers', async () => { const transport = await createTransport( 'test-server', { url: 'http://test-server', }, false, ); expect(transport).toEqual( new SSEClientTransport(new URL('http://test-server'), {}), ); }); it('with headers', async () => { const transport = await createTransport( 'test-server', { url: 'http://test-server', headers: { Authorization: 'derp' }, }, false, ); expect(transport).toEqual( new SSEClientTransport(new URL('http://test-server'), { requestInit: { headers: { Authorization: 'derp' }, }, }), ); }); }); it('should connect via command', async () => { const mockedTransport = vi .spyOn(SdkClientStdioLib, 'StdioClientTransport') .mockReturnValue({} as SdkClientStdioLib.StdioClientTransport); await createTransport( 'test-server', { command: 'test-command', args: ['--foo', 'bar'], env: { FOO: 'bar' }, cwd: 'test/cwd', }, false, ); expect(mockedTransport).toHaveBeenCalledWith({ command: 'test-command', args: ['--foo', 'bar'], cwd: 'test/cwd', env: { ...process.env, FOO: 'bar' }, stderr: 'pipe', }); }); describe('useGoogleCredentialProvider', () => { it('should use GoogleCredentialProvider when specified', async () => { const transport = await createTransport( 'test-server', { httpUrl: 'http://test.googleapis.com', authProviderType: AuthProviderType.GOOGLE_CREDENTIALS, oauth: { scopes: ['scope1'], }, }, false, ); expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); // eslint-disable-next-line @typescript-eslint/no-explicit-any const authProvider = (transport as any)._authProvider; expect(authProvider).toBeInstanceOf(GoogleCredentialProvider); }); it('should use GoogleCredentialProvider with SSE transport', async () => { const transport = await createTransport( 'test-server', { url: 'http://test.googleapis.com', authProviderType: AuthProviderType.GOOGLE_CREDENTIALS, oauth: { scopes: ['scope1'], }, }, false, ); expect(transport).toBeInstanceOf(SSEClientTransport); // eslint-disable-next-line @typescript-eslint/no-explicit-any const authProvider = (transport as any)._authProvider; expect(authProvider).toBeInstanceOf(GoogleCredentialProvider); }); it('should throw an error if no URL is provided with GoogleCredentialProvider', async () => { await expect( createTransport( 'test-server', { authProviderType: AuthProviderType.GOOGLE_CREDENTIALS, oauth: { scopes: ['scope1'], }, }, false, ), ).rejects.toThrow( 'URL must be provided in the config for Google Credentials provider', ); }); }); }); describe('isEnabled', () => { const funcDecl = { name: 'myTool' }; const serverName = 'myServer'; it('should return true if no include or exclude lists are provided', () => { const mcpServerConfig = {}; expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true); }); it('should return false if the tool is in the exclude list', () => { const mcpServerConfig = { excludeTools: ['myTool'] }; expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false); }); it('should return true if the tool is in the include list', () => { const mcpServerConfig = { includeTools: ['myTool'] }; expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true); }); it('should return true if the tool is in the include list with parentheses', () => { const mcpServerConfig = { includeTools: ['myTool()'] }; expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(true); }); it('should return false if the include list exists but does not contain the tool', () => { const mcpServerConfig = { includeTools: ['anotherTool'] }; expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false); }); it('should return false if the tool is in both the include and exclude lists', () => { const mcpServerConfig = { includeTools: ['myTool'], excludeTools: ['myTool'], }; expect(isEnabled(funcDecl, serverName, mcpServerConfig)).toBe(false); }); it('should return false if the function declaration has no name', () => { const namelessFuncDecl = {}; const mcpServerConfig = {}; expect(isEnabled(namelessFuncDecl, serverName, mcpServerConfig)).toBe( false, ); }); }); describe('hasNetworkTransport', () => { it('should return true if only url is provided', () => { const config = { url: 'http://example.com' }; expect(hasNetworkTransport(config)).toBe(true); }); it('should return true if only httpUrl is provided', () => { const config = { httpUrl: 'http://example.com' }; expect(hasNetworkTransport(config)).toBe(true); }); it('should return true if both url and httpUrl are provided', () => { const config = { url: 'http://example.com/sse', httpUrl: 'http://example.com/http', }; expect(hasNetworkTransport(config)).toBe(true); }); it('should return false if neither url nor httpUrl is provided', () => { const config = { command: 'do-something' }; expect(hasNetworkTransport(config)).toBe(false); }); it('should return false for an empty config object', () => { const config = {}; expect(hasNetworkTransport(config)).toBe(false); }); }); }); describe('connectToMcpServer with OAuth', () => { let mockedClient: ClientLib.Client; let workspaceContext: WorkspaceContext; let testWorkspace: string; let mockAuthProvider: MCPOAuthProvider; let mockTokenStorage: MCPOAuthTokenStorage; beforeEach(() => { mockedClient = { connect: vi.fn(), close: vi.fn(), registerCapabilities: vi.fn(), setRequestHandler: vi.fn(), onclose: vi.fn(), notification: vi.fn(), } as unknown as ClientLib.Client; vi.mocked(ClientLib.Client).mockImplementation(() => mockedClient); testWorkspace = fs.mkdtempSync( path.join(os.tmpdir(), 'gemini-agent-test-'), ); workspaceContext = new WorkspaceContext(testWorkspace); vi.spyOn(console, 'log').mockImplementation(() => {}); vi.spyOn(console, 'warn').mockImplementation(() => {}); vi.spyOn(console, 'error').mockImplementation(() => {}); mockTokenStorage = { getCredentials: vi.fn().mockResolvedValue({ clientId: 'test-client' }), } as unknown as MCPOAuthTokenStorage; vi.mocked(MCPOAuthTokenStorage).mockReturnValue(mockTokenStorage); mockAuthProvider = { authenticate: vi.fn().mockResolvedValue(undefined), getValidToken: vi.fn().mockResolvedValue('test-access-token'), tokenStorage: mockTokenStorage, } as unknown as MCPOAuthProvider; vi.mocked(MCPOAuthProvider).mockReturnValue(mockAuthProvider); }); afterEach(() => { vi.clearAllMocks(); }); it('should handle automatic OAuth flow on 401 with www-authenticate header', async () => { const serverUrl = 'http://test-server.com/'; const authUrl = 'http://auth.example.com/auth'; const tokenUrl = 'http://auth.example.com/token'; const wwwAuthHeader = `Bearer realm="test", resource_metadata="http://test-server.com/.well-known/oauth-protected-resource"`; vi.mocked(mockedClient.connect).mockRejectedValueOnce( new Error(`401 Unauthorized\nwww-authenticate: ${wwwAuthHeader}`), ); vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({ authorizationUrl: authUrl, tokenUrl, scopes: ['test-scope'], }); // We need this to be an any type because we dig into its private state. // eslint-disable-next-line @typescript-eslint/no-explicit-any let capturedTransport: any; vi.mocked(mockedClient.connect).mockImplementationOnce( async (transport) => { capturedTransport = transport; return Promise.resolve(); }, ); const client = await connectToMcpServer( 'test-server', { httpUrl: serverUrl }, false, workspaceContext, ); expect(client).toBe(mockedClient); expect(mockedClient.connect).toHaveBeenCalledTimes(2); expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); const authHeader = capturedTransport._requestInit?.headers?.['Authorization']; expect(authHeader).toBe('Bearer test-access-token'); }); it('should discover oauth config if not in www-authenticate header', async () => { const serverUrl = 'http://test-server.com'; const authUrl = 'http://auth.example.com/auth'; const tokenUrl = 'http://auth.example.com/token'; vi.mocked(mockedClient.connect).mockRejectedValueOnce( new Error('401 Unauthorized'), ); vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({ authorizationUrl: authUrl, tokenUrl, scopes: ['test-scope'], }); vi.mocked(mockAuthProvider.getValidToken).mockResolvedValue( 'test-access-token-from-discovery', ); // We need this to be an any type because we dig into its private state. // eslint-disable-next-line @typescript-eslint/no-explicit-any let capturedTransport: any; vi.mocked(mockedClient.connect).mockImplementationOnce( async (transport) => { capturedTransport = transport; return Promise.resolve(); }, ); const client = await connectToMcpServer( 'test-server', { httpUrl: serverUrl }, false, workspaceContext, ); expect(client).toBe(mockedClient); expect(mockedClient.connect).toHaveBeenCalledTimes(2); expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(serverUrl); const authHeader = capturedTransport._requestInit?.headers?.['Authorization']; expect(authHeader).toBe('Bearer test-access-token-from-discovery'); }); });