diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 19430c2f9a..3e592825dd 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -2056,6 +2056,90 @@ describe('connectToMcpServer with OAuth', () => { capturedTransport._requestInit?.headers?.['Authorization']; expect(authHeader).toBe('Bearer test-access-token-from-discovery'); }); + + it('should use discoverOAuthFromWWWAuthenticate when it succeeds and skip discoverOAuthConfig', async () => { + const serverUrl = 'http://test-server.com/mcp'; + const authUrl = 'http://auth.example.com/auth'; + const tokenUrl = 'http://auth.example.com/token'; + const wwwAuthHeader = `Bearer realm="test", resource_metadata="http://test-server.com/.well-known/oauth-protected-resource"`; + + vi.mocked(mockedClient.connect).mockRejectedValueOnce( + new StreamableHTTPError( + 401, + `Unauthorized\nwww-authenticate: ${wwwAuthHeader}`, + ), + ); + + vi.mocked(OAuthUtils.discoverOAuthFromWWWAuthenticate).mockResolvedValue({ + authorizationUrl: authUrl, + tokenUrl, + scopes: ['read'], + }); + + vi.mocked(mockedClient.connect).mockResolvedValueOnce(undefined); + + const client = await connectToMcpServer( + '0.0.1', + 'test-server', + { httpUrl: serverUrl, oauth: { enabled: true } }, + false, + workspaceContext, + EMPTY_CONFIG, + ); + + expect(client).toBe(mockedClient); + expect(OAuthUtils.discoverOAuthFromWWWAuthenticate).toHaveBeenCalledWith( + wwwAuthHeader, + serverUrl, + ); + expect(OAuthUtils.discoverOAuthConfig).not.toHaveBeenCalled(); + expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); + }); + + it('should fall back to extractBaseUrl + discoverOAuthConfig when discoverOAuthFromWWWAuthenticate returns null', async () => { + const serverUrl = 'http://test-server.com/mcp'; + const baseUrl = 'http://test-server.com'; + const authUrl = 'http://auth.example.com/auth'; + const tokenUrl = 'http://auth.example.com/token'; + const wwwAuthHeader = `Bearer realm="test"`; + + vi.mocked(mockedClient.connect).mockRejectedValueOnce( + new StreamableHTTPError( + 401, + `Unauthorized\nwww-authenticate: ${wwwAuthHeader}`, + ), + ); + + vi.mocked(OAuthUtils.discoverOAuthFromWWWAuthenticate).mockResolvedValue( + null, + ); + vi.mocked(OAuthUtils.extractBaseUrl).mockReturnValue(baseUrl); + vi.mocked(OAuthUtils.discoverOAuthConfig).mockResolvedValue({ + authorizationUrl: authUrl, + tokenUrl, + scopes: ['read'], + }); + + vi.mocked(mockedClient.connect).mockResolvedValueOnce(undefined); + + const client = await connectToMcpServer( + '0.0.1', + 'test-server', + { httpUrl: serverUrl, oauth: { enabled: true } }, + false, + workspaceContext, + EMPTY_CONFIG, + ); + + expect(client).toBe(mockedClient); + expect(OAuthUtils.discoverOAuthFromWWWAuthenticate).toHaveBeenCalledWith( + wwwAuthHeader, + serverUrl, + ); + expect(OAuthUtils.extractBaseUrl).toHaveBeenCalledWith(serverUrl); + expect(OAuthUtils.discoverOAuthConfig).toHaveBeenCalledWith(baseUrl); + expect(mockAuthProvider.authenticate).toHaveBeenCalledOnce(); + }); }); describe('connectToMcpServer - HTTP→SSE fallback', () => { diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index a838cf76e5..ccc6bbec3c 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -719,18 +719,17 @@ async function handleAutomaticOAuth( try { debugLogger.log(`🔐 '${mcpServerName}' requires OAuth authentication`); - // Always try to parse the resource metadata URI from the www-authenticate header - let oauthConfig; - const resourceMetadataUri = - OAuthUtils.parseWWWAuthenticateHeader(wwwAuthenticate); - if (resourceMetadataUri) { - oauthConfig = await OAuthUtils.discoverOAuthConfig(resourceMetadataUri); - } else if (hasNetworkTransport(mcpServerConfig)) { + const serverUrl = mcpServerConfig.httpUrl || mcpServerConfig.url; + + // Try to discover OAuth config from the WWW-Authenticate header first + let oauthConfig = await OAuthUtils.discoverOAuthFromWWWAuthenticate( + wwwAuthenticate, + serverUrl, + ); + + if (!oauthConfig && hasNetworkTransport(mcpServerConfig)) { // Fallback: try to discover OAuth config from the base URL - const serverUrl = new URL( - mcpServerConfig.httpUrl || mcpServerConfig.url!, - ); - const baseUrl = `${serverUrl.protocol}//${serverUrl.host}`; + const baseUrl = OAuthUtils.extractBaseUrl(serverUrl!); oauthConfig = await OAuthUtils.discoverOAuthConfig(baseUrl); } @@ -754,8 +753,6 @@ async function handleAutomaticOAuth( }; // Perform OAuth authentication - // Pass the server URL for proper discovery - const serverUrl = mcpServerConfig.httpUrl || mcpServerConfig.url; debugLogger.log( `Starting OAuth authentication for server '${mcpServerName}'...`, );