diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2019-11-08 21:05:19 +0100 |
---|---|---|
committer | Marc Glisse <marc.glisse@inria.fr> | 2019-11-08 21:05:19 +0100 |
commit | 60c52012578265e6b6ac2e4a616cf2b617809d2c (patch) | |
tree | e958905af656f72228f9e778464739093635d35b /src/python | |
parent | 7c80dd28eb16e70316e6acc0bde8f698f79b2003 (diff) | |
parent | db405e686cc859e510b894dca45562158cb5c963 (diff) |
Merge remote-tracking branch 'origin/master' into sklearn_tda
Diffstat (limited to 'src/python')
-rw-r--r-- | src/python/CMakeLists.txt | 67 | ||||
-rw-r--r-- | src/python/doc/index.rst | 7 | ||||
-rw-r--r-- | src/python/doc/installation.rst | 20 | ||||
-rw-r--r-- | src/python/doc/wasserstein_distance_sum.inc | 14 | ||||
-rw-r--r-- | src/python/doc/wasserstein_distance_user.rst | 40 | ||||
-rw-r--r-- | src/python/gudhi/__init__.py.in | 9 | ||||
-rw-r--r-- | src/python/gudhi/simplex_tree.pyx | 2 | ||||
-rw-r--r-- | src/python/gudhi/wasserstein.py | 99 | ||||
-rw-r--r-- | src/python/include/Alpha_complex_interface.h | 11 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 48 |
10 files changed, 277 insertions, 40 deletions
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index 2756d547..2cc578a6 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -50,6 +50,8 @@ if(PYTHONINTERP_FOUND) set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'euclidean_witness_complex', ") set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'euclidean_strong_witness_complex', ") set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'sktda', ") + # Modules that should not be auto-imported in __init__.py + set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'wasserstein', ") add_gudhi_debug_info("Python version ${PYTHON_VERSION_STRING}") add_gudhi_debug_info("Cython version ${CYTHON_VERSION}") @@ -67,6 +69,8 @@ if(PYTHONINTERP_FOUND) endif() if(SKLEARN_FOUND) add_gudhi_debug_info("Scikit-learn version ${SKLEARN_VERSION}") + if(OT_FOUND) + add_gudhi_debug_info("POT version ${OT_VERSION}") endif() set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-DBOOST_RESULT_OF_USE_DECLTYPE', ") @@ -77,7 +81,7 @@ if(PYTHONINTERP_FOUND) if(MSVC) set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'/fp:strict', ") else(MSVC) - set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-std=c++11', ") + set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-std=c++14', ") endif(MSVC) if(CMAKE_COMPILER_IS_GNUCXX) set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-frounding-math', ") @@ -204,6 +208,7 @@ if(PYTHONINTERP_FOUND) # Other .py files file(COPY "gudhi/persistence_graphical_tools.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") file(COPY "gudhi/sktda" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/") + file(COPY "gudhi/wasserstein.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi") add_custom_command( OUTPUT gudhi.so @@ -376,37 +381,47 @@ if(PYTHONINTERP_FOUND) # Reader utils add_gudhi_py_test(test_reader_utils) + # Wasserstein + if(OT_FOUND) + add_gudhi_py_test(test_wasserstein_distance) + endif(OT_FOUND) + # Documentation generation is available through sphinx - requires all modules if(SPHINX_PATH) if(MATPLOTLIB_FOUND) if(NUMPY_FOUND) if(SCIPY_FOUND) - if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) - set (GUDHI_SPHINX_MESSAGE "Generating API documentation with Sphinx in ${CMAKE_CURRENT_BINARY_DIR}/sphinx/") - # User warning - Sphinx is a static pages generator, and configured to work fine with user_version - # Images and biblio warnings because not found on developper version - if (GUDHI_PYTHON_PATH STREQUAL "src/python") - set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developper version. Images and biblio will miss") - endif() - # sphinx target requires gudhi.so, because conf.py reads gudhi version from it - add_custom_target(sphinx - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/doc - COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}" - ${SPHINX_PATH} -b html ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/sphinx - DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/gudhi.so" - COMMENT "${GUDHI_SPHINX_MESSAGE}" VERBATIM) - - add_test(NAME sphinx_py_test - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} - COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}" - ${SPHINX_PATH} -b doctest ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/doctest) - - # Set missing or not modules - set(GUDHI_MODULES ${GUDHI_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MODULES") - else(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) - message("++ Python documentation module will not be compiled because it requires a Eigen3 and CGAL version >= 4.11.0") + if(OT_FOUND) + if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) + set (GUDHI_SPHINX_MESSAGE "Generating API documentation with Sphinx in ${CMAKE_CURRENT_BINARY_DIR}/sphinx/") + # User warning - Sphinx is a static pages generator, and configured to work fine with user_version + # Images and biblio warnings because not found on developper version + if (GUDHI_PYTHON_PATH STREQUAL "src/python") + set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developper version. Images and biblio will miss") + endif() + # sphinx target requires gudhi.so, because conf.py reads gudhi version from it + add_custom_target(sphinx + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/doc + COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}" + ${SPHINX_PATH} -b html ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/sphinx + DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/gudhi.so" + COMMENT "${GUDHI_SPHINX_MESSAGE}" VERBATIM) + + add_test(NAME sphinx_py_test + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}" + ${SPHINX_PATH} -b doctest ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/doctest) + + # Set missing or not modules + set(GUDHI_MODULES ${GUDHI_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MODULES") + else(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) + message("++ Python documentation module will not be compiled because it requires a Eigen3 and CGAL version >= 4.11.0") + set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES") + endif(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) + else(OT_FOUND) + message("++ Python documentation module will not be compiled because POT was not found") set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES") - endif(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) + endif(OT_FOUND) else(SCIPY_FOUND) message("++ Python documentation module will not be compiled because scipy was not found") set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES") diff --git a/src/python/doc/index.rst b/src/python/doc/index.rst index e379bc23..8f27da0d 100644 --- a/src/python/doc/index.rst +++ b/src/python/doc/index.rst @@ -23,7 +23,7 @@ Alpha complex .. include:: alpha_complex_sum.inc Rips complex -------------- +------------ .. include:: rips_complex_sum.inc @@ -73,6 +73,11 @@ Bottleneck distance .. include:: bottleneck_distance_sum.inc +Wasserstein distance +==================== + +.. include:: wasserstein_distance_sum.inc + Persistence graphical tools =========================== diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst index d8b6f861..7699a5bb 100644 --- a/src/python/doc/installation.rst +++ b/src/python/doc/installation.rst @@ -8,11 +8,11 @@ Installation Conda ***** The easiest way to install the Python version of GUDHI is using -`conda <https://gudhi.inria.fr/licensing/>`_. +`conda <https://gudhi.inria.fr/conda/>`_. Compiling ********* -The library uses c++11 and requires `Boost <https://www.boost.org/>`_ ≥ 1.56.0, +The library uses c++14 and requires `Boost <https://www.boost.org/>`_ ≥ 1.56.0, `CMake <https://www.cmake.org/>`_ ≥ 3.1 to generate makefiles, `NumPy <http://numpy.org>`_ and `Cython <https://www.cython.org/>`_ to compile the GUDHI Python module. @@ -138,7 +138,7 @@ Documentation To build the documentation, `sphinx-doc <http://www.sphinx-doc.org>`_ and `sphinxcontrib-bibtex <https://sphinxcontrib-bibtex.readthedocs.io>`_ are -required. As the documentation is auto-tested, `CGAL`_, `Eigen3`_, +required. As the documentation is auto-tested, `CGAL`_, `Eigen`_, `Matplotlib`_, `NumPy`_ and `SciPy`_ are also mandatory to build the documentation. @@ -215,12 +215,20 @@ The following examples require the `Matplotlib <http://matplotlib.org>`_: * :download:`euclidean_strong_witness_complex_diagram_persistence_from_off_file_example.py <../example/euclidean_strong_witness_complex_diagram_persistence_from_off_file_example.py>` * :download:`euclidean_witness_complex_diagram_persistence_from_off_file_example.py <../example/euclidean_witness_complex_diagram_persistence_from_off_file_example.py>` +Python Optimal Transport +======================== + +The :doc:`Wasserstein distance </wasserstein_distance_user>` +module requires `POT <https://pot.readthedocs.io/>`_, a library that provides +several solvers for optimization problems related to Optimal Transport. + SciPy ===== -The :doc:`persistence graphical tools </persistence_graphical_tools_user>` -module requires `SciPy <http://scipy.org>`_, a Python-based ecosystem of -open-source software for mathematics, science, and engineering. +The :doc:`persistence graphical tools </persistence_graphical_tools_user>` and +:doc:`Wasserstein distance </wasserstein_distance_user>` modules require `SciPy +<http://scipy.org>`_, a Python-based ecosystem of open-source software for +mathematics, science, and engineering. Threading Building Blocks ========================= diff --git a/src/python/doc/wasserstein_distance_sum.inc b/src/python/doc/wasserstein_distance_sum.inc new file mode 100644 index 00000000..ffd4d312 --- /dev/null +++ b/src/python/doc/wasserstein_distance_sum.inc @@ -0,0 +1,14 @@ +.. table:: + :widths: 30 50 20 + + +-----------------------------------------------------------------+----------------------------------------------------------------------+------------------------------------------------------------------+ + | .. figure:: | The p-Wasserstein distance measures the similarity between two | :Author: Theo Lacombe | + | ../../doc/Bottleneck_distance/perturb_pd.png | persistence diagrams. It's the minimum value c that can be achieved | | + | :figclass: align-center | by a perfect matching between the points of the two diagrams (+ all | :Introduced in: GUDHI 3.1.0 | + | | diagonal points), where the value of a matching is defined as the | | + | Wasserstein distance is the p-th root of the sum of the | p-th root of the sum of all edge lengths to the power p. Edge lengths| :Copyright: MIT | + | edge lengths to the power p. | are measured in norm q, for :math:`1 \leq q \leq \infty`. | | + | | | :Requires: Python Optimal Transport (POT) :math:`\geq` 0.5.1 | + +-----------------------------------------------------------------+----------------------------------------------------------------------+------------------------------------------------------------------+ + | * :doc:`wasserstein_distance_user` | | + +-----------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst new file mode 100644 index 00000000..a049cfb5 --- /dev/null +++ b/src/python/doc/wasserstein_distance_user.rst @@ -0,0 +1,40 @@ +:orphan: + +.. To get rid of WARNING: document isn't included in any toctree + +Wasserstein distance user manual +================================ +Definition +---------- + +.. include:: wasserstein_distance_sum.inc + +This implementation is based on ideas from "Large Scale Computation of Means and Cluster for Persistence Diagrams via Optimal Transport". + +Function +-------- +.. autofunction:: gudhi.wasserstein.wasserstein_distance + + +Basic example +------------- + +This example computes the 1-Wasserstein distance from 2 persistence diagrams with euclidean ground metric. +Note that persistence diagrams must be submitted as (n x 2) numpy arrays and must not contain inf values. + +.. testcode:: + + import gudhi.wasserstein + import numpy as np + + diag1 = np.array([[2.7, 3.7],[9.6, 14.],[34.2, 34.974]]) + diag2 = np.array([[2.8, 4.45],[9.5, 14.1]]) + + message = "Wasserstein distance value = " + '%.2f' % gudhi.wasserstein.wasserstein_distance(diag1, diag2, q=2., p=1.) + print(message) + +The output is: + +.. testoutput:: + + Wasserstein distance value = 1.45 diff --git a/src/python/gudhi/__init__.py.in b/src/python/gudhi/__init__.py.in index 28bab0e1..02888fff 100644 --- a/src/python/gudhi/__init__.py.in +++ b/src/python/gudhi/__init__.py.in @@ -21,13 +21,16 @@ __debug_info__ = @GUDHI_PYTHON_DEBUG_INFO@ from sys import exc_info from importlib import import_module -__all__ = [@GUDHI_PYTHON_MODULES@] +__all__ = [@GUDHI_PYTHON_MODULES@ @GUDHI_PYTHON_MODULES_EXTRA@] __available_modules = '' __missing_modules = '' -# try to import * from gudhi.__module_name -for __module_name in __all__: +# Try to import * from gudhi.__module_name for default modules. +# Extra modules require an explicit import by the user (mostly because of +# unusual dependencies, but also to avoid cluttering namespace gudhi and +# speed up the basic import) +for __module_name in [@GUDHI_PYTHON_MODULES@]: try: __module = import_module('gudhi.' + __module_name) try: diff --git a/src/python/gudhi/simplex_tree.pyx b/src/python/gudhi/simplex_tree.pyx index 9f490271..90ddf6dd 100644 --- a/src/python/gudhi/simplex_tree.pyx +++ b/src/python/gudhi/simplex_tree.pyx @@ -362,7 +362,7 @@ cdef class SimplexTree: value than its faces by increasing the filtration values. :returns: True if any filtration value was modified, - False if the filtration was already non-decreasing. + False if the filtration was already non-decreasing. :rtype: bool diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py new file mode 100644 index 00000000..eba7c6d5 --- /dev/null +++ b/src/python/gudhi/wasserstein.py @@ -0,0 +1,99 @@ +import numpy as np +import scipy.spatial.distance as sc +try: + import ot +except ImportError: + print("POT (Python Optimal Transport) package is not installed. Try to run $ conda install -c conda-forge pot ; or $ pip install POT") + +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Theo Lacombe + + Copyright (C) 2019 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +def _proj_on_diag(X): + ''' + :param X: (n x 2) array encoding the points of a persistent diagram. + :returns: (n x 2) array encoding the (respective orthogonal) projections of the points onto the diagonal + ''' + Z = (X[:,0] + X[:,1]) / 2. + return np.array([Z , Z]).T + + +def _build_dist_matrix(X, Y, p=2., q=2.): + ''' + :param X: (n x 2) numpy.array encoding the (points of the) first diagram. + :param Y: (m x 2) numpy.array encoding the second diagram. + :param q: Ground metric (i.e. norm l_q). + :param p: exponent for the Wasserstein metric. + :returns: (n+1) x (m+1) np.array encoding the cost matrix C. + For 1 <= i <= n, 1 <= j <= m, C[i,j] encodes the distance between X[i] and Y[j], while C[i, m+1] (resp. C[n+1, j]) encodes the distance (to the p) between X[i] (resp Y[j]) and its orthogonal proj onto the diagonal. + note also that C[n+1, m+1] = 0 (it costs nothing to move from the diagonal to the diagonal). + ''' + Xdiag = _proj_on_diag(X) + Ydiag = _proj_on_diag(Y) + if np.isinf(q): + C = sc.cdist(X,Y, metric='chebyshev')**p + Cxd = np.linalg.norm(X - Xdiag, ord=q, axis=1)**p + Cdy = np.linalg.norm(Y - Ydiag, ord=q, axis=1)**p + else: + C = sc.cdist(X,Y, metric='minkowski', p=q)**p + Cxd = np.linalg.norm(X - Xdiag, ord=q, axis=1)**p + Cdy = np.linalg.norm(Y - Ydiag, ord=q, axis=1)**p + Cf = np.hstack((C, Cxd[:,None])) + Cdy = np.append(Cdy, 0) + + Cf = np.vstack((Cf, Cdy[None,:])) + + return Cf + + +def _perstot(X, p, q): + ''' + :param X: (n x 2) numpy.array (points of a given diagram). + :param q: Ground metric on the (upper-half) plane (i.e. norm l_q in R^2); Default value is 2 (Euclidean norm). + :param p: exponent for Wasserstein; Default value is 2. + :returns: float, the total persistence of the diagram (that is, its distance to the empty diagram). + ''' + Xdiag = _proj_on_diag(X) + return (np.sum(np.linalg.norm(X - Xdiag, ord=q, axis=1)**p))**(1./p) + + +def wasserstein_distance(X, Y, p=2., q=2.): + ''' + :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate). + :param Y: (m x 2) numpy.array encoding the second diagram. + :param q: Ground metric on the (upper-half) plane (i.e. norm l_q in R^2); Default value is 2 (euclidean norm). + :param p: exponent for Wasserstein; Default value is 2. + :returns: the p-Wasserstein distance (1 <= p < infinity) with respect to the q-norm as ground metric. + :rtype: float + ''' + n = len(X) + m = len(Y) + + # handle empty diagrams + if X.size == 0: + if Y.size == 0: + return 0. + else: + return _perstot(Y, p, q) + elif Y.size == 0: + return _perstot(X, p, q) + + M = _build_dist_matrix(X, Y, p=p, q=q) + a = np.full(n+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here. + a[-1] = a[-1] * m # normalized so that we have a probability measure, required by POT + b = np.full(m+1, 1. / (n + m) ) # weight vector of the input diagram. Uniform here. + b[-1] = b[-1] * n # so that we have a probability measure, required by POT + + # Comptuation of the otcost using the ot.emd2 library. + # Note: it is the squared Wasserstein distance. + # The default numItermax=100000 is not sufficient for some examples with 5000 points, what is a good value? + ot_cost = (n+m) * ot.emd2(a, b, M, numItermax=2000000) + + return ot_cost ** (1./p) + diff --git a/src/python/include/Alpha_complex_interface.h b/src/python/include/Alpha_complex_interface.h index b3553d32..96353cc4 100644 --- a/src/python/include/Alpha_complex_interface.h +++ b/src/python/include/Alpha_complex_interface.h @@ -15,6 +15,8 @@ #include <gudhi/Alpha_complex.h> #include <CGAL/Epick_d.h> +#include <boost/range/adaptor/transformed.hpp> + #include "Simplex_tree_interface.h" #include <iostream> @@ -31,7 +33,10 @@ class Alpha_complex_interface { public: Alpha_complex_interface(const std::vector<std::vector<double>>& points) { - alpha_complex_ = new Alpha_complex<Dynamic_kernel>(points); + auto mkpt = [](std::vector<double> const& vec){ + return Point_d(vec.size(), vec.begin(), vec.end()); + }; + alpha_complex_ = new Alpha_complex<Dynamic_kernel>(boost::adaptors::transform(points, mkpt)); } Alpha_complex_interface(const std::string& off_file_name, bool from_file = true) { @@ -45,9 +50,9 @@ class Alpha_complex_interface { std::vector<double> get_point(int vh) { std::vector<double> vd; try { - Point_d ph = alpha_complex_->get_point(vh); + Point_d const& ph = alpha_complex_->get_point(vh); for (auto coord = ph.cartesian_begin(); coord < ph.cartesian_end(); coord++) - vd.push_back(*coord); + vd.push_back(CGAL::to_double(*coord)); } catch (std::out_of_range const&) { // std::out_of_range is thrown in case not found. Other exceptions must be re-thrown } diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py new file mode 100755 index 00000000..a6bf9901 --- /dev/null +++ b/src/python/test/test_wasserstein_distance.py @@ -0,0 +1,48 @@ +from gudhi.wasserstein import wasserstein_distance +import numpy as np + +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Theo Lacombe + + Copyright (C) 2019 Inria + + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +__author__ = "Theo Lacombe" +__copyright__ = "Copyright (C) 2019 Inria" +__license__ = "MIT" + + +def test_basic_wasserstein(): + diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) + diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) + diag3 = np.array([[0, 2], [4, 6]]) + diag4 = np.array([[0, 3], [4, 8]]) + emptydiag = np.array([[]]) + + assert wasserstein_distance(emptydiag, emptydiag, q=2., p=1.) == 0. + assert wasserstein_distance(emptydiag, emptydiag, q=np.inf, p=1.) == 0. + assert wasserstein_distance(emptydiag, emptydiag, q=np.inf, p=2.) == 0. + assert wasserstein_distance(emptydiag, emptydiag, q=2., p=2.) == 0. + + assert wasserstein_distance(diag3, emptydiag, q=np.inf, p=1.) == 2. + assert wasserstein_distance(diag3, emptydiag, q=1., p=1.) == 4. + + assert wasserstein_distance(diag4, emptydiag, q=1., p=2.) == 5. # thank you Pythagorician triplets + assert wasserstein_distance(diag4, emptydiag, q=np.inf, p=2.) == 2.5 + assert wasserstein_distance(diag4, emptydiag, q=2., p=2.) == 3.5355339059327378 + + assert wasserstein_distance(diag1, diag2, q=2., p=1.) == 1.4453593023967701 + assert wasserstein_distance(diag1, diag2, q=2.35, p=1.74) == 0.9772734057168739 + + assert wasserstein_distance(diag1, emptydiag, q=2.35, p=1.7863) == 3.141592214572228 + + assert wasserstein_distance(diag3, diag4, q=1., p=1.) == 3. + assert wasserstein_distance(diag3, diag4, q=np.inf, p=1.) == 3. # no diag matching here + assert wasserstein_distance(diag3, diag4, q=np.inf, p=2.) == np.sqrt(5) + assert wasserstein_distance(diag3, diag4, q=1., p=2.) == np.sqrt(5) + assert wasserstein_distance(diag3, diag4, q=4.5, p=2.) == np.sqrt(5) + |