diff --git a/packages/cli/src/commands/mcp/add.test.ts b/packages/cli/src/commands/mcp/add.test.ts index 12fba04b15..9eceb45715 100644 --- a/packages/cli/src/commands/mcp/add.test.ts +++ b/packages/cli/src/commands/mcp/add.test.ts @@ -107,6 +107,7 @@ describe('mcp add command', () => { expect(mockSetValue).toHaveBeenCalledWith(SettingScope.User, 'mcpServers', { 'sse-server': { url: 'https://example.com/sse-endpoint', + type: 'sse', headers: { 'X-API-Key': 'your-key' }, }, }); @@ -122,7 +123,8 @@ describe('mcp add command', () => { 'mcpServers', { 'http-server': { - httpUrl: 'https://example.com/mcp', + url: 'https://example.com/mcp', + type: 'http', headers: { Authorization: 'Bearer your-token' }, }, }, diff --git a/packages/cli/src/commands/mcp/add.ts b/packages/cli/src/commands/mcp/add.ts index b960736e4c..eaf599517d 100644 --- a/packages/cli/src/commands/mcp/add.ts +++ b/packages/cli/src/commands/mcp/add.ts @@ -69,6 +69,7 @@ async function addMcpServer( case 'sse': newServer = { url: commandOrUrl, + type: 'sse', headers, timeout, trust, @@ -79,7 +80,8 @@ async function addMcpServer( break; case 'http': newServer = { - httpUrl: commandOrUrl, + url: commandOrUrl, + type: 'http', headers, timeout, trust, diff --git a/packages/cli/src/ui/commands/mcpCommand.ts b/packages/cli/src/ui/commands/mcpCommand.ts index 65abb50741..f3fdd06b8e 100644 --- a/packages/cli/src/ui/commands/mcpCommand.ts +++ b/packages/cli/src/ui/commands/mcpCommand.ts @@ -20,6 +20,7 @@ import { MCPServerStatus, getErrorMessage, MCPOAuthTokenStorage, + mcpServerRequiresOAuth, } from '@google/gemini-cli-core'; import { appEvents, AppEvent } from '../../utils/events.js'; import { MessageType, type HistoryItemMcpStatus } from '../types.js'; @@ -47,12 +48,23 @@ const authCommand: SlashCommand = { const mcpServers = config.getMcpClientManager()?.getMcpServers() ?? {}; if (!serverName) { - // List servers that support OAuth - const oauthServers = Object.entries(mcpServers) + // List servers that support OAuth from two sources: + // 1. Servers with oauth.enabled in config + // 2. Servers detected as requiring OAuth (returned 401) + const configuredOAuthServers = Object.entries(mcpServers) .filter(([_, server]) => server.oauth?.enabled) .map(([name, _]) => name); - if (oauthServers.length === 0) { + const detectedOAuthServers = Array.from( + mcpServerRequiresOAuth.keys(), + ).filter((name) => mcpServers[name]); // Only include configured servers + + // Combine and deduplicate + const allOAuthServers = [ + ...new Set([...configuredOAuthServers, ...detectedOAuthServers]), + ]; + + if (allOAuthServers.length === 0) { return { type: 'message', messageType: 'info', @@ -63,7 +75,7 @@ const authCommand: SlashCommand = { return { type: 'message', messageType: 'info', - content: `MCP servers with OAuth authentication:\n${oauthServers.map((s) => ` - ${s}`).join('\n')}\n\nUse /mcp auth to authenticate.`, + content: `MCP servers with OAuth authentication:\n${allOAuthServers.map((s) => ` - ${s}`).join('\n')}\n\nUse /mcp auth to authenticate.`, }; } @@ -220,7 +232,8 @@ const listAction = async ( const tokenStorage = new MCPOAuthTokenStorage(); for (const serverName of serverNames) { const server = mcpServers[serverName]; - if (server.oauth?.enabled) { + // Check auth status for servers with oauth.enabled OR detected as requiring OAuth + if (server.oauth?.enabled || mcpServerRequiresOAuth.has(serverName)) { const creds = await tokenStorage.getCredentials(serverName); if (creds) { if (creds.token.expiresAt && creds.token.expiresAt < Date.now()) { diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index 11674a5d7e..cadb5898bb 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -190,6 +190,12 @@ export class MCPServerConfig { readonly headers?: Record, // For websocket transport readonly tcp?: string, + // Transport type (optional, for use with 'url' field) + // When set to 'http', uses StreamableHTTPClientTransport + // When set to 'sse', uses SSEClientTransport + // When omitted, auto-detects transport type + // Note: 'httpUrl' is deprecated in favor of 'url' + 'type' + readonly type?: 'sse' | 'http', // Common readonly timeout?: number, readonly trust?: boolean, diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index a21c4e160e..c62e89184a 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -15,7 +15,7 @@ import { MCPDiscoveryState, populateMcpServerCommand, } from './mcp-client.js'; -import { getErrorMessage } from '../utils/errors.js'; +import { getErrorMessage, isAuthenticationError } from '../utils/errors.js'; import type { EventEmitter } from 'node:events'; import { coreEvents } from '../utils/events.js'; import { debugLogger } from '../utils/debugLogger.js'; @@ -186,14 +186,17 @@ export class McpClientManager { this.eventEmitter?.emit('mcp-client-update', this.clients); } catch (error) { this.eventEmitter?.emit('mcp-client-update', this.clients); - // Log the error but don't let a single failed server stop the others - coreEvents.emitFeedback( - 'error', - `Error during discovery for server '${name}': ${getErrorMessage( + // Check if this is a 401/auth error - if so, don't show as red error + // (the info message was already shown in mcp-client.ts) + if (!isAuthenticationError(error)) { + // Log the error but don't let a single failed server stop the others + const errorMessage = getErrorMessage(error); + coreEvents.emitFeedback( + 'error', + `Error during discovery for MCP server '${name}': ${errorMessage}`, error, - )}`, - error, - ); + ); + } } } finally { // This is required to update the content generator configuration with the diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 15e9fd8cb3..6ddd5b1271 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -7,7 +7,10 @@ import * as ClientLib from '@modelcontextprotocol/sdk/client/index.js'; import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; import * as SdkClientStdioLib from '@modelcontextprotocol/sdk/client/stdio.js'; -import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; +import { + StreamableHTTPClientTransport, + StreamableHTTPError, +} from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { AuthProviderType, type Config } from '../config/config.js'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; @@ -490,16 +493,14 @@ describe('mcp-client', () => { ); expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); - expect(transport).toHaveProperty( - '_url', - new URL('http://test-server/'), - ); + expect(transport).toMatchObject({ + _url: new URL('http://test-server'), + _requestInit: { headers: {} }, + }); }); it('with headers', async () => { - // We need this to be an any type because we dig into its private state. - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const transport: any = await createTransport( + const transport = await createTransport( 'test-server', { httpUrl: 'http://test-server', @@ -507,13 +508,14 @@ describe('mcp-client', () => { }, false, ); + expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); - expect(transport).toHaveProperty( - '_url', - new URL('http://test-server/'), - ); - const authHeader = transport._requestInit?.headers?.['Authorization']; - expect(authHeader).toBe('derp'); + expect(transport).toMatchObject({ + _url: new URL('http://test-server'), + _requestInit: { + headers: { Authorization: 'derp' }, + }, + }); }); }); @@ -526,17 +528,15 @@ describe('mcp-client', () => { }, false, ); - expect(transport).toBeInstanceOf(SSEClientTransport); - expect(transport).toHaveProperty( - '_url', - new URL('http://test-server/'), - ); + expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); + expect(transport).toMatchObject({ + _url: new URL('http://test-server'), + _requestInit: { headers: {} }, + }); }); it('with headers', async () => { - // We need this to be an any type because we dig into its private state. - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const transport: any = await createTransport( + const transport = await createTransport( 'test-server', { url: 'http://test-server', @@ -544,13 +544,122 @@ describe('mcp-client', () => { }, false, ); - expect(transport).toBeInstanceOf(SSEClientTransport); - expect(transport).toHaveProperty( - '_url', - new URL('http://test-server/'), + + expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); + expect(transport).toMatchObject({ + _url: new URL('http://test-server'), + _requestInit: { + headers: { Authorization: 'derp' }, + }, + }); + }); + + it('with type="http" creates StreamableHTTPClientTransport', async () => { + const transport = await createTransport( + 'test-server', + { + url: 'http://test-server', + type: 'http', + }, + false, ); - const authHeader = transport._requestInit?.headers?.['Authorization']; - expect(authHeader).toBe('derp'); + + expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); + expect(transport).toMatchObject({ + _url: new URL('http://test-server'), + _requestInit: { headers: {} }, + }); + }); + + it('with type="sse" creates SSEClientTransport', async () => { + const transport = await createTransport( + 'test-server', + { + url: 'http://test-server', + type: 'sse', + }, + false, + ); + + expect(transport).toBeInstanceOf(SSEClientTransport); + expect(transport).toMatchObject({ + _url: new URL('http://test-server'), + _requestInit: { headers: {} }, + }); + }); + + it('without type defaults to StreamableHTTPClientTransport', async () => { + const transport = await createTransport( + 'test-server', + { + url: 'http://test-server', + }, + false, + ); + + expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); + expect(transport).toMatchObject({ + _url: new URL('http://test-server'), + _requestInit: { headers: {} }, + }); + }); + + it('with type="http" and headers applies headers correctly', async () => { + const transport = await createTransport( + 'test-server', + { + url: 'http://test-server', + type: 'http', + headers: { Authorization: 'Bearer token' }, + }, + false, + ); + + expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); + expect(transport).toMatchObject({ + _url: new URL('http://test-server'), + _requestInit: { + headers: { Authorization: 'Bearer token' }, + }, + }); + }); + + it('with type="sse" and headers applies headers correctly', async () => { + const transport = await createTransport( + 'test-server', + { + url: 'http://test-server', + type: 'sse', + headers: { 'X-API-Key': 'key123' }, + }, + false, + ); + + expect(transport).toBeInstanceOf(SSEClientTransport); + expect(transport).toMatchObject({ + _url: new URL('http://test-server'), + _requestInit: { + headers: { 'X-API-Key': 'key123' }, + }, + }); + }); + + it('httpUrl takes priority over url when both are present', async () => { + const transport = await createTransport( + 'test-server', + { + httpUrl: 'http://test-server-http', + url: 'http://test-server-url', + }, + false, + ); + + // httpUrl should take priority and create HTTP transport + expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); + expect(transport).toMatchObject({ + _url: new URL('http://test-server-http'), + _requestInit: { headers: {} }, + }); }); }); @@ -680,6 +789,7 @@ describe('mcp-client', () => { 'test-server', { url: 'http://test.googleapis.com', + type: 'sse', authProviderType: AuthProviderType.GOOGLE_CREDENTIALS, oauth: { scopes: ['scope1'], @@ -839,7 +949,10 @@ describe('connectToMcpServer with OAuth', () => { const wwwAuthHeader = `Bearer realm="test", resource_metadata="http://test-server.com/.well-known/oauth-protected-resource"`; vi.mocked(mockedClient.connect).mockRejectedValueOnce( - new Error(`401 Unauthorized\nwww-authenticate: ${wwwAuthHeader}`), + new StreamableHTTPError( + 401, + `Unauthorized\nwww-authenticate: ${wwwAuthHeader}`, + ), ); vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({ @@ -860,7 +973,7 @@ describe('connectToMcpServer with OAuth', () => { const client = await connectToMcpServer( 'test-server', - { httpUrl: serverUrl }, + { httpUrl: serverUrl, oauth: { enabled: true } }, false, workspaceContext, ); @@ -880,7 +993,7 @@ describe('connectToMcpServer with OAuth', () => { const tokenUrl = 'http://auth.example.com/token'; vi.mocked(mockedClient.connect).mockRejectedValueOnce( - new Error('401 Unauthorized'), + new StreamableHTTPError(401, 'Unauthorized'), ); vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({ @@ -904,7 +1017,7 @@ describe('connectToMcpServer with OAuth', () => { const client = await connectToMcpServer( 'test-server', - { httpUrl: serverUrl }, + { httpUrl: serverUrl, oauth: { enabled: true } }, false, workspaceContext, ); @@ -919,3 +1032,193 @@ describe('connectToMcpServer with OAuth', () => { expect(authHeader).toBe('Bearer test-access-token-from-discovery'); }); }); + +describe('connectToMcpServer - HTTP→SSE fallback', () => { + let mockedClient: ClientLib.Client; + let workspaceContext: WorkspaceContext; + let testWorkspace: string; + + beforeEach(() => { + mockedClient = { + connect: vi.fn(), + close: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + onclose: vi.fn(), + notification: vi.fn(), + } as unknown as ClientLib.Client; + vi.mocked(ClientLib.Client).mockImplementation(() => mockedClient); + + testWorkspace = fs.mkdtempSync( + path.join(os.tmpdir(), 'gemini-agent-test-'), + ); + workspaceContext = new WorkspaceContext(testWorkspace); + + vi.spyOn(console, 'log').mockImplementation(() => {}); + vi.spyOn(console, 'warn').mockImplementation(() => {}); + vi.spyOn(console, 'error').mockImplementation(() => {}); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should NOT trigger fallback when type="http" is explicit', async () => { + vi.mocked(mockedClient.connect).mockRejectedValueOnce( + new Error('Connection failed'), + ); + + await expect( + connectToMcpServer( + 'test-server', + { url: 'http://test-server', type: 'http' }, + false, + workspaceContext, + ), + ).rejects.toThrow('Connection failed'); + + // Should only try once (no fallback) + expect(mockedClient.connect).toHaveBeenCalledTimes(1); + }); + + it('should NOT trigger fallback when type="sse" is explicit', async () => { + vi.mocked(mockedClient.connect).mockRejectedValueOnce( + new Error('Connection failed'), + ); + + await expect( + connectToMcpServer( + 'test-server', + { url: 'http://test-server', type: 'sse' }, + false, + workspaceContext, + ), + ).rejects.toThrow('Connection failed'); + + // Should only try once (no fallback) + expect(mockedClient.connect).toHaveBeenCalledTimes(1); + }); + + it('should trigger fallback when url provided without type and HTTP fails', async () => { + vi.mocked(mockedClient.connect) + .mockRejectedValueOnce(new StreamableHTTPError(500, 'Server error')) + .mockResolvedValueOnce(undefined); + + const client = await connectToMcpServer( + 'test-server', + { url: 'http://test-server' }, + false, + workspaceContext, + ); + + expect(client).toBe(mockedClient); + // First HTTP attempt fails, second SSE attempt succeeds + expect(mockedClient.connect).toHaveBeenCalledTimes(2); + }); + + it('should throw original HTTP error when both HTTP and SSE fail (non-401)', async () => { + const httpError = new StreamableHTTPError(500, 'Server error'); + const sseError = new Error('SSE connection failed'); + + vi.mocked(mockedClient.connect) + .mockRejectedValueOnce(httpError) + .mockRejectedValueOnce(sseError); + + await expect( + connectToMcpServer( + 'test-server', + { url: 'http://test-server' }, + false, + workspaceContext, + ), + ).rejects.toThrow('Server error'); + + expect(mockedClient.connect).toHaveBeenCalledTimes(2); + }); + + it('should handle HTTP 404 followed by SSE success', async () => { + vi.mocked(mockedClient.connect) + .mockRejectedValueOnce(new StreamableHTTPError(404, 'Not Found')) + .mockResolvedValueOnce(undefined); + + const client = await connectToMcpServer( + 'test-server', + { url: 'http://test-server' }, + false, + workspaceContext, + ); + + expect(client).toBe(mockedClient); + expect(mockedClient.connect).toHaveBeenCalledTimes(2); + }); +}); + +describe('connectToMcpServer - OAuth with transport fallback', () => { + let mockedClient: ClientLib.Client; + let workspaceContext: WorkspaceContext; + let testWorkspace: string; + let mockAuthProvider: MCPOAuthProvider; + let mockTokenStorage: MCPOAuthTokenStorage; + + beforeEach(() => { + mockedClient = { + connect: vi.fn(), + close: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + onclose: vi.fn(), + notification: vi.fn(), + } as unknown as ClientLib.Client; + vi.mocked(ClientLib.Client).mockImplementation(() => mockedClient); + + testWorkspace = fs.mkdtempSync( + path.join(os.tmpdir(), 'gemini-agent-test-'), + ); + workspaceContext = new WorkspaceContext(testWorkspace); + + vi.spyOn(console, 'log').mockImplementation(() => {}); + vi.spyOn(console, 'warn').mockImplementation(() => {}); + vi.spyOn(console, 'error').mockImplementation(() => {}); + + mockTokenStorage = { + getCredentials: vi.fn().mockResolvedValue({ clientId: 'test-client' }), + } as unknown as MCPOAuthTokenStorage; + vi.mocked(MCPOAuthTokenStorage).mockReturnValue(mockTokenStorage); + + mockAuthProvider = { + authenticate: vi.fn().mockResolvedValue(undefined), + getValidToken: vi.fn().mockResolvedValue('test-access-token'), + tokenStorage: mockTokenStorage, + } as unknown as MCPOAuthProvider; + vi.mocked(MCPOAuthProvider).mockReturnValue(mockAuthProvider); + + vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({ + authorizationUrl: 'http://auth.example.com/auth', + tokenUrl: 'http://auth.example.com/token', + scopes: ['test-scope'], + }); + }); + + afterEach(() => { + vi.clearAllMocks(); + }); + + it('should handle HTTP 404 → SSE 401 → OAuth → SSE+OAuth succeeds', async () => { + // Tests that OAuth flow works when SSE (not HTTP) requires auth + vi.mocked(mockedClient.connect) + .mockRejectedValueOnce(new StreamableHTTPError(404, 'Not Found')) + .mockRejectedValueOnce(new StreamableHTTPError(401, 'Unauthorized')) + .mockResolvedValueOnce(undefined); + + const client = await connectToMcpServer( + 'test-server', + { url: 'http://test-server', oauth: { enabled: true } }, + false, + workspaceContext, + ); + + expect(client).toBe(mockedClient); + expect(mockedClient.connect).toHaveBeenCalledTimes(3); + expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); + }); +}); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index c4f9cc390e..e988e3b346 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -40,7 +40,11 @@ import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; import { OAuthUtils } from '../mcp/oauth-utils.js'; import type { PromptRegistry } from '../prompts/prompt-registry.js'; -import { getErrorMessage } from '../utils/errors.js'; +import { + getErrorMessage, + isAuthenticationError, + UnauthorizedError, +} from '../utils/errors.js'; import type { Unsubscribe, WorkspaceContext, @@ -443,33 +447,6 @@ function createAuthProvider( return undefined; } -/** - * Create a transport for URL based servers (remote servers). - * - * @param mcpServerConfig The MCP server configuration - * @param transportOptions The transport options - */ -function createUrlTransport( - mcpServerConfig: MCPServerConfig, - transportOptions: - | StreamableHTTPClientTransportOptions - | SSEClientTransportOptions, -): StreamableHTTPClientTransport | SSEClientTransport { - if (mcpServerConfig.httpUrl) { - return new StreamableHTTPClientTransport( - new URL(mcpServerConfig.httpUrl), - transportOptions, - ); - } - if (mcpServerConfig.url) { - return new SSEClientTransport( - new URL(mcpServerConfig.url), - transportOptions, - ); - } - throw new Error('No URL configured for MCP Server'); -} - /** * Create a transport with OAuth token for the given server configuration. * @@ -493,7 +470,7 @@ async function createTransportWithOAuth( requestInit: createTransportRequestInit(mcpServerConfig, headers), }; - return createUrlTransport(mcpServerConfig, transportOptions); + return createUrlTransport(mcpServerName, mcpServerConfig, transportOptions); } catch (error) { coreEvents.emitFeedback( 'error', @@ -921,6 +898,156 @@ export function hasNetworkTransport(config: MCPServerConfig): boolean { return !!(config.url || config.httpUrl); } +/** + * Helper function to retrieve a stored OAuth token for an MCP server. + * Handles token validation and refresh automatically. + * + * @param serverName The name of the MCP server + * @returns The valid access token, or null if no token is stored + */ +async function getStoredOAuthToken(serverName: string): Promise { + const tokenStorage = new MCPOAuthTokenStorage(); + const credentials = await tokenStorage.getCredentials(serverName); + if (!credentials) return null; + + const authProvider = new MCPOAuthProvider(tokenStorage); + return authProvider.getValidToken(serverName, { + // Pass client ID if available + clientId: credentials.clientId, + }); +} + +/** + * Helper function to create an SSE transport with optional OAuth authentication. + * + * @param config The MCP server configuration + * @param accessToken Optional OAuth access token for authentication + * @returns A configured SSE transport ready for connection + */ +function createSSETransportWithAuth( + config: MCPServerConfig, + accessToken?: string | null, +): SSEClientTransport { + const headers = { + ...config.headers, + ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}), + }; + + const options: SSEClientTransportOptions = {}; + if (Object.keys(headers).length > 0) { + options.requestInit = { headers }; + } + + return new SSEClientTransport(new URL(config.url!), options); +} + +/** + * Helper function to connect a client using SSE transport with optional OAuth. + * + * @param client The MCP client to connect + * @param config The MCP server configuration + * @param accessToken Optional OAuth access token for authentication + */ +async function connectWithSSETransport( + client: Client, + config: MCPServerConfig, + accessToken?: string | null, +): Promise { + const transport = createSSETransportWithAuth(config, accessToken); + await client.connect(transport, { + timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + }); +} + +/** + * Helper function to show authentication required message and throw error. + * Checks if there's a stored token that was rejected (requires re-auth). + * + * @param serverName The name of the MCP server + * @throws Always throws an error with authentication instructions + */ +async function showAuthRequiredMessage(serverName: string): Promise { + const hasRejectedToken = !!(await getStoredOAuthToken(serverName)); + + const message = hasRejectedToken + ? `MCP server '${serverName}' rejected stored OAuth token. Please re-authenticate using: /mcp auth ${serverName}` + : `MCP server '${serverName}' requires authentication using: /mcp auth ${serverName}`; + + coreEvents.emitFeedback('info', message); + throw new UnauthorizedError(message); +} + +/** + * Helper function to retry connection with OAuth token after authentication. + * Handles both HTTP and SSE transports based on what previously failed. + * + * @param client The MCP client to connect + * @param serverName The name of the MCP server + * @param config The MCP server configuration + * @param accessToken The OAuth access token to use + * @param httpReturned404 Whether the HTTP transport returned 404 (indicating SSE-only server) + */ +async function retryWithOAuth( + client: Client, + serverName: string, + config: MCPServerConfig, + accessToken: string, + httpReturned404: boolean, +): Promise { + if (httpReturned404) { + // HTTP returned 404, only try SSE + debugLogger.log( + `Retrying SSE connection to '${serverName}' with OAuth token...`, + ); + await connectWithSSETransport(client, config, accessToken); + debugLogger.log( + `Successfully connected to '${serverName}' using SSE with OAuth.`, + ); + return; + } + + // HTTP returned 401, try HTTP with OAuth first + debugLogger.log(`Retrying connection to '${serverName}' with OAuth token...`); + + const httpTransport = await createTransportWithOAuth( + serverName, + config, + accessToken, + ); + if (!httpTransport) { + throw new Error( + `Failed to create OAuth transport for server '${serverName}'`, + ); + } + + try { + await client.connect(httpTransport, { + timeout: config.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + }); + debugLogger.log( + `Successfully connected to '${serverName}' using HTTP with OAuth.`, + ); + } catch (httpError) { + await httpTransport.close(); + + // If HTTP+OAuth returns 404 and auto-detection enabled, try SSE+OAuth + if ( + String(httpError).includes('404') && + config.url && + !config.type && + !config.httpUrl + ) { + debugLogger.log(`HTTP with OAuth returned 404, trying SSE with OAuth...`); + await connectWithSSETransport(client, config, accessToken); + debugLogger.log( + `Successfully connected to '${serverName}' using SSE with OAuth.`, + ); + } else { + throw httpError; + } + } +} + /** * Creates and connects an MCP client to a server based on the provided configuration. * It determines the appropriate transport (Stdio, SSE, or Streamable HTTP) and @@ -993,6 +1120,10 @@ export async function connectToMcpServer( unlistenDirectories = undefined; }; + let firstAttemptError: Error | null = null; + let httpReturned404 = false; // Track if HTTP returned 404 to skip it in OAuth retry + let sseError: Error | null = null; // Track SSE fallback error + try { const transport = await createTransport( mcpServerName, @@ -1006,52 +1137,79 @@ export async function connectToMcpServer( return mcpClient; } catch (error) { await transport.close(); + firstAttemptError = error as Error; throw error; } - } catch (error) { + } catch (initialError) { + let error = initialError; + + // Check if this is a 401 error FIRST (before attempting SSE fallback) + // This ensures OAuth flow happens before we try SSE + if (isAuthenticationError(error) && hasNetworkTransport(mcpServerConfig)) { + // Continue to OAuth handling below (after SSE fallback section) + } else if ( + // If not 401, and HTTP failed with url without explicit type, try SSE fallback + firstAttemptError && + mcpServerConfig.url && + !mcpServerConfig.type && + !mcpServerConfig.httpUrl + ) { + // Check if HTTP returned 404 - if so, we know it's not an HTTP server + httpReturned404 = String(firstAttemptError).includes('404'); + + const logMessage = httpReturned404 + ? `HTTP returned 404, trying SSE transport...` + : `HTTP connection failed, attempting SSE fallback...`; + debugLogger.log(`MCP server '${mcpServerName}': ${logMessage}`); + + try { + // Try SSE with stored OAuth token if available + // This ensures that SSE fallback works for authenticated servers + await connectWithSSETransport( + mcpClient, + mcpServerConfig, + await getStoredOAuthToken(mcpServerName), + ); + + debugLogger.log( + `MCP server '${mcpServerName}': Successfully connected using SSE transport.`, + ); + return mcpClient; + } catch (sseFallbackError) { + sseError = sseFallbackError as Error; + + // If SSE also returned 401, handle OAuth below + if (isAuthenticationError(sseError)) { + debugLogger.log( + `MCP server '${mcpServerName}': SSE returned 401, OAuth authentication required.`, + ); + // Update error to be the SSE error for OAuth handling + error = sseError; + // Continue to OAuth handling below + } else { + debugLogger.log( + `MCP server '${mcpServerName}': SSE fallback also failed.`, + ); + // Both failed without 401, throw the original error + throw firstAttemptError; + } + } + } + // Check if this is a 401 error that might indicate OAuth is required - const errorString = String(error); - if (errorString.includes('401') && hasNetworkTransport(mcpServerConfig)) { + if (isAuthenticationError(error) && hasNetworkTransport(mcpServerConfig)) { mcpServerRequiresOAuth.set(mcpServerName, true); - // Only trigger automatic OAuth discovery for HTTP servers or when OAuth is explicitly configured - // For SSE servers, we should not trigger new OAuth flows automatically - const shouldTriggerOAuth = - mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled; + + // Only trigger automatic OAuth if explicitly enabled in config + // Otherwise, show error and tell user to run /mcp auth command + const shouldTriggerOAuth = mcpServerConfig.oauth?.enabled; if (!shouldTriggerOAuth) { - // For SSE servers without explicit OAuth config, if a token was found but rejected, report it accurately. - const tokenStorage = new MCPOAuthTokenStorage(); - const credentials = await tokenStorage.getCredentials(mcpServerName); - if (credentials) { - const authProvider = new MCPOAuthProvider(tokenStorage); - const hasStoredTokens = await authProvider.getValidToken( - mcpServerName, - { - // Pass client ID if available - clientId: credentials.clientId, - }, - ); - if (hasStoredTokens) { - coreEvents.emitFeedback( - 'error', - `Stored OAuth token for SSE server '${mcpServerName}' was rejected. ` + - `Please re-authenticate using: /mcp auth ${mcpServerName}`, - ); - } else { - coreEvents.emitFeedback( - 'error', - `401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` + - `Please authenticate using: /mcp auth ${mcpServerName}`, - ); - } - } - throw new Error( - `401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` + - `Please authenticate using: /mcp auth ${mcpServerName}`, - ); + await showAuthRequiredMessage(mcpServerName); } // Try to extract www-authenticate header from the error + const errorString = String(error); let wwwAuthenticate = extractWWWAuthenticateHeader(errorString); // If we didn't get the header from the error string, try to get it from the server @@ -1061,12 +1219,27 @@ export async function connectToMcpServer( ); try { const urlToFetch = mcpServerConfig.httpUrl || mcpServerConfig.url!; + + // Determine correct Accept header based on what transport failed + let acceptHeader: string; + if (mcpServerConfig.httpUrl) { + acceptHeader = 'application/json'; + } else if (mcpServerConfig.type === 'http') { + acceptHeader = 'application/json'; + } else if (mcpServerConfig.type === 'sse') { + acceptHeader = 'text/event-stream'; + } else if (httpReturned404) { + // HTTP failed with 404, SSE returned 401 - use SSE header + acceptHeader = 'text/event-stream'; + } else { + // HTTP returned 401 - use HTTP header + acceptHeader = 'application/json'; + } + const response = await fetch(urlToFetch, { method: 'HEAD', headers: { - Accept: mcpServerConfig.httpUrl - ? 'application/json' - : 'text/event-stream', + Accept: acceptHeader, }, signal: AbortSignal.timeout(5000), }); @@ -1101,52 +1274,21 @@ export async function connectToMcpServer( ); if (oauthSuccess) { // Retry connection with OAuth token - debugLogger.log( - `Retrying connection to '${mcpServerName}' with OAuth token...`, - ); - - // Get the valid token - we need to create a proper OAuth config - // The token should already be available from the authentication process - const tokenStorage = new MCPOAuthTokenStorage(); - const credentials = await tokenStorage.getCredentials(mcpServerName); - if (credentials) { - const authProvider = new MCPOAuthProvider(tokenStorage); - const accessToken = await authProvider.getValidToken( - mcpServerName, - { - // Pass client ID if available - clientId: credentials.clientId, - }, - ); - - if (accessToken) { - // Create transport with OAuth token - const oauthTransport = await createTransportWithOAuth( - mcpServerName, - mcpServerConfig, - accessToken, - ); - if (oauthTransport) { - await mcpClient.connect(oauthTransport, { - timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, - }); - // Connection successful with OAuth - return mcpClient; - } else { - throw new Error( - `Failed to create OAuth transport for server '${mcpServerName}'`, - ); - } - } else { - throw new Error( - `Failed to get OAuth token for server '${mcpServerName}'`, - ); - } - } else { + const accessToken = await getStoredOAuthToken(mcpServerName); + if (!accessToken) { throw new Error( - `Failed to get credentials for server '${mcpServerName}' after successful OAuth authentication`, + `Failed to get OAuth token for server '${mcpServerName}'`, ); } + + await retryWithOAuth( + mcpClient, + mcpServerName, + mcpServerConfig, + accessToken, + httpReturned404, + ); + return mcpClient; } else { throw new Error( `Failed to handle automatic OAuth for server '${mcpServerName}'`, @@ -1154,41 +1296,11 @@ export async function connectToMcpServer( } } else { // No www-authenticate header found, but we got a 401 - // Only try OAuth discovery for HTTP servers or when OAuth is explicitly configured - // For SSE servers, we should not trigger new OAuth flows automatically - const shouldTryDiscovery = - mcpServerConfig.httpUrl || mcpServerConfig.oauth?.enabled; + // Only try OAuth discovery when OAuth is explicitly enabled in config + const shouldTryDiscovery = mcpServerConfig.oauth?.enabled; if (!shouldTryDiscovery) { - const tokenStorage = new MCPOAuthTokenStorage(); - const credentials = await tokenStorage.getCredentials(mcpServerName); - if (credentials) { - const authProvider = new MCPOAuthProvider(tokenStorage); - const hasStoredTokens = await authProvider.getValidToken( - mcpServerName, - { - // Pass client ID if available - clientId: credentials.clientId, - }, - ); - if (hasStoredTokens) { - coreEvents.emitFeedback( - 'error', - `Stored OAuth token for SSE server '${mcpServerName}' was rejected. ` + - `Please re-authenticate using: /mcp auth ${mcpServerName}`, - ); - } else { - coreEvents.emitFeedback( - 'error', - `401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` + - `Please authenticate using: /mcp auth ${mcpServerName}`, - ); - } - } - throw new Error( - `401 error received for SSE server '${mcpServerName}' without OAuth configuration. ` + - `Please authenticate using: /mcp auth ${mcpServerName}`, - ); + await showAuthRequiredMessage(mcpServerName); } // For SSE/HTTP servers, try to discover OAuth configuration from the base URL @@ -1234,47 +1346,30 @@ export async function connectToMcpServer( ); // Retry connection with OAuth token - const tokenStorage = new MCPOAuthTokenStorage(); - const credentials = - await tokenStorage.getCredentials(mcpServerName); - if (credentials) { - const authProvider = new MCPOAuthProvider(tokenStorage); - const accessToken = await authProvider.getValidToken( - mcpServerName, - { - // Pass client ID if available - clientId: credentials.clientId, - }, - ); - if (accessToken) { - // Create transport with OAuth token - const oauthTransport = await createTransportWithOAuth( - mcpServerName, - mcpServerConfig, - accessToken, - ); - if (oauthTransport) { - await mcpClient.connect(oauthTransport, { - timeout: - mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, - }); - // Connection successful with OAuth - return mcpClient; - } else { - throw new Error( - `Failed to create OAuth transport for server '${mcpServerName}'`, - ); - } - } else { - throw new Error( - `Failed to get OAuth token for server '${mcpServerName}'`, - ); - } - } else { + const accessToken = await getStoredOAuthToken(mcpServerName); + if (!accessToken) { throw new Error( - `Failed to get stored credentials for server '${mcpServerName}'`, + `Failed to get OAuth token for server '${mcpServerName}'`, ); } + + // Create transport with OAuth token + const oauthTransport = await createTransportWithOAuth( + mcpServerName, + mcpServerConfig, + accessToken, + ); + if (!oauthTransport) { + throw new Error( + `Failed to create OAuth transport for server '${mcpServerName}'`, + ); + } + + await mcpClient.connect(oauthTransport, { + timeout: mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + }); + // Connection successful with OAuth + return mcpClient; } else { throw new Error( `OAuth configuration failed for '${mcpServerName}'. Please authenticate manually with /mcp auth ${mcpServerName}`, @@ -1288,27 +1383,63 @@ export async function connectToMcpServer( } } else { // Handle other connection errors - // Create a concise error message - const errorMessage = (error as Error).message || String(error); - const isNetworkError = - errorMessage.includes('ENOTFOUND') || - errorMessage.includes('ECONNREFUSED'); - - let conciseError: string; - if (isNetworkError) { - conciseError = `Cannot connect to '${mcpServerName}' - server may be down or URL incorrect`; - } else { - conciseError = `Connection failed for '${mcpServerName}': ${errorMessage}`; - } - - if (process.env['SANDBOX']) { - conciseError += ` (check sandbox availability)`; - } - - throw new Error(conciseError); + // Re-throw the original error to preserve its structure + throw error; } } } + +/** + * Helper function to create the appropriate transport based on config + * This handles the logic for httpUrl/url/type consistently + */ +function createUrlTransport( + mcpServerName: string, + mcpServerConfig: MCPServerConfig, + transportOptions: + | StreamableHTTPClientTransportOptions + | SSEClientTransportOptions, +): StreamableHTTPClientTransport | SSEClientTransport { + // Priority 1: httpUrl (deprecated) + if (mcpServerConfig.httpUrl) { + if (mcpServerConfig.url) { + debugLogger.warn( + `MCP server '${mcpServerName}': Both 'httpUrl' and 'url' are configured. ` + + `Using deprecated 'httpUrl'. Please migrate to 'url' with 'type: "http"'.`, + ); + } + return new StreamableHTTPClientTransport( + new URL(mcpServerConfig.httpUrl), + transportOptions, + ); + } + + // Priority 2 & 3: url with explicit type + if (mcpServerConfig.url && mcpServerConfig.type) { + if (mcpServerConfig.type === 'http') { + return new StreamableHTTPClientTransport( + new URL(mcpServerConfig.url), + transportOptions, + ); + } else if (mcpServerConfig.type === 'sse') { + return new SSEClientTransport( + new URL(mcpServerConfig.url), + transportOptions, + ); + } + } + + // Priority 4: url without type (default to HTTP) + if (mcpServerConfig.url) { + return new StreamableHTTPClientTransport( + new URL(mcpServerConfig.url), + transportOptions, + ); + } + + throw new Error(`No URL configured for MCP server '${mcpServerName}'`); +} + /** Visible for Testing */ export async function createTransport( mcpServerName: string, @@ -1333,7 +1464,6 @@ export async function createTransport( ); } } - if (mcpServerConfig.httpUrl || mcpServerConfig.url) { const authProvider = createAuthProvider(mcpServerConfig); const headers: Record = @@ -1342,8 +1472,7 @@ export async function createTransport( if (authProvider === undefined) { // Check if we have OAuth configuration or stored tokens let accessToken: string | null = null; - let hasOAuthConfig = mcpServerConfig.oauth?.enabled; - if (hasOAuthConfig && mcpServerConfig.oauth) { + if (mcpServerConfig.oauth?.enabled && mcpServerConfig.oauth) { const tokenStorage = new MCPOAuthTokenStorage(); const mcpAuthProvider = new MCPOAuthProvider(tokenStorage); accessToken = await mcpAuthProvider.getValidToken( @@ -1352,31 +1481,22 @@ export async function createTransport( ); if (!accessToken) { - throw new Error( - `MCP server '${mcpServerName}' requires OAuth authentication. ` + - `Please authenticate using the /mcp auth command.`, + // Emit info message (not error) since this is expected behavior + coreEvents.emitFeedback( + 'info', + `MCP server '${mcpServerName}' requires authentication using: /mcp auth ${mcpServerName}`, ); } } else { // Check if we have stored OAuth tokens for this server (from previous authentication) - const tokenStorage = new MCPOAuthTokenStorage(); - const credentials = await tokenStorage.getCredentials(mcpServerName); - if (credentials) { - const mcpAuthProvider = new MCPOAuthProvider(tokenStorage); - accessToken = await mcpAuthProvider.getValidToken(mcpServerName, { - // Pass client ID if available - clientId: credentials.clientId, - }); - - if (accessToken) { - hasOAuthConfig = true; - debugLogger.log( - `Found stored OAuth token for server '${mcpServerName}'`, - ); - } + accessToken = await getStoredOAuthToken(mcpServerName); + if (accessToken) { + debugLogger.log( + `Found stored OAuth token for server '${mcpServerName}'`, + ); } } - if (hasOAuthConfig && accessToken) { + if (accessToken) { headers['Authorization'] = `Bearer ${accessToken}`; } } @@ -1388,7 +1508,7 @@ export async function createTransport( authProvider, }; - return createUrlTransport(mcpServerConfig, transportOptions); + return createUrlTransport(mcpServerName, mcpServerConfig, transportOptions); } if (mcpServerConfig.command) { diff --git a/packages/core/src/utils/errors.test.ts b/packages/core/src/utils/errors.test.ts new file mode 100644 index 0000000000..c7f8c5f287 --- /dev/null +++ b/packages/core/src/utils/errors.test.ts @@ -0,0 +1,42 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { describe, it, expect } from 'vitest'; +import { isAuthenticationError, UnauthorizedError } from './errors.js'; + +describe('isAuthenticationError', () => { + it('should detect error with code: 401 property (MCP SDK style)', () => { + const error = { code: 401, message: 'Unauthorized' }; + expect(isAuthenticationError(error)).toBe(true); + }); + + it('should detect UnauthorizedError instance', () => { + const error = new UnauthorizedError('Authentication required'); + expect(isAuthenticationError(error)).toBe(true); + }); + + it('should return false for 404 errors', () => { + const error = { code: 404, message: 'Not Found' }; + expect(isAuthenticationError(error)).toBe(false); + }); + + it('should handle null and undefined gracefully', () => { + expect(isAuthenticationError(null)).toBe(false); + expect(isAuthenticationError(undefined)).toBe(false); + }); + + it('should handle non-error objects', () => { + expect(isAuthenticationError('string error')).toBe(false); + expect(isAuthenticationError(123)).toBe(false); + expect(isAuthenticationError({})).toBe(false); + }); + + it('should detect 401 in various message formats', () => { + expect(isAuthenticationError(new Error('401 Unauthorized'))).toBe(true); + expect(isAuthenticationError(new Error('HTTP 401'))).toBe(true); + expect(isAuthenticationError(new Error('Status code: 401'))).toBe(true); + }); +}); diff --git a/packages/core/src/utils/errors.ts b/packages/core/src/utils/errors.ts index fa5d8bf6d3..a91fafb922 100644 --- a/packages/core/src/utils/errors.ts +++ b/packages/core/src/utils/errors.ts @@ -117,3 +117,42 @@ function parseResponseData(error: GaxiosError): ResponseData { } return error.response?.data as ResponseData; } + +/** + * Checks if an error is a 401 authentication error. + * Uses structured error properties from MCP SDK errors. + * + * @param error The error to check + * @returns true if this is a 401/authentication error + */ +export function isAuthenticationError(error: unknown): boolean { + // Check for MCP SDK errors with code property + // (SseError and StreamableHTTPError both have numeric 'code' property) + if (error && typeof error === 'object' && 'code' in error) { + const errorCode = (error as { code: unknown }).code; + if (errorCode === 401) { + return true; + } + } + + // Check for UnauthorizedError class (from MCP SDK or our own) + if ( + error instanceof Error && + error.constructor.name === 'UnauthorizedError' + ) { + return true; + } + + if (error instanceof UnauthorizedError) { + return true; + } + + // Fallback: Check for MCP SDK's plain Error messages with HTTP 401 + // The SDK sometimes throws: new Error(`Error POSTing to endpoint (HTTP 401): ...`) + const message = getErrorMessage(error); + if (message.includes('401')) { + return true; + } + + return false; +}