diff options
author | Marc Glisse <marc.glisse@inria.fr> | 2020-02-12 12:44:51 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-02-12 12:44:51 +0100 |
commit | bed30b19e57669c0b8ad385f1124586ed3499a2d (patch) | |
tree | 6ddbe2f3015899159818bd6e3003d78dd920d707 | |
parent | ee0f12f1df406c81c6ad860c494eed908021fad9 (diff) | |
parent | d6f3165831d20bf3a91f1ff7e9734a574eaa567a (diff) |
Merge pull request #182 from mglisse/ext
Interface to hera's Wasserstein distance
-rw-r--r-- | .appveyor.yml | 3 | ||||
-rw-r--r-- | .circleci/config.yml | 6 | ||||
-rw-r--r-- | .gitmodules | 3 | ||||
-rw-r--r-- | .travis.yml | 2 | ||||
-rw-r--r-- | Dockerfile_gudhi_installation | 1 | ||||
-rw-r--r-- | README.md | 9 | ||||
-rw-r--r-- | biblio/bibliography.bib | 12 | ||||
m--------- | ext/hera | 0 | ||||
-rw-r--r-- | src/Bottleneck_distance/include/gudhi/Persistence_graph.h | 2 | ||||
-rw-r--r-- | src/cmake/modules/GUDHI_third_party_libraries.cmake | 4 | ||||
-rw-r--r-- | src/cmake/modules/GUDHI_user_version_target.cmake | 5 | ||||
-rw-r--r-- | src/python/CMakeLists.txt | 60 | ||||
-rw-r--r-- | src/python/doc/installation.rst | 5 | ||||
-rw-r--r-- | src/python/doc/wasserstein_distance_user.rst | 17 | ||||
-rw-r--r-- | src/python/gudhi/hera.cc | 71 | ||||
-rw-r--r-- | src/python/gudhi/wasserstein.py | 6 | ||||
-rw-r--r-- | src/python/setup.py.in | 17 | ||||
-rwxr-xr-x | src/python/test/test_wasserstein_distance.py | 61 |
18 files changed, 225 insertions, 59 deletions
diff --git a/.appveyor.yml b/.appveyor.yml index 4a76ea0a..34f42dea 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -39,6 +39,7 @@ init: install: + - git submodule update --init - vcpkg install tbb:x64-windows boost-disjoint-sets:x64-windows boost-serialization:x64-windows boost-date-time:x64-windows boost-system:x64-windows boost-filesystem:x64-windows boost-units:x64-windows boost-thread:x64-windows boost-program-options:x64-windows eigen3:x64-windows mpfr:x64-windows mpir:x64-windows cgal:x64-windows - SET PATH=c:\Tools\vcpkg\installed\x64-windows\bin;%PATH% - SET PATH=%PYTHON%;%PYTHON%\Scripts;%PYTHON%\Library\bin;%PATH% @@ -48,7 +49,7 @@ install: - pip --version - python -m pip install --upgrade pip - pip install -U setuptools numpy matplotlib scipy Cython pytest - - pip install -U POT + - pip install -U POT pybind11 build_script: - mkdir build diff --git a/.circleci/config.yml b/.circleci/config.yml index f4073746..4f86cb12 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -51,6 +51,8 @@ jobs: - run: name: Build and test python module. Generates and tests the python documentation command: | + git submodule init + git submodule update mkdir build; cd build; cmake -DUSER_VERSION_DIR=version ..; @@ -78,6 +80,8 @@ jobs: - run: name: Generates the C++ documentation with doxygen command: | + git submodule init + git submodule update mkdir build; cd build; cmake -DCMAKE_BUILD_TYPE=Release -DWITH_GUDHI_EXAMPLE=OFF -DWITH_GUDHI_TEST=OFF -DWITH_GUDHI_UTILITIES=OFF -DWITH_GUDHI_PYTHON=OFF -DUSER_VERSION_DIR=version ..; @@ -97,4 +101,4 @@ workflows: - tests - utils - python - - doxygen
\ No newline at end of file + - doxygen diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 00000000..6e8b3ab1 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "ext/hera"] + path = ext/hera + url = https://bitbucket.org/grey_narn/hera.git diff --git a/.travis.yml b/.travis.yml index 60d32ef8..8980be10 100644 --- a/.travis.yml +++ b/.travis.yml @@ -61,7 +61,7 @@ before_cache: install: - python3 -m pip install --upgrade pip setuptools wheel - python3 -m pip install --user pytest Cython sphinx sphinxcontrib-bibtex sphinx-paramlinks matplotlib numpy scipy scikit-learn - - python3 -m pip install --user POT + - python3 -m pip install --user POT pybind11 script: - rm -rf build diff --git a/Dockerfile_gudhi_installation b/Dockerfile_gudhi_installation index 33864d11..f9e8813b 100644 --- a/Dockerfile_gudhi_installation +++ b/Dockerfile_gudhi_installation @@ -42,6 +42,7 @@ RUN apt-get install -y make \ python3-pip \ python3-pytest \ python3-tk \ + python3-pybind11 \ libfreetype6-dev \ pkg-config \ curl @@ -10,6 +10,15 @@ The GUDHI library is a generic open source C++ library, with a Python interface, for Topological Data Analysis (TDA) and Higher Dimensional Geometry Understanding. The library offers state-of-the-art data structures and algorithms to construct simplicial complexes and compute persistent homology. +# Source code + +We recommend that users get official releases from [the GUDHI website](https://gudhi.inria.fr/). + +For potential contributors, to fully checkout GUDHI, after cloning the git repository, you may also need to checkout its submodules using +```sh +git submodule update --init +``` + # Compilation and installation To install GUDHI, you can follow the [C++ compilation procedure](https://gudhi.inria.fr/doc/latest/installation.html), the [Python compilation procedure](https://gudhi.inria.fr/python/latest/installation.html), use our [conda-forge package](https://gudhi.inria.fr/conda/), or [go with Docker](https://gudhi.inria.fr/dockerfile/). diff --git a/biblio/bibliography.bib b/biblio/bibliography.bib index a1b951e0..3bbe7960 100644 --- a/biblio/bibliography.bib +++ b/biblio/bibliography.bib @@ -1180,3 +1180,15 @@ language={English} booktitle = {In Neural Information Processing Systems}, year = {2007} } +@inproceedings{10.5555/3327546.3327645, +author = {Lacombe, Th\'{e}o and Cuturi, Marco and Oudot, Steve}, +title = {Large Scale Computation of Means and Clusters for Persistence Diagrams Using Optimal Transport}, +year = {2018}, +publisher = {Curran Associates Inc.}, +address = {Red Hook, NY, USA}, +booktitle = {Proceedings of the 32nd International Conference on Neural Information Processing Systems}, +pages = {9792–9802}, +numpages = {11}, +location = {Montr\'{e}al, Canada}, +series = {NIPS’18} +} diff --git a/ext/hera b/ext/hera new file mode 160000 +Subproject 9a89971855acefe39dce0e2adadf53b88ca8f68 diff --git a/src/Bottleneck_distance/include/gudhi/Persistence_graph.h b/src/Bottleneck_distance/include/gudhi/Persistence_graph.h index f791e37c..e1e3522e 100644 --- a/src/Bottleneck_distance/include/gudhi/Persistence_graph.h +++ b/src/Bottleneck_distance/include/gudhi/Persistence_graph.h @@ -25,7 +25,7 @@ namespace Gudhi { namespace persistence_diagram { -/** \internal \brief Structure representing an euclidean bipartite graph containing +/** \internal \brief Structure representing a Euclidean bipartite graph containing * the points from the two persistence diagrams (including the projections). * * \ingroup bottleneck_distance diff --git a/src/cmake/modules/GUDHI_third_party_libraries.cmake b/src/cmake/modules/GUDHI_third_party_libraries.cmake index 24a34150..359d1c12 100644 --- a/src/cmake/modules/GUDHI_third_party_libraries.cmake +++ b/src/cmake/modules/GUDHI_third_party_libraries.cmake @@ -35,6 +35,9 @@ if(CGAL_FOUND) include( ${CGAL_USE_FILE} ) endif() +# For those who dislike bundled dependencies, this indicates where to find a preinstalled Hera. +set(HERA_WASSERSTEIN_INCLUDE_DIR ${CMAKE_SOURCE_DIR}/ext/hera/geom_matching/wasserstein/include CACHE PATH "Directory where one can find Hera's wasserstein.h") + option(WITH_GUDHI_USE_TBB "Build with Intel TBB parallelization" ON) # Find TBB package for parallel sort - not mandatory, just optional. @@ -127,6 +130,7 @@ if( PYTHONINTERP_FOUND ) find_python_module("sphinx") find_python_module("sklearn") find_python_module("ot") + find_python_module("pybind11") endif() if(NOT GUDHI_PYTHON_PATH) diff --git a/src/cmake/modules/GUDHI_user_version_target.cmake b/src/cmake/modules/GUDHI_user_version_target.cmake index 0b361a0f..5047252f 100644 --- a/src/cmake/modules/GUDHI_user_version_target.cmake +++ b/src/cmake/modules/GUDHI_user_version_target.cmake @@ -54,6 +54,9 @@ add_custom_command(TARGET user_version PRE_BUILD COMMAND ${CMAKE_COMMAND} -E add_custom_command(TARGET user_version PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_SOURCE_DIR}/src/GudhUI ${GUDHI_USER_VERSION_DIR}/GudhUI) +add_custom_command(TARGET user_version PRE_BUILD COMMAND ${CMAKE_COMMAND} -E + copy_directory ${CMAKE_SOURCE_DIR}/ext/hera/geom_matching/wasserstein/include ${GUDHI_USER_VERSION_DIR}/ext/hera/geom_matching/wasserstein/include) + set(GUDHI_DIRECTORIES "doc;example;concept;utilities") set(GUDHI_INCLUDE_DIRECTORIES "include/gudhi") @@ -93,4 +96,4 @@ foreach(GUDHI_MODULE ${GUDHI_MODULES_FULL_LIST}) endforeach() endforeach(GUDHI_INCLUDE_DIRECTORY ${GUDHI_INCLUDE_DIRECTORIES}) -endforeach(GUDHI_MODULE ${GUDHI_MODULES_FULL_LIST})
\ No newline at end of file +endforeach(GUDHI_MODULE ${GUDHI_MODULES_FULL_LIST}) diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt index b558d4c4..090a7446 100644 --- a/src/python/CMakeLists.txt +++ b/src/python/CMakeLists.txt @@ -32,6 +32,10 @@ function( add_gudhi_debug_info DEBUG_INFO ) endfunction( add_gudhi_debug_info ) if(PYTHONINTERP_FOUND) + if(PYBIND11_FOUND) + add_gudhi_debug_info("Pybind11 version ${PYBIND11_VERSION}") + set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'hera', ") + endif() if(CYTHON_FOUND) set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'off_reader', ") set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'simplex_tree', ") @@ -86,6 +90,7 @@ if(PYTHONINTERP_FOUND) endif(MSVC) if(CMAKE_COMPILER_IS_GNUCXX) set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-frounding-math', ") + set(GUDHI_PYBIND11_EXTRA_COMPILE_ARGS "${GUDHI_PYBIND11_EXTRA_COMPILE_ARGS}'-fvisibility=hidden', ") endif(CMAKE_COMPILER_IS_GNUCXX) if (CMAKE_CXX_COMPILER_ID MATCHES Intel) set(GUDHI_PYTHON_EXTRA_COMPILE_ARGS "${GUDHI_PYTHON_EXTRA_COMPILE_ARGS}'-fp-model strict', ") @@ -390,9 +395,9 @@ endif(CGAL_FOUND) add_gudhi_py_test(test_reader_utils) # Wasserstein - if(OT_FOUND) + if(OT_FOUND AND PYBIND11_FOUND) add_gudhi_py_test(test_wasserstein_distance) - endif(OT_FOUND) + endif() # Representations if(SKLEARN_FOUND AND MATPLOTLIB_FOUND) @@ -406,32 +411,37 @@ endif(CGAL_FOUND) if(SCIPY_FOUND) if(SKLEARN_FOUND) if(OT_FOUND) - if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) - set (GUDHI_SPHINX_MESSAGE "Generating API documentation with Sphinx in ${CMAKE_CURRENT_BINARY_DIR}/sphinx/") - # User warning - Sphinx is a static pages generator, and configured to work fine with user_version - # Images and biblio warnings because not found on developper version - if (GUDHI_PYTHON_PATH STREQUAL "src/python") - set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developper version. Images and biblio will miss") - endif() - # sphinx target requires gudhi.so, because conf.py reads gudhi version from it - add_custom_target(sphinx - WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/doc - COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}" - ${SPHINX_PATH} -b html ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/sphinx - DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/gudhi.so" - COMMENT "${GUDHI_SPHINX_MESSAGE}" VERBATIM) + if(PYBIND11_FOUND) + if(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) + set (GUDHI_SPHINX_MESSAGE "Generating API documentation with Sphinx in ${CMAKE_CURRENT_BINARY_DIR}/sphinx/") + # User warning - Sphinx is a static pages generator, and configured to work fine with user_version + # Images and biblio warnings because not found on developper version + if (GUDHI_PYTHON_PATH STREQUAL "src/python") + set (GUDHI_SPHINX_MESSAGE "${GUDHI_SPHINX_MESSAGE} \n WARNING : Sphinx is configured for user version, you run it on developper version. Images and biblio will miss") + endif() + # sphinx target requires gudhi.so, because conf.py reads gudhi version from it + add_custom_target(sphinx + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/doc + COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}" + ${SPHINX_PATH} -b html ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/sphinx + DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/gudhi.so" + COMMENT "${GUDHI_SPHINX_MESSAGE}" VERBATIM) - add_test(NAME sphinx_py_test - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} - COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}" - ${SPHINX_PATH} -b doctest ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/doctest) + add_test(NAME sphinx_py_test + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + COMMAND ${CMAKE_COMMAND} -E env "PYTHONPATH=${CMAKE_CURRENT_BINARY_DIR}" + ${SPHINX_PATH} -b doctest ${CMAKE_CURRENT_SOURCE_DIR}/doc ${CMAKE_CURRENT_BINARY_DIR}/doctest) - # Set missing or not modules - set(GUDHI_MODULES ${GUDHI_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MODULES") - else(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) - message("++ Python documentation module will not be compiled because it requires a Eigen3 and CGAL version >= 4.11.0") + # Set missing or not modules + set(GUDHI_MODULES ${GUDHI_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MODULES") + else(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) + message("++ Python documentation module will not be compiled because it requires a Eigen3 and CGAL version >= 4.11.0") + set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES") + endif(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) + else(PYBIND11_FOUND) + message("++ Python documentation module will not be compiled because pybind11 was not found") set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES") - endif(NOT CGAL_WITH_EIGEN3_VERSION VERSION_LESS 4.11.0) + endif(PYBIND11_FOUND) else(OT_FOUND) message("++ Python documentation module will not be compiled because POT was not found") set(GUDHI_MISSING_MODULES ${GUDHI_MISSING_MODULES} "python-documentation" CACHE INTERNAL "GUDHI_MISSING_MODULES") diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst index 40f3f44b..d459145b 100644 --- a/src/python/doc/installation.rst +++ b/src/python/doc/installation.rst @@ -14,10 +14,11 @@ Compiling ********* The library uses c++14 and requires `Boost <https://www.boost.org/>`_ ≥ 1.56.0, `CMake <https://www.cmake.org/>`_ ≥ 3.1 to generate makefiles, -`NumPy <http://numpy.org>`_ and `Cython <https://www.cython.org/>`_ to compile +`NumPy <http://numpy.org>`_, `Cython <https://www.cython.org/>`_ and +`pybind11 <https://github.com/pybind/pybind11>`_ to compile the GUDHI Python module. It is a multi-platform library and compiles on Linux, Mac OSX and Visual -Studio 2015. +Studio 2017. On `Windows <https://wiki.python.org/moin/WindowsCompilers>`_ , only Python ≥ 3.5 are available because of the required Visual Studio version. diff --git a/src/python/doc/wasserstein_distance_user.rst b/src/python/doc/wasserstein_distance_user.rst index 32999a0c..94b454e2 100644 --- a/src/python/doc/wasserstein_distance_user.rst +++ b/src/python/doc/wasserstein_distance_user.rst @@ -9,17 +9,26 @@ Definition .. include:: wasserstein_distance_sum.inc -This implementation is based on ideas from "Large Scale Computation of Means and Cluster for Persistence Diagrams via Optimal Transport". +Functions +--------- +This implementation uses the Python Optimal Transport library and is based on +ideas from "Large Scale Computation of Means and Cluster for Persistence +Diagrams via Optimal Transport" :cite:`10.5555/3327546.3327645`. -Function --------- .. autofunction:: gudhi.wasserstein.wasserstein_distance +This other implementation comes from `Hera +<https://bitbucket.org/grey_narn/hera/src/master/>`_ (BSD-3-Clause) which is +based on "Geometry Helps to Compare Persistence Diagrams" +:cite:`Kerber:2017:GHC:3047249.3064175` by Michael Kerber, Dmitriy +Morozov, and Arnur Nigmetov. + +.. autofunction:: gudhi.hera.wasserstein_distance Basic example ------------- -This example computes the 1-Wasserstein distance from 2 persistence diagrams with euclidean ground metric. +This example computes the 1-Wasserstein distance from 2 persistence diagrams with Euclidean ground metric. Note that persistence diagrams must be submitted as (n x 2) numpy arrays and must not contain inf values. .. testcode:: diff --git a/src/python/gudhi/hera.cc b/src/python/gudhi/hera.cc new file mode 100644 index 00000000..0d562b4c --- /dev/null +++ b/src/python/gudhi/hera.cc @@ -0,0 +1,71 @@ +/* This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + * See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + * Author(s): Marc Glisse + * + * Copyright (C) 2020 Inria + * + * Modification(s): + * - YYYY/MM Author: Description of the modification + */ + +#include <pybind11/pybind11.h> +#include <pybind11/numpy.h> + +#include <boost/range/iterator_range.hpp> + +#include <wasserstein.h> // Hera + +#include <array> + +namespace py = pybind11; +typedef py::array_t<double, py::array::c_style | py::array::forcecast> Dgm; + +double wasserstein_distance( + Dgm d1, Dgm d2, + double wasserstein_power, double internal_p, + double delta) +{ + py::buffer_info buf1 = d1.request(); + py::buffer_info buf2 = d2.request(); + // shape (n,2) or (0) for empty + if((buf1.ndim!=2 || buf1.shape[1]!=2) && (buf1.ndim!=1 || buf1.shape[0]!=0)) + throw std::runtime_error("Diagram 1 must be an array of size n x 2"); + if((buf2.ndim!=2 || buf2.shape[1]!=2) && (buf2.ndim!=1 || buf2.shape[0]!=0)) + throw std::runtime_error("Diagram 2 must be an array of size n x 2"); + typedef std::array<double, 2> Point; + auto p1 = (Point*)buf1.ptr; + auto p2 = (Point*)buf2.ptr; + auto diag1 = boost::make_iterator_range(p1, p1+buf1.shape[0]); + auto diag2 = boost::make_iterator_range(p2, p2+buf2.shape[0]); + + hera::AuctionParams<double> params; + params.wasserstein_power = wasserstein_power; + // hera encodes infinity as -1... + if(std::isinf(internal_p)) internal_p = hera::get_infinity<double>(); + params.internal_p = internal_p; + params.delta = delta; + // The extra parameters are purposedly not exposed for now. + return hera::wasserstein_dist(diag1, diag2, params); +} + +PYBIND11_MODULE(hera, m) { + m.def("wasserstein_distance", &wasserstein_distance, + py::arg("X"), py::arg("Y"), + py::arg("order") = 1, + py::arg("internal_p") = std::numeric_limits<double>::infinity(), + py::arg("delta") = .01, + R"pbdoc( + Compute the Wasserstein distance between two diagrams. + Points at infinity are supported. + + Parameters: + X (n x 2 numpy array): First diagram + Y (n x 2 numpy array): Second diagram + order (float): Wasserstein exponent W_q + internal_p (float): Internal Minkowski norm L^p in R^2 + delta (float): Relative error 1+delta + + Returns: + float: Approximate Wasserstein distance W_q(X,Y) + )pbdoc"); +} diff --git a/src/python/gudhi/wasserstein.py b/src/python/gudhi/wasserstein.py index db5ddff2..13102094 100644 --- a/src/python/gudhi/wasserstein.py +++ b/src/python/gudhi/wasserstein.py @@ -27,8 +27,8 @@ def _build_dist_matrix(X, Y, order=2., internal_p=2.): ''' :param X: (n x 2) numpy.array encoding the (points of the) first diagram. :param Y: (m x 2) numpy.array encoding the second diagram. - :param internal_p: Ground metric (i.e. norm l_p). :param order: exponent for the Wasserstein metric. + :param internal_p: Ground metric (i.e. norm L^p). :returns: (n+1) x (m+1) np.array encoding the cost matrix C. For 1 <= i <= n, 1 <= j <= m, C[i,j] encodes the distance between X[i] and Y[j], while C[i, m+1] (resp. C[n+1, j]) encodes the distance (to the p) between X[i] (resp Y[j]) and its orthogonal proj onto the diagonal. note also that C[n+1, m+1] = 0 (it costs nothing to move from the diagonal to the diagonal). @@ -54,8 +54,8 @@ def _build_dist_matrix(X, Y, order=2., internal_p=2.): def _perstot(X, order, internal_p): ''' :param X: (n x 2) numpy.array (points of a given diagram). - :param internal_p: Ground metric on the (upper-half) plane (i.e. norm l_p in R^2); Default value is 2 (Euclidean norm). :param order: exponent for Wasserstein. Default value is 2. + :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). :returns: float, the total persistence of the diagram (that is, its distance to the empty diagram). ''' Xdiag = _proj_on_diag(X) @@ -66,8 +66,8 @@ def wasserstein_distance(X, Y, order=2., internal_p=2.): ''' :param X: (n x 2) numpy.array encoding the (finite points of the) first diagram. Must not contain essential points (i.e. with infinite coordinate). :param Y: (m x 2) numpy.array encoding the second diagram. - :param internal_p: Ground metric on the (upper-half) plane (i.e. norm l_p in R^2); Default value is 2 (euclidean norm). :param order: exponent for Wasserstein; Default value is 2. + :param internal_p: Ground metric on the (upper-half) plane (i.e. norm L^p in R^2); Default value is 2 (Euclidean norm). :returns: the Wasserstein distance of order q (1 <= q < infinity) between persistence diagrams with respect to the internal_p-norm as ground metric. :rtype: float ''' diff --git a/src/python/setup.py.in b/src/python/setup.py.in index bd7fb180..f968bd59 100644 --- a/src/python/setup.py.in +++ b/src/python/setup.py.in @@ -12,6 +12,7 @@ from setuptools import setup, Extension, find_packages from Cython.Build import cythonize from numpy import get_include as numpy_get_include import sys +import pybind11 __author__ = "Vincent Rouvreau" __copyright__ = "Copyright (C) 2016 Inria" @@ -42,6 +43,18 @@ for module in modules: runtime_library_dirs=runtime_library_dirs, cython_directives = {'language_level': str(sys.version_info[0])},)) +ext_modules = cythonize(ext_modules) + +ext_modules.append(Extension( + 'gudhi.hera', + sources = [source_dir + 'hera.cc'], + language = 'c++', + include_dirs = include_dirs + + ['@HERA_WASSERSTEIN_INCLUDE_DIR@', + pybind11.get_include(False), pybind11.get_include(True)], + extra_compile_args=extra_compile_args + [@GUDHI_PYBIND11_EXTRA_COMPILE_ARGS@], + )) + setup( name = 'gudhi', packages=find_packages(), # find_namespace_packages(include=["gudhi*"]) @@ -49,7 +62,7 @@ setup( author_email='gudhi-contact@lists.gforge.inria.fr', version='@GUDHI_VERSION@', url='http://gudhi.gforge.inria.fr/', - ext_modules = cythonize(ext_modules), + ext_modules = ext_modules, install_requires = ['cython','numpy >= 1.9',], - setup_requires = ['numpy >= 1.9',], + setup_requires = ['numpy >= 1.9','pybind11',], ) diff --git a/src/python/test/test_wasserstein_distance.py b/src/python/test/test_wasserstein_distance.py index 43dda77e..6a6b217b 100755 --- a/src/python/test/test_wasserstein_distance.py +++ b/src/python/test/test_wasserstein_distance.py @@ -1,6 +1,6 @@ """ This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. - Author(s): Theo Lacombe + Author(s): Theo Lacombe, Marc Glisse Copyright (C) 2019 Inria @@ -8,41 +8,66 @@ - YYYY/MM Author: Description of the modification """ -from gudhi.wasserstein import wasserstein_distance +from gudhi.wasserstein import wasserstein_distance as pot +from gudhi.hera import wasserstein_distance as hera import numpy as np +import pytest __author__ = "Theo Lacombe" __copyright__ = "Copyright (C) 2019 Inria" __license__ = "MIT" - -def test_basic_wasserstein(): +def _basic_wasserstein(wasserstein_distance, delta, test_infinity=True): diag1 = np.array([[2.7, 3.7], [9.6, 14.0], [34.2, 34.974]]) diag2 = np.array([[2.8, 4.45], [9.5, 14.1]]) diag3 = np.array([[0, 2], [4, 6]]) diag4 = np.array([[0, 3], [4, 8]]) - emptydiag = np.array([[]]) + emptydiag = np.array([]) + + # We just need to handle positive numbers here + def approx(x): + return pytest.approx(x, rel=delta) assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=1.) == 0. assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=1.) == 0. assert wasserstein_distance(emptydiag, emptydiag, internal_p=np.inf, order=2.) == 0. assert wasserstein_distance(emptydiag, emptydiag, internal_p=2., order=2.) == 0. - assert wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.) == 2. - assert wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.) == 4. + assert wasserstein_distance(diag3, emptydiag, internal_p=np.inf, order=1.) == approx(2.) + assert wasserstein_distance(diag3, emptydiag, internal_p=1., order=1.) == approx(4.) + + assert wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.) == approx(5.) # thank you Pythagorician triplets + assert wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.) == approx(2.5) + assert wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.) == approx(3.5355339059327378) + + assert wasserstein_distance(diag1, diag2, internal_p=2., order=1.) == approx(1.4453593023967701) + assert wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74) == approx(0.9772734057168739) + + assert wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863) == approx(3.141592214572228) + + assert wasserstein_distance(diag3, diag4, internal_p=1., order=1.) == approx(3.) + assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.) == approx(3.) # no diag matching here + assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.) == approx(np.sqrt(5)) + assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == approx(np.sqrt(5)) + assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == approx(np.sqrt(5)) + + if(not test_infinity): + return - assert wasserstein_distance(diag4, emptydiag, internal_p=1., order=2.) == 5. # thank you Pythagorician triplets - assert wasserstein_distance(diag4, emptydiag, internal_p=np.inf, order=2.) == 2.5 - assert wasserstein_distance(diag4, emptydiag, internal_p=2., order=2.) == 3.5355339059327378 + diag5 = np.array([[0, 3], [4, np.inf]]) + diag6 = np.array([[7, 8], [4, 6], [3, np.inf]]) - assert wasserstein_distance(diag1, diag2, internal_p=2., order=1.) == 1.4453593023967701 - assert wasserstein_distance(diag1, diag2, internal_p=2.35, order=1.74) == 0.9772734057168739 + assert wasserstein_distance(diag4, diag5) == np.inf + assert wasserstein_distance(diag5, diag6, order=1, internal_p=np.inf) == approx(4.) - assert wasserstein_distance(diag1, emptydiag, internal_p=2.35, order=1.7863) == 3.141592214572228 +def hera_wrap(delta): + def fun(*kargs,**kwargs): + return hera(*kargs,**kwargs,delta=delta) + return fun - assert wasserstein_distance(diag3, diag4, internal_p=1., order=1.) == 3. - assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=1.) == 3. # no diag matching here - assert wasserstein_distance(diag3, diag4, internal_p=np.inf, order=2.) == np.sqrt(5) - assert wasserstein_distance(diag3, diag4, internal_p=1., order=2.) == np.sqrt(5) - assert wasserstein_distance(diag3, diag4, internal_p=4.5, order=2.) == np.sqrt(5) +def test_wasserstein_distance_pot(): + _basic_wasserstein(pot, 1e-15, test_infinity=False) +def test_wasserstein_distance_hera(): + _basic_wasserstein(hera_wrap(1e-12), 1e-12) + _basic_wasserstein(hera_wrap(.1), .1) |